@dataclass(kw_only=True)
class ToolSearchTool(AbstractBuiltinTool):
    """A builtin tool that lets the model search for tools during dynamic tool discovery.

    Mark a tool with `defer_loading=True` to keep its definition out of the model's context until
    the model finds it through search.

    Only models whose profile sets `ModelProfile.supports_tool_search` use this builtin tool: they
    receive every tool definition and implement search and loading natively. All other models rely
    on `SearchableToolset` instead.

    Supported by:

    * Anthropic
    """

    search_type: Literal['regex', 'bm25'] | None = None
    """Search strategy used for tool discovery. Currently only supported by Anthropic models.

    When left as `None`, `AnthropicModel` falls back to the regex variant.

    See https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/tool-search-tool for more info.

    - `'regex'`: The model constructs Python `re.search()` patterns. Max 200 characters per query.
      Case-sensitive by default.
    - `'bm25'`: The model writes natural language queries that are matched against tool metadata.
    """

    kind: str = 'tool_search'
items.append(_map_mcp_server_result_block(item, call_part, self.system)) + elif isinstance(item, BetaToolSearchToolResultBlock): + call_part = builtin_tool_calls.get(item.tool_use_id) + items.append(_map_mcp_server_result_block(item, call_part, self.system)) else: assert isinstance(item, BetaToolUseBlock), f'unexpected item type {type(item)}' items.append( @@ -577,6 +585,8 @@ def _add_builtin_tools( ) -> tuple[list[BetaToolUnionParam], list[BetaRequestMCPServerURLDefinitionParam], set[str]]: beta_features: set[str] = set() mcp_servers: list[BetaRequestMCPServerURLDefinitionParam] = [] + tool_search_type: Literal['regex', 'bm25'] | None = None + for tool in model_request_parameters.builtin_tools: if isinstance(tool, WebSearchTool): user_location = UserLocation(type='approximate', **tool.user_location) if tool.user_location else None @@ -629,10 +639,32 @@ def _add_builtin_tools( mcp_server_url_definition_param['authorization_token'] = tool.authorization_token mcp_servers.append(mcp_server_url_definition_param) beta_features.add('mcp-client-2025-04-04') + elif isinstance(tool, ToolSearchTool): + tool_search_type = tool.search_type else: # pragma: no cover raise UserError( f'`{tool.__class__.__name__}` is not supported by `AnthropicModel`. If it should be, please file an issue.' 
) + + needs_tool_search = any(tool.get('defer_loading') for tool in tools) + + if needs_tool_search: + beta_features.add('advanced-tool-use-2025-11-20') + if tool_search_type == 'bm25': + tools.append( + BetaToolSearchToolBm25_20251119Param( + name='tool_search_tool_bm25', + type='tool_search_tool_bm25_20251119', + ) + ) + else: + tools.append( + BetaToolSearchToolRegex20251119Param( + name='tool_search_tool_regex', + type='tool_search_tool_regex_20251119', + ) + ) + return tools, mcp_servers, beta_features def _infer_tool_choice( @@ -1062,6 +1094,8 @@ def _map_tool_definition(self, f: ToolDefinition) -> BetaToolParam: 'description': f.description or '', 'input_schema': f.parameters_json_schema, } + if f.defer_loading: + tool_param['defer_loading'] = True if f.strict and self.profile.supports_json_schema_output: tool_param['strict'] = f.strict return tool_param @@ -1297,8 +1331,12 @@ def _map_server_tool_use_block(item: BetaServerToolUseBlock, provider_name: str) elif item.name in ('bash_code_execution', 'text_editor_code_execution'): # pragma: no cover raise NotImplementedError(f'Anthropic built-in tool {item.name!r} is not currently supported.') elif item.name in ('tool_search_tool_regex', 'tool_search_tool_bm25'): # pragma: no cover - # NOTE this is being implemented in https://github.com/pydantic/pydantic-ai/pull/3550 - raise NotImplementedError(f'Anthropic built-in tool {item.name!r} is not currently supported.') + return BuiltinToolCallPart( + provider_name=provider_name, + tool_name=ToolSearchTool.kind, + args=cast(dict[str, Any], item.input) or None, + tool_call_id=item.id, + ) else: assert_never(item.name) diff --git a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py index 84a1c04012..1a44ba19d1 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py @@ -65,6 +65,9 @@ class ModelProfile: This is currently only used by `OpenAIChatModel`, 
`HuggingFaceModel`, and `GroqModel`. """ + supports_tool_search: bool = False + """Whether the model has native support for tool search (builtin ToolSearchTool) and defer loading tools.""" + @classmethod def from_profile(cls, profile: ModelProfile | None) -> Self: """Build a ModelProfile subclass instance from a ModelProfile instance.""" diff --git a/pydantic_ai_slim/pydantic_ai/profiles/anthropic.py b/pydantic_ai_slim/pydantic_ai/profiles/anthropic.py index 6a59ab2dec..bc76b4d5a9 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/anthropic.py @@ -23,6 +23,7 @@ def anthropic_model_profile(model_name: str) -> ModelProfile | None: thinking_tags=('', ''), supports_json_schema_output=supports_json_schema_output, json_schema_transformer=AnthropicJsonSchemaTransformer, + supports_tool_search=True, ) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index dcd860b019..feb315d795 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -274,6 +274,7 @@ class Tool(Generic[ToolAgentDepsT]): requires_approval: bool metadata: dict[str, Any] | None function_schema: _function_schema.FunctionSchema + defer_loading: bool """ The base JSON schema for the tool's parameters. @@ -297,6 +298,7 @@ def __init__( requires_approval: bool = False, metadata: dict[str, Any] | None = None, function_schema: _function_schema.FunctionSchema | None = None, + defer_loading: bool = False, ): """Create a new tool instance. @@ -353,6 +355,7 @@ async def prep_my_tool( See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info. metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization. function_schema: The function schema to use for the tool. If not provided, it will be generated. 
+ defer_loading: If True, hide the tool by default and only activate it when the model searches for tools. """ self.function = function self.function_schema = function_schema or _function_schema.function_schema( @@ -373,6 +376,7 @@ async def prep_my_tool( self.sequential = sequential self.requires_approval = requires_approval self.metadata = metadata + self.defer_loading = defer_loading @classmethod def from_schema( @@ -429,6 +433,7 @@ def tool_def(self): sequential=self.sequential, metadata=self.metadata, kind='unapproved' if self.requires_approval else 'function', + defer_loading=self.defer_loading, ) async def prepare_tool_def(self, ctx: RunContext[ToolAgentDepsT]) -> ToolDefinition | None: @@ -514,6 +519,14 @@ class ToolDefinition: For MCP tools, this contains the `meta`, `annotations`, and `output_schema` fields from the tool definition. """ + defer_loading: bool = False + """Whether to defer loading this tool until it is discovered via tool search. + + When `True`, this tool will not be loaded into the model's context initially. + + Instead, the model will discover it on-demand when needed, reducing token usage. + """ + @property def defer(self) -> bool: """Whether calls to this tool will be deferred. diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/function.py b/pydantic_ai_slim/pydantic_ai/toolsets/function.py index e185ed0273..0a227bf870 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/function.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/function.py @@ -39,6 +39,7 @@ class FunctionToolset(AbstractToolset[AgentDepsT]): docstring_format: DocstringFormat require_parameter_descriptions: bool schema_generator: type[GenerateJsonSchema] + defer_loading: bool def __init__( self, @@ -53,6 +54,7 @@ def __init__( requires_approval: bool = False, metadata: dict[str, Any] | None = None, id: str | None = None, + defer_loading: bool = False, ): """Build a new function toolset. 
@@ -78,6 +80,7 @@ def __init__( Applies to all tools, unless overridden when adding a tool, which will be merged with the toolset's metadata. id: An optional unique ID for the toolset. A toolset needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the toolset's activities within the workflow. + defer_loading: If True, default to hiding each tool and only activate it when the model searches for tools. """ self.max_retries = max_retries self._id = id @@ -88,6 +91,7 @@ def __init__( self.sequential = sequential self.requires_approval = requires_approval self.metadata = metadata + self.defer_loading = defer_loading self.tools = {} for tool in tools: @@ -137,6 +141,7 @@ def tool( sequential: bool | None = None, requires_approval: bool | None = None, metadata: dict[str, Any] | None = None, + defer_loading: bool | None = None, ) -> Any: """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument. @@ -193,6 +198,8 @@ async def spam(ctx: RunContext[str], y: float) -> float: If `None`, the default value is determined by the toolset. metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization. If `None`, the default value is determined by the toolset. If provided, it will be merged with the toolset's metadata. + defer_loading: If True, hide the tool by default and only activate it when the model searches for tools. + If `None`, the default value is determined by the toolset. 
""" def tool_decorator( @@ -213,6 +220,7 @@ def tool_decorator( sequential=sequential, requires_approval=requires_approval, metadata=metadata, + defer_loading=defer_loading, ) return func_ @@ -233,6 +241,7 @@ def add_function( sequential: bool | None = None, requires_approval: bool | None = None, metadata: dict[str, Any] | None = None, + defer_loading: bool | None = None, ) -> None: """Add a function as a tool to the toolset. @@ -267,6 +276,8 @@ def add_function( If `None`, the default value is determined by the toolset. metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization. If `None`, the default value is determined by the toolset. If provided, it will be merged with the toolset's metadata. + defer_loading: If True, hide the tool by default and only activate it when the model searches for tools. + If `None`, the default value is determined by the toolset. """ if docstring_format is None: docstring_format = self.docstring_format @@ -280,6 +291,8 @@ def add_function( sequential = self.sequential if requires_approval is None: requires_approval = self.requires_approval + if defer_loading is None: + defer_loading = self.defer_loading tool = Tool[AgentDepsT]( func, @@ -295,6 +308,7 @@ def add_function( sequential=sequential, requires_approval=requires_approval, metadata=metadata, + defer_loading=defer_loading, ) self.add_tool(tool) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/searchable.py b/pydantic_ai_slim/pydantic_ai/toolsets/searchable.py new file mode 100644 index 0000000000..be2b66fce1 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/searchable.py @@ -0,0 +1,137 @@ +import logging +import re +from collections.abc import Callable +from dataclasses import dataclass, field, replace +from typing import Any, TypedDict + +from pydantic import TypeAdapter +from typing_extensions import Self + +from .._run_context import AgentDepsT, RunContext +from ..tools import ToolDefinition 
# pydantic rejects `typing.TypedDict` on Python < 3.12, so deliberately rebind the name to the
# `typing_extensions` implementation before `_SearchToolArgs` is declared.
from typing_extensions import TypedDict  # noqa: F811

from ..exceptions import ModelRetry, UserError
from .abstract import AbstractToolset, SchemaValidatorProt, ToolsetTool

_logger = logging.getLogger(__name__)

_SEARCH_TOOL_NAME = 'load_tools'


class _SearchToolArgs(TypedDict):
    """Arguments accepted by the `load_tools` search tool."""

    regex: str


# Built once and shared by the tool's validator and by `SearchableToolset.call_tool`.
_SEARCH_TOOL_ARGS_ADAPTER = TypeAdapter(_SearchToolArgs)


def _search_tool_def() -> ToolDefinition:
    """Build the definition of the `load_tools` tool that is advertised to the model."""
    return ToolDefinition(
        name=_SEARCH_TOOL_NAME,
        description="""Search and load additional tools to make them available to the agent.

DO call this to find and load more tools needed for a task.
NEVER ask the user if you should try loading tools, just try.
""",
        parameters_json_schema={
            'type': 'object',
            'properties': {
                'regex': {
                    'type': 'string',
                    'description': 'Regex pattern to search for relevant tools',
                }
            },
            'required': ['regex'],
        },
    )


def _search_tool_validator() -> SchemaValidatorProt:
    """Return the validator for the search tool's arguments."""
    return _SEARCH_TOOL_ARGS_ADAPTER.validator


@dataclass
class _SearchTool(ToolsetTool[AgentDepsT]):
    """A tool that searches for more relevant tools from a `SearchableToolset`."""

    tool_def: ToolDefinition = field(default_factory=_search_tool_def)
    args_validator: SchemaValidatorProt = field(default_factory=_search_tool_validator)


@dataclass
class SearchableToolset(AbstractToolset[AgentDepsT]):
    """A toolset wrapper that implements tool search and deferred tool loading.

    For models whose profile reports `supports_tool_search`, the wrapped toolset's tools are passed
    through untouched. For every other model a `load_tools` search tool is added, and tools marked
    `defer_loading=True` stay hidden until a search has matched them.
    """

    toolset: AbstractToolset[AgentDepsT]
    # Names activated by earlier searches.
    # NOTE(review): this set lives on the toolset instance and nothing in this module clears it, so
    # activations appear to persist across runs that share the instance; confirm that is intended.
    _active_tool_names: set[str] = field(default_factory=set)

    @property
    def id(self) -> str | None:
        return None  # pragma: no cover

    @property
    def label(self) -> str:
        return f'{self.__class__.__name__}({self.toolset.label})'  # pragma: no cover

    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """Return the tools that are currently visible to the model.

        Raises:
            UserError: If the wrapped toolset already defines a tool named `load_tools`.
        """
        # Models with built-in tool search see every tool as-is and do the searching themselves.
        if ctx.model.profile.supports_tool_search:
            _logger.debug('SearchableToolset.get_tools ==> built-in')
            return await self.toolset.get_tools(ctx)

        # Otherwise expose our own search tool and only the tools that are not (or no longer) deferred.
        _logger.debug('SearchableToolset.get_tools ==> pydantic')
        all_tools: dict[str, ToolsetTool[AgentDepsT]] = {
            _SEARCH_TOOL_NAME: _SearchTool(toolset=self, max_retries=1),
        }

        for tool_name, tool in (await self.toolset.get_tools(ctx)).items():
            if tool_name == _SEARCH_TOOL_NAME:
                # An `assert` would be stripped under `-O` and silently let the search tool be overwritten.
                raise UserError(
                    f'Tool name {_SEARCH_TOOL_NAME!r} is reserved by `SearchableToolset`; '
                    'rename the conflicting tool in the wrapped toolset.'
                )
            if not tool.tool_def.defer_loading or tool_name in self._active_tool_names:
                all_tools[tool_name] = tool

        _logger.debug('SearchableToolset.get_tools ==> %s', list(all_tools))
        return all_tools

    async def call_tool(
        self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
    ) -> Any:
        """Run the search tool locally, or delegate any other tool to the wrapped toolset."""
        if isinstance(tool, _SearchTool):
            typed_args = _SEARCH_TOOL_ARGS_ADAPTER.validate_python(tool_args)
            result = await self.call_search_tool(typed_args, ctx)
        else:
            result = await self.toolset.call_tool(name, tool_args, ctx, tool)
        _logger.debug('SearchableToolset.call_tool(%s, %s) ==> %s', name, tool_args, result)
        return result

    async def call_search_tool(self, args: _SearchToolArgs, ctx: RunContext[AgentDepsT]) -> list[str]:
        """Search for tools matching the regex, activate them and return their names.

        The pattern is searched in each tool's name and description; a missing description is
        treated as empty.

        Raises:
            ModelRetry: If the model supplied a pattern that is not a valid regular expression.
        """
        raw_pattern = args['regex']
        try:
            # Compiled once up front: it does not depend on the tool being inspected.
            # NOTE(review): the pattern is model-supplied and runs unbounded; consider a length cap.
            pattern = re.compile(raw_pattern)
        except re.error as exc:
            raise ModelRetry(f'Invalid regex pattern {raw_pattern!r}: {exc}') from exc

        toolset_tools = await self.toolset.get_tools(ctx)
        # Activate by dictionary key, because that is the name `get_tools` checks against.
        matching_tool_names = [
            tool_name
            for tool_name, tool in toolset_tools.items()
            if pattern.search(tool.tool_def.name) or pattern.search(tool.tool_def.description or '')
        ]

        self._active_tool_names.update(matching_tool_names)
        return matching_tool_names

    def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None:
        self.toolset.apply(visitor)

    def visit_and_replace(
        self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
    ) -> AbstractToolset[AgentDepsT]:
        return replace(self, toolset=self.toolset.visit_and_replace(visitor))

    async def __aenter__(self) -> Self:
        await self.toolset.__aenter__()
        return self

    async def __aexit__(self, *args: Any) -> bool | None:
        return await self.toolset.__aexit__(*args)
"""Minimal example to test SearchableToolset functionality.

Run with: uv run python test_searchable_example.py
Make sure you have ANTHROPIC_API_KEY set in your environment.
"""

import asyncio
import logging
import sys
from typing import Any

from pydantic_ai import Agent
from pydantic_ai.toolsets import FunctionToolset
from pydantic_ai.toolsets.searchable import SearchableToolset

# Loggers that drown out the toolset's own debug output.
_NOISY_LOGGERS = (
    'asyncio',
    'httpx',
    'httpcore.connection',
    'httpcore.http11',
    'anthropic._base_client',
)

# (title, prompt) pairs; each one needs a different group of deferred tools.
_TEST_CASES = (
    ('Calculation task', 'What is 123 multiplied by 456?'),
    ('Database task', 'Can you list the database tables and then fetch user 42?'),
    ('Weather task', "What's the weather like in San Francisco? Search for a weather tool"),
)


def _configure_logging() -> None:
    """Send debug logging to stdout and quieten the noisy third-party loggers.

    Called from the entry point rather than at import time, so importing this module does not
    reconfigure logging for the whole process.
    """
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        stream=sys.stdout,
    )
    for logger_name in _NOISY_LOGGERS:
        logging.getLogger(logger_name).setLevel(logging.WARNING)


# Create a toolset with various tools
toolset = FunctionToolset()


@toolset.tool(defer_loading=True)
def get_weather(city: str) -> str:
    """Get the current weather for a given city.

    Args:
        city: The name of the city to get weather for.
    """
    return f'The weather in {city} is sunny and 72°F'


@toolset.tool(defer_loading=True)
def calculate_sum(a: float, b: float) -> float:
    """Add two numbers together.

    Args:
        a: The first number.
        b: The second number.
    """
    return a + b


@toolset.tool(defer_loading=True)
def calculate_product(a: float, b: float) -> float:
    """Multiply two numbers together.

    Args:
        a: The first number.
        b: The second number.
    """
    return a * b


@toolset.tool(defer_loading=True)
def fetch_user_data(user_id: int) -> dict[str, Any]:
    """Fetch user data from the database.

    Args:
        user_id: The ID of the user to fetch.
    """
    return {'id': user_id, 'name': 'John Doe', 'email': 'john@example.com'}


@toolset.tool(defer_loading=True)
def send_email(recipient: str, subject: str, body: str) -> str:
    """Send an email to a recipient.

    Args:
        recipient: The email address of the recipient.
        subject: The subject line of the email.
        body: The body content of the email.
    """
    return f"Email sent to {recipient} with subject '{subject}'"


@toolset.tool(defer_loading=True)
def list_database_tables() -> list[str]:
    """List all tables in the database."""
    return ['users', 'orders', 'products', 'reviews']


# Wrap the toolset with SearchableToolset
searchable_toolset = SearchableToolset(toolset=toolset)

# Create an agent with the searchable toolset
agent = Agent(
    'anthropic:claude-sonnet-4-5',
    toolsets=[searchable_toolset],
    system_prompt='You are a helpful assistant.',
)


async def main():
    """Run each example prompt through the agent and print its output."""
    print('=' * 60)
    print('Testing SearchableToolset')
    print('=' * 60)
    print()

    for index, (title, prompt) in enumerate(_TEST_CASES, start=1):
        # Every heading after the first is preceded by an extra blank line.
        leading = '' if index == 1 else '\n'
        print(f'{leading}Test {index}: {title}')
        print('-' * 60)
        result = await agent.run(prompt)
        print(f'Result: {result.output}')
        print()


if __name__ == '__main__':
    _configure_logging()
    asyncio.run(main())