diff --git a/httpx/_client.py b/httpx/_client.py index 13cd933673..22c6a88de6 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -6,6 +6,7 @@ import time import typing import warnings +import weakref from contextlib import asynccontextmanager, contextmanager from types import TracebackType @@ -140,13 +141,14 @@ class BoundSyncStream(SyncByteStream): """ A byte stream that is bound to a given response instance, and that ensures the `response.elapsed` is set once the response is closed. + Uses weakref to avoid reference cycles with the response object. """ def __init__( self, stream: SyncByteStream, response: Response, start: float ) -> None: self._stream = stream - self._response = response + self._response_ref: weakref.ref[Response] = weakref.ref(response) self._start = start def __iter__(self) -> typing.Iterator[bytes]: @@ -155,7 +157,9 @@ def __iter__(self) -> typing.Iterator[bytes]: def close(self) -> None: elapsed = time.perf_counter() - self._start - self._response.elapsed = datetime.timedelta(seconds=elapsed) + response = self._response_ref() + if response is not None: + response.elapsed = datetime.timedelta(seconds=elapsed) self._stream.close() @@ -163,13 +167,14 @@ class BoundAsyncStream(AsyncByteStream): """ An async byte stream that is bound to a given response instance, and that ensures the `response.elapsed` is set once the response is closed. + Uses weakref to avoid reference cycles with the response object. """ def __init__( self, stream: AsyncByteStream, response: Response, start: float ) -> None: self._stream = stream - self._response = response + self._response_ref: weakref.ref[Response] = weakref.ref(response) self._start = start async def __aiter__(self) -> typing.AsyncIterator[bytes]: @@ -178,7 +183,9 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]: async def aclose(self) -> None: elapsed = time.perf_counter() - self._start - self._response.elapsed = datetime.timedelta(seconds=elapsed) + response = self._response_ref() + if response is not None: + response.elapsed = datetime.timedelta(seconds=elapsed) await self._stream.aclose() diff --git a/tests/test_bound_stream.py b/tests/test_bound_stream.py new file mode 100644 index 0000000000..0422f64225 --- /dev/null +++ b/tests/test_bound_stream.py @@ -0,0 +1,100 @@ +""" +Tests for BoundSyncStream and BoundAsyncStream weakref behavior. +These tests verify that the streams properly break reference cycles +to allow garbage collection. +""" + +import gc +import typing +import weakref + +import pytest + +import httpx +from httpx._client import BoundAsyncStream, BoundSyncStream +from httpx._types import AsyncByteStream, SyncByteStream + + +class MockSyncStream(SyncByteStream): + def __init__(self) -> None: + self.closed = False + + def __iter__(self) -> typing.Iterator[bytes]: # pragma: no cover + yield b"test" + + def close(self) -> None: + self.closed = True + + +class MockAsyncStream(AsyncByteStream): + def __init__(self) -> None: + self.closed = False + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: # pragma: no cover + yield b"test" + + async def aclose(self) -> None: + self.closed = True + + +def test_bound_sync_stream_sets_elapsed(): + response = httpx.Response(200, content=b"") + stream = MockSyncStream() + bound_stream = BoundSyncStream(stream, response=response, start=0.0) + bound_stream.close() + assert hasattr(response, "_elapsed") + assert response.elapsed.total_seconds() >= 0 + + +def test_bound_sync_stream_handles_collected_response(): + response = httpx.Response(200, content=b"") + stream = MockSyncStream() + bound_stream = BoundSyncStream(stream, response=response, start=0.0) + del response + gc.collect() + bound_stream.close() + assert stream.closed + + +def test_bound_sync_stream_no_reference_cycle(): + response = httpx.Response(200, content=b"") + response_ref = weakref.ref(response) + stream = MockSyncStream() + bound_stream = BoundSyncStream(stream, response=response, start=0.0) + response.stream = bound_stream + del response + gc.collect() + assert response_ref() is None, "Response should have been garbage collected" + + +@pytest.mark.anyio +async def test_bound_async_stream_sets_elapsed(): + response = httpx.Response(200, content=b"") + stream = MockAsyncStream() + bound_stream = BoundAsyncStream(stream, response=response, start=0.0) + await bound_stream.aclose() + assert hasattr(response, "_elapsed") + assert response.elapsed.total_seconds() >= 0 + + +@pytest.mark.anyio +async def test_bound_async_stream_handles_collected_response(): + response = httpx.Response(200, content=b"") + stream = MockAsyncStream() + bound_stream = BoundAsyncStream(stream, response=response, start=0.0) + del response + gc.collect() + await bound_stream.aclose() + assert stream.closed + + +@pytest.mark.anyio +async def test_bound_async_stream_no_reference_cycle(): + response = httpx.Response(200, content=b"") + response_ref = weakref.ref(response) + stream = MockAsyncStream() + bound_stream = BoundAsyncStream(stream, response=response, start=0.0) + response.stream = bound_stream + del response + gc.collect() + assert response_ref() is None, "Response should have been garbage collected"