From bc2b52bad980ae3955ccfc80c3a0265be446ca0c Mon Sep 17 00:00:00 2001 From: Ayushi Ahjolia Date: Thu, 19 Mar 2026 21:05:00 -0700 Subject: [PATCH] refactor(map): Fix null results, abstract completion logic --- docs/core/map.md | 22 ++- docs/design.md | 14 +- .../durable/MapInputValidationTest.java | 4 +- .../lambda/durable/MapIntegrationTest.java | 114 ++++++++++++++-- .../durable/model/CompletionReason.java | 10 -- .../model/ConcurrencyCompletionStatus.java | 1 - .../lambda/durable/model/MapResult.java | 8 +- .../operation/ChildContextOperation.java | 3 +- .../operation/ConcurrencyOperation.java | 128 ++++++------------ .../durable/operation/MapOperation.java | 56 ++------ .../durable/operation/ParallelOperation.java | 48 +++++-- .../lambda/durable/model/MapResultTest.java | 26 ++-- .../operation/BaseDurableOperationTest.java | 3 - .../operation/ConcurrencyOperationTest.java | 43 ++++-- 14 files changed, 280 insertions(+), 200 deletions(-) delete mode 100644 sdk/src/main/java/software/amazon/lambda/durable/model/CompletionReason.java diff --git a/docs/core/map.md b/docs/core/map.md index 0037e740b..544838d17 100644 --- a/docs/core/map.md +++ b/docs/core/map.md @@ -41,14 +41,14 @@ MapResult result = future.get(); | Method | Description | |--------|-------------| | `getResult(i)` | Result at index `i`, or `null` if that item failed | -| `getError(i)` | `ErrorObject` at index `i`, or `null` if that item succeeded | +| `getError(i)` | `MapError` at index `i`, or `null` if that item succeeded | | `getItem(i)` | The `MapResultItem` at index `i` with status, result, and error | | `allSucceeded()` | `true` if every item succeeded | | `size()` | Number of items in the result | | `items()` | All result items as an unmodifiable list | | `results()` | All results as an unmodifiable list (nulls for failed items) | | `succeeded()` | Only the non-null (successful) results | -| `failed()` | Only the non-null `ErrorObject`s | +| `failed()` | Only the non-null `MapError`s | | `completionReason()` | Why the operation completed (`ALL_COMPLETED`, `MIN_SUCCESSFUL_REACHED`, `FAILURE_TOLERANCE_EXCEEDED`) | ### MapResultItem @@ -59,7 +59,17 @@ Each `MapResultItem` contains: |-------|-------------| | `status()` | `SUCCEEDED`, `FAILED`, or `NOT_STARTED` | | `result()` | The result value, or `null` if failed/not started | -| `error()` | The error details as `ErrorObject`, or `null` if succeeded/not started | +| `error()` | The error details as `MapError`, or `null` if succeeded/not started | + +### MapError + +Failed items store error details as `MapError`, a serializable record that survives checkpoint-and-replay cycles: + +| Field | Description | +|-------|-------------| +| `errorType()` | Fully qualified exception class name (e.g., `java.lang.RuntimeException`) | +| `errorMessage()` | The exception message | +| `stackTrace()` | Stack trace frames as a list of strings, or `null` | ### Error Isolation @@ -87,9 +97,11 @@ var config = MapConfig.builder() .build(); var result = ctx.map("process-orders", items, OrderResult.class, - (orderId, index, childCtx) -> process(childCtx, orderId), config); + (orderId, index, childCtx) -> process(orderId, childCtx), config); ``` +`MapConfig` also supports a custom `serDes` for serialization via `.serDes(customSerDes)`. By default, the context's serializer is used. `maxConcurrency` must be at least 1 if set. + #### Concurrency Limiting `maxConcurrency` controls how many items execute concurrently. When set, items beyond the limit are queued and started as earlier items complete. Default is `null` (unlimited). @@ -158,7 +170,7 @@ The function passed to `map()` is a `MapFunction`: ```java @FunctionalInterface public interface MapFunction { - O apply(I item, int index, DurableContext context) throws Exception; + O apply(I item, int index, DurableContext context); } ``` diff --git a/docs/design.md b/docs/design.md index 97cb6015d..e5100d024 100644 --- a/docs/design.md +++ b/docs/design.md @@ -199,8 +199,12 @@ context.step("name", Type.class, supplier, │ - StepOperation │ │ - Queues requests │ │ - WaitOperation │ │ - Batches API calls (750KB) │ │ - WaitForConditionOperation │ │ │ -│ - execute() / get() │ │ - Notifies via callback │ -└──────────────────────────────┘ └──────────────────────────────┘ +│ - ConcurrencyOperation │ │ - Notifies via callback │ +│ - MapOperation │ └──────────────────────────────┘ +│ - ParallelOperation │ +│ - ChildContextOperation │ +│ - execute() / get() │ +└──────────────────────────────┘ │ ▼ ┌──────────────────────────────┐ @@ -235,7 +239,11 @@ software.amazon.lambda.durable │ ├── InvokeOperation # Invoke logic │ ├── CallbackOperation # Callback logic │ ├── WaitOperation # Wait logic -│ └── WaitForConditionOperation # Polling condition logic +│ ├── WaitForConditionOperation # Polling condition logic +│ ├── ConcurrencyOperation # Shared base for map/parallel +│ ├── MapOperation # Map operation logic +│ ├── ParallelOperation # Parallel operation logic +│ └── ChildContextOperation # Per-item child context execution │ ├── logging/ │ ├── DurableLogger # Context-aware logger wrapper (MDC-based) diff --git a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/MapInputValidationTest.java b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/MapInputValidationTest.java index 15bc05796..2c7ea87c6 100644 --- a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/MapInputValidationTest.java +++ b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/MapInputValidationTest.java @@ -7,7 +7,7 @@ import java.util.HashSet; import java.util.List; import org.junit.jupiter.api.Test; -import software.amazon.lambda.durable.model.CompletionReason; +import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus; import software.amazon.lambda.durable.model.ExecutionStatus; import software.amazon.lambda.durable.testing.LocalDurableTestRunner; @@ -54,7 +54,7 @@ void mapWithEmptyCollection_returnsEmptyMapResult() { assertEquals(0, result.size()); assertTrue(result.allSucceeded()); - assertEquals(CompletionReason.ALL_COMPLETED, result.completionReason()); + assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.completionReason()); return "done"; }); diff --git a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/MapIntegrationTest.java b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/MapIntegrationTest.java index e270bcdaa..3bbd8c569 100644 --- a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/MapIntegrationTest.java +++ b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/MapIntegrationTest.java @@ -9,7 +9,7 @@ import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; -import software.amazon.lambda.durable.model.CompletionReason; +import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus; import software.amazon.lambda.durable.model.ExecutionStatus; import software.amazon.lambda.durable.model.MapResultItem; import software.amazon.lambda.durable.testing.LocalDurableTestRunner; @@ -81,7 +81,7 @@ void testMapPartialFailure_failedItemDoesNotPreventOthers() { assertNull(result.getError(0)); assertNull(result.getError(2)); - assertEquals(CompletionReason.ALL_COMPLETED, result.completionReason()); + assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.completionReason()); return "done"; }); @@ -252,7 +252,7 @@ void testMapWithToleratedFailureCount_earlyTermination() { }, config); - assertEquals(CompletionReason.FAILURE_TOLERANCE_EXCEEDED, result.completionReason()); + assertEquals(ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED, result.completionReason()); assertFalse(result.allSucceeded()); assertEquals(5, result.size()); assertEquals("OK", result.getResult(0)); @@ -279,7 +279,7 @@ void testMapWithMinSuccessful_earlyTermination() { var result = context.map( "min-successful", items, String.class, (item, index, ctx) -> item.toUpperCase(), config); - assertEquals(CompletionReason.MIN_SUCCESSFUL_REACHED, result.completionReason()); + assertEquals(ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED, result.completionReason()); assertEquals(5, result.size()); assertEquals("A", result.getResult(0)); assertEquals("B", result.getResult(1)); @@ -419,7 +419,7 @@ void testMapUnlimitedConcurrencyWithToleratedFailureCount() { }, config); - assertEquals(CompletionReason.FAILURE_TOLERANCE_EXCEEDED, result.completionReason()); + assertEquals(ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED, result.completionReason()); assertFalse(result.allSucceeded()); return "done"; }); @@ -442,7 +442,7 @@ void testMapReplayWithFailedBranches() { return item.toUpperCase(); }); - // Errors survive replay since they are stored as ErrorObject (not raw Throwable) + // Errors survive replay since they are stored as MapError (not raw Throwable) assertEquals("OK", result.getResult(0)); assertEquals("OK2", result.getResult(2)); return "done"; @@ -531,7 +531,7 @@ void testMapWithAllSuccessfulCompletionConfig_stopsOnFirstFailure() { }, config); - assertEquals(CompletionReason.FAILURE_TOLERANCE_EXCEEDED, result.completionReason()); + assertEquals(ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED, result.completionReason()); assertEquals("OK1", result.getResult(0)); assertNotNull(result.getError(1)); // Items after the failure should be NOT_STARTED @@ -622,7 +622,7 @@ void testMapWithToleratedFailurePercentage() { }, config); - assertEquals(CompletionReason.FAILURE_TOLERANCE_EXCEEDED, result.completionReason()); + assertEquals(ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED, result.completionReason()); return "done"; }); @@ -630,6 +630,43 @@ void testMapWithToleratedFailurePercentage() { assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); } + @Test + void testMapWithToleratedFailurePercentage_replay() { + var executionCount = new AtomicInteger(0); + + var runner = LocalDurableTestRunner.create(String.class, (input, context) -> { + var items = List.of("ok1", "FAIL1", "ok2", "FAIL2", "ok3", "FAIL3", "ok4"); + var config = MapConfig.builder() + .completionConfig(CompletionConfig.toleratedFailurePercentage(0.3)) + .build(); + var result = context.map( + "pct-fail-replay", + items, + String.class, + (item, index, ctx) -> { + executionCount.incrementAndGet(); + if (item.startsWith("FAIL")) { + throw new RuntimeException("failed: " + item); + } + return item.toUpperCase(); + }, + config); + + assertEquals(ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED, result.completionReason()); + return "done"; + }); + + var result1 = runner.runUntilComplete("test"); + assertEquals(ExecutionStatus.SUCCEEDED, result1.getStatus()); + var firstRunCount = executionCount.get(); + + // Replay — with unlimited concurrency, children replay simultaneously. + // Verify completionReason is consistent and no re-execution occurs. + var result2 = runner.run("test"); + assertEquals(ExecutionStatus.SUCCEEDED, result2.getStatus()); + assertEquals(firstRunCount, executionCount.get(), "Map functions should not re-execute on replay"); + } + @Test void testMapAsyncWithWaitInsideBranches() { var runner = LocalDurableTestRunner.create(String.class, (input, context) -> { @@ -747,7 +784,7 @@ void testMapWithMinSuccessful_replay() { }, config); - assertEquals(CompletionReason.MIN_SUCCESSFUL_REACHED, result.completionReason()); + assertEquals(ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED, result.completionReason()); assertEquals("A", result.getResult(0)); assertEquals("B", result.getResult(1)); return "done"; @@ -826,4 +863,63 @@ void testMapWithLargeResult_replayChildren() { assertEquals(ExecutionStatus.SUCCEEDED, result2.getStatus()); assertEquals(firstRunCount, executionCount.get(), "Map functions should not re-execute on replay"); } + + @Test + void testMapWithNullResults() { + var runner = LocalDurableTestRunner.create(String.class, (input, context) -> { + var items = List.of("a", "b", "c"); + var result = context.map("null-map", items, String.class, (item, index, ctx) -> null); + + assertTrue(result.allSucceeded()); + assertEquals(3, result.size()); + for (int i = 0; i < result.size(); i++) { + assertEquals(MapResultItem.Status.SUCCEEDED, result.getItem(i).status()); + assertNull(result.getResult(i)); + assertNull(result.getError(i)); + } + return "done"; + }); + + var result = runner.runUntilComplete("test"); + assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); + } + + @Test + void testMultipleMapAsyncInParallel() { + var runner = LocalDurableTestRunner.create(String.class, (input, context) -> { + var numbers = List.of(1, 2, 3); + var letters = List.of("a", "b"); + var words = List.of("hello", "world", "foo", "bar"); + + var numbersFuture = context.mapAsync("map-numbers", numbers, String.class, (item, index, ctx) -> { + return ctx.step("double-" + index, String.class, stepCtx -> String.valueOf(item * 2)); + }); + + var lettersFuture = context.mapAsync("map-letters", letters, String.class, (item, index, ctx) -> { + return ctx.step("upper-" + index, String.class, stepCtx -> item.toUpperCase()); + }); + + var wordsFuture = context.mapAsync("map-words", words, String.class, (item, index, ctx) -> { + return ctx.step("reverse-" + index, String.class, stepCtx -> new StringBuilder(item) + .reverse() + .toString()); + }); + + var numbersResult = numbersFuture.get(); + var lettersResult = lettersFuture.get(); + var wordsResult = wordsFuture.get(); + + assertTrue(numbersResult.allSucceeded()); + assertTrue(lettersResult.allSucceeded()); + assertTrue(wordsResult.allSucceeded()); + + return String.join(",", numbersResult.results()) + + "|" + String.join(",", lettersResult.results()) + + "|" + String.join(",", wordsResult.results()); + }); + + var result = runner.runUntilComplete("test"); + assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); + assertEquals("2,4,6|A,B|olleh,dlrow,oof,rab", result.getResult(String.class)); + } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/model/CompletionReason.java b/sdk/src/main/java/software/amazon/lambda/durable/model/CompletionReason.java deleted file mode 100644 index ad7b00f71..000000000 --- a/sdk/src/main/java/software/amazon/lambda/durable/model/CompletionReason.java +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 -package software.amazon.lambda.durable.model; - -/** Indicates why a concurrent operation (map or parallel) completed. */ -public enum CompletionReason { - ALL_COMPLETED, - MIN_SUCCESSFUL_REACHED, - FAILURE_TOLERANCE_EXCEEDED -} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/model/ConcurrencyCompletionStatus.java b/sdk/src/main/java/software/amazon/lambda/durable/model/ConcurrencyCompletionStatus.java index 0a7b51984..220e25867 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/model/ConcurrencyCompletionStatus.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/model/ConcurrencyCompletionStatus.java @@ -5,7 +5,6 @@ public enum ConcurrencyCompletionStatus { ALL_COMPLETED, MIN_SUCCESSFUL_REACHED, - MIN_SUCCESSFUL_NOT_REACHED, FAILURE_TOLERANCE_EXCEEDED; @Override diff --git a/sdk/src/main/java/software/amazon/lambda/durable/model/MapResult.java b/sdk/src/main/java/software/amazon/lambda/durable/model/MapResult.java index 2c7009fcb..a307c4fbf 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/model/MapResult.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/model/MapResult.java @@ -10,7 +10,7 @@ * *

