Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions docs/output.md
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,90 @@ print(result.output)

_(This example is complete, it can be run "as is")_

#### Handling partial output in output functions

!!! warning "Output functions are called multiple times during streaming"
When using streaming mode (`run_stream()`), output functions are called **multiple times** — once for each partial output received from the model, and once for the final complete output.

For output functions with **side effects** (e.g., sending notifications, logging, database updates), you should check the [`RunContext.partial_output`][pydantic_ai.tools.RunContext.partial_output] flag to avoid executing side effects on partial data.

**How `partial_output` works:**

- **In non-streaming mode** (`run()` / `run_sync()`):
  - `partial_output=False` always (function called once)
- **In streaming mode** (`run_stream()`):
- `partial_output=True` for each partial call
- `partial_output=False` for the final complete call

**Example with side effects:**

```python {title="output_function_with_side_effects.py"}
from pydantic import BaseModel

from pydantic_ai import Agent, RunContext


class DatabaseRecord(BaseModel):
name: str
value: int


def save_to_database(ctx: RunContext, record: DatabaseRecord) -> DatabaseRecord:
"""Output function with side effect - only save final output to database."""
if ctx.partial_output:
# Skip side effects for partial outputs
return record

# Only execute side effect for the final output
print(f'Saving to database: {record.name} = {record.value}')
#> Saving to database: test = 42
return record


agent = Agent('openai:gpt-5', output_type=save_to_database)

result = agent.run_sync('Create a record with name "test" and value 42')
print(result.output)
#> name='test' value=42
```

_(This example is complete, it can be run "as is")_

**Example without side effects (transformation only):**

```python {title="output_function_transformation.py"}
from pydantic import BaseModel

from pydantic_ai import Agent


class UserData(BaseModel):
username: str
email: str


def normalize_user_data(user: UserData) -> UserData:
"""Output function without side effects - safe to call multiple times."""
# Pure transformation is safe for multiple calls
user.username = user.username.lower()
user.email = user.email.lower()
return user


agent = Agent('openai:gpt-5', output_type=normalize_user_data)

result = agent.run_sync('Create user with username "JohnDoe" and email "[email protected]"')
print(result.output)
#> username='johndoe' email='[email protected]'
```

_(This example is complete, it can be run "as is")_

**Best practices:**

- If your output function **has** side effects (database writes, API calls, notifications) → use `ctx.partial_output` to guard them
- If your output function only **transforms** data (formatting, validation, normalization) → no need to check the flag

### Output modes

Pydantic AI implements three different methods to get a model to output structured data:
Expand Down
112 changes: 51 additions & 61 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
import sys
from collections import defaultdict
from collections.abc import AsyncIterable, AsyncIterator, Callable
from collections.abc import AsyncIterable, Callable
from dataclasses import dataclass, replace
from datetime import timezone
from typing import Any, Generic, Literal, TypeVar, Union
Expand Down Expand Up @@ -65,7 +65,7 @@
WebSearchTool,
WebSearchUserLocation,
)
from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.models.test import TestModel
from pydantic_ai.output import OutputObjectDefinition, StructuredDict, ToolOutput
from pydantic_ai.result import RunUsage
Expand Down Expand Up @@ -360,83 +360,73 @@ def validate_output(ctx: RunContext[None], o: Foo) -> Foo:
)


def test_output_validator_partial_sync():
"""Test that output validators receive correct value for `partial_output` in sync mode."""
call_log: list[tuple[str, bool]] = []
class TestPartialOutput:
"""Tests for `ctx.partial_output` flag in output validators and output functions."""

agent = Agent[None, str](TestModel(custom_output_text='test output'))
# NOTE: When changing these tests:
# 1. Follow the existing order
# 2. Update tests in `tests/test_streaming.py::TestPartialOutput` as well

@agent.output_validator
def validate_output(ctx: RunContext[None], output: str) -> str:
call_log.append((output, ctx.partial_output))
return output
def test_output_validator_text(self):
"""Test that output validators receive correct value for `partial_output` with text output."""
call_log: list[tuple[str, bool]] = []

result = agent.run_sync('Hello')
assert result.output == 'test output'
def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
return ModelResponse(parts=[TextPart('Hello world!')])

assert call_log == snapshot([('test output', False)])
agent = Agent(FunctionModel(return_model))

@agent.output_validator
def validate_output(ctx: RunContext[None], output: str) -> str:
call_log.append((output, ctx.partial_output))
return output

