diff --git a/python/copilot/client.py b/python/copilot/client.py index 29cdf81d..4e50f44f 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -23,7 +23,7 @@ import uuid from collections.abc import Awaitable, Callable from pathlib import Path -from typing import Any, cast +from typing import Any, cast, overload from .generated.rpc import ServerRpc from .generated.session_events import PermissionRequest, session_event_from_dict @@ -53,6 +53,8 @@ ToolResult, ) +HandlerUnsubcribe = Callable[[], None] + NO_RESULT_PERMISSION_V2_ERROR = ( "Permission handlers cannot return 'no-result' when connected to a protocol v2 server." ) @@ -1097,11 +1099,20 @@ async def set_foreground_session_id(self, session_id: str) -> None: error = response.get("error", "Unknown error") raise RuntimeError(f"Failed to set foreground session: {error}") + @overload + def on(self, handler: SessionLifecycleHandler, /) -> HandlerUnsubcribe: ... + + @overload + def on( + self, event_type: SessionLifecycleEventType, /, handler: SessionLifecycleHandler + ) -> HandlerUnsubcribe: ... + def on( self, event_type_or_handler: SessionLifecycleEventType | SessionLifecycleHandler, + /, handler: SessionLifecycleHandler | None = None, - ) -> Callable[[], None]: + ) -> HandlerUnsubcribe: """ Subscribe to session lifecycle events.