|
1 | 1 | from __future__ import annotations as _annotations |
2 | 2 |
|
3 | | -from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator |
| 3 | +import inspect |
| 4 | +from collections.abc import AsyncIterator, Awaitable, Callable, Coroutine, Iterable, Iterator |
4 | 5 | from contextlib import asynccontextmanager |
5 | 6 | from copy import deepcopy |
6 | 7 | from dataclasses import dataclass, field, replace |
7 | 8 | from datetime import datetime |
8 | | -from typing import TYPE_CHECKING, Generic, cast, overload |
| 9 | +from typing import TYPE_CHECKING, Any, Generic, cast, overload |
9 | 10 |
|
10 | 11 | from pydantic import ValidationError |
11 | 12 | from typing_extensions import TypeVar, deprecated |
@@ -739,15 +740,14 @@ async def my_task(): |
739 | 740 |
|
740 | 741 | def _async_to_sync( |
741 | 742 | self, |
742 | | - func: Callable[[StreamedRunResult[AgentDepsT, OutputDataT]], Awaitable[T]] |
743 | | - | Callable[[StreamedRunResult[AgentDepsT, OutputDataT]], T], |
| 743 | + func: Callable[[StreamedRunResult[AgentDepsT, OutputDataT]], Coroutine[Any, Any, T] | T], |
744 | 744 | ) -> T: |
745 | 745 | async def my_task(): |
746 | 746 | async with self._with_streamed_run_result() as result: |
747 | | - if _utils.is_async_callable(func): |
748 | | - return await func(result) |
749 | | - else: |
750 | | - return func(result) |
| 747 | + res = func(result) |
| 748 | + if inspect.isawaitable(res): |
| 749 | + res = await res |
| 750 | + return res |
751 | 751 |
|
752 | 752 | return _utils.get_event_loop().run_until_complete(my_task()) |
753 | 753 |
|
|
0 commit comments