Skip to content
200 changes: 134 additions & 66 deletions pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from base64 import b64decode
from collections.abc import Mapping, Sequence
from functools import cached_property
from typing import (
Expand All @@ -12,28 +13,37 @@

from ... import ExternalToolset, ToolDefinition
from ...messages import (
AudioUrl,
BinaryContent,
BuiltinToolCallPart,
BuiltinToolReturnPart,
DocumentUrl,
ImageUrl,
ModelMessage,
SystemPromptPart,
TextPart,
ToolCallPart,
ToolReturnPart,
UserPromptPart,
VideoUrl,
)
from ...output import OutputDataT
from ...tools import AgentDepsT
from ...toolsets import AbstractToolset

try:
from ag_ui.core import (
ActivityMessage,
AssistantMessage,
BaseEvent,
BinaryInputContent,
DeveloperMessage,
Message,
RunAgentInput,
SystemMessage,
TextInputContent,
Tool as AGUITool,
ToolCall,
ToolMessage,
UserMessage,
)
Expand Down Expand Up @@ -124,72 +134,130 @@ def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]:
tool_calls: dict[str, str] = {} # Tool call ID to tool name mapping.

for msg in messages:
if isinstance(msg, UserMessage | SystemMessage | DeveloperMessage) or (
isinstance(msg, ToolMessage) and not msg.tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX)
):
if isinstance(msg, UserMessage):
builder.add(UserPromptPart(content=msg.content))
elif isinstance(msg, SystemMessage | DeveloperMessage):
builder.add(SystemPromptPart(content=msg.content))
else:
tool_call_id = msg.tool_call_id
tool_name = tool_calls.get(tool_call_id)
if tool_name is None: # pragma: no cover
raise ValueError(f'Tool call with ID {tool_call_id} not found in the history.')

builder.add(
ToolReturnPart(
tool_name=tool_name,
content=msg.content,
tool_call_id=tool_call_id,
)
)
match msg:
case UserMessage(content=content):
if isinstance(content, str):
builder.add(UserPromptPart(content=content))
else:
user_prompt_content: list[Any] = []
for part in content:
match part:
case TextInputContent(text=text):
user_prompt_content.append(text)
case BinaryInputContent():
user_prompt_content.append(cls.load_binary_part(part))
case _:
raise ValueError(f'Unsupported user message part type: {type(part)}')

elif isinstance(msg, AssistantMessage) or ( # pragma: no branch
isinstance(msg, ToolMessage) and msg.tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX)
):
if isinstance(msg, AssistantMessage):
if msg.content:
builder.add(TextPart(content=msg.content))

if msg.tool_calls:
for tool_call in msg.tool_calls:
tool_call_id = tool_call.id
tool_name = tool_call.function.name
tool_calls[tool_call_id] = tool_name

if tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX):
_, provider_name, tool_call_id = tool_call_id.split('|', 2)
builder.add(
BuiltinToolCallPart(
tool_name=tool_name,
args=tool_call.function.arguments,
tool_call_id=tool_call_id,
provider_name=provider_name,
)
)
else:
builder.add(
ToolCallPart(
tool_name=tool_name,
tool_call_id=tool_call_id,
args=tool_call.function.arguments,
)
)
else:
tool_call_id = msg.tool_call_id
tool_name = tool_calls.get(tool_call_id)
if tool_name is None: # pragma: no cover
raise ValueError(f'Tool call with ID {tool_call_id} not found in the history.')
_, provider_name, tool_call_id = tool_call_id.split('|', 2)

builder.add(
BuiltinToolReturnPart(
tool_name=tool_name,
content=msg.content,
tool_call_id=tool_call_id,
provider_name=provider_name,
)
)
if user_prompt_content:
content_to_add = (
user_prompt_content[0]
if len(user_prompt_content) == 1 and isinstance(user_prompt_content[0], str)
else user_prompt_content
)
builder.add(UserPromptPart(content=content_to_add))

case SystemMessage(content=content) | DeveloperMessage(content=content):
builder.add(SystemPromptPart(content=content))

case AssistantMessage(content=content, tool_calls=tool_calls_list):
if content:
builder.add(TextPart(content=content))
if tool_calls_list:
cls.add_assistant_tool_parts(builder, tool_calls_list, tool_calls)

case ToolMessage() as tool_msg:
cls.add_tool_return_part(builder, tool_msg, tool_calls)

case ActivityMessage():
raise ValueError(f'Unsupported message type: {type(msg)}')

return builder.messages

@classmethod
def load_binary_part(cls, part: BinaryInputContent) -> BinaryContent | ImageUrl | VideoUrl | AudioUrl | DocumentUrl:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These methods really shouldn't be part of the public interface, so please make them private, either has class methods, or if that doesn't work, as private functions in the module, outside of the class

"""Transforms an AG-UI BinaryInputContent part into a Pydantic AI content part."""
if part.url:
try:
return BinaryContent.from_data_uri(part.url)
except ValueError:
media_type_constructors = {
'image': ImageUrl,
'video': VideoUrl,
'audio': AudioUrl,
}
media_type_prefix = part.mime_type.split('/', 1)[0]
constructor = media_type_constructors.get(media_type_prefix, DocumentUrl)
return constructor(
url=part.url,
media_type=part.mime_type,
identifier=part.id,
)
if part.data:
return BinaryContent(data=b64decode(part.data), kind='binary', media_type=part.mime_type)

raise ValueError('BinaryInputContent must have either a `url` or `data` field.')

@classmethod
def add_assistant_tool_parts(
cls,
builder: MessagesBuilder,
tool_calls_list: list[ToolCall],
tool_calls_map: dict[str, str],
) -> None:
"""Adds tool call parts from an AssistantMessage to the builder."""
for tool_call in tool_calls_list:
tool_call_id = tool_call.id
tool_name = tool_call.function.name
tool_calls_map[tool_call_id] = tool_name

if tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX):
_, provider_name, original_id = tool_call_id.split('|', 2)
builder.add(
BuiltinToolCallPart(
tool_name=tool_name,
args=tool_call.function.arguments,
tool_call_id=original_id,
provider_name=provider_name,
)
)
else:
builder.add(
ToolCallPart(
tool_name=tool_name,
tool_call_id=tool_call_id,
args=tool_call.function.arguments,
)
)

@classmethod
def add_tool_return_part(
cls,
builder: MessagesBuilder,
msg: ToolMessage,
tool_calls_map: dict[str, str],
) -> None:
"""Adds a tool return part from a ToolMessage to the builder."""
tool_call_id = msg.tool_call_id
tool_name = tool_calls_map.get(tool_call_id)
if tool_name is None: # pragma: no cover
raise ValueError(f'Tool call with ID {tool_call_id} not found in the history.')

if tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX):
_, provider_name, original_id = tool_call_id.split('|', 2)
builder.add(
BuiltinToolReturnPart(
tool_name=tool_name,
content=msg.content,
tool_call_id=original_id,
provider_name=provider_name,
)
)
else:
builder.add(
ToolReturnPart(
tool_name=tool_name,
content=msg.content,
tool_call_id=tool_call_id,
)
)
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ ui = ["starlette>=0.45.3"]
# A2A
a2a = ["fasta2a>=0.4.1"]
# AG-UI
ag-ui = ["ag-ui-protocol>=0.1.8", "starlette>=0.45.3"]
ag-ui = ["ag-ui-protocol>=0.1.10", "starlette>=0.45.3"]
# Retries
retries = ["tenacity>=8.2.3"]
# Temporal
Expand Down
Loading
Loading