From be04364a1204aa335dac2c8d063aab12971b2616 Mon Sep 17 00:00:00 2001 From: yaythomas Date: Thu, 29 Jan 2026 13:36:50 -0800 Subject: [PATCH 1/2] feat: Add ExecutionContext with durable execution ARN Add immutable ExecutionContext dataclass to provide readonly access to execution-level metadata. The context contains the durable execution ARN which uniquely identifies the execution instance within AWS. Changes: - Create ExecutionContext frozen dataclass with durable_execution_arn field - Add execution_context parameter to DurableContext constructor - Update from_lambda_context() to create ExecutionContext from state - Update create_child_context() to propagate execution_context - Export ExecutionContext in public API - Add create_test_context() helper in test files for easier instantiation - Update all test files to use new helper function The ExecutionContext is automatically created in from_lambda_context() using the state's durable_execution_arn, keeping the API simple while maintaining immutability and thread safety. closes #283 --- .../context.py | 21 ++ tests/context_test.py | 214 +++++++++++++----- tests/operation/map_test.py | 31 ++- tests/operation/parallel_test.py | 31 ++- tests/test_helpers.py | 7 +- 5 files changed, 237 insertions(+), 67 deletions(-) diff --git a/src/aws_durable_execution_sdk_python/context.py b/src/aws_durable_execution_sdk_python/context.py index 8efaed0..46d48a1 100644 --- a/src/aws_durable_execution_sdk_python/context.py +++ b/src/aws_durable_execution_sdk_python/context.py @@ -2,6 +2,7 @@ import hashlib import logging +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Concatenate, Generic, ParamSpec, TypeVar from aws_durable_execution_sdk_python.config import ( @@ -74,6 +75,20 @@ PASS_THROUGH_SERDES: SerDes[Any] = PassThroughSerDes() +@dataclass(frozen=True) +class ExecutionContext: + """Readonly metadata about the current durable execution context. + + This class provides immutable access to execution-level metadata. + + Attributes: + durable_execution_arn: The Amazon Resource Name (ARN) of the current + durable execution. + """ + + durable_execution_arn: str + + def durable_step( func: Callable[Concatenate[StepContext, Params], T], ) -> Callable[Params, Callable[[StepContext], T]]: @@ -218,11 +233,13 @@ class DurableContext(DurableContextProtocol): def __init__( self, state: ExecutionState, + execution_context: ExecutionContext, lambda_context: LambdaContext | None = None, parent_id: str | None = None, logger: Logger | None = None, ) -> None: self.state: ExecutionState = state + self.execution_context: ExecutionContext = execution_context self.lambda_context = lambda_context self._parent_id: str | None = parent_id self._step_counter: OrderedCounter = OrderedCounter() @@ -245,6 +262,9 @@ def from_lambda_context( ): return DurableContext( state=state, + execution_context=ExecutionContext( + durable_execution_arn=state.durable_execution_arn + ), lambda_context=lambda_context, parent_id=None, ) @@ -254,6 +274,7 @@ def create_child_context(self, parent_id: str) -> DurableContext: logger.debug("Creating child context for parent %s", parent_id) return DurableContext( state=self.state, + execution_context=self.execution_context, lambda_context=self.lambda_context, parent_id=parent_id, logger=self.logger.with_log_info( diff --git a/tests/context_test.py b/tests/context_test.py index 4e43347..507cfc5 100644 --- a/tests/context_test.py +++ b/tests/context_test.py @@ -16,7 +16,11 @@ ParallelConfig, StepConfig, ) -from aws_durable_execution_sdk_python.context import Callback, DurableContext +from aws_durable_execution_sdk_python.context import ( + Callback, + DurableContext, + ExecutionContext, +) from aws_durable_execution_sdk_python.exceptions import ( CallbackError, SuspendExecution, @@ -39,6 +43,24 @@ from tests.test_helpers import operation_id_sequence +def create_test_context( + state: ExecutionState | None = None, parent_id: str | None = None +) -> DurableContext: + """Helper to create DurableContext for tests with required execution_context.""" + if state is None: + state = Mock(spec=ExecutionState) + state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test" + ) + + execution_context = ExecutionContext( + durable_execution_arn=state.durable_execution_arn + ) + return DurableContext( + state=state, execution_context=execution_context, parent_id=parent_id + ) + + def test_durable_context(): """Test the context module.""" assert DurableContext is not None @@ -250,7 +272,7 @@ def test_create_callback_basic(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) operation_ids = operation_id_sequence() expected_operation_id = next(operation_ids) @@ -282,7 +304,7 @@ def test_create_callback_with_name_and_config(mock_executor_class): ) config = CallbackConfig() - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) operation_ids = operation_id_sequence() [next(operation_ids) for _ in range(5)] # Skip 5 IDs expected_operation_id = next(operation_ids) # Get the 6th ID @@ -315,7 +337,7 @@ def test_create_callback_with_parent_id(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state, parent_id="parent123") + context = create_test_context(state=mock_state, parent_id="parent123") operation_ids = operation_id_sequence("parent123") [next(operation_ids) for _ in range(2)] # Skip 2 IDs expected_operation_id = next(operation_ids) # Get the 3rd ID @@ -345,7 +367,7 @@ def test_create_callback_increments_counter(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) [context._create_step_id() for _ in range(10)] # Set counter to 10 # noqa: SLF001 callback1 = context.create_callback() @@ -383,7 +405,7 @@ def test_step_basic(mock_executor_class): mock_callable._original_name # noqa: SLF001 ) # Ensure _original_name doesn't exist - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) operation_ids = operation_id_sequence() expected_operation_id = next(operation_ids) @@ -418,7 +440,7 @@ def test_step_with_name_and_config(mock_executor_class): ) # Ensure Mock doesn't have _original_name config = StepConfig() - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) [context._create_step_id() for _ in range(5)] # Set counter to 5 # noqa: SLF001 result = context.step(mock_callable, config=config) @@ -456,7 +478,7 @@ def test_step_with_parent_id(mock_executor_class): mock_callable._original_name # noqa: SLF001 ) # Ensure _original_name doesn't exist - context = DurableContext(state=mock_state, parent_id="parent123") + context = create_test_context(state=mock_state, parent_id="parent123") [context._create_step_id() for _ in range(2)] # Set counter to 2 # noqa: SLF001 context.step(mock_callable) @@ -493,7 +515,7 @@ def test_step_increments_counter(mock_executor_class): mock_callable._original_name # noqa: SLF001 ) # Ensure _original_name doesn't exist - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) [context._create_step_id() for _ in range(10)] # Set counter to 10 # noqa: SLF001 context.step(mock_callable) @@ -529,7 +551,7 @@ def test_step_with_original_name(mock_executor_class): mock_callable = Mock() mock_callable._original_name = "original_function" # noqa: SLF001 - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) context.step(mock_callable, name="override_name") @@ -564,7 +586,7 @@ def test_invoke_basic(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) operation_ids = operation_id_sequence() expected_operation_id = next(operation_ids) @@ -596,7 +618,7 @@ def test_invoke_with_name_and_config(mock_executor_class): ) config = InvokeConfig[str, str](timeout=Duration.from_seconds(30)) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) [context._create_step_id() for _ in range(5)] # Set counter to 5 # noqa: SLF001 result = context.invoke( @@ -632,7 +654,7 @@ def test_invoke_with_parent_id(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state, parent_id="parent123") + context = create_test_context(state=mock_state, parent_id="parent123") [context._create_step_id() for _ in range(2)] # Set counter to 2 # noqa: SLF001 context.invoke("test_function", None) @@ -664,7 +686,7 @@ def test_invoke_increments_counter(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) [context._create_step_id() for _ in range(10)] # Set counter to 10 # noqa: SLF001 context.invoke("function1", "payload1") @@ -697,7 +719,7 @@ def test_invoke_with_none_payload(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.invoke("test_function", None) @@ -737,7 +759,7 @@ def test_invoke_with_custom_serdes(mock_executor_class): timeout=Duration.from_minutes(1), ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.invoke( "test_function", @@ -778,7 +800,7 @@ def test_wait_basic(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) operation_ids = operation_id_sequence() expected_operation_id = next(operation_ids) @@ -804,7 +826,7 @@ def test_wait_with_name(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) [context._create_step_id() for _ in range(5)] # Set counter to 5 # noqa: SLF001 context.wait(Duration.from_minutes(1), name="test_wait") @@ -833,7 +855,7 @@ def test_wait_with_parent_id(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state, parent_id="parent123") + context = create_test_context(state=mock_state, parent_id="parent123") [context._create_step_id() for _ in range(2)] # Set counter to 2 # noqa: SLF001 context.wait(Duration.from_seconds(45)) @@ -862,7 +884,7 @@ def test_wait_increments_counter(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) [context._create_step_id() for _ in range(10)] # Set counter to 10 # noqa: SLF001 context.wait(Duration.from_seconds(15)) @@ -894,7 +916,7 @@ def test_wait_returns_none(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.wait(Duration.from_seconds(10)) @@ -913,7 +935,7 @@ def test_wait_with_time_less_than_one(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) with pytest.raises(ValidationError): context.wait(Duration.from_seconds(0)) @@ -936,7 +958,7 @@ def test_run_in_child_context_basic(mock_handler): mock_callable._original_name # noqa: SLF001 ) # Ensure _original_name doesn't exist - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) operation_ids = operation_id_sequence() expected_operation_id = next(operation_ids) @@ -967,7 +989,7 @@ def test_run_in_child_context_with_name_and_config(mock_handler): config = ChildConfig() - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) [context._create_step_id() for _ in range(3)] # Set counter to 3 # noqa: SLF001 result = context.run_in_child_context(mock_callable, config=config) @@ -1001,7 +1023,7 @@ def test_run_in_child_context_with_parent_id(mock_executor_class): mock_callable._original_name # noqa: SLF001 ) # Ensure Mock doesn't have _original_name - context = DurableContext(state=mock_state, parent_id="parent456") + context = create_test_context(state=mock_state, parent_id="parent456") [context._create_step_id() for _ in range(1)] # Set counter to 1 # noqa: SLF001 context.run_in_child_context(mock_callable) @@ -1037,7 +1059,7 @@ def capture_child_context(child_context): mock_callable = Mock(side_effect=capture_child_context) mock_executor_class.side_effect = lambda func, **kwargs: func() - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.run_in_child_context(mock_callable) @@ -1062,7 +1084,7 @@ def test_run_in_child_context_increments_counter(mock_executor_class): mock_callable._original_name # noqa: SLF001 ) # Ensure _original_name doesn't exist - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) [context._create_step_id() for _ in range(5)] # Set counter to 5 # noqa: SLF001 context.run_in_child_context(mock_callable) @@ -1097,7 +1119,7 @@ def test_run_in_child_context_resolves_name_from_callable(mock_executor_class): mock_callable = Mock() mock_callable._original_name = "original_function_name" # noqa: SLF001 - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) context.run_in_child_context(mock_callable) @@ -1128,7 +1150,7 @@ def test_wait_for_callback_basic(mock_executor_class): with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: mock_run_in_child.return_value = "callback_result" - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.wait_for_callback(mock_submitter) @@ -1158,7 +1180,7 @@ def test_wait_for_callback_with_name_and_config(mock_executor_class): with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: mock_run_in_child.return_value = "configured_callback_result" - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.wait_for_callback(mock_submitter, config=config) @@ -1186,7 +1208,7 @@ def test_wait_for_callback_resolves_name_from_submitter(mock_executor_class): with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: mock_run_in_child.return_value = "named_callback_result" - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) context.wait_for_callback(mock_submitter) @@ -1214,11 +1236,11 @@ def capture_handler_call(context, submitter, name, config): def run_child_context(callable_func, name): # Execute the child context callable - child_context = DurableContext(state=mock_state, parent_id="test") + child_context = create_test_context(state=mock_state, parent_id="test") return callable_func(child_context) mock_run_in_child.side_effect = run_child_context - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.wait_for_callback(mock_submitter) @@ -1244,7 +1266,7 @@ def test_function(context, item, index, items): inputs = [1, 2, 3] - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.map(inputs, test_function) @@ -1273,7 +1295,7 @@ def test_function(context, item, index, items): inputs = ["a", "b", "c"] config = MapConfig() - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.map(inputs, test_function, name="custom_map", config=config) @@ -1298,7 +1320,7 @@ def test_function(context, item, index, items): inputs = ["hello", "world"] - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.map(inputs, test_function) @@ -1322,7 +1344,7 @@ def test_function(context, item, index, items): with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: mock_run_in_child.return_value = "empty_map_result" - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.map(inputs, test_function) @@ -1345,7 +1367,7 @@ def test_function(context, item, index, items): with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: mock_run_in_child.return_value = "mixed_map_result" - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.map(inputs, test_function) @@ -1373,7 +1395,7 @@ def task2(context): callables = [task1, task2] - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.parallel(callables) @@ -1403,7 +1425,7 @@ def task2(context): callables = [task1, task2] config = ParallelConfig() - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.parallel(callables, name="custom_parallel", config=config) @@ -1435,7 +1457,7 @@ def task2(context): callables = [task1, task2] - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) # Use _resolve_step_name to test name resolution resolved_name = context._resolve_step_name(None, mock_callable) # noqa: SLF001 @@ -1466,7 +1488,7 @@ def task2(context): callables = [task1, task2] - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.parallel(callables) @@ -1487,7 +1509,7 @@ def test_parallel_with_empty_callables(mock_handler): with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: mock_run_in_child.return_value = "empty_parallel_result" - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.parallel(callables) @@ -1510,7 +1532,7 @@ def single_task(context): with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: mock_run_in_child.return_value = "single_parallel_result" - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.parallel(callables) @@ -1536,7 +1558,7 @@ def task(context): with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: mock_run_in_child.return_value = "many_parallel_result" - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.parallel(callables) @@ -1562,7 +1584,7 @@ def test_function(context, item, index, items): inputs = ["a", "b", "c"] config = MapConfig() - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.map(inputs, test_function, config=config) @@ -1588,7 +1610,7 @@ def task2(context): callables = [task1, task2] config = ParallelConfig() - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.parallel(callables, config=config) @@ -1603,7 +1625,7 @@ def test_wait_for_condition_validation_errors(): mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) def dummy_wait_strategy(state, attempt): return None @@ -1640,7 +1662,7 @@ def test_function(context, item, index, items): state = Mock() state.durable_execution_arn = "test_arn" - context = DurableContext(state=state) + context = create_test_context(state=state) # Mock the handlers to track calls with patch( @@ -1680,7 +1702,7 @@ def test_callable_2(context): state = Mock() state.durable_execution_arn = "test_arn" - context = DurableContext(state=state) + context = create_test_context(state=state) # Mock the handlers to track calls with patch( @@ -1719,7 +1741,7 @@ def test_wait_strategy(state, attempt): state = Mock() state.durable_execution_arn = "test_arn" - context = DurableContext(state=state) + context = create_test_context(state=state) # Create config config = WaitForConditionConfig( @@ -1820,7 +1842,7 @@ def test_invoke_with_explicit_tenant_id(mock_executor_class): ) config = InvokeConfig(tenant_id="explicit-tenant") - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.invoke("test_function", "payload", config=config) @@ -1842,7 +1864,7 @@ def test_invoke_without_tenant_id_defaults_to_none(mock_executor_class): "arn:aws:durable:us-east-1:123456789012:execution/test" ) - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.invoke("test_function", "payload") @@ -1851,3 +1873,89 @@ def test_invoke_without_tenant_id_defaults_to_none(mock_executor_class): call_args = mock_executor_class.call_args[1] assert isinstance(call_args["config"], InvokeConfig) assert call_args["config"].tenant_id is None + + +# region ExecutionContext tests + + +def test_execution_context_exists_on_durable_context(): + """Test that DurableContext has execution_context attribute.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test-execution" + ) + + context = create_test_context(state=mock_state) + + assert hasattr(context, "execution_context") + assert context.execution_context is not None + + +def test_execution_context_has_correct_arn(): + """Test that ExecutionContext contains the correct durable_execution_arn.""" + expected_arn = "arn:aws:durable:us-west-2:987654321098:execution/my-execution" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = expected_arn + + context = create_test_context(state=mock_state) + + assert context.execution_context.durable_execution_arn == expected_arn + + +def test_execution_context_is_immutable(): + """Test that ExecutionContext is frozen and immutable.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test" + ) + + context = create_test_context(state=mock_state) + + # Attempt to modify should raise FrozenInstanceError for frozen dataclass + with pytest.raises(AttributeError, match="cannot assign to field"): + context.execution_context.durable_execution_arn = "new-arn" + + +def test_execution_context_propagates_to_child_context(): + """Test that child contexts inherit the same execution_context.""" + parent_arn = "arn:aws:durable:eu-west-1:111222333444:execution/parent-exec" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = parent_arn + + parent_context = create_test_context(state=mock_state) + child_context = parent_context.create_child_context(parent_id="parent-op-123") + + assert child_context.execution_context is not None + assert child_context.execution_context.durable_execution_arn == parent_arn + # Should be the same instance (not a copy) + assert child_context.execution_context is parent_context.execution_context + + +def test_from_lambda_context_creates_execution_context(): + """Test that from_lambda_context factory creates ExecutionContext.""" + expected_arn = "arn:aws:durable:ap-south-1:555666777888:execution/lambda-exec" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = expected_arn + mock_lambda_context = Mock() + + context = DurableContext.from_lambda_context( + state=mock_state, lambda_context=mock_lambda_context + ) + + assert context.execution_context is not None + assert context.execution_context.durable_execution_arn == expected_arn + + +def test_execution_context_type(): + """Test that execution_context is of type ExecutionContext.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test" + ) + + context = create_test_context(state=mock_state) + + assert isinstance(context.execution_context, ExecutionContext) + + +# endregion ExecutionContext tests diff --git a/tests/operation/map_test.py b/tests/operation/map_test.py index 5c5a5a1..69d2f31 100644 --- a/tests/operation/map_test.py +++ b/tests/operation/map_test.py @@ -19,15 +19,34 @@ ItemBatcher, MapConfig, ) -from aws_durable_execution_sdk_python.context import DurableContext +from aws_durable_execution_sdk_python.context import DurableContext, ExecutionContext from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import OperationSubType from aws_durable_execution_sdk_python.operation import child # PLC0415 from aws_durable_execution_sdk_python.operation.map import MapExecutor, map_handler from aws_durable_execution_sdk_python.serdes import serialize +from aws_durable_execution_sdk_python.state import ExecutionState from tests.serdes_test import CustomStrSerDes +def create_test_context( + state: ExecutionState | None = None, parent_id: str | None = None +) -> DurableContext: + """Helper to create DurableContext for tests with required execution_context.""" + if state is None: + state = Mock(spec=ExecutionState) + state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test" + ) + + execution_context = ExecutionContext( + durable_execution_arn=state.durable_execution_arn + ) + return DurableContext( + state=state, execution_context=execution_context, parent_id=parent_id + ) + + def test_map_executor_init(): """Test MapExecutor initialization.""" executables = [Executable(index=0, func=lambda: None)] @@ -808,7 +827,7 @@ def create_id(self, i): ) with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id): - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) context.map( ["a", "b"], lambda ctx, item, idx, items: item, @@ -870,7 +889,7 @@ def create_id(self, i): ) with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id): - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) context.map( ["a", "b"], lambda ctx, item, idx, items: item, @@ -970,7 +989,7 @@ def create_id(self, i): with patch.object( DurableContext, "_create_step_id_for_logical_step", create_id ): - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.map(["a", "b"], lambda ctx, item, idx, items: item) assert len(mock_serdes_serialize.call_args_list) == 3 @@ -1022,7 +1041,7 @@ def create_id(self, i): with patch.object( DurableContext, "_create_step_id_for_logical_step", create_id ): - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.map(["a", "b"], lambda ctx, item, idx, items: item) assert isinstance(result, BatchResult) @@ -1078,7 +1097,7 @@ def create_id(self, i): with patch.object( DurableContext, "_create_step_id_for_logical_step", create_id ): - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.map( ["a", "b"], lambda ctx, item, idx, items: item, diff --git a/tests/operation/parallel_test.py b/tests/operation/parallel_test.py index c43be7e..5021788 100644 --- a/tests/operation/parallel_test.py +++ b/tests/operation/parallel_test.py @@ -17,7 +17,7 @@ Executable, ) from aws_durable_execution_sdk_python.config import CompletionConfig, ParallelConfig -from aws_durable_execution_sdk_python.context import DurableContext +from aws_durable_execution_sdk_python.context import DurableContext, ExecutionContext from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import OperationSubType from aws_durable_execution_sdk_python.operation import child @@ -26,9 +26,28 @@ parallel_handler, ) from aws_durable_execution_sdk_python.serdes import serialize +from aws_durable_execution_sdk_python.state import ExecutionState from tests.serdes_test import CustomStrSerDes +def create_test_context( + state: ExecutionState | None = None, parent_id: str | None = None +) -> DurableContext: + """Helper to create DurableContext for tests with required execution_context.""" + if state is None: + state = Mock(spec=ExecutionState) + state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test" + ) + + execution_context = ExecutionContext( + durable_execution_arn=state.durable_execution_arn + ) + return DurableContext( + state=state, execution_context=execution_context, parent_id=parent_id + ) + + def test_parallel_executor_init(): """Test ParallelExecutor initialization.""" executables = [Executable(index=0, func=lambda x: x)] @@ -791,7 +810,7 @@ def create_id(self, i): ) with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id): - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) context.parallel( [lambda ctx: "a", lambda ctx: "b"], config=ParallelConfig(serdes=batch_serdes, item_serdes=item_serdes), @@ -852,7 +871,7 @@ def create_id(self, i): ) with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id): - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) context.parallel( [lambda ctx: "a", lambda ctx: "b"], config=ParallelConfig(serdes=batch_serdes, item_serdes=item_serdes), @@ -964,7 +983,7 @@ def create_id(self, i): with patch.object( DurableContext, "_create_step_id_for_logical_step", create_id ): - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.parallel([lambda ctx: "a", lambda ctx: "b"]) assert len(mock_serdes_serialize.call_args_list) == 3 @@ -1015,7 +1034,7 @@ def create_id(self, i): with patch.object( DurableContext, "_create_step_id_for_logical_step", create_id ): - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.parallel([lambda ctx: "a", lambda ctx: "b"]) assert isinstance(result, BatchResult) @@ -1071,7 +1090,7 @@ def create_id(self, i): with patch.object( DurableContext, "_create_step_id_for_logical_step", create_id ): - context = DurableContext(state=mock_state) + context = create_test_context(state=mock_state) result = context.parallel( [lambda ctx: "a", lambda ctx: "b"], config=ParallelConfig(serdes=custom_serdes), diff --git a/tests/test_helpers.py b/tests/test_helpers.py index dca15a0..77611a3 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -2,7 +2,7 @@ from unittest.mock import Mock -from aws_durable_execution_sdk_python.context import DurableContext +from aws_durable_execution_sdk_python.context import DurableContext, ExecutionContext from aws_durable_execution_sdk_python.execution import ExecutionState @@ -11,7 +11,10 @@ def operation_id_sequence(parent_id: str | None = None): mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test-arn" - context = DurableContext(state=mock_state, parent_id=parent_id) + execution_context = ExecutionContext(durable_execution_arn="test-arn") + context = DurableContext( + state=mock_state, execution_context=execution_context, parent_id=parent_id + ) while True: yield context._create_step_id() # noqa: SLF001 From 77fa1f661da780bd93201fad7d15b0027ce30090 Mon Sep 17 00:00:00 2001 From: yaythomas Date: Thu, 29 Jan 2026 13:45:19 -0800 Subject: [PATCH 2/2] refactor: Import concrete BatchResult Change __init__.py to import BatchResult from concurrency.models (concrete implementation) rather than from types (protocol definition). This provides users with the actual class implementation for better IDE support and type checking. The protocol remains in types.py for internal type checking, while the public API exports the concrete dataclass implementation. closes #284 --- src/aws_durable_execution_sdk_python/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/aws_durable_execution_sdk_python/__init__.py b/src/aws_durable_execution_sdk_python/__init__.py index 1a24d31..0514767 100644 --- a/src/aws_durable_execution_sdk_python/__init__.py +++ b/src/aws_durable_execution_sdk_python/__init__.py @@ -2,6 +2,8 @@ # Main context - used in every durable function # Helper decorators - commonly used for step functions +# Concurrency +from aws_durable_execution_sdk_python.concurrency.models import BatchResult from aws_durable_execution_sdk_python.context import ( DurableContext, durable_step, @@ -20,7 +22,7 @@ from aws_durable_execution_sdk_python.execution import durable_execution # Essential context types - passed to user functions -from aws_durable_execution_sdk_python.types import BatchResult, StepContext +from aws_durable_execution_sdk_python.types import StepContext __all__ = [ "BatchResult",