diff --git a/docs/output.md b/docs/output.md
index 3b20109b71..9e259afbd4 100644
--- a/docs/output.md
+++ b/docs/output.md
@@ -258,6 +258,90 @@ print(result.output)
 
 _(This example is complete, it can be run "as is")_
 
+#### Handling partial output in output functions
+
+!!! warning "Output functions are called multiple times during streaming"
+    When using streaming mode (`run_stream()`), output functions are called **multiple times**: once for each partial output received from the model, and once more for the final complete output.
+
+    For output functions with **side effects** (e.g. sending notifications, logging, database updates), you should check the [`RunContext.partial_output`][pydantic_ai.tools.RunContext.partial_output] flag to avoid executing side effects on partial data.
+
+**How `partial_output` works:**
+
+- **In non-streaming mode** (`run()`, `run_sync()`):
+    - `partial_output=False` always (the function is called once, with the complete output)
+- **In streaming mode** (`run_stream()`):
+    - `partial_output=True` for each partial call
+    - `partial_output=False` for the final complete call
+
+**Example with side effects:**
+
+```python {title="output_function_with_side_effects.py"}
+from pydantic import BaseModel
+
+from pydantic_ai import Agent, RunContext
+
+
+class DatabaseRecord(BaseModel):
+    name: str
+    value: int
+
+
+def save_to_database(ctx: RunContext, record: DatabaseRecord) -> DatabaseRecord:
+    """Output function with a side effect: only save the final output to the database."""
+    if ctx.partial_output:
+        # Skip side effects for partial outputs
+        return record
+
+    # Only execute the side effect for the final output
+    print(f'Saving to database: {record.name} = {record.value}')
+    #> Saving to database: test = 42
+    return record
+
+
+agent = Agent('openai:gpt-5', output_type=save_to_database)
+
+result = agent.run_sync('Create a record with name "test" and value 42')
+print(result.output)
+#> name='test' value=42
+```
+
+_(This example is complete, it can be run "as is")_
+
+**Example without side effects (transformation only):**
+
+```python {title="output_function_transformation.py"}
+from pydantic import BaseModel
+
+from pydantic_ai import Agent
+
+
+class UserData(BaseModel):
+    username: str
+    email: str
+
+
+def normalize_user_data(user: UserData) -> UserData:
+    """Output function without side effects: safe to call multiple times."""
+    # A pure transformation is safe to run on every partial output
+    user.username = user.username.lower()
+    user.email = user.email.lower()
+    return user
+
+
+agent = Agent('openai:gpt-5', output_type=normalize_user_data)
+
+result = agent.run_sync('Create user with username "JohnDoe" and email "JOHN@EXAMPLE.COM"')
+print(result.output)
+#> username='johndoe' email='john@example.com'
+```
+
+_(This example is complete, it can be run "as is")_
+
+**Best practices:**
+
+- If your output function **has** side effects (database writes, API calls, notifications) → use `ctx.partial_output` to guard them (see the streaming sketch below)
+- If your output function only **transforms** data (formatting, validation, normalization) → no need to check the flag
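+
+For a sense of what this looks like during streaming, here's a minimal sketch (the `NotificationResult` model, `notify_user` function, and prompt are hypothetical, and the exact partial values depend on how the model streams its response):
+
+```python {title="output_function_streaming_sketch.py" test="skip"}
+from pydantic import BaseModel
+
+from pydantic_ai import Agent, RunContext
+
+
+class NotificationResult(BaseModel):
+    message: str
+
+
+def notify_user(ctx: RunContext, notification: NotificationResult) -> NotificationResult:
+    # Called with partial_output=True for each partial output,
+    # then once more with partial_output=False for the complete output
+    if not ctx.partial_output:
+        print(f'Sending notification: {notification.message}')
+    return notification
+
+
+agent = Agent('openai:gpt-5', output_type=notify_user)
+
+
+async def main():
+    async with agent.run_stream('Draft a notification that the export job finished') as result:
+        async for partial in result.stream_output(debounce_by=None):
+            # `notify_user` has already run for this partial with partial_output=True,
+            # so the notification side effect was skipped
+            print(partial)
+```
+
+_(A sketch rather than a tested example; you'll need to add `asyncio.run(main())` to run `main`.)_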
+
 ### Output modes
 
 Pydantic AI implements three different methods to get a model to output structured data:
diff --git a/tests/test_agent.py b/tests/test_agent.py
index 6ce2d91c54..68de266968 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -3,7 +3,7 @@
 import re
 import sys
 from collections import defaultdict
-from collections.abc import AsyncIterable, AsyncIterator, Callable
+from collections.abc import AsyncIterable, Callable
 from dataclasses import dataclass, replace
 from datetime import timezone
 from typing import Any, Generic, Literal, TypeVar, Union
@@ -65,7 +65,7 @@
     WebSearchTool,
     WebSearchUserLocation,
 )
-from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
+from pydantic_ai.models.function import AgentInfo, FunctionModel
 from pydantic_ai.models.test import TestModel
 from pydantic_ai.output import OutputObjectDefinition, StructuredDict, ToolOutput
 from pydantic_ai.result import RunUsage
@@ -360,83 +360,73 @@ def validate_output(ctx: RunContext[None], o: Foo) -> Foo:
     )
 
 
-def test_output_validator_partial_sync():
-    """Test that output validators receive correct value for `partial_output` in sync mode."""
-    call_log: list[tuple[str, bool]] = []
+class TestPartialOutput:
+    """Tests for `ctx.partial_output` flag in output validators and output functions."""
 
-    agent = Agent[None, str](TestModel(custom_output_text='test output'))
+    # NOTE: When changing these tests:
+    # 1. Follow the existing order
+    # 2. Update tests in `tests/test_streaming.py::TestPartialOutput` as well
 
-    @agent.output_validator
-    def validate_output(ctx: RunContext[None], output: str) -> str:
-        call_log.append((output, ctx.partial_output))
-        return output
+    def test_output_validator_text(self):
+        """Test that output validators receive correct value for `partial_output` with text output."""
+        call_log: list[tuple[str, bool]] = []
 
-    result = agent.run_sync('Hello')
-    assert result.output == 'test output'
+        def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+            return ModelResponse(parts=[TextPart('Hello world!')])
 
-    assert call_log == snapshot([('test output', False)])
+        agent = Agent(FunctionModel(return_model))
 
+        @agent.output_validator
+        def validate_output(ctx: RunContext[None], output: str) -> str:
+            call_log.append((output, ctx.partial_output))
+            return output
 
-async def test_output_validator_partial_stream_text():
-    """Test that output validators receive correct value for `partial_output` when using stream_text()."""
-    call_log: list[tuple[str, bool]] = []
+        result = agent.run_sync('test')
 
-    async def stream_text(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str]:
-        for chunk in ['Hello', ' ', 'world', '!']:
-            yield chunk
+        assert result.output == 'Hello world!'
+        assert call_log == snapshot([('Hello world!', False)])
 
-    agent = Agent(FunctionModel(stream_function=stream_text))
+    def test_output_validator_structured(self):
+        """Test that output validators receive correct value for `partial_output` with structured output."""
+        call_log: list[tuple[Foo, bool]] = []
 
-    @agent.output_validator
-    def validate_output(ctx: RunContext[None], output: str) -> str:
-        call_log.append((output, ctx.partial_output))
-        return output
+        def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+            assert info.output_tools is not None
+            tool_name = info.output_tools[0].name
+            args_json = '{"a": 42, "b": "foo"}'
+            return ModelResponse(parts=[ToolCallPart(tool_name, args_json)])
 
-    async with agent.run_stream('Hello') as result:
-        text_parts = []
-        async for chunk in result.stream_text(debounce_by=None):
-            text_parts.append(chunk)
+        agent = Agent(FunctionModel(return_model), output_type=Foo)
 
-    assert text_parts[-1] == 'Hello world!'
-    assert call_log == snapshot(
-        [
-            ('Hello', True),
-            ('Hello ', True),
-            ('Hello world', True),
-            ('Hello world!', True),
-            ('Hello world!', False),
-        ]
-    )
+        @agent.output_validator
+        def validate_output(ctx: RunContext[None], output: Foo) -> Foo:
+            call_log.append((output, ctx.partial_output))
+            return output
+
+        result = agent.run_sync('test')
 
-async def test_output_validator_partial_stream_output():
-    """Test that output validators receive correct value for `partial_output` when using stream_output()."""
-    call_log: list[tuple[Foo, bool]] = []
+        assert result.output == Foo(a=42, b='foo')
+        assert call_log == snapshot([(Foo(a=42, b='foo'), False)])
 
-    async def stream_model(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
-        assert info.output_tools is not None
-        yield {0: DeltaToolCall(name=info.output_tools[0].name, json_args='{"a": 42')}
-        yield {0: DeltaToolCall(json_args=', "b": "f')}
-        yield {0: DeltaToolCall(json_args='oo"}')}
+    def test_output_function_structured(self):
+        """Test that output functions receive correct value for `partial_output` with structured output."""
+        call_log: list[tuple[Foo, bool]] = []
 
-    agent = Agent(FunctionModel(stream_function=stream_model), output_type=Foo)
+        def process_foo(ctx: RunContext[None], foo: Foo) -> Foo:
+            call_log.append((foo, ctx.partial_output))
+            return Foo(a=foo.a * 2, b=foo.b.upper())
 
-    @agent.output_validator
-    def validate_output(ctx: RunContext[None], output: Foo) -> Foo:
-        call_log.append((output, ctx.partial_output))
-        return output
+        def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+            assert info.output_tools is not None
+            tool_name = info.output_tools[0].name
+            args_json = '{"a": 21, "b": "foo"}'
+            return ModelResponse(parts=[ToolCallPart(tool_name, args_json)])
 
-    async with agent.run_stream('Hello') as result:
-        outputs = [output async for output in result.stream_output(debounce_by=None)]
+        agent = Agent(FunctionModel(return_model), output_type=process_foo)
+        result = agent.run_sync('test')
 
-    assert outputs[-1] == Foo(a=42, b='foo')
-    assert call_log == snapshot(
-        [
-            (Foo(a=42, b='f'), True),
-            (Foo(a=42, b='foo'), True),
-            (Foo(a=42, b='foo'), False),
-        ]
-    )
+        assert result.output == Foo(a=42, b='FOO')
+        assert call_log == snapshot([(Foo(a=21, b='foo'), False)])
 
 
 def test_plain_response_then_tuple():
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 8ed0828250..f786f0bc7a 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -539,6 +539,16 @@ async def call_tool(
     'What do I have on my calendar today?': "You're going to spend all day playing with Pydantic AI.",
     'Write a long story about a cat': 'Once upon a time, there was a curious cat named Whiskers who loved to explore the world around him...',
     'What is the first sentence on https://ai.pydantic.dev?': 'Pydantic AI is a Python agent framework designed to make it less painful to build production grade applications with Generative AI.',
+    'Create a record with name "test" and value 42': ToolCallPart(
+        tool_name='final_result',
+        args={'name': 'test', 'value': 42},
+        tool_call_id='pyd_ai_tool_call_id',
+    ),
+    'Create user with username "JohnDoe" and email "JOHN@EXAMPLE.COM"': ToolCallPart(
+        tool_name='final_result',
+        args={'username': 'JohnDoe', 'email': 'JOHN@EXAMPLE.COM'},
+        tool_call_id='pyd_ai_tool_call_id',
+    ),
 }
 
 tool_responses: dict[tuple[str, str], str] = {
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
index 9149d19d1b..3905027ea7 100644
--- a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -56,6 +56,11 @@
 
 pytestmark = pytest.mark.anyio
 
 
+class Foo(BaseModel):
+    a: int
+    b: str
+
+
 async def test_streamed_text_response():
     m = TestModel()
@@ -747,6 +752,100 @@ async def ret_a(x: str) -> str:  # pragma: no cover
     )
 
 
+class TestPartialOutput:
+    """Tests for `ctx.partial_output` flag in output validators and output functions."""
+
+    # NOTE: When changing these tests:
+    # 1. Follow the existing order
+    # 2. Update tests in `tests/test_agent.py::TestPartialOutput` as well
+
+    async def test_output_validator_text(self):
+        """Test that output validators receive correct value for `partial_output` with text output."""
+        call_log: list[tuple[str, bool]] = []
+
+        async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str]:
+            for chunk in ['Hello', ' ', 'world', '!']:
+                yield chunk
+
+        agent = Agent(FunctionModel(stream_function=sf))
+
+        @agent.output_validator
+        def validate_output(ctx: RunContext[None], output: str) -> str:
+            call_log.append((output, ctx.partial_output))
+            return output
+
+        async with agent.run_stream('test') as result:
+            text_parts = [text_part async for text_part in result.stream_text(debounce_by=None)]
+
+        assert text_parts[-1] == 'Hello world!'
+        assert call_log == snapshot(
+            [
+                ('Hello', True),
+                ('Hello ', True),
+                ('Hello world', True),
+                ('Hello world!', True),
+                ('Hello world!', False),
+            ]
+        )
+
+    async def test_output_validator_structured(self):
+        """Test that output validators receive correct value for `partial_output` with structured output."""
+        call_log: list[tuple[Foo, bool]] = []
+
+        async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
+            assert info.output_tools is not None
+            yield {0: DeltaToolCall(name=info.output_tools[0].name, json_args='{"a": 42')}
+            yield {0: DeltaToolCall(json_args=', "b": "f')}
+            yield {0: DeltaToolCall(json_args='oo"}')}
+
+        agent = Agent(FunctionModel(stream_function=sf), output_type=Foo)
+
+        @agent.output_validator
+        def validate_output(ctx: RunContext[None], output: Foo) -> Foo:
+            call_log.append((output, ctx.partial_output))
+            return output
+
+        async with agent.run_stream('test') as result:
+            outputs = [output async for output in result.stream_output(debounce_by=None)]
+
+        assert outputs[-1] == Foo(a=42, b='foo')
+        assert call_log == snapshot(
+            [
+                (Foo(a=42, b='f'), True),
+                (Foo(a=42, b='foo'), True),
+                (Foo(a=42, b='foo'), False),
+            ]
+        )
+
+    async def test_output_function_structured(self):
+        """Test that output functions receive correct value for `partial_output` with structured output."""
+        call_log: list[tuple[Foo, bool]] = []
+
+        def process_foo(ctx: RunContext[None], foo: Foo) -> Foo:
+            call_log.append((foo, ctx.partial_output))
+            return Foo(a=foo.a * 2, b=foo.b.upper())
+
+        async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
+            assert info.output_tools is not None
+            yield {0: DeltaToolCall(name=info.output_tools[0].name, json_args='{"a": 21')}
+            yield {0: DeltaToolCall(json_args=', "b": "f')}
+            yield {0: DeltaToolCall(json_args='oo"}')}
+
+        agent = Agent(FunctionModel(stream_function=sf), output_type=process_foo)
+
+        async with agent.run_stream('test') as result:
+            outputs = [output async for output in result.stream_output(debounce_by=None)]
+
+        assert outputs[-1] == Foo(a=42, b='FOO')
+        assert call_log == snapshot(
+            [
+                (Foo(a=21, b='f'), True),
+                (Foo(a=21, b='foo'), True),
+                (Foo(a=21, b='foo'), False),
+            ]
+        )
+
+
 class OutputType(BaseModel):
     """Result type used by multiple tests."""