diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 0f298965d7..1420d0cfb0 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -778,7 +778,7 @@ async def _map_tool_result_part( pass return text elif isinstance(part, mcp_types.ImageContent): - return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType) + return messages.BinaryImage(data=base64.b64decode(part.data), media_type=part.mimeType) elif isinstance(part, mcp_types.AudioContent): # NOTE: The FastMCP server doesn't support audio content. # See for more details. @@ -799,8 +799,10 @@ def _get_content( if isinstance(resource, mcp_types.TextResourceContents): return resource.text elif isinstance(resource, mcp_types.BlobResourceContents): - return messages.BinaryContent( - data=base64.b64decode(resource.blob), media_type=resource.mimeType or 'application/octet-stream' + return messages.BinaryContent.narrow_type( + messages.BinaryContent( + data=base64.b64decode(resource.blob), media_type=resource.mimeType or 'application/octet-stream' + ) ) else: assert_never(resource) diff --git a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py index fe3513ae58..d0882c77cf 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py +++ b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py @@ -2,6 +2,7 @@ from __future__ import annotations +from base64 import b64decode from collections.abc import Mapping, Sequence from functools import cached_property from typing import ( @@ -12,14 +13,19 @@ from ... import ExternalToolset, ToolDefinition from ...messages import ( + AudioUrl, + BinaryContent, BuiltinToolCallPart, BuiltinToolReturnPart, + DocumentUrl, + ImageUrl, ModelMessage, SystemPromptPart, TextPart, ToolCallPart, ToolReturnPart, UserPromptPart, + VideoUrl, ) from ...output import OutputDataT from ...tools import AgentDepsT @@ -27,12 +33,15 @@ try: from ag_ui.core import ( + ActivityMessage, AssistantMessage, BaseEvent, + BinaryInputContent, DeveloperMessage, Message, RunAgentInput, SystemMessage, + TextInputContent, Tool as AGUITool, ToolMessage, UserMessage, @@ -118,53 +127,71 @@ def state(self) -> dict[str, Any] | None: return cast('dict[str, Any]', state) @classmethod - def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]: + def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]: # noqa: C901 """Transform AG-UI messages into Pydantic AI messages.""" builder = MessagesBuilder() tool_calls: dict[str, str] = {} # Tool call ID to tool name mapping. - for msg in messages: - if isinstance(msg, UserMessage | SystemMessage | DeveloperMessage) or ( - isinstance(msg, ToolMessage) and not msg.tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX) - ): - if isinstance(msg, UserMessage): - builder.add(UserPromptPart(content=msg.content)) - elif isinstance(msg, SystemMessage | DeveloperMessage): - builder.add(SystemPromptPart(content=msg.content)) - else: - tool_call_id = msg.tool_call_id - tool_name = tool_calls.get(tool_call_id) - if tool_name is None: # pragma: no cover - raise ValueError(f'Tool call with ID {tool_call_id} not found in the history.') - - builder.add( - ToolReturnPart( - tool_name=tool_name, - content=msg.content, - tool_call_id=tool_call_id, - ) - ) - - elif isinstance(msg, AssistantMessage) or ( # pragma: no branch - isinstance(msg, ToolMessage) and msg.tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX) - ): - if isinstance(msg, AssistantMessage): - if msg.content: - builder.add(TextPart(content=msg.content)) - - if msg.tool_calls: - for tool_call in msg.tool_calls: + match msg: + case UserMessage(content=content): + if isinstance(content, str): + builder.add(UserPromptPart(content=content)) + else: + user_prompt_content: list[Any] = [] + for part in content: + match part: + case TextInputContent(text=text): + user_prompt_content.append(text) + case BinaryInputContent(): + if part.url: + try: + binary_part = BinaryContent.from_data_uri(part.url) + except ValueError: + media_type_constructors = { + 'image': ImageUrl, + 'video': VideoUrl, + 'audio': AudioUrl, + } + media_type_prefix = part.mime_type.split('/', 1)[0] + constructor = media_type_constructors.get(media_type_prefix, DocumentUrl) + binary_part = constructor(url=part.url, media_type=part.mime_type) + elif part.data: + binary_part = BinaryContent( + data=b64decode(part.data), media_type=part.mime_type + ) + else: # pragma: no cover + raise ValueError('BinaryInputContent must have either a `url` or `data` field.') + user_prompt_content.append(binary_part) + case _: # pragma: no cover + raise ValueError(f'Unsupported user message part type: {type(part)}') + + if user_prompt_content: # pragma: no branch + content_to_add = ( + user_prompt_content[0] + if len(user_prompt_content) == 1 and isinstance(user_prompt_content[0], str) + else user_prompt_content + ) + builder.add(UserPromptPart(content=content_to_add)) + + case SystemMessage(content=content) | DeveloperMessage(content=content): + builder.add(SystemPromptPart(content=content)) + + case AssistantMessage(content=content, tool_calls=tool_calls_list): + if content: + builder.add(TextPart(content=content)) + if tool_calls_list: + for tool_call in tool_calls_list: tool_call_id = tool_call.id tool_name = tool_call.function.name tool_calls[tool_call_id] = tool_name if tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX): - _, provider_name, tool_call_id = tool_call_id.split('|', 2) + _, provider_name, original_id = tool_call_id.split('|', 2) builder.add( BuiltinToolCallPart( tool_name=tool_name, args=tool_call.function.arguments, - tool_call_id=tool_call_id, + tool_call_id=original_id, provider_name=provider_name, ) ) @@ -176,20 +203,32 @@ def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]: args=tool_call.function.arguments, ) ) - else: - tool_call_id = msg.tool_call_id + case ToolMessage() as tool_msg: + tool_call_id = tool_msg.tool_call_id tool_name = tool_calls.get(tool_call_id) if tool_name is None: # pragma: no cover raise ValueError(f'Tool call with ID {tool_call_id} not found in the history.') - _, provider_name, tool_call_id = tool_call_id.split('|', 2) - - builder.add( - BuiltinToolReturnPart( - tool_name=tool_name, - content=msg.content, - tool_call_id=tool_call_id, - provider_name=provider_name, + + if tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX): + _, provider_name, original_id = tool_call_id.split('|', 2) + builder.add( + BuiltinToolReturnPart( + tool_name=tool_name, + content=tool_msg.content, + tool_call_id=original_id, + provider_name=provider_name, + ) + ) + else: + builder.add( + ToolReturnPart( + tool_name=tool_name, + content=tool_msg.content, + tool_call_id=tool_call_id, + ) ) - ) + + case ActivityMessage(): # pragma: no cover + raise ValueError(f'Unsupported message type: {type(msg)}') return builder.messages diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index c5a85c0b3d..8d196a813b 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -103,7 +103,7 @@ ui = ["starlette>=0.45.3"] # A2A a2a = ["fasta2a>=0.4.1"] # AG-UI -ag-ui = ["ag-ui-protocol>=0.1.8", "starlette>=0.45.3"] +ag-ui = ["ag-ui-protocol>=0.1.10", "starlette>=0.45.3"] # Web web = ["starlette>=0.45.3", "httpx>=0.27.0", "uvicorn>=0.38.0"] # Retries diff --git a/tests/conftest.py b/tests/conftest.py index 32b4a475cc..35a163f300 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,7 +23,7 @@ from vcr import VCR, request as vcr_request import pydantic_ai.models -from pydantic_ai import Agent, BinaryContent +from pydantic_ai import Agent, BinaryContent, BinaryImage from pydantic_ai.models import Model, cached_async_http_client __all__ = ( @@ -333,9 +333,9 @@ def audio_content(assets_path: Path) -> BinaryContent: @pytest.fixture(scope='session') -def image_content(assets_path: Path) -> BinaryContent: +def image_content(assets_path: Path) -> BinaryImage: image_bytes = assets_path.joinpath('kiwi.png').read_bytes() - return BinaryContent(data=image_bytes, media_type='image/png') + return BinaryImage(data=image_bytes, media_type='image/png') @pytest.fixture(scope='session') diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index ee1aa83b15..fe6fa8399a 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -63,7 +63,7 @@ from pydantic_ai.tools import ToolDefinition from pydantic_ai.usage import RequestUsage -from ..conftest import ClientWithHandler, IsBytes, IsDatetime, IsInstance, IsNow, IsStr, TestEnv, try_import +from ..conftest import ClientWithHandler, IsDatetime, IsInstance, IsNow, IsStr, TestEnv, try_import pytestmark = [ pytest.mark.anyio, @@ -1226,10 +1226,7 @@ async def get_image() -> BinaryContent: UserPromptPart( content=[ 'This is file 1c8566:', - BinaryContent( - data=IsBytes(), - media_type='image/png', - ), + image_content, ], timestamp=IsDatetime(), ), diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py index b062de8b6f..49abd1f5e9 100644 --- a/tests/test_ag_ui.py +++ b/tests/test_ag_ui.py @@ -17,10 +17,14 @@ from pydantic import BaseModel from pydantic_ai import ( + AudioUrl, + BinaryContent, BuiltinToolCallPart, BuiltinToolReturnPart, + DocumentUrl, FunctionToolCallEvent, FunctionToolResultEvent, + ImageUrl, ModelMessage, ModelRequest, ModelResponse, @@ -35,6 +39,7 @@ ToolReturn, ToolReturnPart, UserPromptPart, + VideoUrl, ) from pydantic_ai._run_context import RunContext from pydantic_ai.agent import Agent, AgentRunResult @@ -58,6 +63,7 @@ from ag_ui.core import ( AssistantMessage, BaseEvent, + BinaryInputContent, CustomEvent, DeveloperMessage, EventType, @@ -66,6 +72,7 @@ RunAgentInput, StateSnapshotEvent, SystemMessage, + TextInputContent, Tool, ToolCall, ToolMessage, @@ -1558,7 +1565,7 @@ async def async_callback(run_result: AgentRunResult[Any]) -> None: assert events[-1]['type'] == 'RUN_FINISHED' -async def test_messages() -> None: +async def test_messages(image_content: BinaryContent, document_content: BinaryContent) -> None: messages = [ SystemMessage( id='msg_1', @@ -1576,6 +1583,32 @@ async def test_messages() -> None: id='msg_4', content='User message', ), + UserMessage( + id='msg_1', + content=[ + TextInputContent(text='this is an image:'), + BinaryInputContent(url=image_content.data_uri, mime_type=image_content.media_type), + ], + ), + UserMessage( + id='msg2', + content=[BinaryInputContent(url='http://example.com/image.png', mime_type='image/png')], + ), + UserMessage( + id='msg3', + content=[BinaryInputContent(url='http://example.com/video.mp4', mime_type='video/mp4')], + ), + UserMessage( + id='msg4', + content=[BinaryInputContent(url='http://example.com/audio.mp3', mime_type='audio/mpeg')], + ), + UserMessage( + id='msg5', + content=[BinaryInputContent(url='http://example.com/doc.pdf', mime_type='application/pdf')], + ), + UserMessage( + id='msg6', content=[BinaryInputContent(data=document_content.base64, mime_type=document_content.media_type)] + ), AssistantMessage( id='msg_5', tool_calls=[ @@ -1661,6 +1694,48 @@ async def test_messages() -> None: content='User message', timestamp=IsDatetime(), ), + UserPromptPart( + content=['this is an image:', image_content], + timestamp=IsDatetime(), + ), + UserPromptPart( + content=[ + ImageUrl( + url='http://example.com/image.png', _media_type='image/png', media_type='image/png' + ) + ], + timestamp=IsDatetime(), + ), + UserPromptPart( + content=[ + VideoUrl( + url='http://example.com/video.mp4', _media_type='video/mp4', media_type='video/mp4' + ) + ], + timestamp=IsDatetime(), + ), + UserPromptPart( + content=[ + AudioUrl( + url='http://example.com/audio.mp3', _media_type='audio/mpeg', media_type='audio/mpeg' + ) + ], + timestamp=IsDatetime(), + ), + UserPromptPart( + content=[ + DocumentUrl( + url='http://example.com/doc.pdf', + _media_type='application/pdf', + media_type='application/pdf', + ) + ], + timestamp=IsDatetime(), + ), + UserPromptPart( + content=[document_content], + timestamp=IsDatetime(), + ), ] ), ModelResponse( diff --git a/uv.lock b/uv.lock index dae85d1180..ba4012112b 100644 --- a/uv.lock +++ b/uv.lock @@ -44,14 +44,14 @@ wheels = [ [[package]] name = "ag-ui-protocol" -version = "0.1.8" +version = "0.1.10" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/27/de/0bddf7f26d5f38274c99401735c82ad59df9cead6de42f4bb2ad837286fe/ag_ui_protocol-0.1.8.tar.gz", hash = "sha256:eb745855e9fc30964c77e953890092f8bd7d4bbe6550d6413845428dd0faac0b", size = 5323, upload-time = "2025-07-15T10:55:36.389Z" } +sdist = { url = "https://files.pythonhosted.org/packages/67/bb/5a5ec893eea5805fb9a3db76a9888c3429710dfb6f24bbb37568f2cf7320/ag_ui_protocol-0.1.10.tar.gz", hash = "sha256:3213991c6b2eb24bb1a8c362ee270c16705a07a4c5962267a083d0959ed894f4", size = 6945, upload-time = "2025-11-06T15:17:17.068Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c4/00/40c6b0313c25d1ab6fac2ecba1cd5b15b1cd3c3a71b3d267ad890e405889/ag_ui_protocol-0.1.8-py3-none-any.whl", hash = "sha256:1567ccb067b7b8158035b941a985e7bb185172d660d4542f3f9c6fff77b55c6e", size = 7066, upload-time = "2025-07-15T10:55:35.075Z" }, + { url = "https://files.pythonhosted.org/packages/8f/78/eb55fabaab41abc53f52c0918a9a8c0f747807e5306273f51120fd695957/ag_ui_protocol-0.1.10-py3-none-any.whl", hash = "sha256:c81e6981f30aabdf97a7ee312bfd4df0cd38e718d9fc10019c7d438128b93ab5", size = 7889, upload-time = "2025-11-06T15:17:15.325Z" }, ] [[package]] @@ -5672,7 +5672,7 @@ web = [ [package.metadata] requires-dist = [ - { name = "ag-ui-protocol", marker = "extra == 'ag-ui'", specifier = ">=0.1.8" }, + { name = "ag-ui-protocol", marker = "extra == 'ag-ui'", specifier = ">=0.1.10" }, { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.75.0" }, { name = "argcomplete", marker = "extra == 'cli'", specifier = ">=3.5.0" }, { name = "boto3", marker = "extra == 'bedrock'", specifier = ">=1.40.14" },