Holds ordered results from a map operation. Each index corresponds to the input item at the same position. Each * item is represented as a {@link MapResultItem} containing its status, result, and error. Includes the - * {@link CompletionReason} indicating why the operation completed. + * {@link ConcurrencyCompletionStatus} indicating why the operation completed. * *

Errors are stored as {@link MapError} rather than raw Throwable, so they survive serialization across * checkpoint-and-replay cycles without requiring AWS SDK-specific Jackson modules. @@ -19,17 +19,17 @@ * @param completionReason why the operation completed * @param the result type of each item */ -public record MapResult(List> items, CompletionReason completionReason) { +public record MapResult(List> items, ConcurrencyCompletionStatus completionReason) { /** Compact constructor that applies defensive copy and defaults. */ public MapResult { items = items != null ? List.copyOf(items) : Collections.emptyList(); - completionReason = completionReason != null ? completionReason : CompletionReason.ALL_COMPLETED; + completionReason = completionReason != null ? completionReason : ConcurrencyCompletionStatus.ALL_COMPLETED; } /** Returns an empty MapResult with no items. */ public static MapResult empty() { - return new MapResult<>(Collections.emptyList(), CompletionReason.ALL_COMPLETED); + return new MapResult<>(Collections.emptyList(), ConcurrencyCompletionStatus.ALL_COMPLETED); } /** Returns the result item at the given index. */ diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java index 92fc7a02b..4fb9e1beb 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java @@ -175,9 +175,8 @@ private void checkpointSuccess(T result) { } var serialized = serializeResult(result); - var serializedBytes = serialized.getBytes(StandardCharsets.UTF_8); - if (serializedBytes.length < LARGE_RESULT_THRESHOLD) { + if (serialized == null || serialized.getBytes(StandardCharsets.UTF_8).length < LARGE_RESULT_THRESHOLD) { sendOperationUpdate( OperationUpdate.builder().action(OperationAction.SUCCEED).payload(serialized)); } else { diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java index dc5cbb2fb..7fd8fd5f7 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java @@ -33,8 +33,8 @@ *

    *
  • Does NOT register its own thread — child context threads handle all suspension *
  • Uses a pending queue + running counter for concurrency control - *
  • Completion is determined by {@code minSuccessful}, {@code failureRateThreshold} and - * {@code toleratedFailureCount} + *
  • Completion is determined by subclass-specific logic via abstract {@code canComplete()} and + * {@code validateItemCount()} *
  • When a child suspends, the running count is NOT decremented *
* @@ -45,9 +45,6 @@ public abstract class ConcurrencyOperation extends BaseDurableOperation { private static final Logger logger = LoggerFactory.getLogger(ConcurrencyOperation.class); private final int maxConcurrency; - private final int minSuccessful; - private final int toleratedFailureCount; - private final double failureRateThreshold; private final AtomicInteger succeededCount = new AtomicInteger(0); private final AtomicInteger failedCount = new AtomicInteger(0); private final AtomicInteger runningCount = new AtomicInteger(0); @@ -55,7 +52,6 @@ public abstract class ConcurrencyOperation extends BaseDurableOperation { private final Queue> pendingQueue = new ConcurrentLinkedDeque<>(); private final List> childOperations = Collections.synchronizedList(new ArrayList<>()); private final Set completedOperations = Collections.synchronizedSet(new HashSet()); - protected ConcurrencyCompletionStatus completionStatus; private OperationIdGenerator operationIdGenerator; private final DurableContextImpl rootContext; @@ -64,38 +60,13 @@ protected ConcurrencyOperation( TypeToken resultTypeToken, SerDes resultSerDes, DurableContextImpl durableContext, - int maxConcurrency, - int minSuccessful, - int toleratedFailureCount, - double failureRateThreshold) { + int maxConcurrency) { super(operationIdentifier, resultTypeToken, resultSerDes, durableContext); this.maxConcurrency = maxConcurrency; - this.minSuccessful = minSuccessful; - this.toleratedFailureCount = toleratedFailureCount; - this.failureRateThreshold = failureRateThreshold; this.operationIdGenerator = new OperationIdGenerator(getOperationId()); this.rootContext = durableContext.createChildContextWithoutSettingThreadContext(getOperationId(), getName()); } - protected ConcurrencyOperation( - OperationIdentifier operationIdentifier, - TypeToken resultTypeToken, - SerDes resultSerDes, - DurableContextImpl durableContext, - int maxConcurrency, - int minSuccessful, - int toleratedFailureCount) { - this( - operationIdentifier, - resultTypeToken, - resultSerDes, - durableContext, - maxConcurrency, - minSuccessful, - toleratedFailureCount, - 100); - } - // ========== Template methods for subclasses ========== /** @@ -118,11 +89,8 @@ protected abstract ChildContextOperation createItem( SerDes serDes, DurableContextImpl parentContext); - /** - * Called when the concurrency operation succeeds (minSuccessful threshold met). Subclasses define checkpointing - * behavior. - */ - protected abstract void handleSuccess(); + /** Called when the concurrency operation succeeds. Subclasses define checkpointing behavior. */ + protected abstract void handleSuccess(ConcurrencyCompletionStatus concurrencyCompletionStatus); /** Called when the concurrency operation fails. Subclasses define checkpointing and exception behavior. */ protected abstract void handleFailure(ConcurrencyCompletionStatus concurrencyCompletionStatus); @@ -209,26 +177,38 @@ private void executeNextItemIfAllowed() { * @param child the child operation that completed */ public void onItemComplete(ChildContextOperation child) { - if (completedOperations.contains(child.getOperationId())) { + if (!completedOperations.add(child.getOperationId())) { return; - } else { - completedOperations.add(child.getOperationId()); } - runningCount.decrementAndGet(); + + // Evaluate child result outside the lock — child.get() may block waiting for a checkpoint response. logger.debug("OnItemComplete called by {}, Id: {}", child.getName(), child.getOperationId()); + boolean succeeded; try { child.get(); logger.debug("Result succeeded - {}", child.getName()); - succeededCount.incrementAndGet(); + succeeded = true; } catch (Throwable e) { - failedCount.incrementAndGet(); logger.debug("Child operation {} failed: {}", child.getOperationId(), e.getMessage()); + succeeded = false; } - if (canComplete()) { - handleComplete(); - } else { - executeNextItemIfAllowed(); + // Counter updates, completion check, and next-item dispatch must be atomic to prevent + // the main thread's join() from seeing runningCount==0 with incomplete counters. + synchronized (this) { + if (succeeded) { + succeededCount.incrementAndGet(); + } else { + failedCount.incrementAndGet(); + } + runningCount.decrementAndGet(); + + var status = canComplete(); + if (status != null) { + handleComplete(status); + } else { + executeNextItemIfAllowed(); + } } } @@ -239,57 +219,24 @@ public void onItemComplete(ChildContextOperation child) { * * @throws IllegalArgumentException if the item count cannot satisfy the criteria */ - protected void validateItemCount() { - int totalItems = childOperations.size(); - - if (minSuccessful > totalItems - failedCount.get()) { - throw new IllegalArgumentException("minSuccessful (" + minSuccessful - + ") exceeds the number of registered items (" + totalItems + ")"); - } - } + protected abstract void validateItemCount(); /** * Checks whether the concurrency operation can be considered complete. * - * @return true if enough items succeeded, too many failed, or not enough remaining items to reach minSuccessful + * @return the completion status if the operation is complete, or null if it should continue */ - protected boolean canComplete() { - int totalItems = childOperations.size(); - int succeeded = succeededCount.get(); - int failed = failedCount.get(); - - // If we've met the minimum successful count, we're done - if (minSuccessful != -1 && succeeded >= minSuccessful) { - completionStatus = ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED; - return true; - } - - // If we've exceeded the failure tolerance, we're done - if ((minSuccessful == -1 && failed > 0) - || failed > toleratedFailureCount - || (double) failed / totalItems > failureRateThreshold) { - completionStatus = ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED; - return true; - } - - // This will only happens when minSuccessful == -1 and user calls join() - if (isJoined.get() && minSuccessful == -1 && pendingQueue.isEmpty() && runningCount.get() == 0) { - completionStatus = ConcurrencyCompletionStatus.ALL_COMPLETED; - return true; - } + protected abstract ConcurrencyCompletionStatus canComplete(); - return false; - } - - private void handleComplete() { + private void handleComplete(ConcurrencyCompletionStatus status) { synchronized (this) { if (isOperationCompleted()) { return; } - if (completionStatus.isSucceeded()) { - handleSuccess(); + if (status.isSucceeded()) { + handleSuccess(status); } else { - handleFailure(completionStatus); + handleFailure(status); } } } @@ -305,8 +252,11 @@ protected void join() { return; } - if (canComplete()) { - handleComplete(); + synchronized (this) { + var status = canComplete(); + if (status != null) { + handleComplete(status); + } } waitForOperationCompletion(); diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java index 170f15351..d9931a475 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java @@ -17,7 +17,6 @@ import software.amazon.lambda.durable.MapFunction; import software.amazon.lambda.durable.TypeToken; import software.amazon.lambda.durable.context.DurableContextImpl; -import software.amazon.lambda.durable.model.CompletionReason; import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus; import software.amazon.lambda.durable.model.MapError; import software.amazon.lambda.durable.model.MapResult; @@ -48,6 +47,8 @@ public class MapOperation extends ConcurrencyOperation> { private final SerDes serDes; private final CompletionConfig completionConfig; private boolean replayFromPayload; + private volatile MapResult cachedResult; + private ConcurrencyCompletionStatus completionStatus; public MapOperation( OperationIdentifier operationIdentifier, @@ -61,16 +62,7 @@ public MapOperation( new TypeToken<>() {}, config.serDes(), durableContext, - config.maxConcurrency() != null ? config.maxConcurrency() : -1, - config.completionConfig().minSuccessful() != null - ? config.completionConfig().minSuccessful() - : -1, - config.completionConfig().toleratedFailureCount() != null - ? config.completionConfig().toleratedFailureCount() - : Integer.MAX_VALUE, - config.completionConfig().toleratedFailurePercentage() != null - ? config.completionConfig().toleratedFailurePercentage() - : 100); + config.maxConcurrency() != null ? config.maxConcurrency() : -1); this.items = List.copyOf(items); this.function = function; this.itemResultType = itemResultType; @@ -142,12 +134,14 @@ private void addAllItems() { } @Override - protected void handleSuccess() { + protected void handleSuccess(ConcurrencyCompletionStatus concurrencyCompletionStatus) { + this.completionStatus = concurrencyCompletionStatus; checkpointMapResult(); } @Override protected void handleFailure(ConcurrencyCompletionStatus concurrencyCompletionStatus) { + this.completionStatus = concurrencyCompletionStatus; checkpointMapResult(); } @@ -165,42 +159,39 @@ protected void validateItemCount() { * Map's default {@code allCompleted()} allows failures without early termination. */ @Override - protected boolean canComplete() { + protected ConcurrencyCompletionStatus canComplete() { int succeeded = getSucceededCount(); int failed = getFailedCount(); int totalCompleted = succeeded + failed; // Check minSuccessful if (completionConfig.minSuccessful() != null && succeeded >= completionConfig.minSuccessful()) { - completionStatus = ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED; - return true; + return ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED; } // Check toleratedFailureCount if (completionConfig.toleratedFailureCount() != null && failed > completionConfig.toleratedFailureCount()) { - completionStatus = ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED; - return true; + return ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED; } // Check toleratedFailurePercentage if (completionConfig.toleratedFailurePercentage() != null && totalCompleted > 0 && ((double) failed / totalCompleted) > completionConfig.toleratedFailurePercentage()) { - completionStatus = ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED; - return true; + return ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED; } // All items finished (no pending, no running) — complete with ALL_COMPLETED if (isAllItemsFinished()) { - completionStatus = ConcurrencyCompletionStatus.ALL_COMPLETED; - return true; + return ConcurrencyCompletionStatus.ALL_COMPLETED; } - return false; + return null; } private void checkpointMapResult() { var result = aggregateResults(); + this.cachedResult = result; var serialized = serializeResult(result); var serializedBytes = serialized.getBytes(java.nio.charset.StandardCharsets.UTF_8); @@ -230,7 +221,7 @@ public MapResult get() { } // First execution or large result replay: wait for children, then aggregate join(); - return aggregateResults(); + return cachedResult != null ? cachedResult : aggregateResults(); } /** @@ -262,24 +253,7 @@ private MapResult aggregateResults() { resultItems.set(i, MapResultItem.notStarted()); } - return new MapResult<>(resultItems, toCompletionReason()); - } - - private CompletionReason toCompletionReason() { - if (completionConfig.minSuccessful() != null && getSucceededCount() >= completionConfig.minSuccessful()) { - return CompletionReason.MIN_SUCCESSFUL_REACHED; - } - if (completionConfig.toleratedFailureCount() != null - && getFailedCount() > completionConfig.toleratedFailureCount()) { - return CompletionReason.FAILURE_TOLERANCE_EXCEEDED; - } - if (completionConfig.toleratedFailurePercentage() != null && getFailedCount() > 0) { - int total = getSucceededCount() + getFailedCount(); - if (total > 0 && ((double) getFailedCount() / total) > completionConfig.toleratedFailurePercentage()) { - return CompletionReason.FAILURE_TOLERANCE_EXCEEDED; - } - } - return CompletionReason.ALL_COMPLETED; + return new MapResult<>(resultItems, completionStatus); } private static MapError buildMapError(Exception e) { diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java index 64e937797..a68e3167e 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java @@ -43,6 +43,8 @@ */ public class ParallelOperation extends ConcurrencyOperation { + private final int minSuccessful; + private final int toleratedFailureCount; private boolean skipCheckpoint = false; public ParallelOperation( @@ -53,14 +55,9 @@ public ParallelOperation( int maxConcurrency, int minSuccessful, int toleratedFailureCount) { - super( - operationIdentifier, - resultTypeToken, - resultSerDes, - durableContext, - maxConcurrency, - minSuccessful, - toleratedFailureCount); + super(operationIdentifier, resultTypeToken, resultSerDes, durableContext, maxConcurrency); + this.minSuccessful = minSuccessful; + this.toleratedFailureCount = toleratedFailureCount; } @Override @@ -81,7 +78,7 @@ protected ChildContextOperation createItem( } @Override - protected void handleSuccess() { + protected void handleSuccess(ConcurrencyCompletionStatus concurrencyCompletionStatus) { if (skipCheckpoint) { // Do not send checkpoint during replay markAlreadyCompleted(); @@ -95,7 +92,7 @@ protected void handleSuccess() { @Override protected void handleFailure(ConcurrencyCompletionStatus concurrencyCompletionStatus) { - handleSuccess(); + handleSuccess(concurrencyCompletionStatus); } @Override @@ -118,4 +115,35 @@ public T get() { join(); return null; } + + @Override + protected void validateItemCount() { + if (minSuccessful > getTotalItems() - getFailedCount()) { + throw new IllegalArgumentException("minSuccessful (" + minSuccessful + + ") exceeds the number of registered items (" + getTotalItems() + ")"); + } + } + + @Override + protected ConcurrencyCompletionStatus canComplete() { + int succeeded = getSucceededCount(); + int failed = getFailedCount(); + + // If we've met the minimum successful count, we're done + if (minSuccessful != -1 && succeeded >= minSuccessful) { + return ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED; + } + + // If we've exceeded the failure tolerance, we're done + if ((minSuccessful == -1 && failed > 0) || failed > toleratedFailureCount) { + return ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED; + } + + // All items finished — complete + if (isAllItemsFinished()) { + return ConcurrencyCompletionStatus.ALL_COMPLETED; + } + + return null; + } } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/model/MapResultTest.java b/sdk/src/test/java/software/amazon/lambda/durable/model/MapResultTest.java index 86e7e3476..09d97e6d2 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/model/MapResultTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/model/MapResultTest.java @@ -19,7 +19,7 @@ void empty_returnsZeroSizeResult() { assertEquals(0, result.size()); assertTrue(result.allSucceeded()); - assertEquals(CompletionReason.ALL_COMPLETED, result.completionReason()); + assertEquals(ConcurrencyCompletionStatus.ALL_COMPLETED, result.completionReason()); assertTrue(result.results().isEmpty()); assertTrue(result.succeeded().isEmpty()); assertTrue(result.failed().isEmpty()); @@ -28,7 +28,8 @@ void empty_returnsZeroSizeResult() { @Test void allSucceeded_trueWhenNoErrors() { var result = new MapResult<>( - List.of(MapResultItem.success("a"), MapResultItem.success("b")), CompletionReason.ALL_COMPLETED); + List.of(MapResultItem.success("a"), MapResultItem.success("b")), + ConcurrencyCompletionStatus.ALL_COMPLETED); assertTrue(result.allSucceeded()); assertEquals(2, result.size()); @@ -43,7 +44,7 @@ void allSucceeded_falseWhenAnyError() { var error = testError("fail"); var result = new MapResult<>( List.of(MapResultItem.success("a"), MapResultItem.failure(error)), - CompletionReason.ALL_COMPLETED); + ConcurrencyCompletionStatus.ALL_COMPLETED); assertFalse(result.allSucceeded()); } @@ -53,7 +54,7 @@ void getResult_returnsNullForFailedItem() { var error = testError("fail"); var result = new MapResult<>( List.of(MapResultItem.success("a"), MapResultItem.failure(error)), - CompletionReason.ALL_COMPLETED); + ConcurrencyCompletionStatus.ALL_COMPLETED); assertEquals("a", result.getResult(0)); assertNull(result.getResult(1)); @@ -64,7 +65,7 @@ void getError_returnsNullForSucceededItem() { var error = testError("fail"); var result = new MapResult<>( List.of(MapResultItem.success("a"), MapResultItem.failure(error)), - CompletionReason.ALL_COMPLETED); + ConcurrencyCompletionStatus.ALL_COMPLETED); assertNull(result.getError(0)); assertSame(error, result.getError(1)); @@ -77,7 +78,7 @@ void succeeded_filtersNullResults() { MapResultItem.success("a"), MapResultItem.failure(testError("fail")), MapResultItem.success("c")), - CompletionReason.ALL_COMPLETED); + ConcurrencyCompletionStatus.ALL_COMPLETED); assertEquals(List.of("a", "c"), result.succeeded()); } @@ -87,7 +88,7 @@ void failed_filtersNullErrors() { var error = testError("fail"); var result = new MapResult<>( List.of(MapResultItem.success("a"), MapResultItem.failure(error), MapResultItem.success("c")), - CompletionReason.ALL_COMPLETED); + ConcurrencyCompletionStatus.ALL_COMPLETED); var failures = result.failed(); assertEquals(1, failures.size()); @@ -96,14 +97,15 @@ void failed_filtersNullErrors() { @Test void completionReason_preserved() { - var result = new MapResult<>(List.of(MapResultItem.success("a")), CompletionReason.MIN_SUCCESSFUL_REACHED); + var result = new MapResult<>( + List.of(MapResultItem.success("a")), ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED); - assertEquals(CompletionReason.MIN_SUCCESSFUL_REACHED, result.completionReason()); + assertEquals(ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED, result.completionReason()); } @Test void items_returnsUnmodifiableList() { - var result = new MapResult<>(List.of(MapResultItem.success("a")), CompletionReason.ALL_COMPLETED); + var result = new MapResult<>(List.of(MapResultItem.success("a")), ConcurrencyCompletionStatus.ALL_COMPLETED); assertThrows(UnsupportedOperationException.class, () -> result.items().add(MapResultItem.success("b"))); } @@ -112,7 +114,7 @@ void items_returnsUnmodifiableList() { void getItem_returnsMapResultItem() { var result = new MapResult<>( List.of(MapResultItem.success("a"), MapResultItem.failure(testError("fail"))), - CompletionReason.ALL_COMPLETED); + ConcurrencyCompletionStatus.ALL_COMPLETED); assertEquals(MapResultItem.Status.SUCCEEDED, result.getItem(0).status()); assertEquals("a", result.getItem(0).result()); @@ -127,7 +129,7 @@ void getItem_returnsMapResultItem() { void notStartedItems_haveNotStartedStatusAndNullResultAndError() { var result = new MapResult<>( List.of(MapResultItem.success("a"), MapResultItem.notStarted()), - CompletionReason.MIN_SUCCESSFUL_REACHED); + ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED); assertEquals(MapResultItem.Status.NOT_STARTED, result.getItem(1).status()); assertNull(result.getResult(1)); diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/BaseDurableOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/BaseDurableOperationTest.java index 7ae7ac718..8df56e796 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/BaseDurableOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/BaseDurableOperationTest.java @@ -63,9 +63,6 @@ void setUp() { when(durableContext.getExecutionManager()).thenReturn(executionManager); when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext(CONTEXT_ID, ThreadType.CONTEXT)); when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)).thenReturn(OPERATION); - // Stub runUntilCompleteOrSuspend to pass through the user future — in unit tests there's - // no executionExceptionFuture to race against, so just wait on the completionFuture directly. - when(executionManager.runUntilCompleteOrSuspend(any())).thenAnswer(invocation -> invocation.getArgument(0)); } @Test diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java index f0aa4c235..67f7c2ebe 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java @@ -213,6 +213,8 @@ static class TestConcurrencyOperation extends ConcurrencyOperation { private boolean failureHandled = false; private final AtomicInteger executingCount = new AtomicInteger(0); private DurableContextImpl lastParentContext; + private final int minSuccessful; + private final int toleratedFailureCount; TestConcurrencyOperation( OperationIdentifier operationIdentifier, @@ -222,14 +224,9 @@ static class TestConcurrencyOperation extends ConcurrencyOperation { int maxConcurrency, int minSuccessful, int toleratedFailureCount) { - super( - operationIdentifier, - resultTypeToken, - resultSerDes, - durableContext, - maxConcurrency, - minSuccessful, - toleratedFailureCount); + super(operationIdentifier, resultTypeToken, resultSerDes, durableContext, maxConcurrency); + this.minSuccessful = minSuccessful; + this.toleratedFailureCount = toleratedFailureCount; } @Override @@ -257,7 +254,7 @@ public void execute() { } @Override - protected void handleSuccess() { + protected void handleSuccess(ConcurrencyCompletionStatus completionStatus) { successHandled = true; // Simulate the checkpoint ACK that a real subclass would receive after sendOperationUpdate. // This drives completionFuture to completion so waitForOperationCompletion() unblocks. @@ -282,6 +279,34 @@ protected void start() {} @Override protected void replay(Operation existing) {} + @Override + protected void validateItemCount() { + if (minSuccessful > getTotalItems() - getFailedCount()) { + throw new IllegalArgumentException("minSuccessful (" + minSuccessful + + ") exceeds the number of registered items (" + getTotalItems() + ")"); + } + } + + @Override + protected ConcurrencyCompletionStatus canComplete() { + int succeeded = getSucceededCount(); + int failed = getFailedCount(); + + if (minSuccessful != -1 && succeeded >= minSuccessful) { + return ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED; + } + + if ((minSuccessful == -1 && failed > 0) || failed > toleratedFailureCount) { + return ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED; + } + + if (isAllItemsFinished()) { + return ConcurrencyCompletionStatus.ALL_COMPLETED; + } + + return null; + } + @Override public Void get() { return null;