Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
run: uv build

- name: Upload distributions
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v7
with:
name: release-dists
path: dist/
Expand All @@ -40,7 +40,7 @@ jobs:

steps:
- name: Retrieve release distributions
uses: actions/download-artifact@v7
uses: actions/download-artifact@v8
with:
name: release-dists
path: dist/
Expand Down
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
# Changelog

## [0.3.25](https://github.com/a2aproject/a2a-python/compare/v0.3.24...v0.3.25) (2026-03-10)


### Features

* Implement a vertex based task store ([#752](https://github.com/a2aproject/a2a-python/issues/752)) ([fa14dbf](https://github.com/a2aproject/a2a-python/commit/fa14dbf46b603f288a1f1c474401483bf53950e4))


### Bug Fixes

* return background task from consume_and_break_on_interrupt to prevent GC ([#775](https://github.com/a2aproject/a2a-python/issues/775)) ([a236d4d](https://github.com/a2aproject/a2a-python/commit/a236d4df8dceb2db1e1170e0b57599f3837ebd71))
* use default_factory for mutable field defaults in ServerCallContext ([#744](https://github.com/a2aproject/a2a-python/issues/744)) ([22b25d6](https://github.com/a2aproject/a2a-python/commit/22b25d653e57e2d1453bbc282052e51dbd904ac6))

## [0.3.24](https://github.com/a2aproject/a2a-python/compare/v0.3.23...v0.3.24) (2026-02-20)


Expand Down
Empty file added src/a2a/contrib/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions src/a2a/server/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ class ServerCallContext(BaseModel):

model_config = ConfigDict(arbitrary_types_allowed=True)

state: State = Field(default={})
user: User = Field(default=UnauthenticatedUser())
state: State = Field(default_factory=dict)
user: User = Field(default_factory=UnauthenticatedUser)
tenant: str = Field(default='')
requested_extensions: set[str] = Field(default_factory=set)
activated_extensions: set[str] = Field(default_factory=set)
5 changes: 5 additions & 0 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,12 +351,17 @@ async def push_notification_callback(event: Event) -> None:
(
result,
interrupted_or_non_blocking,
bg_consume_task,
) = await result_aggregator.consume_and_break_on_interrupt(
consumer,
blocking=blocking,
event_callback=push_notification_callback,
)

if bg_consume_task is not None:
bg_consume_task.set_name(f'continue_consuming:{task_id}')
self._track_background_task(bg_consume_task)

except Exception:
logger.exception('Agent execution failed')
producer_task.cancel()
Expand Down
17 changes: 12 additions & 5 deletions src/a2a/server/tasks/result_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ async def consume_and_break_on_interrupt(
consumer: EventConsumer,
blocking: bool = True,
event_callback: Callable[[Event], Awaitable[None]] | None = None,
) -> tuple[Task | Message | None, bool]:
) -> tuple[Task | Message | None, bool, asyncio.Task | None]:
"""Processes the event stream until completion or an interruptible state is encountered.

If `blocking` is False, it returns after the first event that creates a Task or Message.
Expand All @@ -119,16 +119,23 @@ async def consume_and_break_on_interrupt(
A tuple containing:
- The current aggregated result (`Task` or `Message`) at the point of completion or interruption.
- A boolean indicating whether the consumption was interrupted (`True`) or completed naturally (`False`).
- The background ``asyncio.Task`` that continues consuming events
after an interruption, or ``None`` when no background work was
spawned. **Callers must hold a strong reference** to this task
(e.g. in a ``set``) to prevent the garbage collector from
collecting it before it finishes — the event loop only keeps
weak references to tasks.

Raises:
BaseException: If the `EventConsumer` raises an exception during consumption.
"""
event_stream = consumer.consume_all()
interrupted = False
bg_task: asyncio.Task | None = None
async for event in event_stream:
if isinstance(event, Message):
self._message = event
return event, False
return event, False, None
await self.task_manager.process(event)

if event_callback:
Expand Down Expand Up @@ -161,13 +168,13 @@ async def consume_and_break_on_interrupt(

if should_interrupt:
# Continue consuming the rest of the events in the background.
# TODO: We should track all outstanding tasks to ensure they eventually complete.
asyncio.create_task( # noqa: RUF006
# The caller is responsible for tracking this task to prevent GC.
bg_task = asyncio.create_task(
self._continue_consuming(event_stream, event_callback)
)
interrupted = True
break
return await self.task_manager.get_task(), interrupted
return await self.task_manager.get_task(), interrupted, bg_task

async def _continue_consuming(
self,
Expand Down
12 changes: 9 additions & 3 deletions tck/sut_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
DefaultRequestHandler,
)
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
from a2a.server.tasks.task_store import TaskStore
from a2a.types import (
AgentCapabilities,
AgentCard,
Expand Down Expand Up @@ -128,8 +129,8 @@ async def execute(
await event_queue.enqueue_event(final_update)


def main() -> None:
"""Main entrypoint."""
def serve(task_store: TaskStore) -> None:
"""Sets up the A2A service and starts the HTTP server."""
http_port = int(os.environ.get('HTTP_PORT', '41241'))

agent_card = AgentCard(
Expand Down Expand Up @@ -168,7 +169,7 @@ def main() -> None:

request_handler = DefaultRequestHandler(
agent_executor=SUTAgentExecutor(),
task_store=InMemoryTaskStore(),
task_store=task_store,
)

server = A2AStarletteApplication(
Expand All @@ -182,5 +183,10 @@ def main() -> None:
uvicorn.run(app, host='127.0.0.1', port=http_port, log_level='info')


def main() -> None:
"""Main entrypoint."""
serve(InMemoryTaskStore())


if __name__ == '__main__':
main()
Empty file added tests/contrib/__init__.py
Empty file.
12 changes: 11 additions & 1 deletion tests/server/request_handlers/test_default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,7 @@ async def test_on_message_send_with_push_notification():
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
final_task_result,
False,
None,
)

# Mock the current_result async property to return the final task result
Expand Down Expand Up @@ -643,6 +644,7 @@ async def test_on_message_send_with_push_notification_in_non_blocking_request():
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
initial_task,
True, # interrupted = True for non-blocking
MagicMock(spec=asyncio.Task), # background task
)

# Mock the current_result async property to return the final task
Expand All @@ -666,7 +668,11 @@ async def mock_consume_and_break_on_interrupt(
event_callback_received = event_callback
if event_callback_received:
await event_callback_received(final_task)
return initial_task, True # interrupted = True for non-blocking
return (
initial_task,
True,
MagicMock(spec=asyncio.Task),
) # interrupted = True for non-blocking

mock_result_aggregator_instance.consume_and_break_on_interrupt = (
mock_consume_and_break_on_interrupt
Expand Down Expand Up @@ -758,6 +764,7 @@ async def test_on_message_send_with_push_notification_no_existing_Task():
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
final_task_result,
False,
None,
)

# Mock the current_result async property to return the final task result
Expand Down Expand Up @@ -815,6 +822,7 @@ async def test_on_message_send_no_result_from_aggregator():
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
None,
False,
None,
)

with (
Expand Down Expand Up @@ -864,6 +872,7 @@ async def test_on_message_send_task_id_mismatch():
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
mismatched_task,
False,
None,
)

with (
Expand Down Expand Up @@ -1069,6 +1078,7 @@ async def test_on_message_send_interrupted_flow():
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
interrupt_task_result,
True,
MagicMock(spec=asyncio.Task), # background task
) # Interrupted = True

# Collect coroutines passed to create_task so we can close them
Expand Down
6 changes: 3 additions & 3 deletions tests/server/request_handlers/test_jsonrpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ async def test_on_message_new_message_success(

with patch(
'a2a.server.tasks.result_aggregator.ResultAggregator.consume_and_break_on_interrupt',
return_value=(mock_task, False),
return_value=(mock_task, False, None),
):
request = SendMessageRequest(
message=create_message(
Expand Down Expand Up @@ -352,7 +352,7 @@ async def test_on_message_new_message_with_existing_task_success(

with patch(
'a2a.server.tasks.result_aggregator.ResultAggregator.consume_and_break_on_interrupt',
return_value=(mock_task, False),
return_value=(mock_task, False, None),
):
request = SendMessageRequest(
message=create_message(
Expand Down Expand Up @@ -1021,7 +1021,7 @@ async def test_on_message_send_task_id_mismatch(self) -> None:
# Task returned has task_id='task_123' but request_context will have generated UUID
with patch(
'a2a.server.tasks.result_aggregator.ResultAggregator.consume_and_break_on_interrupt',
return_value=(mock_task, False),
return_value=(mock_task, False, None),
):
request = SendMessageRequest(
message=create_message(), # No task_id, so UUID is generated
Expand Down
12 changes: 11 additions & 1 deletion tests/server/tasks/test_result_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,14 @@ async def mock_consume_generator():
(
result,
interrupted,
bg_task,
) = await self.aggregator.consume_and_break_on_interrupt(
self.mock_event_consumer
)

self.assertEqual(result, sample_message)
self.assertFalse(interrupted)
self.assertIsNone(bg_task)
self.mock_task_manager.process.assert_not_called() # Process is not called for the Message if returned directly
# _continue_consuming should not be called if it's a message interrupt
# and no auth_required state.
Expand Down Expand Up @@ -268,12 +270,14 @@ async def mock_consume_generator():
(
result,
interrupted,
bg_task,
) = await self.aggregator.consume_and_break_on_interrupt(
self.mock_event_consumer
)

self.assertEqual(result, auth_task)
self.assertTrue(interrupted)
self.assertIsNotNone(bg_task)
self.mock_task_manager.process.assert_called_once_with(auth_task)
mock_create_task.assert_called_once() # Check that create_task was called
# self.aggregator._continue_consuming is an AsyncMock.
Expand Down Expand Up @@ -322,12 +326,14 @@ async def mock_consume_generator():
(
result,
interrupted,
bg_task,
) = await self.aggregator.consume_and_break_on_interrupt(
self.mock_event_consumer
)

self.assertEqual(result, current_task_state_after_update)
self.assertTrue(interrupted)
self.assertIsNotNone(bg_task)
self.mock_task_manager.process.assert_called_once_with(
auth_status_update
)
Expand Down Expand Up @@ -358,13 +364,15 @@ async def mock_consume_generator():
(
result,
interrupted,
bg_task,
) = await self.aggregator.consume_and_break_on_interrupt(
self.mock_event_consumer
)

# If the first event is a Message, it's returned directly.
self.assertEqual(result, event1)
self.assertFalse(interrupted)
self.assertIsNone(bg_task)
# process() is NOT called for the Message if it's the one causing the return
self.mock_task_manager.process.assert_not_called()
self.mock_task_manager.get_task.assert_not_called()
Expand Down Expand Up @@ -420,12 +428,14 @@ async def mock_consume_generator():
(
result,
interrupted,
bg_task,
) = await self.aggregator.consume_and_break_on_interrupt(
self.mock_event_consumer, blocking=False
)

self.assertEqual(result, first_event)
self.assertTrue(interrupted)
self.assertIsNotNone(bg_task)
self.mock_task_manager.process.assert_called_once_with(first_event)
mock_create_task.assert_called_once()
# The background task should be created with the remaining stream
Expand Down Expand Up @@ -474,7 +484,7 @@ async def initial_consume_generator():
mock_create_task.side_effect = lambda coro: asyncio.ensure_future(coro)

# Call the main method that triggers _continue_consuming via create_task
_, _ = await self.aggregator.consume_and_break_on_interrupt(
_, _, _ = await self.aggregator.consume_and_break_on_interrupt(
self.mock_event_consumer
)

Expand Down
Loading
Loading