async def test_output_validator_partial_stream_text():
"""Test that output validators receive correct value for `partial_output` when using stream_text()."""
call_log: list[tuple[str, bool]] = []
result = agent.run_sync('test')

async def stream_text(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str]:
for chunk in ['Hello', ' ', 'world', '!']:
yield chunk
assert result.output == 'Hello world!'
assert call_log == snapshot([('Hello world!', False)])

agent = Agent(FunctionModel(stream_function=stream_text))
def test_output_validator_structured(self):
"""Test that output validators receive correct value for `partial_output` with structured output."""
call_log: list[tuple[Foo, bool]] = []

@agent.output_validator
def validate_output(ctx: RunContext[None], output: str) -> str:
call_log.append((output, ctx.partial_output))
return output
def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
assert info.output_tools is not None
tool_name = info.output_tools[0].name
args_json = '{"a": 42, "b": "foo"}'
return ModelResponse(parts=[ToolCallPart(tool_name, args_json)])

async with agent.run_stream('Hello') as result:
text_parts = []
async for chunk in result.stream_text(debounce_by=None):
text_parts.append(chunk)
agent = Agent(FunctionModel(return_model), output_type=Foo)

assert text_parts[-1] == 'Hello world!'
assert call_log == snapshot(
[
('Hello', True),
('Hello ', True),
('Hello world', True),
('Hello world!', True),
('Hello world!', False),
]
)
@agent.output_validator
def validate_output(ctx: RunContext[None], output: Foo) -> Foo:
call_log.append((output, ctx.partial_output))
return output

result = agent.run_sync('test')

async def test_output_validator_partial_stream_output():
"""Test that output validators receive correct value for `partial_output` when using stream_output()."""
call_log: list[tuple[Foo, bool]] = []
assert result.output == Foo(a=42, b='foo')
assert call_log == snapshot([(Foo(a=42, b='foo'), False)])

