From 20bfac9287c4009953f01cf91eb175e5bb827410 Mon Sep 17 00:00:00 2001 From: James Austin Date: Tue, 28 Oct 2025 19:03:16 -0700 Subject: [PATCH 1/3] Abort streams --- src/replit_river/client_session.py | 3 +++ src/replit_river/server_session.py | 3 +++ src/replit_river/session.py | 12 +++++++++--- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index ed498847..5b493291 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -99,6 +99,7 @@ async def serve(self) -> None: try: await self._handle_messages_from_ws() except ConnectionClosed: + self._abort_all_streams() if self._retry_connection_callback: self._task_manager.create_task(self._retry_connection_callback()) @@ -106,10 +107,12 @@ async def serve(self) -> None: logger.debug("ConnectionClosed while serving", exc_info=True) except FailedSendingMessageException: # Expected error if the connection is closed. + self._abort_all_streams() logger.debug( "FailedSendingMessageException while serving", exc_info=True ) except Exception: + self._abort_all_streams() logger.exception("caught exception at message iterator") except ExceptionGroup as eg: _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index c397e900..755fb8d6 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -87,6 +87,7 @@ async def serve(self) -> None: try: await self._handle_messages_from_ws(tg) except ConnectionClosed: + self._abort_all_streams() if self._retry_connection_callback: self._task_manager.create_task( self._retry_connection_callback() @@ -96,10 +97,12 @@ async def serve(self) -> None: logger.debug("ConnectionClosed while serving", exc_info=True) except FailedSendingMessageException: # Expected error if the connection is closed. + self._abort_all_streams() logger.debug( "FailedSendingMessageException while serving", exc_info=True ) except Exception: + self._abort_all_streams() logger.exception("caught exception at message iterator") except ExceptionGroup as eg: _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) diff --git a/src/replit_river/session.py b/src/replit_river/session.py index 465a6672..b27f09e7 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -286,6 +286,14 @@ async def close_websocket( if should_retry and self._retry_connection_callback: self._task_manager.create_task(self._retry_connection_callback()) + def _abort_all_streams(self) -> None: + """Close all active stream channels, notifying any waiting consumers.""" + if not self._streams: + return + for stream in self._streams.values(): + stream.close() + self._streams.clear() + async def close(self) -> None: """Close the session and all associated streams.""" logger.info( @@ -310,9 +318,7 @@ async def close(self) -> None: # TODO: unexpected_close should close stream differently here to # throw exception correctly. - for stream in self._streams.values(): - stream.close() - self._streams.clear() + self._abort_all_streams() self._state = SessionState.CLOSED From 248cdf1ffc93f6151364dd73e29666f7d8b28e85 Mon Sep 17 00:00:00 2001 From: James Austin Date: Tue, 28 Oct 2025 20:45:11 -0700 Subject: [PATCH 2/3] Fix 2 --- src/replit_river/client_session.py | 6 ++-- src/replit_river/server_session.py | 6 ++-- src/replit_river/server_transport.py | 44 ++++++++++++++++++++-------- src/replit_river/session.py | 3 ++ 4 files changed, 42 insertions(+), 17 deletions(-) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index 5b493291..99160b57 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -99,7 +99,8 @@ async def serve(self) -> None: try: await self._handle_messages_from_ws() except ConnectionClosed: - self._abort_all_streams() + if self._should_abort_streams_after_transport_failure(): + self._abort_all_streams() if self._retry_connection_callback: self._task_manager.create_task(self._retry_connection_callback()) @@ -107,7 +108,8 @@ async def serve(self) -> None: logger.debug("ConnectionClosed while serving", exc_info=True) except FailedSendingMessageException: # Expected error if the connection is closed. - self._abort_all_streams() + if self._should_abort_streams_after_transport_failure(): + self._abort_all_streams() logger.debug( "FailedSendingMessageException while serving", exc_info=True ) diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index 755fb8d6..4eaf6950 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -87,7 +87,8 @@ async def serve(self) -> None: try: await self._handle_messages_from_ws(tg) except ConnectionClosed: - self._abort_all_streams() + if self._should_abort_streams_after_transport_failure(): + self._abort_all_streams() if self._retry_connection_callback: self._task_manager.create_task( self._retry_connection_callback() @@ -97,7 +98,8 @@ async def serve(self) -> None: logger.debug("ConnectionClosed while serving", exc_info=True) except FailedSendingMessageException: # Expected error if the connection is closed. - self._abort_all_streams() + if self._should_abort_streams_after_transport_failure(): + self._abort_all_streams() logger.debug( "FailedSendingMessageException while serving", exc_info=True ) diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index 3f743e51..ca65183f 100644 --- a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -153,18 +153,35 @@ async def _get_or_create_session( close_session_callback=self._delete_session, ) else: - # If the instance id is the same, we reuse the session and assign - # a new websocket to it. - logger.debug( - 'Reuse old session with "%s" using new ws: %s', - to_id, - websocket.id, - ) - try: - await old_session.replace_with_new_websocket(websocket) - new_session = old_session - except FailedSendingMessageException as e: - raise e + if not await old_session.is_session_open(): + logger.info( + 'Session "%s" is not active, creating replacement ' + "session %s instead of reusing", + to_id, + session_id, + ) + new_session = ServerSession( + transport_id, + to_id, + session_id, + websocket, + self._transport_options, + self._handlers, + close_session_callback=self._delete_session, + ) + else: + # If the instance id is the same, we reuse the session and assign + # a new websocket to it. + logger.debug( + 'Reuse old session with "%s" using new ws: %s', + to_id, + websocket.id, + ) + try: + await old_session.replace_with_new_websocket(websocket) + new_session = old_session + except FailedSendingMessageException as e: + raise e self._sessions[new_session._to_id] = new_session @@ -311,5 +328,6 @@ async def _establish_handshake( async def _delete_session(self, session: Session) -> None: async with self._session_lock: - if session._to_id in self._sessions: + existing_session = self._sessions.get(session._to_id) + if existing_session is session: del self._sessions[session._to_id] diff --git a/src/replit_river/session.py b/src/replit_river/session.py index b27f09e7..2980b561 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -286,6 +286,9 @@ async def close_websocket( if should_retry and self._retry_connection_callback: self._task_manager.create_task(self._retry_connection_callback()) + def _should_abort_streams_after_transport_failure(self) -> bool: + return self._retry_connection_callback is None + def _abort_all_streams(self) -> None: """Close all active stream channels, notifying any waiting consumers.""" if not self._streams: From 80ce2dc2bce8a688b94b8ebb1ca592349670e07f Mon Sep 17 00:00:00 2001 From: James Austin Date: Wed, 29 Oct 2025 10:28:57 -0700 Subject: [PATCH 3/3] Incorporated Jacky's feedback --- src/replit_river/client_session.py | 6 ++-- src/replit_river/server_session.py | 5 ---- src/replit_river/server_transport.py | 44 ++++++++-------------------- src/replit_river/session.py | 2 +- 4 files changed, 17 insertions(+), 40 deletions(-) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index 99160b57..d703c87b 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -100,7 +100,7 @@ async def serve(self) -> None: await self._handle_messages_from_ws() except ConnectionClosed: if self._should_abort_streams_after_transport_failure(): - self._abort_all_streams() + await self.close() if self._retry_connection_callback: self._task_manager.create_task(self._retry_connection_callback()) @@ -109,12 +109,12 @@ async def serve(self) -> None: except FailedSendingMessageException: # Expected error if the connection is closed. if self._should_abort_streams_after_transport_failure(): - self._abort_all_streams() + await self.close() logger.debug( "FailedSendingMessageException while serving", exc_info=True ) except Exception: - self._abort_all_streams() + await self.close() logger.exception("caught exception at message iterator") except ExceptionGroup as eg: _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index 4eaf6950..c397e900 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -87,8 +87,6 @@ async def serve(self) -> None: try: await self._handle_messages_from_ws(tg) except ConnectionClosed: - if self._should_abort_streams_after_transport_failure(): - self._abort_all_streams() if self._retry_connection_callback: self._task_manager.create_task( self._retry_connection_callback() @@ -98,13 +96,10 @@ async def serve(self) -> None: logger.debug("ConnectionClosed while serving", exc_info=True) except FailedSendingMessageException: # Expected error if the connection is closed. - if self._should_abort_streams_after_transport_failure(): - self._abort_all_streams() logger.debug( "FailedSendingMessageException while serving", exc_info=True ) except Exception: - self._abort_all_streams() logger.exception("caught exception at message iterator") except ExceptionGroup as eg: _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index ca65183f..3f743e51 100644 --- a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -153,35 +153,18 @@ async def _get_or_create_session( close_session_callback=self._delete_session, ) else: - if not await old_session.is_session_open(): - logger.info( - 'Session "%s" is not active, creating replacement ' - "session %s instead of reusing", - to_id, - session_id, - ) - new_session = ServerSession( - transport_id, - to_id, - session_id, - websocket, - self._transport_options, - self._handlers, - close_session_callback=self._delete_session, - ) - else: - # If the instance id is the same, we reuse the session and assign - # a new websocket to it. - logger.debug( - 'Reuse old session with "%s" using new ws: %s', - to_id, - websocket.id, - ) - try: - await old_session.replace_with_new_websocket(websocket) - new_session = old_session - except FailedSendingMessageException as e: - raise e + # If the instance id is the same, we reuse the session and assign + # a new websocket to it. + logger.debug( + 'Reuse old session with "%s" using new ws: %s', + to_id, + websocket.id, + ) + try: + await old_session.replace_with_new_websocket(websocket) + new_session = old_session + except FailedSendingMessageException as e: + raise e self._sessions[new_session._to_id] = new_session @@ -328,6 +311,5 @@ async def _establish_handshake( async def _delete_session(self, session: Session) -> None: async with self._session_lock: - existing_session = self._sessions.get(session._to_id) - if existing_session is session: + if session._to_id in self._sessions: del self._sessions[session._to_id] diff --git a/src/replit_river/session.py b/src/replit_river/session.py index 2980b561..b4240b1a 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -287,7 +287,7 @@ async def close_websocket( self._task_manager.create_task(self._retry_connection_callback()) def _should_abort_streams_after_transport_failure(self) -> bool: - return self._retry_connection_callback is None + return not self._transport_options.transparent_reconnect def _abort_all_streams(self) -> None: """Close all active stream channels, notifying any waiting consumers."""