diff --git a/pydantic_ai_slim/pydantic_ai/agent/abstract.py b/pydantic_ai_slim/pydantic_ai/agent/abstract.py
index 4be5e5164b..0be61fd741 100644
--- a/pydantic_ai_slim/pydantic_ai/agent/abstract.py
+++ b/pydantic_ai_slim/pydantic_ai/agent/abstract.py
@@ -746,8 +746,7 @@ async def _consume_stream():
             ) as stream_result:
                 yield stream_result
 
-        async_result = _utils.get_event_loop().run_until_complete(anext(_consume_stream()))
-        return result.StreamedRunResultSync(async_result)
+        return result.StreamedRunResultSync(_consume_stream())
 
     @overload
     def run_stream_events(
diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py
index 13a1a15f2e..40f5c7737c 100644
--- a/pydantic_ai_slim/pydantic_ai/result.py
+++ b/pydantic_ai_slim/pydantic_ai/result.py
@@ -1,6 +1,8 @@
 from __future__ import annotations as _annotations
 
+import inspect
 from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator
+from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field, replace
 from datetime import datetime
@@ -585,10 +587,17 @@ async def _marked_completed(self, message: _messages.ModelResponse | None = None
 class StreamedRunResultSync(Generic[AgentDepsT, OutputDataT]):
     """Synchronous wrapper for [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] that only exposes sync methods."""
 
-    _streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT]
+    _streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT] | None = None
 
-    def __init__(self, streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT]) -> None:
-        self._streamed_run_result = streamed_run_result
+    def __init__(
+        self,
+        streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT]
+        | AsyncIterator[StreamedRunResult[AgentDepsT, OutputDataT]],
+    ) -> None:
+        if isinstance(streamed_run_result, StreamedRunResult):
+            self._streamed_run_result = streamed_run_result
+        else:
+            self._stream = streamed_run_result
 
     def all_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
         """Return the history of messages.
@@ -602,7 +611,9 @@ def all_messages(self, *, output_tool_return_content: str | None = None) -> list
         Returns:
             List of messages.
         """
-        return self._streamed_run_result.all_messages(output_tool_return_content=output_tool_return_content)
+        return self._async_to_sync(
+            lambda result: result.all_messages(output_tool_return_content=output_tool_return_content)
+        )
 
     def all_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes:  # pragma: no cover
         """Return all messages from [`all_messages`][pydantic_ai.result.StreamedRunResultSync.all_messages] as JSON bytes.
@@ -616,7 +627,9 @@ def all_messages_json(self, *, output_tool_return_content: str | None = None) ->
         Returns:
             JSON bytes representing the messages.
         """
-        return self._streamed_run_result.all_messages_json(output_tool_return_content=output_tool_return_content)
+        return self._async_to_sync(
+            lambda result: result.all_messages_json(output_tool_return_content=output_tool_return_content)
+        )
 
     def new_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
         """Return new messages associated with this run.
@@ -632,7 +645,9 @@ def new_messages(self, *, output_tool_return_content: str | None = None) -> list
         Returns:
             List of new messages.
""" - return self._streamed_run_result.new_messages(output_tool_return_content=output_tool_return_content) + return self._async_to_sync( + lambda result: result.new_messages(output_tool_return_content=output_tool_return_content) + ) def new_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes: # pragma: no cover """Return new messages from [`new_messages`][pydantic_ai.result.StreamedRunResultSync.new_messages] as JSON bytes. @@ -646,7 +661,9 @@ def new_messages_json(self, *, output_tool_return_content: str | None = None) -> Returns: JSON bytes representing the new messages. """ - return self._streamed_run_result.new_messages_json(output_tool_return_content=output_tool_return_content) + return self._async_to_sync( + lambda result: result.new_messages_json(output_tool_return_content=output_tool_return_content) + ) def stream_output(self, *, debounce_by: float | None = 0.1) -> Iterator[OutputDataT]: """Stream the output as an iterable. @@ -663,7 +680,7 @@ def stream_output(self, *, debounce_by: float | None = 0.1) -> Iterator[OutputDa Returns: An iterable of the response data. """ - return _utils.sync_async_iterator(self._streamed_run_result.stream_output(debounce_by=debounce_by)) + return self._async_iterator_to_sync(lambda result: result.stream_output(debounce_by=debounce_by)) def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> Iterator[str]: """Stream the text result as an iterable. @@ -678,7 +695,7 @@ def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) - Debouncing is particularly important for long structured responses to reduce the overhead of performing validation as each token is received. """ - return _utils.sync_async_iterator(self._streamed_run_result.stream_text(delta=delta, debounce_by=debounce_by)) + return self._async_iterator_to_sync(lambda result: result.stream_text(delta=delta, debounce_by=debounce_by)) def stream_responses(self, *, debounce_by: float | None = 0.1) -> Iterator[tuple[_messages.ModelResponse, bool]]: """Stream the response as an iterable of Structured LLM Messages. @@ -691,16 +708,66 @@ def stream_responses(self, *, debounce_by: float | None = 0.1) -> Iterator[tuple Returns: An iterable of the structured response message and whether that is the last message. 
""" - return _utils.sync_async_iterator(self._streamed_run_result.stream_responses(debounce_by=debounce_by)) + return self._async_iterator_to_sync(lambda result: result.stream_responses(debounce_by=debounce_by)) def get_output(self) -> OutputDataT: """Stream the whole response, validate and return it.""" - return _utils.get_event_loop().run_until_complete(self._streamed_run_result.get_output()) + return self._async_to_sync(lambda result: result.get_output()) + + @asynccontextmanager + async def _with_streamed_run_result(self) -> AsyncIterator[StreamedRunResult[AgentDepsT, OutputDataT]]: + clean_up = False + if self._streamed_run_result is None: + clean_up = True + self._streamed_run_result = await anext(self._stream) + + yield self._streamed_run_result + + if clean_up: + try: + await anext(self._stream) + except StopAsyncIteration: + pass + + def _async_iterator_to_sync( + self, + func: Callable[[StreamedRunResult[AgentDepsT, OutputDataT]], AsyncIterator[T]], + ) -> Iterator[T]: + async def my_task(): + try: + async with self._with_streamed_run_result() as result: + async for item in func(result): + yield item + except RuntimeError as e: + if str(e) != 'Attempted to exit cancel scope in a different task than it was entered in': + raise # pragma: no cover + + return _utils.sync_async_iterator(my_task()) + + def _async_to_sync( + self, + func: Callable[[StreamedRunResult[AgentDepsT, OutputDataT]], T] + | Callable[[StreamedRunResult[AgentDepsT, OutputDataT]], Awaitable[T]], + ) -> T: + if self._streamed_run_result is not None: + res = func(self._streamed_run_result) + if inspect.isawaitable(res): + res = _utils.get_event_loop().run_until_complete(res) + return res + + async def my_task(): + async with self._with_streamed_run_result() as result: + res = func(result) + if inspect.isawaitable(res): + res = cast(T, await res) + return res + + return _utils.get_event_loop().run_until_complete(my_task()) @property def response(self) -> _messages.ModelResponse: """Return the current state of the response.""" - return self._streamed_run_result.response + return self._async_to_sync(lambda result: result.response) def usage(self) -> RunUsage: """Return the usage of the whole run. @@ -708,22 +775,20 @@ def usage(self) -> RunUsage: !!! note This won't return the full usage until the stream is finished. """ - return self._streamed_run_result.usage() + return self._async_to_sync(lambda result: result.usage()) def timestamp(self) -> datetime: """Get the timestamp of the response.""" - return self._streamed_run_result.timestamp() + return self._async_to_sync(lambda result: result.timestamp()) @property def run_id(self) -> str: """The unique identifier for the agent run.""" - return self._streamed_run_result.run_id + return self._async_to_sync(lambda result: result.run_id) def validate_response_output(self, message: _messages.ModelResponse, *, allow_partial: bool = False) -> OutputDataT: """Validate a structured result message.""" - return _utils.get_event_loop().run_until_complete( - self._streamed_run_result.validate_response_output(message, allow_partial=allow_partial) - ) + return self._async_to_sync(lambda result: result.validate_response_output(message, allow_partial=allow_partial)) @property def is_complete(self) -> bool: @@ -735,7 +800,7 @@ def is_complete(self) -> bool: [`stream_responses`][pydantic_ai.result.StreamedRunResultSync.stream_responses] or [`get_output`][pydantic_ai.result.StreamedRunResultSync.get_output] completes. 
""" - return self._streamed_run_result.is_complete + return self._async_to_sync(lambda result: result.is_complete) @dataclass(repr=False) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 7d3f21fd33..41460a47df 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -11,6 +11,7 @@ import pytest from inline_snapshot import snapshot +from logfire.testing import CaptureLogfire from pydantic import BaseModel from pydantic_core import ErrorDetails @@ -46,12 +47,15 @@ from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel from pydantic_ai.output import PromptedOutput, TextOutput, ToolOutput -from pydantic_ai.result import AgentStream, FinalResult, RunUsage +from pydantic_ai.result import AgentStream, FinalResult, RunUsage, StreamedRunResultSync from pydantic_ai.tools import DeferredToolRequests, DeferredToolResults, ToolApproved, ToolDefinition from pydantic_ai.usage import RequestUsage from pydantic_graph import End -from .conftest import IsDatetime, IsInt, IsNow, IsStr +from .conftest import IsDatetime, IsInt, IsNow, IsStr, try_import + +with try_import() as logfire_imports_successful: + from logfire.testing import CaptureLogfire pytestmark = pytest.mark.anyio @@ -191,7 +195,7 @@ async def ret_a(x: str) -> str: RunUsage( requests=2, input_tokens=103, - output_tokens=5, + output_tokens=11, tool_calls=1, ) ) @@ -2141,12 +2145,12 @@ def call_final_result_with_bad_data(messages: list[ModelMessage], info: AgentInf FunctionToolResultEvent( result=RetryPromptPart( content=[ - ErrorDetails( - type='missing', - loc=('value',), - msg='Field required', - input={'bad_value': 'invalid'}, - ), + { + 'type': 'missing', + 'loc': ('value',), + 'msg': 'Field required', + 'input': {'bad_value': 'invalid'}, + }, ], tool_name='final_result', tool_call_id=IsStr(), @@ -2286,7 +2290,7 @@ def my_tool(x: int) -> int: assert await result.validate_response_output(responses[0]) == snapshot( DeferredToolRequests(calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())]) ) - assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=51, output_tokens=0)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=51)) assert result.timestamp() == IsNow(tz=timezone.utc) assert result.is_complete @@ -2671,8 +2675,7 @@ async def event_stream_handler(ctx: RunContext[None], stream: AsyncIterable[Agen content=[ 'This is file bd38f5:', ImageUrl( - url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg', - identifier='bd38f5', + url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg' ), ], ), @@ -2801,3 +2804,98 @@ async def test_get_output_after_stream_output(): ), ] ) + + +async def test_streamed_run_result_sync(): + m = TestModel(custom_output_text='The cat sat on the mat.') + + agent = Agent(m) + + async with agent.run_stream('Hello') as result: + output = await result.get_output() + assert output == snapshot('The cat sat on the mat.') + result_sync = StreamedRunResultSync(result) + assert result_sync.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='Hello', + timestamp=IsNow(tz=timezone.utc), + ) + ], + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='The cat sat on the mat.')], + usage=RequestUsage(input_tokens=51, output_tokens=7), + model_name='test', + timestamp=IsNow(tz=timezone.utc), + provider_name='test', + run_id=IsStr(), + ), + ] + 
+        )
+
+
+def test_stream_output_after_get_output_sync():
+    m = TestModel()
+
+    agent = Agent(m, output_type=bool)
+
+    result = agent.run_stream_sync('Hello')
+
+    assert result.get_output() == snapshot(False)
+    assert [c for c in result.stream_output()] == snapshot([False])
+
+    assert result.all_messages() == snapshot(
+        [
+            ModelRequest(
+                parts=[
+                    UserPromptPart(
+                        content='Hello',
+                        timestamp=IsNow(tz=timezone.utc),
+                    )
+                ],
+                run_id=IsStr(),
+            ),
+            ModelResponse(
+                parts=[
+                    ToolCallPart(
+                        tool_name='final_result',
+                        args={'response': False},
+                        tool_call_id='pyd_ai_tool_call_id__final_result',
+                    )
+                ],
+                usage=RequestUsage(input_tokens=51),
+                model_name='test',
+                timestamp=IsNow(tz=timezone.utc),
+                provider_name='test',
+                run_id=IsStr(),
+            ),
+            ModelRequest(
+                parts=[
+                    ToolReturnPart(
+                        tool_name='final_result',
+                        content='Final result processed.',
+                        tool_call_id='pyd_ai_tool_call_id__final_result',
+                        timestamp=IsNow(tz=timezone.utc),
+                    )
+                ],
+                run_id=IsStr(),
+            ),
+        ]
+    )
+
+
+@pytest.mark.skipif(not logfire_imports_successful(), reason='logfire not installed')
+def test_run_stream_sync_instrumentation(capfire: CaptureLogfire):
+    m = TestModel()
+
+    agent = Agent(m, instrument=True)
+
+    result = agent.run_stream_sync('Hello')
+    output = [c for c in result.stream_output()]
+    assert output == snapshot(['success (no tool calls)'])
+
+    assert capfire.exporter.exported_spans_as_dict(parse_json_attributes=True) == snapshot()
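
Note (illustration, not part of the patch): run_stream_sync() now returns immediately, handing StreamedRunResultSync the unconsumed _consume_stream() async generator instead of eagerly driving it with run_until_complete(anext(...)). The first method call on the wrapper enters the stream via _with_streamed_run_result, caches the StreamedRunResult, and afterwards advances the generator once more so the underlying async-with cleanup runs. A minimal caller-side sketch of the intended behavior, using only APIs from this patch (the model and output text are illustrative):

    from pydantic_ai import Agent
    from pydantic_ai.models.test import TestModel

    agent = Agent(TestModel(custom_output_text='The cat sat on the mat.'))

    # No stream is opened yet: the wrapper holds the unstarted async generator.
    result = agent.run_stream_sync('Hello')

    # First access lazily enters the stream and bridges the async iterator
    # onto the event loop via _utils.sync_async_iterator.
    for text in result.stream_text():
        print(text)

    # Subsequent calls reuse the cached StreamedRunResult.
    print(result.usage())
    assert result.is_complete

The lazy hand-off also means creating the result no longer requires running the event loop at call time; the loop is only used once a consuming method is invoked.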