async def stream_model(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
assert info.output_tools is not None
yield {0: DeltaToolCall(name=info.output_tools[0].name, json_args='{"a": 42')}
yield {0: DeltaToolCall(json_args=', "b": "f')}
yield {0: DeltaToolCall(json_args='oo"}')}
def test_output_function_structured(self):
"""Test that output functions receive correct value for `partial_output` with structured output."""
call_log: list[tuple[Foo, bool]] = []

agent = Agent(FunctionModel(stream_function=stream_model), output_type=Foo)
def process_foo(ctx: RunContext[None], foo: Foo) -> Foo:
call_log.append((foo, ctx.partial_output))
return Foo(a=foo.a * 2, b=foo.b.upper())

@agent.output_validator
def validate_output(ctx: RunContext[None], output: Foo) -> Foo:
call_log.append((output, ctx.partial_output))
return output
def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
assert info.output_tools is not None
tool_name = info.output_tools[0].name
args_json = '{"a": 21, "b": "foo"}'
return ModelResponse(parts=[ToolCallPart(tool_name, args_json)])

async with agent.run_stream('Hello') as result:
outputs = [output async for output in result.stream_output(debounce_by=None)]
agent = Agent(FunctionModel(return_model), output_type=process_foo)
result = agent.run_sync('test')

assert outputs[-1] == Foo(a=42, b='foo')
assert call_log == snapshot(
[
(Foo(a=42, b='f'), True),
(Foo(a=42, b='foo'), True),
(Foo(a=42, b='foo'), False),
]
)
assert result.output == Foo(a=42, b='FOO')
assert call_log == snapshot([(Foo(a=21, b='foo'), False)])


def test_plain_response_then_tuple():
Expand Down
10 changes: 10 additions & 0 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,16 @@ async def call_tool(
'What do I have on my calendar today?': "You're going to spend all day playing with Pydantic AI.",
'Write a long story about a cat': 'Once upon a time, there was a curious cat named Whiskers who loved to explore the world around him...',
'What is the first sentence on https://ai.pydantic.dev?': 'Pydantic AI is a Python agent framework designed to make it less painful to build production grade applications with Generative AI.',
'Create a record with name "test" and value 42': ToolCallPart(
tool_name='final_result',
args={'name': 'test', 'value': 42},
tool_call_id='pyd_ai_tool_call_id',
),
'Create user with username "JohnDoe" and email "[email protected]"': ToolCallPart(
tool_name='final_result',
args={'username': 'JohnDoe', 'email': '[email protected]'},
tool_call_id='pyd_ai_tool_call_id',
),
}

tool_responses: dict[tuple[str, str], str] = {
Expand Down
99 changes: 99 additions & 0 deletions tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@
pytestmark = pytest.mark.anyio


class Foo(BaseModel):
a: int
b: str


async def test_streamed_text_response():
m = TestModel()

Expand Down Expand Up @@ -747,6 +752,100 @@ async def ret_a(x: str) -> str: # pragma: no cover
)


class TestPartialOutput:
    """Tests for `ctx.partial_output` flag in output validators and output functions."""

    # NOTE: When changing these tests:
    # 1. Follow the existing order
    # 2. Update tests in `tests/test_agent.py::TestPartialOutput` as well

    async def test_output_validator_text(self):
        """Test that output validators receive correct value for `partial_output` with text output."""
        call_log: list[tuple[str, bool]] = []

        async def stream_chunks(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str]:
            # Emit the response one fragment at a time so the validator runs per-chunk.
            for fragment in ['Hello', ' ', 'world', '!']:
                yield fragment

        agent = Agent(FunctionModel(stream_function=stream_chunks))

        @agent.output_validator
        def validate_output(ctx: RunContext[None], output: str) -> str:
            call_log.append((output, ctx.partial_output))
            return output

        async with agent.run_stream('test') as result:
            collected: list[str] = []
            async for piece in result.stream_text(debounce_by=None):
                collected.append(piece)

        assert collected[-1] == 'Hello world!'
        # Every partial call carries True; the final repeated value carries False.
        assert call_log == snapshot(
            [
                ('Hello', True),
                ('Hello ', True),
                ('Hello world', True),
                ('Hello world!', True),
                ('Hello world!', False),
            ]
        )

    async def test_output_validator_structured(self):
        """Test that output validators receive correct value for `partial_output` with structured output."""
        call_log: list[tuple[Foo, bool]] = []

        async def stream_tool_args(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
            assert info.output_tools is not None
            # Split the tool-call JSON so `b` is seen truncated ('f') before completion.
            yield {0: DeltaToolCall(name=info.output_tools[0].name, json_args='{"a": 42')}
            yield {0: DeltaToolCall(json_args=', "b": "f')}
            yield {0: DeltaToolCall(json_args='oo"}')}

        agent = Agent(FunctionModel(stream_function=stream_tool_args), output_type=Foo)

        @agent.output_validator
        def validate_output(ctx: RunContext[None], output: Foo) -> Foo:
            call_log.append((output, ctx.partial_output))
            return output

        async with agent.run_stream('test') as result:
            seen = [item async for item in result.stream_output(debounce_by=None)]

        assert seen[-1] == Foo(a=42, b='foo')
        assert call_log == snapshot(
            [
                (Foo(a=42, b='f'), True),
                (Foo(a=42, b='foo'), True),
                (Foo(a=42, b='foo'), False),
            ]
        )

    async def test_output_function_structured(self):
        """Test that output functions receive correct value for `partial_output` with structured output."""
        call_log: list[tuple[Foo, bool]] = []

        def process_foo(ctx: RunContext[None], foo: Foo) -> Foo:
            # The output function transforms the model's value; it is invoked on
            # each partial as well as on the final complete output.
            call_log.append((foo, ctx.partial_output))
            return Foo(a=foo.a * 2, b=foo.b.upper())

        async def stream_tool_args(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
            assert info.output_tools is not None
            yield {0: DeltaToolCall(name=info.output_tools[0].name, json_args='{"a": 21')}
            yield {0: DeltaToolCall(json_args=', "b": "f')}
            yield {0: DeltaToolCall(json_args='oo"}')}

        agent = Agent(FunctionModel(stream_function=stream_tool_args), output_type=process_foo)

        async with agent.run_stream('test') as result:
            seen = [item async for item in result.stream_output(debounce_by=None)]

        assert seen[-1] == Foo(a=42, b='FOO')
        # Raw (pre-transformation) values are logged: a=21/b='foo', not the doubled output.
        assert call_log == snapshot(
            [
                (Foo(a=21, b='f'), True),
                (Foo(a=21, b='foo'), True),
                (Foo(a=21, b='foo'), False),
            ]
        )


class OutputType(BaseModel):
"""Result type used by multiple tests."""

Expand Down