From 04e6dbeedb1927060f9666b68833ccb0b433a0d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jacobo=20Coll=20Morag=C3=B3n?= Date: Tue, 24 Feb 2026 14:22:24 +0000 Subject: [PATCH 1/7] mongodb: Allow MongoPersistentCursor to work with aggregation pipelines. #TASK-8038 --- .../mongodb/MongoPersistentCursor.java | 118 +++++++++++++++++- 1 file changed, 117 insertions(+), 1 deletion(-) diff --git a/commons-datastore/commons-datastore-mongodb/src/main/java/org/opencb/commons/datastore/mongodb/MongoPersistentCursor.java b/commons-datastore/commons-datastore-mongodb/src/main/java/org/opencb/commons/datastore/mongodb/MongoPersistentCursor.java index 4deba0933..ab30ea5b0 100644 --- a/commons-datastore/commons-datastore-mongodb/src/main/java/org/opencb/commons/datastore/mongodb/MongoPersistentCursor.java +++ b/commons-datastore/commons-datastore-mongodb/src/main/java/org/opencb/commons/datastore/mongodb/MongoPersistentCursor.java @@ -20,16 +20,21 @@ import com.mongodb.ServerAddress; import com.mongodb.ServerCursor; import com.mongodb.annotations.NotThreadSafe; +import com.mongodb.client.AggregateIterable; import com.mongodb.client.FindIterable; import com.mongodb.client.MongoCursor; import com.mongodb.client.model.Filters; import com.mongodb.client.model.Sorts; +import org.bson.BsonDocument; import org.bson.Document; import org.bson.conversions.Bson; import org.opencb.commons.datastore.core.QueryOptions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.ArrayList; +import java.util.List; + /** * MongoDBCursor wrapper for queries that require a long time to process the results. * Avoids {@link MongoCursorNotFoundException}. 
@@ -47,6 +52,7 @@ public class MongoPersistentCursor implements MongoCursor { private final QueryOptions options; private final Bson query; private final Bson projection; + private final List pipeline; private MongoDBCollection collection; private MongoCursor mongoCursor; @@ -71,6 +77,57 @@ public MongoPersistentCursor(MongoDBCollection collection, Bson query, Bson proj this.options = options; this.query = query; this.projection = projection; + this.pipeline = null; + this.collection = collection; + + if (batchSize > 0) { + this.batchSize = batchSize; + } + if (limit > 0) { + this.limit = limit; + } + if (skip > 0) { + this.skip = skip; + } + + reset(); + } + + /** + * Create a persistent cursor backed by an aggregation pipeline. + * The pipeline's first stage should be a {@code $match} stage so that the resume filter + * ({@code _id > lastId}) can be merged into it efficiently. + * The cursor will automatically add a {@code $sort: {_id: 1}} stage when no sort is present, + * ensuring deterministic ordering required for reliable resume. + * Will fail if the {@code $project} stage excludes the {@code _id} field. + * + * @param collection MongoDB collection to run the pipeline against. + * @param pipeline Aggregation pipeline. Must not be empty and should start with {@code $match}. + * @param options Query options (BATCH_SIZE, LIMIT, SKIP are read from here). + */ + public MongoPersistentCursor(MongoDBCollection collection, List pipeline, QueryOptions options) { + this(collection, pipeline, options, + options != null ? options.getInt(MongoDBCollection.BATCH_SIZE, 0) : 0, + options != null ? options.getInt(QueryOptions.LIMIT, 0) : 0, + options != null ? options.getInt(QueryOptions.SKIP, 0) : 0); + } + + /** + * Create a persistent cursor backed by an aggregation pipeline with explicit pagination parameters. + * + * @param collection MongoDB collection to run the pipeline against. + * @param pipeline Aggregation pipeline. 
Must not be empty and should start with {@code $match}. + * @param options Query options (used for sort detection). + * @param batchSize MongoDB cursor batch size (0 = server default). + * @param limit Maximum number of documents to return (0 = unlimited). + * @param skip Number of documents to skip at the start (0 = none). + */ + public MongoPersistentCursor(MongoDBCollection collection, List pipeline, QueryOptions options, + int batchSize, int limit, int skip) { + this.options = options; + this.query = null; + this.projection = null; + this.pipeline = new ArrayList<>(pipeline); this.collection = collection; if (batchSize > 0) { @@ -93,6 +150,15 @@ protected void reset() { } protected MongoPersistentCursor resume(Object lastObjectId) { + if (pipeline != null) { + resumeAggregate(lastObjectId); + } else { + resumeFind(lastObjectId); + } + return this; + } + + private void resumeFind(Object lastObjectId) { Bson query; if (lastObjectId != null) { query = Filters.and(Filters.gt("_id", lastObjectId), this.query); @@ -108,7 +174,57 @@ protected MongoPersistentCursor resume(Object lastObjectId) { .limit(limit) .skip(skip) .iterator(); - return this; + } + + /** + * Build and execute an aggregation pipeline cursor, injecting a resume filter and ensuring + * deterministic {@code _id} ordering when resuming after a cursor expiry. + */ + private void resumeAggregate(Object lastObjectId) { + List activePipeline = new ArrayList<>(pipeline); + + // Inject _id > lastId into the first $match stage so MongoDB can use the index. 
+        if (lastObjectId != null) {
+            Bson resumeFilter = Filters.gt("_id", lastObjectId);
+
+            // Guard BEFORE get(0): an empty pipeline must not throw IndexOutOfBoundsException.
+            BsonDocument firstStage = activePipeline.isEmpty() ? null : activePipeline.get(0).toBsonDocument();
+            if (firstStage != null && firstStage.containsKey("$match")) {
+                BsonDocument existingMatch = firstStage.get("$match").asDocument();
+                activePipeline.set(0, new Document("$match", Filters.and(existingMatch, resumeFilter)));
+            } else {
+                activePipeline.add(0, new Document("$match", resumeFilter));
+            }
+        }
+
+        // Ensure a $sort: {_id: 1} stage is present so that any resume starts past the right point.
+        // Only added when no explicit sort stage exists and no sort is specified in options.
+        boolean hasSortStage = activePipeline.stream().anyMatch(s -> s.toBsonDocument().containsKey("$sort"));
+        if (!hasSortStage && (options == null || !options.containsKey(QueryOptions.SORT))) {
+            // Insert before any existing $skip or $limit stages.
+            int insertPos = activePipeline.size();
+            for (int i = 0; i < activePipeline.size(); i++) {
+                BsonDocument stage = activePipeline.get(i).toBsonDocument();
+                if (stage.containsKey("$skip") || stage.containsKey("$limit")) {
+                    insertPos = i;
+                    break;
+                }
+            }
+            activePipeline.add(insertPos, new Document("$sort", new Document("_id", 1)));
+        }
+
+        if (skip > 0) {
+            activePipeline.add(new Document("$skip", skip));
+        }
+        if (limit > 0) {
+            activePipeline.add(new Document("$limit", limit));
+        }
+
+        AggregateIterable iterable = collection.nativeQuery().getDbCollection().aggregate(activePipeline);
+        if (batchSize > 0) {
+            iterable.batchSize(batchSize);
+        }
+        mongoCursor = iterable.iterator();
     }
 
     protected FindIterable newFindIterable(Bson query, Bson projection, QueryOptions options) {

From cbe47f9096bd49f24d8ab332bc4a4a3f535d601c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jacobo=20Coll=20Morag=C3=B3n?=
Date: Tue, 24 Feb 2026 16:30:06 +0000
Subject: [PATCH 2/7] mongodb: Add tests for MongoPersistentCursor.
#TASK-8038 --- .../commons-datastore-mongodb/pom.xml | 5 + .../mongodb/MongoPersistentCursor.java | 10 +- .../mongodb/MongoDBCollectionTest.java | 37 - .../mongodb/MongoPersistentCursorTest.java | 709 ++++++++++++++++++ pom.xml | 7 + 5 files changed, 728 insertions(+), 40 deletions(-) create mode 100644 commons-datastore/commons-datastore-mongodb/src/test/java/org/opencb/commons/datastore/mongodb/MongoPersistentCursorTest.java diff --git a/commons-datastore/commons-datastore-mongodb/pom.xml b/commons-datastore/commons-datastore-mongodb/pom.xml index 04e5b2189..087831d18 100644 --- a/commons-datastore/commons-datastore-mongodb/pom.xml +++ b/commons-datastore/commons-datastore-mongodb/pom.xml @@ -60,6 +60,11 @@ junit test + + org.mockito + mockito-core + test + org.slf4j diff --git a/commons-datastore/commons-datastore-mongodb/src/main/java/org/opencb/commons/datastore/mongodb/MongoPersistentCursor.java b/commons-datastore/commons-datastore-mongodb/src/main/java/org/opencb/commons/datastore/mongodb/MongoPersistentCursor.java index ab30ea5b0..13c11f952 100644 --- a/commons-datastore/commons-datastore-mongodb/src/main/java/org/opencb/commons/datastore/mongodb/MongoPersistentCursor.java +++ b/commons-datastore/commons-datastore-mongodb/src/main/java/org/opencb/commons/datastore/mongodb/MongoPersistentCursor.java @@ -172,7 +172,7 @@ private void resumeFind(Object lastObjectId) { mongoCursor = iterable .batchSize(batchSize) .limit(limit) - .skip(skip) + .skip(lastObjectId == null ? 
skip : 0) .iterator(); } @@ -213,14 +213,14 @@ private void resumeAggregate(Object lastObjectId) { activePipeline.add(insertPos, new Document("$sort", new Document("_id", 1))); } - if (skip > 0) { + if (skip > 0 && lastObjectId == null) { activePipeline.add(new Document("$skip", skip)); } if (limit > 0) { activePipeline.add(new Document("$limit", limit)); } - AggregateIterable iterable = collection.nativeQuery().getDbCollection().aggregate(activePipeline); + AggregateIterable iterable = newAggregateIterable(activePipeline); if (batchSize > 0) { iterable.batchSize(batchSize); } @@ -231,6 +231,10 @@ protected FindIterable newFindIterable(Bson query, Bson projection, Qu return this.collection.nativeQuery().nativeFind(null, query, projection, options); } + protected AggregateIterable newAggregateIterable(List activePipeline) { + return collection.nativeQuery().getDbCollection().aggregate(activePipeline); + } + public Object getLastId() { return lastId; } diff --git a/commons-datastore/commons-datastore-mongodb/src/test/java/org/opencb/commons/datastore/mongodb/MongoDBCollectionTest.java b/commons-datastore/commons-datastore-mongodb/src/test/java/org/opencb/commons/datastore/mongodb/MongoDBCollectionTest.java index de6a06afe..5427186c2 100644 --- a/commons-datastore/commons-datastore-mongodb/src/test/java/org/opencb/commons/datastore/mongodb/MongoDBCollectionTest.java +++ b/commons-datastore/commons-datastore-mongodb/src/test/java/org/opencb/commons/datastore/mongodb/MongoDBCollectionTest.java @@ -488,43 +488,6 @@ public void testFind8() throws Exception { } } - @Test - @Ignore - public void testPermanentCursor() throws Exception { - Document query = new Document(); - QueryOptions queryOptions = new QueryOptions(); - int documents = 50000; - MongoDBCollection collection = createTestCollection("cursor5", documents); - - MongoPersistentCursor cursor = new MongoPersistentCursor(collection, query, null, queryOptions); - - int i = 0; - while (cursor.hasNext()) { - Document 
document = cursor.next(); - if (i % (documents / 50) == 0) { - System.out.println("document.get(\"_id\") = " + document.get("_id")); - } - i++; - if (i == 10) { - System.out.println("SLEEP!!! " + i + " document.get(\"_id\") = " + document.get("_id")); - int totalMin = 1; - for (int min = 0; min < totalMin; min++) { - System.out.println("Continue sleeping: " + min + "/" + totalMin); - Thread.sleep(60 * 1000); - } - System.out.println("Woke up!!!"); - document = cursor.next(); - i++; - System.out.println("Woke up!!! " + i + " document.get(\"_id\") = " + document.get("_id")); - - } - } - - assertEquals(1, cursor.getNumExceptions()); - assertEquals(documents, cursor.getCount()); - assertEquals(documents, i); - } - @Test public void testAggregate() { List dbObjectList = new ArrayList<>(); diff --git a/commons-datastore/commons-datastore-mongodb/src/test/java/org/opencb/commons/datastore/mongodb/MongoPersistentCursorTest.java b/commons-datastore/commons-datastore-mongodb/src/test/java/org/opencb/commons/datastore/mongodb/MongoPersistentCursorTest.java new file mode 100644 index 000000000..2c31aaff3 --- /dev/null +++ b/commons-datastore/commons-datastore-mongodb/src/test/java/org/opencb/commons/datastore/mongodb/MongoPersistentCursorTest.java @@ -0,0 +1,709 @@ +/* + * Copyright 2015-2017 OpenCB + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.opencb.commons.datastore.mongodb; + +import com.mongodb.MongoCursorNotFoundException; +import com.mongodb.ServerAddress; +import com.mongodb.ServerCursor; +import com.mongodb.client.AggregateIterable; +import com.mongodb.client.FindIterable; +import com.mongodb.client.MongoCursor; +import com.mongodb.client.model.Filters; +import org.bson.Document; +import org.bson.conversions.Bson; +import org.bson.types.ObjectId; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; +import org.opencb.commons.datastore.core.QueryOptions; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link MongoPersistentCursor} in both find and aggregation pipeline modes. + */ +public class MongoPersistentCursorTest { + + private static MongoDataStoreManager mongoDataStoreManager; + private static MongoDataStore mongoDataStore; + private static MongoDBCollection collection; + + private static final int NUM_DOCS = 2000; + /** Number of docs fetched before the simulated cursor expiry fires. 
*/ + private static final int FAIL_AFTER = 300; + + @BeforeClass + public static void beforeClass() { + mongoDataStoreManager = new MongoDataStoreManager("localhost", 27017); + mongoDataStoreManager.get("persistent_cursor_test_db"); // register in map first + mongoDataStoreManager.drop("persistent_cursor_test_db"); // now the drop takes effect + mongoDataStore = mongoDataStoreManager.get("persistent_cursor_test_db"); + collection = mongoDataStore.getCollection("test"); + + for (int i = 0; i < NUM_DOCS; i++) { + Document doc = new Document("value", i).append("group", i % 4); + collection.nativeQuery().insert(doc, null); + } + } + + @AfterClass + public static void afterClass() { + mongoDataStore.close(); + } + + // ------------------------------------------------------------------------- + // Find mode + // ------------------------------------------------------------------------- + + @Test + public void testFindModeReturnsAllDocuments() { + MongoPersistentCursor cursor = new MongoPersistentCursor(collection, new Document(), null, new QueryOptions()); + int count = 0; + while (cursor.hasNext()) { + cursor.next(); + count++; + } + cursor.close(); + + assertEquals(NUM_DOCS, count); + assertEquals(NUM_DOCS, cursor.getCount()); + assertEquals(0, cursor.getNumExceptions()); + } + + @Test + public void testFindModeWithFilter() { + // Documents where group == 0: indices 0, 4, 8, ... 
-> NUM_DOCS/4 documents + MongoPersistentCursor cursor = new MongoPersistentCursor( + collection, + Filters.eq("group", 0), + null, + new QueryOptions()); + + List results = new ArrayList<>(); + while (cursor.hasNext()) { + results.add(cursor.next()); + } + cursor.close(); + + assertEquals(NUM_DOCS / 4, results.size()); + for (Document doc : results) { + assertEquals(0, doc.getInteger("group").intValue()); + } + } + + @Test + public void testFindModeWithLimit() { + int limit = 7; + QueryOptions options = new QueryOptions(QueryOptions.LIMIT, limit); + MongoPersistentCursor cursor = new MongoPersistentCursor(collection, new Document(), null, options); + + int count = 0; + while (cursor.hasNext()) { + cursor.next(); + count++; + } + cursor.close(); + + assertEquals(limit, count); + } + + @Test + public void testFindModeTracksLastId() { + MongoPersistentCursor cursor = new MongoPersistentCursor(collection, new Document(), null, new QueryOptions()); + + assertNull(cursor.getLastId()); + + Document first = cursor.next(); + assertEquals(first.get("_id"), cursor.getLastId()); + + Document second = cursor.next(); + assertEquals(second.get("_id"), cursor.getLastId()); + + cursor.close(); + } + + @Test + public void testFindModeDocumentsAreOrderedById() { + MongoPersistentCursor cursor = new MongoPersistentCursor(collection, new Document(), null, new QueryOptions()); + + ObjectId previous = null; + while (cursor.hasNext()) { + Document doc = cursor.next(); + ObjectId current = doc.getObjectId("_id"); + if (previous != null) { + assertTrue("Documents must be in ascending _id order", current.compareTo(previous) > 0); + } + previous = current; + } + cursor.close(); + } + + @Test + public void testFindModeResumesAfterCursorException() { + MongoPersistentCursor cursor = new FailingCursorPersistentCursor( + collection, new Document(), null, new QueryOptions(), FAIL_AFTER, 1); + + Set values = drainValues(cursor); + + assertCompleteValueSet(values); + assertEquals(NUM_DOCS, 
cursor.getCount()); + assertEquals(1, cursor.getNumExceptions()); + } + + @Test + public void testFindModeResumesMultipleTimes() { + int maxFailures = 3; + MongoPersistentCursor cursor = new FailingCursorPersistentCursor( + collection, new Document(), null, new QueryOptions(), FAIL_AFTER, maxFailures); + + Set values = drainValues(cursor); + + assertCompleteValueSet(values); + assertEquals(NUM_DOCS, cursor.getCount()); + assertEquals(maxFailures, cursor.getNumExceptions()); + } + + // ------------------------------------------------------------------------- + // Aggregation pipeline mode + // ------------------------------------------------------------------------- + + @Test + public void testAggregationModeReturnsAllDocuments() { + List pipeline = Arrays.asList( + new Document("$match", new Document()) + ); + + MongoPersistentCursor cursor = new MongoPersistentCursor(collection, pipeline, new QueryOptions()); + + int count = 0; + while (cursor.hasNext()) { + cursor.next(); + count++; + } + cursor.close(); + + assertEquals(NUM_DOCS, count); + assertEquals(NUM_DOCS, cursor.getCount()); + assertEquals(0, cursor.getNumExceptions()); + } + + @Test + public void testAggregationModeWithMatchFilter() { + // group == 1: NUM_DOCS/4 documents + List pipeline = Arrays.asList( + new Document("$match", new Document("group", 1)) + ); + + MongoPersistentCursor cursor = new MongoPersistentCursor(collection, pipeline, new QueryOptions()); + + List results = new ArrayList<>(); + while (cursor.hasNext()) { + results.add(cursor.next()); + } + cursor.close(); + + assertEquals(NUM_DOCS / 4, results.size()); + for (Document doc : results) { + assertEquals(1, doc.getInteger("group").intValue()); + } + } + + @Test + public void testAggregationModeWithLimit() { + int limit = 6; + QueryOptions options = new QueryOptions(QueryOptions.LIMIT, limit); + List pipeline = Arrays.asList( + new Document("$match", new Document()) + ); + + MongoPersistentCursor cursor = new 
MongoPersistentCursor(collection, pipeline, options); + + int count = 0; + while (cursor.hasNext()) { + cursor.next(); + count++; + } + cursor.close(); + + assertEquals(limit, count); + } + + @Test + public void testAggregationModeTracksLastId() { + List pipeline = Arrays.asList( + new Document("$match", new Document()) + ); + + MongoPersistentCursor cursor = new MongoPersistentCursor(collection, pipeline, new QueryOptions()); + + assertNull(cursor.getLastId()); + + Document first = cursor.next(); + assertEquals(first.get("_id"), cursor.getLastId()); + + Document second = cursor.next(); + assertEquals(second.get("_id"), cursor.getLastId()); + + cursor.close(); + } + + @Test + public void testAggregationModeDocumentsAreOrderedById() { + List pipeline = Arrays.asList( + new Document("$match", new Document()) + ); + + MongoPersistentCursor cursor = new MongoPersistentCursor(collection, pipeline, new QueryOptions()); + + ObjectId previous = null; + while (cursor.hasNext()) { + Document doc = cursor.next(); + ObjectId current = doc.getObjectId("_id"); + if (previous != null) { + assertTrue("Aggregation pipeline documents must be in ascending _id order", + current.compareTo(previous) > 0); + } + previous = current; + } + cursor.close(); + } + + @Test + public void testAggregationModeWithAddFieldsStage() { + List pipeline = Arrays.asList( + new Document("$match", new Document()), + new Document("$addFields", new Document("doubled", + new Document("$multiply", Arrays.asList("$value", 2)))) + ); + + MongoPersistentCursor cursor = new MongoPersistentCursor(collection, pipeline, new QueryOptions()); + + int count = 0; + while (cursor.hasNext()) { + Document doc = cursor.next(); + assertEquals(doc.getInteger("value") * 2, doc.getInteger("doubled").intValue()); + count++; + } + cursor.close(); + + assertEquals(NUM_DOCS, count); + } + + @Test + public void testAggregationModeResumesAfterCursorException() { + List pipeline = Arrays.asList(new Document("$match", new Document())); 
+ + MongoPersistentCursor cursor = new FailingCursorPersistentCursor( + collection, pipeline, new QueryOptions(), FAIL_AFTER, 1); + + Set values = drainValues(cursor); + + assertCompleteValueSet(values); + assertEquals(NUM_DOCS, cursor.getCount()); + assertEquals(1, cursor.getNumExceptions()); + } + + @Test + public void testAggregationModeResumesMultipleTimes() { + int maxFailures = 3; + List pipeline = Arrays.asList(new Document("$match", new Document())); + + MongoPersistentCursor cursor = new FailingCursorPersistentCursor( + collection, pipeline, new QueryOptions(), FAIL_AFTER, maxFailures); + + Set values = drainValues(cursor); + + assertCompleteValueSet(values); + assertEquals(NUM_DOCS, cursor.getCount()); + assertEquals(maxFailures, cursor.getNumExceptions()); + } + + // ------------------------------------------------------------------------- + // #1 – Exception fires before the first document (lastId is null on resume) + // ------------------------------------------------------------------------- + + @Test + public void testFindModeResumesWhenExceptionBeforeFirstDocument() { + MongoPersistentCursor cursor = new FailingCursorPersistentCursor( + collection, new Document(), null, new QueryOptions(), 0, 1); + + Set values = drainValues(cursor); + + assertCompleteValueSet(values); + assertEquals(NUM_DOCS, cursor.getCount()); + assertEquals(1, cursor.getNumExceptions()); + } + + @Test + public void testAggregationModeResumesWhenExceptionBeforeFirstDocument() { + List pipeline = Arrays.asList(new Document("$match", new Document())); + + MongoPersistentCursor cursor = new FailingCursorPersistentCursor( + collection, pipeline, new QueryOptions(), 0, 1); + + Set values = drainValues(cursor); + + assertCompleteValueSet(values); + assertEquals(NUM_DOCS, cursor.getCount()); + assertEquals(1, cursor.getNumExceptions()); + } + + // ------------------------------------------------------------------------- + // #2 – Exception fires after the second-to-last document + // 
------------------------------------------------------------------------- + + @Test + public void testFindModeResumesWhenExceptionAfterLastDocument() { + // failAfter = NUM_DOCS-1: the failing cursor returns all but the last doc, + // throws, then the resume cursor returns that final document. + MongoPersistentCursor cursor = new FailingCursorPersistentCursor( + collection, new Document(), null, new QueryOptions(), NUM_DOCS - 1, 1); + + Set values = drainValues(cursor); + + assertCompleteValueSet(values); + assertEquals(NUM_DOCS, cursor.getCount()); + assertEquals(1, cursor.getNumExceptions()); + } + + @Test + public void testAggregationModeResumesWhenExceptionAfterLastDocument() { + List pipeline = Arrays.asList(new Document("$match", new Document())); + + MongoPersistentCursor cursor = new FailingCursorPersistentCursor( + collection, pipeline, new QueryOptions(), NUM_DOCS - 1, 1); + + Set values = drainValues(cursor); + + assertCompleteValueSet(values); + assertEquals(NUM_DOCS, cursor.getCount()); + assertEquals(1, cursor.getNumExceptions()); + } + + // ------------------------------------------------------------------------- + // #3 – Aggregation pipeline without a leading $match stage + // Tests the branch in resumeAggregate that prepends the resume filter + // instead of merging it into an existing $match. + // ------------------------------------------------------------------------- + + @Test + public void testAggregationModeResumesWithoutLeadingMatchStage() { + // No $match, no $sort – resumeAggregate must prepend $match:{_id:{$gt:lastId}} + // and inject $sort:{_id:1} on every iteration including resumes. 
+ List pipeline = Arrays.asList( + new Document("$addFields", new Document("doubled", + new Document("$multiply", Arrays.asList("$value", 2)))) + ); + + MongoPersistentCursor cursor = new FailingCursorPersistentCursor( + collection, pipeline, new QueryOptions(), FAIL_AFTER, 1); + + Set values = drainValues(cursor); + + assertCompleteValueSet(values); + assertEquals(NUM_DOCS, cursor.getCount()); + assertEquals(1, cursor.getNumExceptions()); + } + + // ------------------------------------------------------------------------- + // #5 – skip must not be re-applied after a cursor resume + // ------------------------------------------------------------------------- + + @Test + public void testFindModeSkipNotReappliedOnResume() { + int skip = 5; + QueryOptions options = new QueryOptions(QueryOptions.SKIP, skip); + MongoPersistentCursor cursor = new FailingCursorPersistentCursor( + collection, new Document(), null, options, FAIL_AFTER, 1); + + Set values = drainValues(cursor); + + Set expected = new HashSet<>(); + for (int i = skip; i < NUM_DOCS; i++) { + expected.add(i); + } + assertEquals("skip must apply only at the start, not on each resume", expected, values); + assertEquals(1, cursor.getNumExceptions()); + } + + @Test + public void testAggregationModeSkipNotReappliedOnResume() { + int skip = 5; + QueryOptions options = new QueryOptions(QueryOptions.SKIP, skip); + List pipeline = Arrays.asList(new Document("$match", new Document())); + + MongoPersistentCursor cursor = new FailingCursorPersistentCursor( + collection, pipeline, options, FAIL_AFTER, 1); + + Set values = drainValues(cursor); + + Set expected = new HashSet<>(); + for (int i = skip; i < NUM_DOCS; i++) { + expected.add(i); + } + assertEquals("skip must apply only at the start, not on each resume", expected, values); + assertEquals(1, cursor.getNumExceptions()); + } + + // ------------------------------------------------------------------------- + // Real cursor expiry (manual integration test — requires a 
live MongoDB) + // ------------------------------------------------------------------------- + + /** + * Verifies that the cursor recovers transparently from a real server-side + * {@link MongoCursorNotFoundException}. + * + *

 Temporarily reduces {@code cursorTimeoutMillis} to 1 s so that the
+     * deliberate sleep is short. A small {@code batchSize} forces MongoDB to
+     * open a real server cursor (without it, all documents may be returned in
+     * a single batch and no cursor would exist to expire).
+     *
+     *

Run manually: + *

+     *   mvn test -pl commons-datastore/commons-datastore-mongodb \
+     *            -Dtest=MongoPersistentCursorTest#testRealCursorExpiry \
+     *            -Dsurefire.failIfNoSpecifiedTests=false
+     * 
+ */ + @Test + @Ignore + public void testRealCursorExpiry() throws Exception { + int cursorTimeoutMs = 1_000; + int sleepMs = 60_000; + + // Reduce cursor TTL so the sleep stays short. + mongoDataStore.getMongoClient().getDatabase("admin") + .runCommand(new Document("setParameter", 1).append("cursorTimeoutMillis", cursorTimeoutMs)); + try { + // batchSize forces a real server-side cursor; without it every document + // may arrive in one batch and the cursor never exists on the server. + QueryOptions options = new QueryOptions(MongoDBCollection.BATCH_SIZE, 100); + MongoPersistentCursor cursor = new MongoPersistentCursor(collection, new Document(), null, options); + + Set values = new HashSet<>(); + int count = 0; + while (cursor.hasNext()) { + Document doc = cursor.next(); + assertTrue("Duplicate value " + doc.getInteger("value"), values.add(doc.getInteger("value"))); + count++; + + if (count == 10) { + // Let the server-side cursor expire. + System.out.println("Sleeping " + sleepMs + " ms to expire the server cursor after _id=" + doc.get("_id")); + Thread.sleep(sleepMs); + System.out.println("Woke up. Calling cursor.next() to exercise the next() recovery path..."); + + // Call next() directly so we cover the MongoCursorNotFoundException + // catch block inside MongoPersistentCursor.next(), not just hasNext(). + Document resumed = cursor.next(); + assertTrue("Duplicate value after resume", values.add(resumed.getInteger("value"))); + System.out.println("Recovered. First doc after expiry: _id=" + resumed.get("_id")); + } + } + cursor.close(); + + assertCompleteValueSet(values); + assertEquals(NUM_DOCS, cursor.getCount()); + assertEquals(1, cursor.getNumExceptions()); + } finally { + // Restore MongoDB default cursor timeout (10 minutes). 
+ mongoDataStore.getMongoClient().getDatabase("admin") + .runCommand(new Document("setParameter", 1).append("cursorTimeoutMillis", 600_000)); + } + } + + // ========================================================================= + // Helpers + // ========================================================================= + + /** + * Drains the cursor into a set of {@code value} integers and closes it. + * Using a Set detects duplicates (add returns false) and, combined with + * {@link #assertCompleteValueSet}, detects skipped documents. + */ + private static Set drainValues(MongoPersistentCursor cursor) { + Set values = new HashSet<>(); + while (cursor.hasNext()) { + Document doc = cursor.next(); + int value = doc.getInteger("value"); + assertTrue("Duplicate value " + value + " returned by cursor", values.add(value)); + } + cursor.close(); + return values; + } + + /** + * Asserts that the set is exactly {@code {0, 1, …, NUM_DOCS-1}}: every document + * was returned exactly once with no gaps. + */ + private static void assertCompleteValueSet(Set values) { + Set expected = new HashSet<>(); + for (int i = 0; i < NUM_DOCS; i++) { + expected.add(i); + } + assertEquals("Cursor must return every document exactly once", expected, values); + } + + /** + * Subclass that injects a {@link FailingMongoCursor} for the first {@code maxFailures} + * resume calls, simulating {@link MongoCursorNotFoundException} mid-iteration. + * + *

The {@code initialized} flag is false while {@code super()} runs (Java guarantees that + * subclass fields are still at their defaults during the superclass constructor), so the first + * cursor created during construction uses the real MongoDB cursor. Once our constructor body + * sets {@code initialized = true} and calls {@link #reset()}, subsequent calls to + * {@code newFindIterable} / {@code newAggregateIterable} wrap the cursor with a failing one. + */ + private static class FailingCursorPersistentCursor extends MongoPersistentCursor { + + private final int failAfter; + private int failuresLeft; + /** False while super() runs so the constructor-time reset uses a real cursor. */ + private boolean initialized; + + // Find mode + FailingCursorPersistentCursor(MongoDBCollection collection, Bson query, Bson projection, + QueryOptions options, int failAfter, int maxFailures) { + super(collection, query, projection, options); + this.failAfter = failAfter; + this.failuresLeft = maxFailures; + this.initialized = true; + reset(); // restart fresh – this time newFindIterable wraps with a failing cursor + } + + // Aggregate mode + FailingCursorPersistentCursor(MongoDBCollection collection, List pipeline, + QueryOptions options, int failAfter, int maxFailures) { + super(collection, pipeline, options); + this.failAfter = failAfter; + this.failuresLeft = maxFailures; + this.initialized = true; + reset(); // restart fresh – this time newAggregateIterable wraps with a failing cursor + } + + @Override + protected FindIterable newFindIterable(Bson query, Bson projection, QueryOptions options) { + FindIterable real = super.newFindIterable(query, projection, options); + if (!initialized || failuresLeft == 0) { + return real; + } + failuresLeft--; + return wrapFind(real); + } + + @Override + protected AggregateIterable newAggregateIterable(List activePipeline) { + AggregateIterable real = super.newAggregateIterable(activePipeline); + if (!initialized || failuresLeft == 0) { + 
return real; + } + failuresLeft--; + return wrapAggregate(real); + } + + private FindIterable wrapFind(FindIterable real) { + // Fluent calls are forwarded to `real` so that sort/skip/limit are applied + // before real.iterator() is opened — important for tests that use skip/limit. + FindIterable mockIterable = mock(FindIterable.class); + when(mockIterable.sort(any(Bson.class))).thenAnswer(inv -> { real.sort(inv.getArgument(0)); return mockIterable; }); + when(mockIterable.batchSize(anyInt())).thenAnswer(inv -> { real.batchSize(inv.getArgument(0)); return mockIterable; }); + when(mockIterable.limit(anyInt())).thenAnswer(inv -> { real.limit(inv.getArgument(0)); return mockIterable; }); + when(mockIterable.skip(anyInt())).thenAnswer(inv -> { real.skip(inv.getArgument(0)); return mockIterable; }); + when(mockIterable.iterator()).thenAnswer(inv -> new FailingMongoCursor(real.iterator(), failAfter)); + return mockIterable; + } + + private AggregateIterable wrapAggregate(AggregateIterable real) { + AggregateIterable mockIterable = mock(AggregateIterable.class); + when(mockIterable.batchSize(anyInt())).thenAnswer(inv -> { real.batchSize(inv.getArgument(0)); return mockIterable; }); + when(mockIterable.iterator()).thenAnswer(inv -> new FailingMongoCursor(real.iterator(), failAfter)); + return mockIterable; + } + } + + /** + * Wraps a real {@link MongoCursor} and throws {@link MongoCursorNotFoundException} after + * exactly {@code failAfter} documents have been returned via {@link #next()}. 
+ */ + private static class FailingMongoCursor implements MongoCursor { + + private final MongoCursor delegate; + private int remaining; + + FailingMongoCursor(MongoCursor delegate, int failAfter) { + this.delegate = delegate; + this.remaining = failAfter; + } + + @Override + public boolean hasNext() { + if (remaining == 0) { + throw new MongoCursorNotFoundException(0xDEADBEEFL, new ServerAddress()); + } + return delegate.hasNext(); + } + + @Override + public Document next() { + remaining--; + return delegate.next(); + } + + @Override + public Document tryNext() { + if (remaining == 0) { + throw new MongoCursorNotFoundException(0xDEADBEEFL, new ServerAddress()); + } + Document doc = delegate.tryNext(); + if (doc != null) { + remaining--; + } + return doc; + } + + @Override + public int available() { + return delegate.available(); + } + + @Override + public ServerCursor getServerCursor() { + return delegate.getServerCursor(); + } + + @Override + public ServerAddress getServerAddress() { + return delegate.getServerAddress(); + } + + @Override + public void close() { + delegate.close(); + } + } +} diff --git a/pom.xml b/pom.xml index cee87ec8f..6aa823ad0 100644 --- a/pom.xml +++ b/pom.xml @@ -30,6 +30,7 @@ 2.4.0 1.3 4.13.2 + 2.2.27 opencb https://sonarcloud.io @@ -187,6 +188,12 @@ ${junit.version} test + + org.mockito + mockito-core + ${mockito.version} + test + From 9464324934c9018fbeff38c96d5569fe8ce41c33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jacobo=20Coll=20Morag=C3=B3n?= Date: Tue, 24 Feb 2026 17:20:33 +0000 Subject: [PATCH 3/7] commons: Add DataWriter.tee (sequential + parallel); add Task.tee, deprecate Task.join. 
#TASK-8038 --- .../org/opencb/commons/io/DataReader.java | 18 ++ .../org/opencb/commons/io/DataWriter.java | 108 +++++++ .../java/org/opencb/commons/run/Task.java | 39 ++- .../org/opencb/commons/io/DataWriterTest.java | 265 ++++++++++++++++++ .../java/org/opencb/commons/run/TaskTest.java | 231 +++++++++++++++ 5 files changed, 649 insertions(+), 12 deletions(-) create mode 100644 commons-lib/src/test/java/org/opencb/commons/io/DataWriterTest.java create mode 100644 commons-lib/src/test/java/org/opencb/commons/run/TaskTest.java diff --git a/commons-lib/src/main/java/org/opencb/commons/io/DataReader.java b/commons-lib/src/main/java/org/opencb/commons/io/DataReader.java index e203fa7fc..c3dde17a1 100644 --- a/commons-lib/src/main/java/org/opencb/commons/io/DataReader.java +++ b/commons-lib/src/main/java/org/opencb/commons/io/DataReader.java @@ -173,4 +173,22 @@ default void forEach(Consumer action, int batchSize) { action.accept(t); } } + + static DataReader wrap(Iterable iterable) { + return wrap(iterable.iterator()); + } + + static DataReader wrap(Iterator iterator) { + return new DataReader() { + @Override + public List read(int batchSize) { + List batch = new ArrayList<>(batchSize); + while (iterator.hasNext() && batchSize > batch.size()) { + batch.add(iterator.next()); + } + return batch; + } + }; + } + } diff --git a/commons-lib/src/main/java/org/opencb/commons/io/DataWriter.java b/commons-lib/src/main/java/org/opencb/commons/io/DataWriter.java index fc86358ad..721159279 100644 --- a/commons-lib/src/main/java/org/opencb/commons/io/DataWriter.java +++ b/commons-lib/src/main/java/org/opencb/commons/io/DataWriter.java @@ -20,6 +20,9 @@ import java.util.Collections; import java.util.List; +import java.util.Optional; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicBoolean; /** @@ -98,6 +101,111 @@ public void post() throws Exception { }; } + /** + * Fan out to two writers receiving 
the same batches sequentially. + * + * @param dw1 First writer + * @param dw2 Second writer + * @param Batch element type + * @return Composite writer that writes to both dw1 and dw2 in sequence. + */ + static DataWriter tee(DataWriter dw1, DataWriter dw2) { + return dw1.then(dw2.asTask()); + } + + /** + * Fan out to two writers receiving the same batches. + * + *

When {@code parallel=true} each writer runs in its own background thread. + * Batches are enqueued from the caller thread; the background threads consume and write. + * Any exception thrown by a background thread is rethrown from {@link #post()}. + * + * @param dw1 First writer + * @param dw2 Second writer + * @param parallel Whether to run each writer in its own background thread + * @param Batch element type + * @return Composite writer that writes to both dw1 and dw2. + */ + static DataWriter tee(DataWriter dw1, DataWriter dw2, boolean parallel) { + if (!parallel) { + return tee(dw1, dw2); + } + return new DataWriter() { + private final BlockingQueue>> queue1 = new LinkedBlockingQueue<>(); + private final BlockingQueue>> queue2 = new LinkedBlockingQueue<>(); + private Thread thread1; + private Thread thread2; + private volatile Throwable error1; + private volatile Throwable error2; + + @Override + public boolean pre() { + thread1 = new Thread(() -> { + try { + dw1.open(); + dw1.pre(); + Optional> item = queue1.take(); + while (item.isPresent()) { + dw1.write(item.get()); + item = queue1.take(); + } + dw1.post(); + dw1.close(); + } catch (Throwable t) { + error1 = t; + } + }, Thread.currentThread().getName() + "writer-1"); + thread2 = new Thread(() -> { + try { + dw2.open(); + dw2.pre(); + Optional> item = queue2.take(); + while (item.isPresent()) { + dw2.write(item.get()); + item = queue2.take(); + } + dw2.post(); + dw2.close(); + } catch (Throwable t) { + error2 = t; + } + }, Thread.currentThread().getName() + "writer-2"); + thread1.start(); + thread2.start(); + return true; + } + + @Override + public boolean write(List batch) { + if (error1 != null || error2 != null) { + throw new RuntimeException("Tee background writer has failed"); + } + queue1.add(Optional.of(batch)); + queue2.add(Optional.of(batch)); + return true; + } + + @Override + public boolean post() { + queue1.add(Optional.empty()); + queue2.add(Optional.empty()); + try { + thread1.join(); + 
thread2.join(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + if (error1 != null) { + throw new RuntimeException("Error in tee background writer 1", error1); + } + if (error2 != null) { + throw new RuntimeException("Error in tee background writer 2", error2); + } + return true; + } + }; + } + default DataWriter then(DataWriter nextWriter) { return then(nextWriter.asTask()); } diff --git a/commons-lib/src/main/java/org/opencb/commons/run/Task.java b/commons-lib/src/main/java/org/opencb/commons/run/Task.java index 971057946..ab24c69b9 100644 --- a/commons-lib/src/main/java/org/opencb/commons/run/Task.java +++ b/commons-lib/src/main/java/org/opencb/commons/run/Task.java @@ -130,47 +130,62 @@ default Task then(DataWriter writer) { /** * Use to execute multiple Tasks with the same input. - * Only the output of the main task will be propagated. + * Only the output of the main task will be propagated. The side task runs as a side-effect. * - * task = Task.join(task1, task2); + * task = Task.tee(task1, task2); * - * @param mainTask Main task to propagate - * @param otherTask Task to execute with the same input. The output will be lost. + * @param mainTask Main task whose output is propagated + * @param sideTask Task to execute with the same input. The output will be discarded. * @param Input type. * @param Return type. * @return Task that runs both tasks with the same input. 
*/ - static Task join(Task mainTask, Task otherTask) { + static Task tee(Task mainTask, Task sideTask) { return new Task() { @Override public void pre() throws Exception { mainTask.pre(); - otherTask.pre(); + sideTask.pre(); } @Override public List apply(List batch) throws Exception { List apply1 = mainTask.apply(batch); - otherTask.apply(batch); // ignore output + sideTask.apply(batch); // ignore output return apply1; } @Override public List drain() throws Exception { - // Drain both tasks List drain1 = mainTask.drain(); - otherTask.drain(); // ignore output - - // Return drain1 + sideTask.drain(); // ignore output return drain1; } @Override public void post() throws Exception { mainTask.post(); - otherTask.post(); + sideTask.post(); } }; } + /** + * Use to execute multiple Tasks with the same input. + * Only the output of the main task will be propagated. + * + * task = Task.join(task1, task2); + * + * @param mainTask Main task to propagate + * @param otherTask Task to execute with the same input. The output will be lost. + * @param Input type. + * @param Return type. + * @return Task that runs both tasks with the same input. + * @deprecated Use {@link #tee(Task, Task)} instead. 
+ */ + @Deprecated + static Task join(Task mainTask, Task otherTask) { + return tee(mainTask, otherTask); + } + } diff --git a/commons-lib/src/test/java/org/opencb/commons/io/DataWriterTest.java b/commons-lib/src/test/java/org/opencb/commons/io/DataWriterTest.java new file mode 100644 index 000000000..38604ecb6 --- /dev/null +++ b/commons-lib/src/test/java/org/opencb/commons/io/DataWriterTest.java @@ -0,0 +1,265 @@ +package org.opencb.commons.io; + +import org.junit.Assert; +import org.junit.Test; +import org.opencb.commons.run.Task; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +public class DataWriterTest { + + // ===== tee(dw1, dw2) — sequential ===== + + @Test + public void testTeeSequentialBothReceiveSameBatches() { + List> received1 = new ArrayList<>(); + List> received2 = new ArrayList<>(); + DataWriter dw1 = b -> { received1.add(new ArrayList<>(b)); return true; }; + DataWriter dw2 = b -> { received2.add(new ArrayList<>(b)); return true; }; + + DataWriter tee = DataWriter.tee(dw1, dw2); + tee.write(Arrays.asList("a", "b")); + tee.write(Collections.singletonList("c")); + + Assert.assertEquals(received1, received2); + Assert.assertEquals(2, received1.size()); + Assert.assertEquals(Arrays.asList("a", "b"), received1.get(0)); + Assert.assertEquals(Collections.singletonList("c"), received1.get(1)); + } + + @Test + public void testTeeSequentialLifecycleCalledOnBothWriters() { + AtomicBoolean pre1 = new AtomicBoolean(); + AtomicBoolean post1 = new AtomicBoolean(); + AtomicBoolean pre2 = new AtomicBoolean(); + AtomicBoolean post2 = new AtomicBoolean(); + + DataWriter dw1 = new DataWriter() { + @Override public boolean pre() { pre1.set(true); return true; } + @Override public boolean write(List b) { return true; } + @Override public boolean post() { post1.set(true); return true; } + }; + DataWriter 
dw2 = new DataWriter() { + @Override public boolean pre() { pre2.set(true); return true; } + @Override public boolean write(List b) { return true; } + @Override public boolean post() { post2.set(true); return true; } + }; + + DataWriter tee = DataWriter.tee(dw1, dw2); + tee.pre(); + tee.write(Collections.singletonList("x")); + tee.post(); + + Assert.assertTrue(pre1.get()); + Assert.assertTrue(pre2.get()); + Assert.assertTrue(post1.get()); + Assert.assertTrue(post2.get()); + } + + // ===== tee(dw1, dw2, false) — explicit non-parallel ===== + + @Test + public void testTeeNonParallelEquivalentToSequential() { + List> received1 = new ArrayList<>(); + List> received2 = new ArrayList<>(); + DataWriter dw1 = b -> { received1.add(new ArrayList<>(b)); return true; }; + DataWriter dw2 = b -> { received2.add(new ArrayList<>(b)); return true; }; + + DataWriter tee = DataWriter.tee(dw1, dw2, false); + tee.write(Arrays.asList("x", "y")); + + Assert.assertEquals(received1, received2); + Assert.assertEquals(1, received1.size()); + } + + // ===== tee(dw1, dw2, true) — parallel ===== + + @Test + public void testTeeParallelBothReceiveAllBatches() throws Exception { + List received1 = Collections.synchronizedList(new ArrayList<>()); + List received2 = Collections.synchronizedList(new ArrayList<>()); + DataWriter dw1 = b -> { received1.addAll(b); return true; }; + DataWriter dw2 = b -> { received2.addAll(b); return true; }; + + DataWriter tee = DataWriter.tee(dw1, dw2, true); + tee.pre(); + tee.write(Arrays.asList("a", "b", "c")); + tee.write(Arrays.asList("d", "e")); + tee.post(); // blocks until both background threads complete + + List expected = Arrays.asList("a", "b", "c", "d", "e"); + Assert.assertEquals(expected, received1); + Assert.assertEquals(expected, received2); + } + + @Test + public void testTeeParallelOpenPrePostCloseCalledOnBothWriters() throws Exception { + AtomicBoolean open1 = new AtomicBoolean(); + AtomicBoolean pre1 = new AtomicBoolean(); + AtomicBoolean post1 = 
new AtomicBoolean(); + AtomicBoolean close1 = new AtomicBoolean(); + AtomicBoolean open2 = new AtomicBoolean(); + AtomicBoolean pre2 = new AtomicBoolean(); + AtomicBoolean post2 = new AtomicBoolean(); + AtomicBoolean close2 = new AtomicBoolean(); + + DataWriter dw1 = new DataWriter() { + @Override public boolean open() { open1.set(true); return true; } + @Override public boolean pre() { pre1.set(true); return true; } + @Override public boolean write(List b) { return true; } + @Override public boolean post() { post1.set(true); return true; } + @Override public boolean close() { close1.set(true); return true; } + }; + DataWriter dw2 = new DataWriter() { + @Override public boolean open() { open2.set(true); return true; } + @Override public boolean pre() { pre2.set(true); return true; } + @Override public boolean write(List b) { return true; } + @Override public boolean post() { post2.set(true); return true; } + @Override public boolean close() { close2.set(true); return true; } + }; + + DataWriter tee = DataWriter.tee(dw1, dw2, true); + tee.pre(); + tee.write(Collections.singletonList("x")); + tee.post(); // joins background threads → all lifecycle steps completed + + Assert.assertTrue(open1.get()); + Assert.assertTrue(pre1.get()); + Assert.assertTrue(post1.get()); + Assert.assertTrue(close1.get()); + Assert.assertTrue(open2.get()); + Assert.assertTrue(pre2.get()); + Assert.assertTrue(post2.get()); + Assert.assertTrue(close2.get()); + } + + @Test(expected = RuntimeException.class) + public void testTeeParallelExceptionInBackgroundWriterRethrownOnPost() throws Exception { + // dw1 always throws; the error must surface when post() joins the background thread + DataWriter failing = b -> { throw new RuntimeException("write failed"); }; + DataWriter ok = b -> true; + + DataWriter tee = DataWriter.tee(failing, ok, true); + tee.pre(); + tee.write(Collections.singletonList("x")); + tee.post(); // background thread for 'failing' set error1; post() must rethrow it + } + + 
@Test(expected = RuntimeException.class) + public void testTeeParallelExceptionDetectedOnSubsequentWrite() throws Exception { + // After the background thread has failed, the next write() should detect it + DataWriter failing = b -> { throw new RuntimeException("write failed"); }; + DataWriter ok = b -> true; + + DataWriter tee = DataWriter.tee(failing, ok, true); + tee.pre(); + tee.write(Collections.singletonList("trigger-failure")); + Thread.sleep(200); // wait for background thread to process and set the error + tee.write(Collections.singletonList("should-throw")); // error already set → throws + } + + // ===== asTask() ===== + + @Test + public void testAsTaskDelegatesFullLifecycle() throws Exception { + AtomicBoolean opened = new AtomicBoolean(); + AtomicBoolean pre = new AtomicBoolean(); + AtomicBoolean post = new AtomicBoolean(); + AtomicBoolean closed = new AtomicBoolean(); + List> written = new ArrayList<>(); + + DataWriter dw = new DataWriter() { + @Override public boolean open() { opened.set(true); return true; } + @Override public boolean pre() { pre.set(true); return true; } + @Override public boolean write(List b) { written.add(new ArrayList<>(b)); return true; } + @Override public boolean post() { post.set(true); return true; } + @Override public boolean close() { closed.set(true); return true; } + }; + + Task task = dw.asTask(); + task.pre(); + List result = task.apply(Arrays.asList("x", "y")); + task.post(); + + Assert.assertTrue(opened.get()); + Assert.assertTrue(pre.get()); + Assert.assertTrue(post.get()); + Assert.assertTrue(closed.get()); + Assert.assertEquals(1, written.size()); + Assert.assertEquals(Arrays.asList("x", "y"), result); // batch passed through + } + + @Test + public void testAsTaskPreAndPostAreIdempotent() throws Exception { + // open/pre are called only once even if task.pre() is invoked multiple times, + // same for post/close. 
+ AtomicInteger preCount = new AtomicInteger(); + AtomicInteger postCount = new AtomicInteger(); + + DataWriter dw = new DataWriter() { + @Override public boolean pre() { preCount.incrementAndGet(); return true; } + @Override public boolean write(List b) { return true; } + @Override public boolean post() { postCount.incrementAndGet(); return true; } + }; + + Task task = dw.asTask(); + task.pre(); + task.pre(); + task.post(); + task.post(); + + Assert.assertEquals(1, preCount.get()); + Assert.assertEquals(1, postCount.get()); + } + + // ===== then(DataWriter) ===== + + @Test + public void testThenDataWriterBothWritersReceiveSameBatches() { + List received1 = new ArrayList<>(); + List received2 = new ArrayList<>(); + DataWriter dw1 = b -> { received1.addAll(b); return true; }; + DataWriter dw2 = b -> { received2.addAll(b); return true; }; + + DataWriter chained = dw1.then(dw2); + chained.write(Arrays.asList("a", "b")); + chained.write(Collections.singletonList("c")); + + Assert.assertEquals(received1, received2); + Assert.assertEquals(Arrays.asList("a", "b", "c"), received1); + } + + @Test + public void testThenDataWriterLifecycleCalledOnBoth() { + AtomicBoolean pre1 = new AtomicBoolean(); + AtomicBoolean post1 = new AtomicBoolean(); + AtomicBoolean pre2 = new AtomicBoolean(); + AtomicBoolean post2 = new AtomicBoolean(); + + DataWriter dw1 = new DataWriter() { + @Override public boolean pre() { pre1.set(true); return true; } + @Override public boolean write(List b) { return true; } + @Override public boolean post() { post1.set(true); return true; } + }; + DataWriter dw2 = new DataWriter() { + @Override public boolean pre() { pre2.set(true); return true; } + @Override public boolean write(List b) { return true; } + @Override public boolean post() { post2.set(true); return true; } + }; + + DataWriter chained = dw1.then(dw2); + chained.pre(); + chained.write(Collections.singletonList("x")); + chained.post(); + + Assert.assertTrue(pre1.get()); + 
Assert.assertTrue(pre2.get()); + Assert.assertTrue(post1.get()); + Assert.assertTrue(post2.get()); + } +} diff --git a/commons-lib/src/test/java/org/opencb/commons/run/TaskTest.java b/commons-lib/src/test/java/org/opencb/commons/run/TaskTest.java new file mode 100644 index 000000000..559b8d120 --- /dev/null +++ b/commons-lib/src/test/java/org/opencb/commons/run/TaskTest.java @@ -0,0 +1,231 @@ +package org.opencb.commons.run; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; + +public class TaskTest { + + // ===== forEach(Function) ===== + + @Test + public void testForEachFunctionTransforms() throws Exception { + Task task = Task.forEach(String::length); + List result = task.apply(Arrays.asList("a", "bb", "ccc")); + Assert.assertEquals(Arrays.asList(1, 2, 3), result); + } + + @Test + public void testForEachFunctionFiltersNullReturns() throws Exception { + // Elements where the function returns null are excluded from the output + Task task = Task.forEach(s -> s.startsWith("a") ? 
s.length() : null); + List result = task.apply(Arrays.asList("abc", "xyz", "ab")); + Assert.assertEquals(Arrays.asList(3, 2), result); + } + + @Test + public void testForEachFunctionAllNullReturnsEmptyList() throws Exception { + Task task = Task.forEach(s -> null); + List result = task.apply(Arrays.asList("a", "b")); + Assert.assertEquals(Collections.emptyList(), result); + } + + @Test + public void testForEachFunctionEmptyBatch() throws Exception { + Task task = Task.forEach(String::length); + Assert.assertEquals(Collections.emptyList(), task.apply(Collections.emptyList())); + } + + @Test + public void testForEachFunctionNullBatch() throws Exception { + Task task = Task.forEach(String::length); + Assert.assertEquals(Collections.emptyList(), task.apply(null)); + } + + // ===== forEach(Consumer) ===== + + @Test + public void testForEachConsumerRunsOnEachElementAndPassesThrough() throws Exception { + List visited = new ArrayList<>(); + Task task = Task.forEach((Consumer) s -> { visited.add(s); }); + List input = Arrays.asList("a", "b", "c"); + + List result = task.apply(input); + + Assert.assertEquals(input, visited); // consumer received every element + Assert.assertEquals(input, result); // original batch passed through unchanged + } + + @Test + public void testForEachConsumerEmptyBatch() throws Exception { + Task task = Task.forEach((Consumer) s -> {}); + Assert.assertEquals(Collections.emptyList(), task.apply(Collections.emptyList())); + } + + @Test + public void testForEachConsumerNullBatch() throws Exception { + Task task = Task.forEach((Consumer) s -> {}); + Assert.assertEquals(Collections.emptyList(), task.apply(null)); + } + + // ===== then(Task) ===== + + @Test + public void testThenChainsOutputToNextTaskInput() throws Exception { + Task lengths = Task.forEach(String::length); + Task labels = Task.forEach(i -> "len=" + i); + + List result = lengths.then(labels).apply(Arrays.asList("hello", "hi")); + Assert.assertEquals(Arrays.asList("len=5", "len=2"), 
result); + } + + @Test + public void testThenPreAndPostCalledOnBothTasks() throws Exception { + AtomicBoolean pre1 = new AtomicBoolean(); + AtomicBoolean post1 = new AtomicBoolean(); + AtomicBoolean pre2 = new AtomicBoolean(); + AtomicBoolean post2 = new AtomicBoolean(); + + Task task1 = new Task() { + @Override public void pre() { pre1.set(true); } + @Override public List apply(List b) { return b; } + @Override public void post() { post1.set(true); } + }; + Task task2 = new Task() { + @Override public void pre() { pre2.set(true); } + @Override public List apply(List b) { return b; } + @Override public void post() { post2.set(true); } + }; + + Task combined = task1.then(task2); + combined.pre(); + combined.post(); + + Assert.assertTrue(pre1.get()); + Assert.assertTrue(pre2.get()); + Assert.assertTrue(post1.get()); + Assert.assertTrue(post2.get()); + } + + @Test + public void testThenDrainFromFirstTaskFeedsIntoSecond() throws Exception { + // task1 drains an extra element; task2 uppercases everything + Task task1 = new Task() { + @Override public List apply(List b) { return b; } + @Override public List drain() { return Collections.singletonList("drained"); } + }; + Task task2 = Task.forEach(s -> { + return s.toUpperCase(); + }); + + // drain1 = ["drained"] → task2.apply(drain1) = ["DRAINED"], task2.drain() = [] + List drain = task1.then(task2).drain(); + Assert.assertEquals(Collections.singletonList("DRAINED"), drain); + } + + @Test + public void testThenDrainCombinesBothDrains() throws Exception { + // Both tasks drain something; the combined result should include both contributions + Task task1 = new Task() { + @Override public List apply(List b) { return b; } + @Override public List drain() { return Collections.singletonList("from-task1"); } + }; + Task task2 = new Task() { + @Override public List apply(List b) { return b; } + @Override public List drain() { return Collections.singletonList("from-task2"); } + }; + + // task2.apply(["from-task1"]) = 
["from-task1"], task2.drain() = ["from-task2"] + List drain = task1.then(task2).drain(); + Assert.assertEquals(Arrays.asList("from-task1", "from-task2"), drain); + } + + // ===== tee(mainTask, sideTask) ===== + + @Test + public void testTeeMainOutputPropagated() throws Exception { + Task main = Task.forEach(String::length); + Task side = Task.forEach(s -> { + return s.toUpperCase(); + }); // output discarded + + List result = Task.tee(main, side).apply(Arrays.asList("hello", "world")); + Assert.assertEquals(Arrays.asList(5, 5), result); + } + + @Test + public void testTeeSideTaskReceivesSameInputAsBatch() throws Exception { + List sideInput = new ArrayList<>(); + Task main = Task.forEach(String::length); + Task side = Task.forEach(s -> { sideInput.add(s); return s; }); + + Task.tee(main, side).apply(Arrays.asList("hello", "world")); + Assert.assertEquals(Arrays.asList("hello", "world"), sideInput); + } + + @Test + public void testTeePrePostAndDrainCalledOnBothTasks() throws Exception { + AtomicBoolean preMain = new AtomicBoolean(); + AtomicBoolean postMain = new AtomicBoolean(); + AtomicBoolean preSide = new AtomicBoolean(); + AtomicBoolean postSide = new AtomicBoolean(); + AtomicBoolean drainSide = new AtomicBoolean(); + + Task main = new Task() { + @Override public void pre() { preMain.set(true); } + @Override public List apply(List b) { return b; } + @Override public void post() { postMain.set(true); } + }; + Task side = new Task() { + @Override public void pre() { preSide.set(true); } + @Override public List apply(List b) { return b; } + @Override public List drain() { drainSide.set(true); return Collections.emptyList(); } + @Override public void post() { postSide.set(true); } + }; + + Task tee = Task.tee(main, side); + tee.pre(); + tee.drain(); + tee.post(); + + Assert.assertTrue(preMain.get()); + Assert.assertTrue(preSide.get()); + Assert.assertTrue(postMain.get()); + Assert.assertTrue(postSide.get()); + Assert.assertTrue(drainSide.get()); + } + + @Test + 
public void testTeeSideExceptionPropagates() throws Exception { + Task main = Task.forEach(String::length); + Task side = batch -> { throw new RuntimeException("side failed"); }; + + try { + Task.tee(main, side).apply(Collections.singletonList("x")); + Assert.fail("Expected exception from side task"); + } catch (RuntimeException e) { + Assert.assertEquals("side failed", e.getMessage()); + } + } + + // ===== join (deprecated alias for tee) ===== + + @Test + public void testJoinDelegatesToTee() throws Exception { + List sideInput = new ArrayList<>(); + Task main = Task.forEach(String::length); + Task side = Task.forEach(s -> { sideInput.add(s); return s; }); + + @SuppressWarnings("deprecation") + List result = Task.join(main, side).apply(Arrays.asList("abc", "de")); + + Assert.assertEquals(Arrays.asList(3, 2), result); + Assert.assertEquals(Arrays.asList("abc", "de"), sideInput); + } +} From d82f2a8c1379ea1637d9c52b8ff4556345812443 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jacobo=20Coll=20Morag=C3=B3n?= Date: Thu, 12 Mar 2026 00:25:39 +0000 Subject: [PATCH 4/7] commons: Bound parallel tee queues to prevent unbounded memory growth. #TASK-8038 --- .../org/opencb/commons/io/DataWriter.java | 43 +++++++++++++------ 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/commons-lib/src/main/java/org/opencb/commons/io/DataWriter.java b/commons-lib/src/main/java/org/opencb/commons/io/DataWriter.java index 721159279..55f75ec75 100644 --- a/commons-lib/src/main/java/org/opencb/commons/io/DataWriter.java +++ b/commons-lib/src/main/java/org/opencb/commons/io/DataWriter.java @@ -113,6 +113,10 @@ static DataWriter tee(DataWriter dw1, DataWriter dw2) { return dw1.then(dw2.asTask()); } + static DataWriter tee(DataWriter dw1, DataWriter dw2, boolean parallel) { + return tee(dw1, dw2, parallel, 1); + } + /** * Fan out to two writers receiving the same batches. 
* @@ -120,19 +124,20 @@ static DataWriter tee(DataWriter dw1, DataWriter dw2) { * Batches are enqueued from the caller thread; the background threads consume and write. * Any exception thrown by a background thread is rethrown from {@link #post()}. * - * @param dw1 First writer - * @param dw2 Second writer - * @param parallel Whether to run each writer in its own background thread - * @param Batch element type - * @return Composite writer that writes to both dw1 and dw2. + * @param dw1 First writer + * @param dw2 Second writer + * @param parallel Whether to run each writer in its own background thread + * @param queueCapacity Maximum number of batches buffered per writer when parallel + * @param Batch element type + * @return Composite writer that writes to both dw1 and dw2. */ - static DataWriter tee(DataWriter dw1, DataWriter dw2, boolean parallel) { + static DataWriter tee(DataWriter dw1, DataWriter dw2, boolean parallel, int queueCapacity) { if (!parallel) { return tee(dw1, dw2); } return new DataWriter() { - private final BlockingQueue>> queue1 = new LinkedBlockingQueue<>(); - private final BlockingQueue>> queue2 = new LinkedBlockingQueue<>(); + private final BlockingQueue>> queue1 = new LinkedBlockingQueue<>(queueCapacity); + private final BlockingQueue>> queue2 = new LinkedBlockingQueue<>(queueCapacity); private Thread thread1; private Thread thread2; private volatile Throwable error1; @@ -154,7 +159,7 @@ public boolean pre() { } catch (Throwable t) { error1 = t; } - }, Thread.currentThread().getName() + "writer-1"); + }, Thread.currentThread().getName() + "-writer-1"); thread2 = new Thread(() -> { try { dw2.open(); @@ -169,7 +174,7 @@ public boolean pre() { } catch (Throwable t) { error2 = t; } - }, Thread.currentThread().getName() + "writer-2"); + }, Thread.currentThread().getName() + "-writer-2"); thread1.start(); thread2.start(); return true; @@ -180,15 +185,25 @@ public boolean write(List batch) { if (error1 != null || error2 != null) { throw new 
RuntimeException("Tee background writer has failed"); } - queue1.add(Optional.of(batch)); - queue2.add(Optional.of(batch)); + try { + queue1.put(Optional.of(batch)); + queue2.put(Optional.of(batch)); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while enqueueing batch", e); + } return true; } @Override public boolean post() { - queue1.add(Optional.empty()); - queue2.add(Optional.empty()); + try { + queue1.put(Optional.empty()); + queue2.put(Optional.empty()); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while sending poison pill", e); + } try { thread1.join(); thread2.join(); From 7fb7ba94ccbc2c298867622bd3acfa942a9e7065 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jacobo=20Coll=20Morag=C3=B3n?= Date: Thu, 12 Mar 2026 09:18:17 +0000 Subject: [PATCH 5/7] storage: Fix NPE at ProgressLogger.asTask() #TASK-8038 --- .../main/java/org/opencb/commons/ProgressLogger.java | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/commons-lib/src/main/java/org/opencb/commons/ProgressLogger.java b/commons-lib/src/main/java/org/opencb/commons/ProgressLogger.java index b87f3691d..6a694b381 100644 --- a/commons-lib/src/main/java/org/opencb/commons/ProgressLogger.java +++ b/commons-lib/src/main/java/org/opencb/commons/ProgressLogger.java @@ -310,10 +310,14 @@ public List apply(List batch) throws Exception { if (batch == null || batch.isEmpty()) { return batch; } - increment(batch.size(), () -> { - T lastElement = batch.get(batch.size() - 1); - return messageBuilder.apply(lastElement); - }); + if (messageBuilder == null) { + increment(batch.size()); + } else { + increment(batch.size(), () -> { + T lastElement = batch.get(batch.size() - 1); + return messageBuilder.apply(lastElement); + }); + } return batch; } }; From a8cc8f889a5c5db86a81c43b3e68c15502145490 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Jacobo=20Coll=20Morag=C3=B3n?= Date: Tue, 31 Mar 2026 15:37:28 +0100 Subject: [PATCH 6/7] commons: Fix DataWriter.tee hang when background thread dies with full queue. #TASK-8038 --- .../org/opencb/commons/io/DataWriter.java | 65 ++++++++++++++----- .../org/opencb/commons/io/DataWriterTest.java | 55 ++++++++++++++++ 2 files changed, 104 insertions(+), 16 deletions(-) diff --git a/commons-lib/src/main/java/org/opencb/commons/io/DataWriter.java b/commons-lib/src/main/java/org/opencb/commons/io/DataWriter.java index 55f75ec75..5c00ab6c1 100644 --- a/commons-lib/src/main/java/org/opencb/commons/io/DataWriter.java +++ b/commons-lib/src/main/java/org/opencb/commons/io/DataWriter.java @@ -23,6 +23,7 @@ import java.util.Optional; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; /** @@ -182,12 +183,10 @@ public boolean pre() { @Override public boolean write(List batch) { - if (error1 != null || error2 != null) { - throw new RuntimeException("Tee background writer has failed"); - } + checkErrors(); try { - queue1.put(Optional.of(batch)); - queue2.put(Optional.of(batch)); + offerIfAlive(queue1, Optional.of(batch), thread1); + offerIfAlive(queue2, Optional.of(batch), thread2); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new RuntimeException("Interrupted while enqueueing batch", e); @@ -197,26 +196,60 @@ public boolean write(List batch) { @Override public boolean post() { + boolean pill1Sent = false; + boolean pill2Sent = false; try { - queue1.put(Optional.empty()); - queue2.put(Optional.empty()); + offerIfAlive(queue1, Optional.empty(), thread1); + pill1Sent = true; + offerIfAlive(queue2, Optional.empty(), thread2); + pill2Sent = true; } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new RuntimeException("Interrupted while sending poison pill", e); + } finally { + if 
(!pill1Sent) { + thread1.interrupt(); + } + if (!pill2Sent) { + thread2.interrupt(); + } + try { + thread1.join(); + thread2.join(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } } - try { - thread1.join(); - thread2.join(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); + checkErrors(); + return true; + } + + /** + * Offer an item to the queue, retrying with a timeout. If the consumer thread + * is dead (can't drain the queue), clear the queue and return. + */ + private void offerIfAlive(BlockingQueue>> queue, Optional> item, Thread thread) + throws InterruptedException { + while (!queue.offer(item, 100, TimeUnit.MILLISECONDS)) { + if (!thread.isAlive()) { + queue.clear(); + checkErrors(); + return; + } } - if (error1 != null) { + } + + private void checkErrors() { + if (error1 != null && error2 != null) { + RuntimeException e = new RuntimeException("Error in tee background writers"); + e.addSuppressed(error1); + e.addSuppressed(error2); + throw e; + } else if (error1 != null) { throw new RuntimeException("Error in tee background writer 1", error1); - } - if (error2 != null) { + } else if (error2 != null) { throw new RuntimeException("Error in tee background writer 2", error2); } - return true; } }; } diff --git a/commons-lib/src/test/java/org/opencb/commons/io/DataWriterTest.java b/commons-lib/src/test/java/org/opencb/commons/io/DataWriterTest.java index 38604ecb6..42ab6096b 100644 --- a/commons-lib/src/test/java/org/opencb/commons/io/DataWriterTest.java +++ b/commons-lib/src/test/java/org/opencb/commons/io/DataWriterTest.java @@ -8,6 +8,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -163,6 +164,60 @@ public void testTeeParallelExceptionDetectedOnSubsequentWrite() throws Exception 
tee.write(Collections.singletonList("should-throw")); // error already set → throws } + @Test(timeout = 5000) + public void testTeeParallelPostShouldNotHangWhenBackgroundWriterDies() throws Exception { + // Scenario: thread1 dies (exception in write), leaving an unconsumed batch in + // its bounded queue. post() tries to enqueue a poison pill into the full queue + // and hangs forever because there's no consumer to drain it. + CountDownLatch writerStarted = new CountDownLatch(1); + CountDownLatch writerCanProceed = new CountDownLatch(1); + CountDownLatch okWriterConsumedBatch1 = new CountDownLatch(1); + + DataWriter failing = new DataWriter() { + @Override + public boolean write(List batch) { + writerStarted.countDown(); + try { + writerCanProceed.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + throw new RuntimeException("write failed"); + } + }; + + DataWriter ok = b -> { + okWriterConsumedBatch1.countDown(); + return true; + }; + + DataWriter tee = DataWriter.tee(failing, ok, true, 1); + tee.pre(); + + // batch1: thread1 takes it, enters write(), signals writerStarted, blocks on latch + tee.write(Collections.singletonList("batch1")); + writerStarted.await(); + + // Ensure thread2 has consumed batch1 so queue2.put won't block on next write + okWriterConsumedBatch1.await(); + + // batch2: fills queue1 to capacity (thread1 is blocked, can't consume) + tee.write(Collections.singletonList("batch2")); + + // Let thread1 proceed → it throws and dies. queue1 still has batch2 (full). + writerCanProceed.countDown(); + Thread.sleep(200); + + // post() should complete and report the error, not hang. + // BUG: queue1.put(poison pill) blocks forever because queue1 is full and thread1 is dead. 
+ try { + tee.post(); + Assert.fail("post() should have thrown due to failed background writer"); + } catch (RuntimeException e) { + // Expected: error from failed writer should propagate + } + } + // ===== asTask() ===== @Test From 33c75eb216f1bbc1761746d5bc691c2e9d3a1085 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jacobo=20Coll=20Morag=C3=B3n?= Date: Thu, 23 Apr 2026 09:55:19 +0100 Subject: [PATCH 7/7] datastore: Move EmbeddedMongoDBManager to datastore. #TASK-8038 --- .../commons-datastore-mongodb/pom.xml | 24 +- .../mongodb/test/EmbeddedMongoDBManager.java | 301 ++++++++++++++++++ .../mongodb/test/EmbeddedMongoDBRule.java | 64 ++++ pom.xml | 20 ++ 4 files changed, 408 insertions(+), 1 deletion(-) create mode 100644 commons-datastore/commons-datastore-mongodb/src/main/java/org/opencb/commons/datastore/mongodb/test/EmbeddedMongoDBManager.java create mode 100644 commons-datastore/commons-datastore-mongodb/src/main/java/org/opencb/commons/datastore/mongodb/test/EmbeddedMongoDBRule.java diff --git a/commons-datastore/commons-datastore-mongodb/pom.xml b/commons-datastore/commons-datastore-mongodb/pom.xml index 087831d18..0e4b5f0b3 100644 --- a/commons-datastore/commons-datastore-mongodb/pom.xml +++ b/commons-datastore/commons-datastore-mongodb/pom.xml @@ -55,10 +55,32 @@ org.hamcrest hamcrest-core + junit junit - test + compile + true + + + de.flapdoodle.embed + de.flapdoodle.embed.mongo + true + + + de.flapdoodle.embed + de.flapdoodle.embed.process + true + + + de.flapdoodle.reverse + de.flapdoodle.reverse + true org.mockito diff --git a/commons-datastore/commons-datastore-mongodb/src/main/java/org/opencb/commons/datastore/mongodb/test/EmbeddedMongoDBManager.java b/commons-datastore/commons-datastore-mongodb/src/main/java/org/opencb/commons/datastore/mongodb/test/EmbeddedMongoDBManager.java new file mode 100644 index 000000000..c83fc7956 --- /dev/null +++ 
b/commons-datastore/commons-datastore-mongodb/src/main/java/org/opencb/commons/datastore/mongodb/test/EmbeddedMongoDBManager.java @@ -0,0 +1,301 @@ +/* + * Copyright 2015-2020 OpenCB + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.opencb.commons.datastore.mongodb.test; + +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; +import com.mongodb.client.MongoDatabase; +import de.flapdoodle.embed.mongo.commands.MongodArguments; +import de.flapdoodle.embed.mongo.config.Net; +import de.flapdoodle.embed.mongo.config.Storage; +import de.flapdoodle.embed.mongo.distribution.Version; +import de.flapdoodle.embed.mongo.transitions.Mongod; +import de.flapdoodle.embed.mongo.transitions.RunningMongodProcess; +import de.flapdoodle.reverse.TransitionWalker; +import de.flapdoodle.reverse.transitions.Start; +import org.bson.Document; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.PrintStream; +import java.net.ServerSocket; +import java.util.Collections; + +/** + * Manages a process-singleton embedded MongoDB instance for testing purposes, backed by flapdoodle. + * Provides faster test execution and better isolation than relying on an external MongoDB service. + * + *

+ * <p>Configuration via system properties:</p>
+ * <ul>
+ *   <li>{@code opencb.test.embeddedMongo} - Enable/disable embedded MongoDB (default: true)</li>
+ *   <li>{@code opencb.test.mongodb.version} - MongoDB version to use (default: 7.0)</li>
+ *   <li>{@code opencb.test.mongo.verbose} - Enable verbose mongod stdout/stderr passthrough (default: false)</li>
+ * </ul>
+ *
+ * <p>The embedded mongod is started lazily on the first {@link #start()} call, reused across the
+ * entire JVM, and stopped via a shutdown hook. A single-member replica set {@code rs0} is
+ * initialised so transactions are available.</p>

+ */ +public final class EmbeddedMongoDBManager { + private static final Logger LOGGER = LoggerFactory.getLogger(EmbeddedMongoDBManager.class); + private static final String DEFAULT_MONGODB_VERSION = "7.0"; + + private static EmbeddedMongoDBManager instance; + private TransitionWalker.ReachedState runningMongod; + private int port; + private final boolean enabled; + private final String mongoVersion; + private final boolean verbose; + + private EmbeddedMongoDBManager() { + this.enabled = Boolean.parseBoolean(System.getProperty("opencb.test.embeddedMongo", "true")); + this.mongoVersion = System.getProperty("opencb.test.mongodb.version", DEFAULT_MONGODB_VERSION); + this.verbose = Boolean.parseBoolean(System.getProperty("opencb.test.mongo.verbose", "false")); + } + + public static synchronized EmbeddedMongoDBManager getInstance() { + if (instance == null) { + instance = new EmbeddedMongoDBManager(); + } + return instance; + } + + private Version.Main getMongoVersion() { + switch (mongoVersion) { + case "3.6": + return Version.Main.V3_6; + case "4.0": + return Version.Main.V4_0; + case "4.2": + return Version.Main.V4_2; + case "4.4": + return Version.Main.V4_4; + case "5.0": + return Version.Main.V5_0; + case "6.0": + return Version.Main.V6_0; + case "7.0": + return Version.Main.V7_0; + case "8.0": + return Version.Main.V8_0; + case "8.1": + return Version.Main.V8_1; + default: + throw new IllegalArgumentException("Unsupported MongoDB version: " + mongoVersion); + } + } + + public synchronized void start() throws IOException { + if (!enabled) { + LOGGER.info("Embedded MongoDB is disabled. 
Using external MongoDB instance."); + return; + } + + if (runningMongod != null) { + LOGGER.debug("Embedded MongoDB is already running on port {}", port); + return; + } + + try { + LOGGER.info("Starting embedded MongoDB {} with replica set support", mongoVersion); + + port = findAvailablePort(); + LOGGER.info("Found available port: {}", port); + + if (!verbose) { + System.setOut(new FilteringPrintStream(System.out)); + System.setErr(new FilteringPrintStream(System.err)); + LOGGER.info("MongoDB output filtering enabled (use -Dopencb.test.mongo.verbose=true to see all logs)"); + } + + runningMongod = Mongod.instance() + .withNet(Start.to(Net.class).initializedWith(Net.defaults().withPort(port))) + .withMongodArguments(Start.to(MongodArguments.class) + .initializedWith(MongodArguments.defaults() + .withReplication(Storage.of("rs0", 10)))) + .start(getMongoVersion()); + + LOGGER.info("Embedded MongoDB {} started on port {}, initializing replica set...", mongoVersion, port); + + Thread.sleep(500); + + initializeReplicaSet(); + + LOGGER.info("Embedded MongoDB {} with replica set 'rs0' ready on port {}", mongoVersion, port); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + if (verbose) { + LOGGER.info("Shutting down embedded MongoDB via shutdown hook"); + } + stop(); + })); + + } catch (Exception e) { + LOGGER.error("Failed to start embedded MongoDB version {}", mongoVersion, e); + throw new IOException("Failed to start embedded MongoDB", e); + } + } + + private void initializeReplicaSet() { + MongoClient mongoClient = null; + try { + String connectionString = String.format("mongodb://localhost:%d", port); + mongoClient = MongoClients.create(connectionString); + + MongoDatabase adminDb = mongoClient.getDatabase("admin"); + + Document config = new Document("_id", "rs0") + .append("members", Collections.singletonList( + new Document("_id", 0) + .append("host", "localhost:" + port) + )); + + adminDb.runCommand(new Document("replSetInitiate", config)); + + 
LOGGER.debug("Replica set initiation command sent"); + + long timeout = System.currentTimeMillis() + 30000; + while (System.currentTimeMillis() < timeout) { + try { + Document result = adminDb.runCommand(new Document("replSetGetStatus", 1)); + Integer myState = result.getInteger("myState"); + if (myState != null && myState == 1) { + LOGGER.info("Replica set initialized successfully and is PRIMARY"); + return; + } + LOGGER.debug("Replica set state: {}, waiting for PRIMARY...", myState); + } catch (Exception e) { + LOGGER.trace("Waiting for replica set to be ready: {}", e.getMessage()); + } + Thread.sleep(100); + } + + throw new RuntimeException("Replica set initialization timed out after 30 seconds"); + + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Replica set initialization was interrupted", e); + } catch (Exception e) { + LOGGER.error("Failed to initialize replica set", e); + throw new RuntimeException("Could not initialize replica set", e); + } finally { + if (mongoClient != null) { + try { + mongoClient.close(); + } catch (Exception e) { + LOGGER.warn("Error closing MongoDB client", e); + } + } + } + } + + public synchronized void stop() { + if (!enabled) { + return; + } + + if (runningMongod != null) { + LOGGER.info("Stopping embedded MongoDB on port {}", port); + try { + runningMongod.close(); + } catch (Exception e) { + LOGGER.warn("Error stopping embedded MongoDB", e); + } finally { + runningMongod = null; + } + } + } + + public int getPort() { + return port; + } + + public String getConnectionString() { + if (!enabled) { + return "localhost:27017"; + } + return "localhost:" + port; + } + + public boolean isEnabled() { + return enabled; + } + + public boolean isRunning() { + return enabled && runningMongod != null; + } + + private int findAvailablePort() throws IOException { + try (ServerSocket socket = new ServerSocket(0)) { + socket.setReuseAddress(true); + return socket.getLocalPort(); + } catch 
(IOException e) { + throw new IOException("Failed to find an available port", e); + } + } + + private static class FilteringPrintStream extends PrintStream { + private final PrintStream original; + private final StringBuilder lineBuffer = new StringBuilder(); + + FilteringPrintStream(PrintStream original) { + super(original); + this.original = original; + } + + @Override + public void write(int b) { + if (b == '\n') { + String line = lineBuffer.toString(); + boolean isMongodLog = line.contains("{\"t\":{\"$date\":") + || line.startsWith("[mongod output]") + || line.startsWith("[mongod error]"); + if (!isMongodLog) { + original.print(lineBuffer.toString()); + original.write(b); + original.flush(); + } + lineBuffer.setLength(0); + } else if (b != '\r') { + lineBuffer.append((char) b); + } + } + + @Override + public void write(byte[] buf, int off, int len) { + for (int i = 0; i < len; i++) { + write(buf[off + i]); + } + } + + @Override + public void flush() { + if (lineBuffer.length() > 0) { + String line = lineBuffer.toString(); + boolean isMongodLog = line.contains("{\"t\":{\"$date\":") + || line.startsWith("[mongod output]") + || line.startsWith("[mongod error]"); + if (!isMongodLog) { + original.print(lineBuffer.toString()); + } + lineBuffer.setLength(0); + } + original.flush(); + } + } +} diff --git a/commons-datastore/commons-datastore-mongodb/src/main/java/org/opencb/commons/datastore/mongodb/test/EmbeddedMongoDBRule.java b/commons-datastore/commons-datastore-mongodb/src/main/java/org/opencb/commons/datastore/mongodb/test/EmbeddedMongoDBRule.java new file mode 100644 index 000000000..eff6956ff --- /dev/null +++ b/commons-datastore/commons-datastore-mongodb/src/main/java/org/opencb/commons/datastore/mongodb/test/EmbeddedMongoDBRule.java @@ -0,0 +1,64 @@ +/* + * Copyright 2015-2020 OpenCB + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.opencb.commons.datastore.mongodb.test; + +import org.junit.rules.ExternalResource; + +/** + * JUnit {@link org.junit.rules.ExternalResource} wrapper around the process-singleton + * {@link EmbeddedMongoDBManager}. The rule ensures the embedded mongod is running when the + * test class starts; it intentionally does NOT stop it in {@link #after()} — the singleton's + * JVM shutdown hook owns the teardown. + * + *

+ * <p>Intended usage:</p>
+ * <pre>
+ * public class MyMongoTest {
+ *     &#64;ClassRule
+ *     public static final EmbeddedMongoDBRule EMBEDDED_MONGO = new EmbeddedMongoDBRule();
+ *
+ *     &#64;Test public void ... { }
+ * }
+ * </pre>
+ *
+ * <p>Subclasses may extend this rule to add per-class housekeeping (e.g. closing engine-level
+ * connections) in {@link #after()}.</p>

+ */ +public class EmbeddedMongoDBRule extends ExternalResource { + + @Override + protected void before() throws Throwable { + EmbeddedMongoDBManager.getInstance().start(); + } + + @Override + protected void after() { + // Intentionally empty: the embedded mongod is a JVM-wide singleton owned by + // EmbeddedMongoDBManager. Its shutdown hook stops it on JVM exit. + } + + public EmbeddedMongoDBManager getManager() { + return EmbeddedMongoDBManager.getInstance(); + } + + public String getConnectionString() { + return EmbeddedMongoDBManager.getInstance().getConnectionString(); + } + + public int getPort() { + return EmbeddedMongoDBManager.getInstance().getPort(); + } +} diff --git a/pom.xml b/pom.xml index 6aa823ad0..426b2e8ef 100644 --- a/pom.xml +++ b/pom.xml @@ -31,6 +31,9 @@ 1.3 4.13.2 2.2.27 + 4.21.0 + 1.9.1 + 4.17.0 opencb https://sonarcloud.io @@ -194,6 +197,23 @@ ${mockito.version} test
+ + + + de.flapdoodle.embed + de.flapdoodle.embed.mongo + ${flapdoodle.version} + + + de.flapdoodle.embed + de.flapdoodle.embed.process + ${flapdoodle.process.version} + + + de.flapdoodle.reverse + de.flapdoodle.reverse + ${flapdoodle.reverse.version} +