Skip to content
23 changes: 23 additions & 0 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from types import TracebackType
from typing import Any, Protocol, overload

import anyio.lowlevel
Expand Down Expand Up @@ -108,6 +109,8 @@ class ClientSession(
types.ServerNotification,
]
):
_entered: bool

def __init__(
self,
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
Expand Down Expand Up @@ -140,11 +143,31 @@ def __init__(
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
self._server_capabilities: types.ServerCapabilities | None = None
self._experimental_features: ExperimentalClientFeatures | None = None
self._entered = False

# Experimental: Task handlers (use defaults if not provided)
self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers()

async def __aenter__(self) -> "ClientSession":
self._entered = True
await super().__aenter__()
return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
self._entered = False
await super().__aexit__(exc_type, exc_value, traceback)

def _check_is_active(self) -> None:
if not self._entered:
raise RuntimeError("ClientSession must be used within an 'async with' block.")

async def initialize(self) -> types.InitializeResult:
self._check_is_active()
sampling = (
(self._sampling_capabilities or types.SamplingCapability())
if self._sampling_callback is not _default_sampling_callback
Expand Down
19 changes: 18 additions & 1 deletion tests/client/test_session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, cast

import anyio
import pytest
Expand Down Expand Up @@ -768,3 +768,20 @@ async def mock_server():
await session.initialize()

await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, meta=meta)


@pytest.mark.anyio
async def test_initialize_without_context_manager_raises_error():
"""
Test that calling initialize() without entering the context manager raises RuntimeError.
"""
send_stream, receive_stream = anyio.create_memory_object_stream[Any](0)

read_stream = cast(Any, receive_stream)
write_stream = cast(Any, send_stream)

async with send_stream, receive_stream:
session = ClientSession(read_stream, write_stream)

with pytest.raises(RuntimeError, match="must be used within"):
await session.initialize()