diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelExample.java index d9d79d3e..a13005c7 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelExample.java @@ -8,6 +8,7 @@ import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.DurableHandler; import software.amazon.lambda.durable.ParallelConfig; +import software.amazon.lambda.durable.model.ParallelResult; /** * Example demonstrating parallel branch execution with the Durable Execution SDK. @@ -38,8 +39,9 @@ public Output handleRequest(Input input, DurableContext context) { var config = ParallelConfig.builder().build(); var futures = new ArrayList>(items.size()); + var parallel = context.parallel("process-items", config); - try (var parallel = context.parallel("process-items", config)) { + try (parallel) { for (var item : items) { var future = parallel.branch("process-" + item, String.class, branchCtx -> { branchCtx.getLogger().info("Processing item: {}", item); @@ -49,7 +51,12 @@ public Output handleRequest(Input input, DurableContext context) { } } // join() called here via AutoCloseable - logger.info("All branches complete, collecting results"); + ParallelResult parallelResult = parallel.get(); + logger.info( + "Parallel complete: total={}, succeeded={}, failed={}", + parallelResult.getTotalBranches(), + parallelResult.getSucceededBranches(), + parallelResult.getFailedBranches()); var results = futures.stream().map(DurableFuture::get).toList(); diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExample.java index 38f4b903..b498db8d 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExample.java @@ -9,6 +9,7 @@ import software.amazon.lambda.durable.DurableHandler; import software.amazon.lambda.durable.ParallelConfig; import software.amazon.lambda.durable.StepConfig; +import software.amazon.lambda.durable.model.ParallelResult; import software.amazon.lambda.durable.retry.RetryStrategies; /** @@ -24,9 +25,9 @@ public class ParallelFailureToleranceExample extends DurableHandler { - public record Input(List services, int toleratedFailures) {} + public record Input(List services, int toleratedFailures, int minSuccessful) {} - public record Output(List succeeded, List failed) {} + public record Output(int succeeded, int failed) {} @Override public Output handleRequest(Input input, DurableContext context) { @@ -34,12 +35,14 @@ public Output handleRequest(Input input, DurableContext context) { logger.info("Starting parallel execution with toleratedFailureCount={}", input.toleratedFailures()); var config = ParallelConfig.builder() + .minSuccessful(input.minSuccessful()) .toleratedFailureCount(input.toleratedFailures()) .build(); var futures = new ArrayList>(input.services().size()); + var parallel = context.parallel("call-services", config); - try (var parallel = context.parallel("call-services", config)) { + try (parallel) { for (var service : input.services()) { var future = parallel.branch("call-" + service, String.class, branchCtx -> { return branchCtx.step( @@ -59,20 +62,17 @@ public Output handleRequest(Input input, DurableContext context) { } } - var succeeded = new ArrayList(); - var failed = new ArrayList(); + ParallelResult parallelResult = parallel.get(); + logger.info( + "Parallel complete: succeeded={}, failed={}, status={}", + parallelResult.getSucceededBranches(), + parallelResult.getFailedBranches(), + parallelResult.getCompletionStatus().isSucceeded() ? "succeeded" : "failed"); - for (int i = 0; i < futures.size(); i++) { - try { - var result = futures.get(i).get(); - succeeded.add(result); - } catch (Exception e) { - failed.add(input.services().get(i)); - logger.info("Branch failed for service {}: {}", input.services().get(i), e.getMessage()); - } - } + var succeeded = parallelResult.getSucceededBranches(); + var failed = parallelResult.getFailedBranches(); - logger.info("Completed: {} succeeded, {} failed", succeeded.size(), failed.size()); + logger.info("Completed: {} succeeded, {} failed", succeeded, failed); return new Output(succeeded, failed); } } diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExample.java index a25464b4..63bbee21 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExample.java @@ -9,6 +9,7 @@ import software.amazon.lambda.durable.DurableFuture; import software.amazon.lambda.durable.DurableHandler; import software.amazon.lambda.durable.ParallelConfig; +import software.amazon.lambda.durable.model.ParallelResult; /** * Example demonstrating parallel branches where some branches include wait operations. @@ -29,7 +30,7 @@ public class ParallelWithWaitExample public record Input(String userId, String message) {} - public record Output(List deliveries) {} + public record Output(List deliveries, int success, int faiure) {} @Override public Output handleRequest(Input input, DurableContext context) { @@ -38,8 +39,9 @@ public Output handleRequest(Input input, DurableContext context) { var config = ParallelConfig.builder().build(); var futures = new ArrayList>(3); + var parallel = context.parallel("notify", config); - try (var parallel = context.parallel("notify", config)) { + try (parallel) { // Branch 1: email — no wait, deliver immediately futures.add(parallel.branch("email", String.class, ctx -> { @@ -60,10 +62,12 @@ public Output handleRequest(Input input, DurableContext context) { })); } + ParallelResult result = parallel.get(); + var deliveries = futures.stream().map(DurableFuture::get).toList(); logger.info("All {} notifications delivered", deliveries.size()); // Test replay context.wait("wait for finalization", Duration.ofSeconds(5)); - return new Output(deliveries); + return new Output(deliveries, result.getSucceededBranches(), result.getFailedBranches()); } } diff --git a/examples/src/test/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExampleTest.java b/examples/src/test/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExampleTest.java index 2b57970e..7d4dd72d 100644 --- a/examples/src/test/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExampleTest.java +++ b/examples/src/test/java/software/amazon/lambda/durable/examples/parallel/ParallelFailureToleranceExampleTest.java @@ -17,17 +17,14 @@ void succeedsWhenFailuresAreWithinTolerance() { var runner = LocalDurableTestRunner.create(ParallelFailureToleranceExample.Input.class, handler); // 2 good services, 1 bad — toleratedFailureCount=1 so the parallel op still succeeds - var input = new ParallelFailureToleranceExample.Input(List.of("svc-a", "bad-svc-b", "svc-c"), 1); + var input = new ParallelFailureToleranceExample.Input(List.of("svc-a", "bad-svc-b", "svc-c"), 1, -1); var result = runner.runUntilComplete(input); assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); var output = result.getResult(ParallelFailureToleranceExample.Output.class); - assertEquals(2, output.succeeded().size()); - assertEquals(1, output.failed().size()); - assertTrue(output.succeeded().contains("ok:svc-a")); - assertTrue(output.succeeded().contains("ok:svc-c")); - assertTrue(output.failed().contains("bad-svc-b")); + assertEquals(2, output.succeeded()); + assertEquals(1, output.failed()); } @Test @@ -35,14 +32,13 @@ void succeedsWhenAllBranchesSucceed() { var handler = new ParallelFailureToleranceExample(); var runner = LocalDurableTestRunner.create(ParallelFailureToleranceExample.Input.class, handler); - var input = new ParallelFailureToleranceExample.Input(List.of("svc-a", "svc-b", "svc-c"), 2); + var input = new ParallelFailureToleranceExample.Input(List.of("svc-a", "svc-b", "svc-c"), 2, -1); var result = runner.runUntilComplete(input); assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); var output = result.getResult(ParallelFailureToleranceExample.Output.class); - assertEquals(3, output.succeeded().size()); - assertTrue(output.failed().isEmpty()); + assertEquals(3, output.succeeded()); } @Test @@ -51,13 +47,13 @@ void failsWhenFailuresExceedTolerance() { var runner = LocalDurableTestRunner.create(ParallelFailureToleranceExample.Input.class, handler); // 2 bad services, toleratedFailureCount=1 — second failure exceeds tolerance - var input = new ParallelFailureToleranceExample.Input(List.of("svc-a", "bad-svc-b", "bad-svc-c"), 1); + var input = new ParallelFailureToleranceExample.Input(List.of("svc-a", "bad-svc-b", "bad-svc-c"), 1, 2); var result = runner.runUntilComplete(input); assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); var output = result.getResult(ParallelFailureToleranceExample.Output.class); - assertEquals(2, output.failed().size()); - assertEquals(1, output.succeeded().size()); + assertEquals(2, output.failed()); + assertEquals(1, output.succeeded()); } } diff --git a/examples/src/test/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExampleTest.java b/examples/src/test/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExampleTest.java index 4352ed23..bb44dc6c 100644 --- a/examples/src/test/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExampleTest.java +++ b/examples/src/test/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExampleTest.java @@ -29,5 +29,6 @@ void completesAfterManuallyAdvancingWaits() { var output = result.getResult(ParallelWithWaitExample.Output.class); assertEquals(List.of("email:world", "sms:world", "push:world"), output.deliveries()); + assertEquals(3, output.success()); } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/ParallelContext.java b/sdk/src/main/java/software/amazon/lambda/durable/ParallelContext.java index 6ca03166..6debb975 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/ParallelContext.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/ParallelContext.java @@ -3,15 +3,17 @@ package software.amazon.lambda.durable; import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; +import software.amazon.lambda.durable.model.ParallelResult; import software.amazon.lambda.durable.operation.ParallelOperation; /** User-facing context for managing parallel branch execution within a durable function. */ -public class ParallelContext implements AutoCloseable { +public class ParallelContext implements AutoCloseable, DurableFuture { - private final ParallelOperation parallelOperation; + private final ParallelOperation parallelOperation; private final DurableContext durableContext; - private boolean joined; + private final AtomicBoolean joined = new AtomicBoolean(false); /** * Creates a new ParallelContext. @@ -19,7 +21,7 @@ public class ParallelContext implements AutoCloseable { * @param parallelOperation the underlying parallel operation managing concurrency * @param durableContext the durable context for creating child operations */ - public ParallelContext(ParallelOperation parallelOperation, DurableContext durableContext) { + public ParallelContext(ParallelOperation parallelOperation, DurableContext durableContext) { this.parallelOperation = Objects.requireNonNull(parallelOperation, "parallelOperation cannot be null"); this.durableContext = Objects.requireNonNull(durableContext, "durableContext cannot be null"); } @@ -49,7 +51,7 @@ public DurableFuture branch(String name, Class resultType, Function DurableFuture branch(String name, TypeToken resultType, Function func) { - if (joined) { + if (joined.get()) { throw new IllegalStateException("Cannot add branches after join() has been called"); } return parallelOperation.addItem( @@ -66,11 +68,23 @@ public DurableFuture branch(String name, TypeToken resultType, Functio * @throws software.amazon.lambda.durable.exception.ConcurrencyExecutionException if failure threshold exceeded */ public void join() { - if (joined) { + if (!joined.compareAndSet(false, true)) { return; } - joined = true; - parallelOperation.get(); + parallelOperation.join(); + } + + /** + * Blocks until the parallel operation completes and returns the {@link ParallelResult}. + * + *

Calling {@code get()} implicitly calls {@code join()} if it has not been called yet. + * + * @return the {@link ParallelResult} summarising branch outcomes + */ + @Override + public ParallelResult get() { + joined.set(true); + return parallelOperation.get(); } /** diff --git a/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java b/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java index e4bb6531..e171276d 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/context/DurableContextImpl.java @@ -558,9 +558,8 @@ public ParallelContext parallel(String name, ParallelConfig config) { Objects.requireNonNull(config, "config cannot be null"); var operationId = nextOperationId(); - var parallelOp = new ParallelOperation<>( + var parallelOp = new ParallelOperation( OperationIdentifier.of(operationId, name, OperationType.CONTEXT, OperationSubType.PARALLEL), - TypeToken.get(Void.class), getDurableConfig().getSerDes(), this, config.maxConcurrency(), diff --git a/sdk/src/main/java/software/amazon/lambda/durable/model/ParallelResult.java b/sdk/src/main/java/software/amazon/lambda/durable/model/ParallelResult.java new file mode 100644 index 00000000..f0e0fcf1 --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/model/ParallelResult.java @@ -0,0 +1,48 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable.model; + +/** + * Summary result of a parallel operation. + * + *

Captures the aggregate outcome of a parallel execution: how many branches were registered, how many succeeded, how + * many failed, and why the operation completed. + */ +public class ParallelResult { + + private final int totalBranches; + private final int succeededBranches; + private final int failedBranches; + private final ConcurrencyCompletionStatus completionStatus; + + public ParallelResult( + int totalBranches, + int succeededBranches, + int failedBranches, + ConcurrencyCompletionStatus completionStatus) { + this.totalBranches = totalBranches; + this.succeededBranches = succeededBranches; + this.failedBranches = failedBranches; + this.completionStatus = completionStatus; + } + + /** Returns the total number of branches registered before {@code join()} was called. */ + public int getTotalBranches() { + return totalBranches; + } + + /** Returns the number of branches that completed without throwing. */ + public int getSucceededBranches() { + return succeededBranches; + } + + /** Returns the number of branches that threw an exception. */ + public int getFailedBranches() { + return failedBranches; + } + + /** Returns the status indicating why the parallel operation completed. */ + public ConcurrencyCompletionStatus getCompletionStatus() { + return completionStatus; + } +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java index 7fd8fd5f..f46b890b 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java @@ -54,6 +54,7 @@ public abstract class ConcurrencyOperation extends BaseDurableOperation { private final Set completedOperations = Collections.synchronizedSet(new HashSet()); private OperationIdGenerator operationIdGenerator; private final DurableContextImpl rootContext; + private ConcurrencyCompletionStatus completionStatus; protected ConcurrencyOperation( OperationIdentifier operationIdentifier, @@ -203,9 +204,9 @@ public void onItemComplete(ChildContextOperation child) { } runningCount.decrementAndGet(); - var status = canComplete(); - if (status != null) { - handleComplete(status); + this.completionStatus = canComplete(); + if (this.completionStatus != null) { + handleComplete(this.completionStatus); } else { executeNextItemIfAllowed(); } @@ -245,17 +246,13 @@ private void handleComplete(ConcurrencyCompletionStatus status) { * Blocks the calling thread until the concurrency operation reaches a terminal state. Validates item count, handles * zero-branch case, then delegates to {@code waitForOperationCompletion()} from BaseDurableOperation. */ - protected void join() { + public void join() { validateItemCount(); isJoined.set(true); - if (childOperations.isEmpty()) { - return; - } - synchronized (this) { - var status = canComplete(); - if (status != null) { - handleComplete(status); + this.completionStatus = canComplete(); + if (this.completionStatus != null) { + handleComplete(this.completionStatus); } } @@ -274,6 +271,10 @@ protected int getTotalItems() { return childOperations.size(); } + protected ConcurrencyCompletionStatus getCompletionStatus() { + return completionStatus; + } + protected List> getChildOperations() { return Collections.unmodifiableList(childOperations); } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java index a68e3167..13aaa5d5 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java @@ -11,11 +11,11 @@ import software.amazon.lambda.durable.DurableContext; import software.amazon.lambda.durable.TypeToken; import software.amazon.lambda.durable.context.DurableContextImpl; -import software.amazon.lambda.durable.exception.ConcurrencyExecutionException; import software.amazon.lambda.durable.execution.ExecutionManager; import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus; import software.amazon.lambda.durable.model.OperationIdentifier; import software.amazon.lambda.durable.model.OperationSubType; +import software.amazon.lambda.durable.model.ParallelResult; import software.amazon.lambda.durable.serde.SerDes; /** @@ -25,8 +25,8 @@ * *

    *
  • Creates branches as {@link ChildContextOperation} with {@link OperationSubType#PARALLEL_BRANCH} - *
  • Checkpoints SUCCESS/FAIL on the parallel context when completion criteria are met - *
  • Throws {@link ConcurrencyExecutionException} when the operation fails + *
  • Checkpoints SUCCESS on the parallel context when completion criteria are met + *
  • Returns a {@link ParallelResult} summarising branch outcomes *
* *

Context hierarchy: @@ -38,10 +38,8 @@ * ├── Branch 2 context (ChildContextOperation with PARALLEL_BRANCH) * └── Branch N context (ChildContextOperation with PARALLEL_BRANCH) * - * - * @param the result type of this operation (typically Void) */ -public class ParallelOperation extends ConcurrencyOperation { +public class ParallelOperation extends ConcurrencyOperation { private final int minSuccessful; private final int toleratedFailureCount; @@ -49,13 +47,12 @@ public class ParallelOperation extends ConcurrencyOperation { public ParallelOperation( OperationIdentifier operationIdentifier, - TypeToken resultTypeToken, SerDes resultSerDes, DurableContextImpl durableContext, int maxConcurrency, int minSuccessful, int toleratedFailureCount) { - super(operationIdentifier, resultTypeToken, resultSerDes, durableContext, maxConcurrency); + super(operationIdentifier, new TypeToken() {}, resultSerDes, durableContext, maxConcurrency); this.minSuccessful = minSuccessful; this.toleratedFailureCount = toleratedFailureCount; } @@ -110,15 +107,14 @@ protected void replay(Operation existing) { } @Override - public T get() { - // TODO: implement proper return value handling + public ParallelResult get() { join(); - return null; + return new ParallelResult(getTotalItems(), getSucceededCount(), getFailedCount(), getCompletionStatus()); } @Override protected void validateItemCount() { - if (minSuccessful > getTotalItems() - getFailedCount()) { + if (minSuccessful > getTotalItems()) { throw new IllegalArgumentException("minSuccessful (" + minSuccessful + ") exceeds the number of registered items (" + getTotalItems() + ")"); } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/model/ParallelResultTest.java b/sdk/src/test/java/software/amazon/lambda/durable/model/ParallelResultTest.java new file mode 100644 index 00000000..c5cfeebc --- /dev/null +++ b/sdk/src/test/java/software/amazon/lambda/durable/model/ParallelResultTest.java @@ -0,0 +1,20 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable.model; + +import static org.junit.jupiter.api.Assertions.*; + +import org.junit.jupiter.api.Test; + +class ParallelResultTest { + + @Test + void allBranchesSucceed_countsAreCorrect() { + var result = new ParallelResult(3, 3, 0, ConcurrencyCompletionStatus.ALL_COMPLETED); + + assertEquals(3, result.getTotalBranches()); + assertEquals(3, result.getSucceededBranches()); + assertEquals(0, result.getFailedBranches()); + assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.getCompletionStatus()); + } +} diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java index 5b335065..a44e5345 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java @@ -27,6 +27,7 @@ import software.amazon.lambda.durable.execution.OperationIdGenerator; import software.amazon.lambda.durable.execution.ThreadContext; import software.amazon.lambda.durable.execution.ThreadType; +import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus; import software.amazon.lambda.durable.model.OperationIdentifier; import software.amazon.lambda.durable.model.OperationSubType; import software.amazon.lambda.durable.serde.JacksonSerDes; @@ -36,7 +37,6 @@ class ParallelOperationTest { private static final SerDes SER_DES = new JacksonSerDes(); private static final String OPERATION_ID = "parallel-op-1"; - private static final TypeToken RESULT_TYPE = TypeToken.get(Void.class); private DurableContextImpl durableContext; private ExecutionManager executionManager; @@ -107,10 +107,9 @@ void setUp() { .sendOperationUpdate(any()); } - private ParallelOperation createOperation(int maxConcurrency, int minSuccessful, int toleratedFailureCount) { - return new ParallelOperation<>( + private ParallelOperation createOperation(int maxConcurrency, int minSuccessful, int toleratedFailureCount) { + return new ParallelOperation( OperationIdentifier.of(OPERATION_ID, "test-parallel", OperationType.CONTEXT, OperationSubType.PARALLEL), - RESULT_TYPE, SER_DES, durableContext, maxConcurrency, @@ -160,11 +159,10 @@ void branchCreation_childOperationHasParentReference() throws Exception { assertInstanceOf(ChildContextOperation.class, childOp); } - // ===== handleSuccess checkpointing ===== + // ===== All branches succeed ===== @Test - void handleSuccess_sendsSucceedCheckpoint() throws Exception { - // Set up two branches that both succeed in replay + void allBranchesSucceed_sendsSucceedCheckpointAndReturnsCorrectResult() throws Exception { when(executionManager.getOperationAndUpdateReplayState("child-1")) .thenReturn(Operation.builder() .id("child-1") @@ -191,16 +189,20 @@ void handleSuccess_sendsSucceedCheckpoint() throws Exception { op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES); - op.get(); + var result = op.get(); verify(executionManager).sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); + assertEquals(2, result.getTotalBranches()); + assertEquals(2, result.getSucceededBranches()); + assertEquals(0, result.getFailedBranches()); + assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.getCompletionStatus()); + assertTrue(result.getCompletionStatus().isSucceeded()); } // ===== MinSuccessful satisfaction ===== @Test - void minSuccessful_joinCompletesWhenThresholdMet() throws Exception { - // minSuccessful=1, 2 branches — first succeeds → should complete without throwing + void minSuccessful_completesWhenThresholdMetAndReturnsResult() throws Exception { when(executionManager.getOperationAndUpdateReplayState("child-1")) .thenReturn(Operation.builder() .id("child-1") @@ -216,9 +218,14 @@ void minSuccessful_joinCompletesWhenThresholdMet() throws Exception { setOperationIdGenerator(op, mockIdGenerator); op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); - // Should not throw - op.get(); - assertEquals(1, op.getSucceededCount()); + var result = op.get(); + + verify(executionManager).sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); + assertEquals(1, result.getTotalBranches()); + assertEquals(1, result.getSucceededBranches()); + assertEquals(0, result.getFailedBranches()); + assertEquals(ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED, result.getCompletionStatus()); + assertTrue(result.getCompletionStatus().isSucceeded()); } // ===== Context hierarchy ===== @@ -240,8 +247,7 @@ void contextHierarchy_branchesUseParallelContextAsParent() throws Exception { // ===== Replay ===== @Test - void replay_doesNotSendStartCheckpoint() throws Exception { - // Simulate the parallel operation already existing in the service (STARTED status) + void replay_fromStartedState_sendsSucceedCheckpointAndReturnsResult() throws Exception { when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)) .thenReturn(Operation.builder() .id(OPERATION_ID) @@ -250,7 +256,6 @@ void replay_doesNotSendStartCheckpoint() throws Exception { .subType(OperationSubType.PARALLEL.getValue()) .status(OperationStatus.STARTED) .build()); - // Both branches already succeeded when(executionManager.getOperationAndUpdateReplayState("child-1")) .thenReturn(Operation.builder() .id("child-1") @@ -278,16 +283,20 @@ void replay_doesNotSendStartCheckpoint() throws Exception { op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES); - op.get(); + var result = op.get(); verify(executionManager, never()) .sendOperationUpdate(argThat(update -> update.action() == OperationAction.START)); verify(executionManager, times(1)) .sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); + assertEquals(2, result.getTotalBranches()); + assertEquals(2, result.getSucceededBranches()); + assertEquals(0, result.getFailedBranches()); + assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.getCompletionStatus()); } @Test - void replay_doesNotSendSucceedCheckpointWhenParallelAlreadySucceeded() throws Exception { + void replay_fromSucceededState_skipsCheckpointAndReturnsResult() throws Exception { when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)) .thenReturn(Operation.builder() .id(OPERATION_ID) @@ -323,20 +332,22 @@ void replay_doesNotSendSucceedCheckpointWhenParallelAlreadySucceeded() throws Ex op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES); - op.get(); + var result = op.get(); verify(executionManager, never()) .sendOperationUpdate(argThat(update -> update.action() == OperationAction.START)); verify(executionManager, never()) .sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); + assertEquals(2, result.getTotalBranches()); + assertEquals(2, result.getSucceededBranches()); + assertEquals(0, result.getFailedBranches()); + assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.getCompletionStatus()); } - // ===== handleFailure still sends SUCCEED ===== + // ===== Branch failure sends SUCCEED checkpoint and returns result ===== @Test - void handleFailure_sendsSucceedCheckpointEvenWhenFailureToleranceExceeded() throws Exception { - // toleratedFailureCount=0, so the first failure triggers handleFailure - // ParallelOperation.handleFailure() delegates to handleSuccess(), so SUCCEED must be sent + void branchFailure_sendsSucceedCheckpointAndReturnsFailureCounts() throws Exception { when(executionManager.getOperationAndUpdateReplayState("child-1")) .thenReturn(Operation.builder() .id("child-1") @@ -356,10 +367,69 @@ void handleFailure_sendsSucceedCheckpointEvenWhenFailureToleranceExceeded() thro TypeToken.get(String.class), SER_DES); - op.get(); + var result = assertDoesNotThrow(() -> op.get()); verify(executionManager).sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); verify(executionManager, never()) .sendOperationUpdate(argThat(update -> update.action() == OperationAction.FAIL)); + assertEquals(1, result.getTotalBranches()); + assertEquals(0, result.getSucceededBranches()); + assertEquals(1, result.getFailedBranches()); + assertFalse(result.getCompletionStatus().isSucceeded()); + } + + @Test + void get_someBranchesFail_returnsCorrectCountsAndFailureStatus() throws Exception { + when(executionManager.getOperationAndUpdateReplayState("child-1")) + .thenReturn(Operation.builder() + .id("child-1") + .name("branch-1") + .type(OperationType.CONTEXT) + .subType(OperationSubType.PARALLEL_BRANCH.getValue()) + .status(OperationStatus.SUCCEEDED) + .contextDetails( + ContextDetails.builder().result("\"r1\"").build()) + .build()); + when(executionManager.getOperationAndUpdateReplayState("child-2")) + .thenReturn(Operation.builder() + .id("child-2") + .name("branch-2") + .type(OperationType.CONTEXT) + .subType(OperationSubType.PARALLEL_BRANCH.getValue()) + .status(OperationStatus.FAILED) + .build()); + + // toleratedFailureCount=1 so the operation completes after both branches finish + var op = createOperation(-1, -1, 1); + setOperationIdGenerator(op, mockIdGenerator); + op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); + op.addItem( + "branch-2", + ctx -> { + throw new RuntimeException("branch failed"); + }, + TypeToken.get(String.class), + SER_DES); + + var result = op.get(); + + verify(executionManager).sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); + assertEquals(2, result.getTotalBranches()); + assertEquals(1, result.getSucceededBranches()); + assertEquals(1, result.getFailedBranches()); + assertFalse(result.getCompletionStatus().isSucceeded()); + } + + @Test + void get_zeroBranches_returnsAllZerosAndAllCompletedStatus() throws Exception { + var op = createOperation(-1, -1, 0); + + var result = op.get(); + + assertEquals(0, result.getTotalBranches()); + assertEquals(0, result.getSucceededBranches()); + assertEquals(0, result.getFailedBranches()); + assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.getCompletionStatus()); + verify(executionManager).sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); } }