Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 39 additions & 71 deletions src/a2a/client/base_client.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from collections.abc import AsyncGenerator, AsyncIterator, Callable
from typing import Any

from a2a.client.client import (
Client,
ClientCallContext,
ClientConfig,
ClientEvent,
Consumer,
)
from a2a.client.client_task_manager import ClientTaskManager
from a2a.client.middleware import ClientCallInterceptor
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.client.transports.base import ClientTransport
from a2a.types.a2a_pb2 import (
AgentCard,
Expand All @@ -23,8 +21,6 @@
ListTaskPushNotificationConfigsResponse,
ListTasksRequest,
ListTasksResponse,
Message,
SendMessageConfiguration,
SendMessageRequest,
StreamResponse,
SubscribeToTaskRequest,
Expand All @@ -51,12 +47,9 @@

async def send_message(
self,
request: Message,
request: SendMessageRequest,
*,
configuration: SendMessageConfiguration | None = None,
context: ClientCallContext | None = None,
request_metadata: dict[str, Any] | None = None,
extensions: list[str] | None = None,
) -> AsyncIterator[ClientEvent]:
"""Sends a message to the agent.

Expand All @@ -66,35 +59,15 @@

Args:
request: The message to send to the agent.
configuration: Optional per-call overrides for message sending behavior.
context: The client call context.
request_metadata: Extensions Metadata attached to the request.
extensions: List of extensions to be activated.
context: Optional client call context.

Yields:
An async iterator of `ClientEvent`
"""
config = SendMessageConfiguration(
accepted_output_modes=self._config.accepted_output_modes,
blocking=not self._config.polling,
push_notification_config=(
self._config.push_notification_configs[0]
if self._config.push_notification_configs
else None
),
)

if configuration:
config.MergeFrom(configuration)
config.blocking = configuration.blocking

send_message_request = SendMessageRequest(
message=request, configuration=config, metadata=request_metadata
)

self._apply_client_config(request)
if not self._config.streaming or not self._card.capabilities.streaming:
response = await self._transport.send_message(
send_message_request, context=context, extensions=extensions
request, context=context
)

# In non-streaming case we convert to a StreamResponse so that the
Expand All @@ -116,11 +89,29 @@
return

stream = self._transport.send_message_streaming(
send_message_request, context=context, extensions=extensions
request, context=context
)
async for client_event in self._process_stream(stream):
yield client_event

def _apply_client_config(self, request: SendMessageRequest):

Check failure on line 97 in src/a2a/client/base_client.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

ruff (ANN202)

src/a2a/client/base_client.py:97:9: ANN202 Missing return type annotation for private function `_apply_client_config`
if not request.configuration.blocking and self._config.polling:
request.configuration.blocking = self._config.polling
if (
not request.configuration.HasField('push_notification_config')
and self._config.push_notification_configs
):
request.configuration.push_notification_config.CopyFrom(
self._config.push_notification_configs[0]
)
if (
not request.configuration.accepted_output_modes
and self._config.accepted_output_modes
):
request.configuration.accepted_output_modes.extend(
self._config.accepted_output_modes
)

async def _process_stream(
self, stream: AsyncIterator[StreamResponse]
) -> AsyncGenerator[ClientEvent]:
Expand All @@ -147,21 +138,17 @@
request: GetTaskRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Retrieves the current state and history of a specific task.

Args:
request: The `GetTaskRequest` object specifying the task ID.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.

Returns:
A `Task` object representing the current state of the task.
"""
return await self._transport.get_task(
request, context=context, extensions=extensions
)
return await self._transport.get_task(request, context=context)

async def list_tasks(
self,
Expand All @@ -177,118 +164,104 @@
request: CancelTaskRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Requests the agent to cancel a specific task.

Args:
request: The `CancelTaskRequest` object specifying the task ID.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.

Returns:
A `Task` object containing the updated task status.
"""
return await self._transport.cancel_task(
request, context=context, extensions=extensions
)
return await self._transport.cancel_task(request, context=context)

async def create_task_push_notification_config(
self,
request: CreateTaskPushNotificationConfigRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Sets or updates the push notification configuration for a specific task.

Args:
request: The `TaskPushNotificationConfig` object with the new configuration.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.

Returns:
The created or updated `TaskPushNotificationConfig` object.
"""
return await self._transport.create_task_push_notification_config(
request, context=context, extensions=extensions
request, context=context
)

async def get_task_push_notification_config(
self,
request: GetTaskPushNotificationConfigRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Retrieves the push notification configuration for a specific task.

Args:
request: The `GetTaskPushNotificationConfigParams` object specifying the task.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.

Returns:
A `TaskPushNotificationConfig` object containing the configuration.
"""
return await self._transport.get_task_push_notification_config(
request, context=context, extensions=extensions
request, context=context
)

async def list_task_push_notification_configs(
self,
request: ListTaskPushNotificationConfigsRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> ListTaskPushNotificationConfigsResponse:
"""Lists push notification configurations for a specific task.

Args:
request: The `ListTaskPushNotificationConfigsRequest` object specifying the request.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.

Returns:
A `ListTaskPushNotificationConfigsResponse` object.
"""
return await self._transport.list_task_push_notification_configs(
request, context=context, extensions=extensions
request, context=context
)

async def delete_task_push_notification_config(
self,
request: DeleteTaskPushNotificationConfigRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> None:
"""Deletes the push notification configuration for a specific task.

Args:
request: The `DeleteTaskPushNotificationConfigRequest` object specifying the request.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.
"""
await self._transport.delete_task_push_notification_config(
request, context=context, extensions=extensions
request, context=context
)

async def subscribe(
self,
request: SubscribeToTaskRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> AsyncIterator[ClientEvent]:
"""Resubscribes to a task's event stream.

This is only available if both the client and server support streaming.

Args:
request: Parameters to identify the task to resubscribe to.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.

Yields:
An async iterator of `ClientEvent` objects.
Expand All @@ -304,9 +277,7 @@
# Note: resubscribe can only be called on an existing task. As such,
# we should never see Message updates, despite the typing of the service
# definition indicating it may be possible.
stream = self._transport.subscribe(
request, context=context, extensions=extensions
)
stream = self._transport.subscribe(request, context=context)
async for client_event in self._process_stream(stream):
yield client_event

Expand All @@ -315,7 +286,6 @@
request: GetExtendedAgentCardRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card.
Expand All @@ -325,8 +295,7 @@

Args:
request: The `GetExtendedAgentCardRequest` object specifying the request.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.
signature_verifier: A callable used to verify the agent card's signatures.

Returns:
Expand All @@ -335,7 +304,6 @@
card = await self._transport.get_extended_agent_card(
request,
context=context,
extensions=extensions,
signature_verifier=signature_verifier,
)
self._card = card
Expand Down
19 changes: 2 additions & 17 deletions src/a2a/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@
ListTaskPushNotificationConfigsResponse,
ListTasksRequest,
ListTasksResponse,
Message,
PushNotificationConfig,
SendMessageConfiguration,
SendMessageRequest,
StreamResponse,
SubscribeToTaskRequest,
Task,
Expand Down Expand Up @@ -77,9 +76,6 @@ class ClientConfig:
)
"""Push notification configurations to use for every request."""

extensions: list[str] = dataclasses.field(default_factory=list)
"""A list of extension URIs the client supports."""


ClientEvent = tuple[StreamResponse, Task | None]

Expand Down Expand Up @@ -130,12 +126,9 @@ async def __aexit__(
@abstractmethod
async def send_message(
self,
request: Message,
request: SendMessageRequest,
*,
configuration: SendMessageConfiguration | None = None,
context: ClientCallContext | None = None,
request_metadata: dict[str, Any] | None = None,
extensions: list[str] | None = None,
) -> AsyncIterator[ClientEvent]:
"""Sends a message to the server.

Expand All @@ -154,7 +147,6 @@ async def get_task(
request: GetTaskRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Retrieves the current state and history of a specific task."""

Expand All @@ -173,7 +165,6 @@ async def cancel_task(
request: CancelTaskRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Requests the agent to cancel a specific task."""

Expand All @@ -183,7 +174,6 @@ async def create_task_push_notification_config(
request: CreateTaskPushNotificationConfigRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Sets or updates the push notification configuration for a specific task."""

Expand All @@ -193,7 +183,6 @@ async def get_task_push_notification_config(
request: GetTaskPushNotificationConfigRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Retrieves the push notification configuration for a specific task."""

Expand All @@ -203,7 +192,6 @@ async def list_task_push_notification_configs(
request: ListTaskPushNotificationConfigsRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> ListTaskPushNotificationConfigsResponse:
"""Lists push notification configurations for a specific task."""

Expand All @@ -213,7 +201,6 @@ async def delete_task_push_notification_config(
request: DeleteTaskPushNotificationConfigRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> None:
"""Deletes the push notification configuration for a specific task."""

Expand All @@ -223,7 +210,6 @@ async def subscribe(
request: SubscribeToTaskRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> AsyncIterator[ClientEvent]:
"""Resubscribes to a task's event stream."""
return
Expand All @@ -235,7 +221,6 @@ async def get_extended_agent_card(
request: GetExtendedAgentCardRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""
Expand Down
Loading
Loading