diff --git a/agent-schema.json b/agent-schema.json index b73775acb..48325f409 100644 --- a/agent-schema.json +++ b/agent-schema.json @@ -812,6 +812,10 @@ "type": "string", "description": "A comma-delimited list of regular expressions of tools to toonify" }, + "model": { + "type": "string", + "description": "Model to use for the LLM turn that processes tool results from this toolset. Enables per-tool model routing: cheaper/faster models handle simple tool results (e.g. knowledge-base lookups, file reads) while the agent's primary model handles complex reasoning. Value can be a model name from the models section or an inline provider/model format (e.g. 'openai/gpt-4o-mini')." + }, "ref": { "type": "string", "description": "Reference to a Docker MCP tool (e.g., 'docker:context7') or a named MCP definition from the top-level 'mcps' section" diff --git a/examples/per_tool_model_routing.yaml b/examples/per_tool_model_routing.yaml new file mode 100644 index 000000000..9738d59d8 --- /dev/null +++ b/examples/per_tool_model_routing.yaml @@ -0,0 +1,40 @@ +#!/usr/bin/env docker agent run + +# Per-Tool Model Routing Example +# +# This example demonstrates how to use the `model` field on toolsets +# to automatically route specific tool results through a cheaper/faster +# model, while keeping the agent's primary model for complex reasoning. +# +# When the LLM calls a tool from a toolset with a `model` field, the +# next LLM turn (processing the tool results) uses the specified model +# instead of the agent's primary model. This is a one-shot override: +# subsequent turns return to the primary model. + +models: + primary: + provider: anthropic + model: claude-sonnet-4-5 + fast: + provider: anthropic + model: claude-haiku-4-5 + +agents: + root: + model: primary + description: > + An assistant that uses a fast model for simple tool operations + and the primary model for complex reasoning. + instruction: > + You are a helpful assistant. Use the available tools to help the user. + toolsets: + # The filesystem toolset uses the fast model to process results. + # Reading files and listing directories are simple operations that + # don't need the most capable model to interpret. + - type: filesystem + model: fast + + # The shell toolset also uses the fast model. Most shell command + # outputs (ls, cat, grep, etc.) are straightforward to interpret. + - type: shell + model: fast diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index 92b19ee42..c4abd8aaf 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -556,6 +556,11 @@ type Toolset struct { Instruction string `json:"instruction,omitempty"` Toon string `json:"toon,omitempty"` + // Model overrides the LLM used for the turn that processes tool results + // from this toolset, enabling per-toolset model routing. Value can be a + // model name from the models section or "provider/model" (e.g. "openai/gpt-4o-mini"). + Model string `json:"model,omitempty"` + Defer DeferConfig `json:"defer" yaml:"defer,omitempty"` // For the `mcp` tool diff --git a/pkg/runtime/model_switcher.go b/pkg/runtime/model_switcher.go index aca8b3383..bc8bf461e 100644 --- a/pkg/runtime/model_switcher.go +++ b/pkg/runtime/model_switcher.go @@ -117,23 +117,45 @@ func (r *LocalRuntime) SetAgentModel(ctx context.Context, agentName, modelRef st return nil } - // Try parsing as inline spec (provider/model) + // Try single inline spec (provider/model) + prov, err := r.resolveModelRef(ctx, modelRef) + if err != nil { + return fmt.Errorf("failed to resolve model %q: %w", modelRef, err) + } + a.SetModelOverride(prov) + slog.Info("Set agent model override (inline)", "agent", agentName, "model", prov.ID()) + return nil +} + +// resolveModelRef resolves a model reference to a single provider. +// The reference can be a named model from the config or an inline +// "provider/model" spec (e.g. "openai/gpt-4o-mini"). +func (r *LocalRuntime) resolveModelRef(ctx context.Context, modelRef string) (provider.Provider, error) { + if r.modelSwitcherCfg == nil { + return nil, fmt.Errorf("model switching not configured for this runtime") + } + + // Try named model from config first. + if modelCfg, exists := r.modelSwitcherCfg.Models[modelRef]; exists { + if isAlloyModelConfig(modelCfg) { + return nil, fmt.Errorf("model reference %q is an alloy (multi-model) config and cannot be used as a single model override", modelRef) + } + modelCfg.Name = modelRef + return r.createProviderFromConfig(ctx, &modelCfg) + } + + // Try inline "provider/model" format. providerName, modelName, ok := strings.Cut(modelRef, "/") - if !ok { - return fmt.Errorf("invalid model reference %q: expected a model name from config or 'provider/model' format", modelRef) + if !ok || providerName == "" || modelName == "" { + return nil, fmt.Errorf("invalid model reference %q: expected a model name from config or 'provider/model' format", modelRef) } inlineCfg := &latest.ModelConfig{ Provider: providerName, Model: modelName, } - prov, err := r.createProviderFromConfig(ctx, inlineCfg) - if err != nil { - return fmt.Errorf("failed to create inline model: %w", err) - } - a.SetModelOverride(prov) - slog.Info("Set agent model override (inline)", "agent", agentName, "model", prov.ID()) - return nil + + return r.createProviderFromConfig(ctx, inlineCfg) } // isAlloyModelConfig checks if a model config is an alloy model (multiple models). diff --git a/pkg/runtime/model_switcher_test.go b/pkg/runtime/model_switcher_test.go index 477c5d29f..a4dd0038c 100644 --- a/pkg/runtime/model_switcher_test.go +++ b/pkg/runtime/model_switcher_test.go @@ -383,3 +383,58 @@ func TestBuildCatalogChoicesWithDuplicates(t *testing.T) { assert.NotEqual(t, "openai/gpt-4o", c.Ref, "should not include duplicates from config") } } + +func TestResolveModelRef_RejectsAlloyConfig(t *testing.T) { + t.Parallel() + + r := &LocalRuntime{ + modelSwitcherCfg: &ModelSwitcherConfig{ + Models: map[string]latest.ModelConfig{ + // Alloy config: no provider, comma-separated models + "alloy_model": {Model: "openai/gpt-4o,anthropic/claude-sonnet-4-0"}, + }, + }, + } + + _, err := r.resolveModelRef(t.Context(), "alloy_model") + require.Error(t, err) + assert.Contains(t, err.Error(), "alloy") +} + +func TestResolveModelRef_NilConfig(t *testing.T) { + t.Parallel() + + r := &LocalRuntime{} + + _, err := r.resolveModelRef(t.Context(), "openai/gpt-4o") + require.Error(t, err) + assert.Contains(t, err.Error(), "not configured") +} + +func TestResolveModelRef_InvalidFormat(t *testing.T) { + t.Parallel() + + r := &LocalRuntime{ + modelSwitcherCfg: &ModelSwitcherConfig{ + Models: map[string]latest.ModelConfig{}, + }, + } + + tests := []struct { + name string + modelRef string + }{ + {"no slash", "invalid"}, + {"empty provider", "/model"}, + {"empty model", "provider/"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := r.resolveModelRef(t.Context(), tt.modelRef) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid model reference") + }) + } +} diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 9e0956b13..497cd38ff 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -1043,9 +1043,21 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c // Use a runtime copy of maxIterations so we don't modify the session's persistent config runtimeMaxIterations := sess.MaxIterations + // toolModelOverride holds the per-toolset model from the most recent + // tool calls. It applies for one LLM turn, then resets. + var toolModelOverride string + var prevAgentName string + for { a = r.resolveSessionAgent(sess) + // Clear per-tool model override on agent switch so it doesn't + // leak from one agent's toolset into another agent's turn. + if a.Name() != prevAgentName { + toolModelOverride = "" + prevAgentName = a.Name() + } + r.emitAgentWarnings(a, events) r.configureToolsetHandlers(a, events) @@ -1118,6 +1130,21 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c )) model := a.Model() + defaultModelID := r.getEffectiveModelID(a) + + // Per-tool model routing: use a cheaper model for this turn + // if the previous tool calls specified one, then reset. + if toolModelOverride != "" { + if overrideModel, err := r.resolveModelRef(ctx, toolModelOverride); err != nil { + slog.Warn("Failed to resolve per-tool model override; using agent default", + "model_override", toolModelOverride, "error", err) + } else { + slog.Info("Using per-tool model override for this turn", + "agent", a.Name(), "override", overrideModel.ID(), "primary", model.ID()) + model = overrideModel + } + toolModelOverride = "" + } // Apply thinking setting based on session state. // When thinking is disabled: clone with thinking=false to clear any thinking config. @@ -1135,6 +1162,12 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c } modelID := model.ID() + + // Notify sidebar when this turn uses a different model (per-tool override). + if modelID != defaultModelID { + events <- AgentInfo(a.Name(), modelID, a.Description(), a.WelcomeMessage()) + } + slog.Debug("Using agent", "agent", a.Name(), "model", modelID) slog.Debug("Getting model definition", "model_id", modelID) m, err := r.modelsStore.GetModel(ctx, modelID) @@ -1209,10 +1242,16 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c return } - // Update model info if we used a fallback + // Update sidebar model info to reflect what was actually used this turn. + // Fallback models are sticky (cooldown system persists them), so we only + // emit once. Per-tool model overrides are temporary (one turn), so we + // emit the override and then revert to the agent's default. if usedModel != nil && usedModel.ID() != model.ID() { slog.Info("Used fallback model", "agent", a.Name(), "primary", model.ID(), "used", usedModel.ID()) events <- AgentInfo(a.Name(), usedModel.ID(), a.Description(), a.WelcomeMessage()) + } else if model.ID() != defaultModelID { + // Per-tool override was active: revert sidebar to the agent's default model. + events <- AgentInfo(a.Name(), defaultModelID, a.Description(), a.WelcomeMessage()) } streamSpan.SetAttributes( attribute.Int("tool.calls", len(res.Calls)), @@ -1294,6 +1333,9 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c r.processToolCalls(ctx, sess, res.Calls, agentTools, events) + // Record per-toolset model override for the next LLM turn. + toolModelOverride = resolveToolCallModelOverride(res.Calls, agentTools) + if res.Stopped { slog.Debug("Conversation stopped", "agent", a.Name()) break diff --git a/pkg/runtime/tool_model_override.go b/pkg/runtime/tool_model_override.go new file mode 100644 index 000000000..2c0cbcccb --- /dev/null +++ b/pkg/runtime/tool_model_override.go @@ -0,0 +1,31 @@ +package runtime + +import ( + "log/slog" + + "github.com/docker/docker-agent/pkg/tools" +) + +// resolveToolCallModelOverride returns the per-toolset model override from the +// given tool calls, or "" if none. When multiple tools specify different +// overrides, the first one wins. +func resolveToolCallModelOverride(calls []tools.ToolCall, agentTools []tools.Tool) string { + if len(calls) == 0 { + return "" + } + + toolMap := make(map[string]tools.Tool, len(agentTools)) + for _, t := range agentTools { + toolMap[t.Name] = t + } + + for _, call := range calls { + if t, ok := toolMap[call.Function.Name]; ok && t.ModelOverride != "" { + slog.Debug("Per-tool model override detected", + "tool", call.Function.Name, "model", t.ModelOverride) + return t.ModelOverride + } + } + + return "" +} diff --git a/pkg/runtime/tool_model_override_test.go b/pkg/runtime/tool_model_override_test.go new file mode 100644 index 000000000..bdbc0c250 --- /dev/null +++ b/pkg/runtime/tool_model_override_test.go @@ -0,0 +1,82 @@ +package runtime + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/docker/docker-agent/pkg/tools" +) + +func TestResolveToolCallModelOverride_NoCalls(t *testing.T) { + result := resolveToolCallModelOverride(nil, nil) + assert.Empty(t, result) +} + +func TestResolveToolCallModelOverride_NoOverride(t *testing.T) { + agentTools := []tools.Tool{ + {Name: "read_file"}, + {Name: "write_file"}, + } + calls := []tools.ToolCall{ + {Function: tools.FunctionCall{Name: "read_file"}}, + } + + result := resolveToolCallModelOverride(calls, agentTools) + assert.Empty(t, result) +} + +func TestResolveToolCallModelOverride_SingleOverride(t *testing.T) { + agentTools := []tools.Tool{ + {Name: "read_file", ModelOverride: "openai/gpt-4o-mini"}, + {Name: "write_file"}, + } + calls := []tools.ToolCall{ + {Function: tools.FunctionCall{Name: "read_file"}}, + } + + result := resolveToolCallModelOverride(calls, agentTools) + assert.Equal(t, "openai/gpt-4o-mini", result) +} + +func TestResolveToolCallModelOverride_FirstOverrideWins(t *testing.T) { + agentTools := []tools.Tool{ + {Name: "read_file", ModelOverride: "openai/gpt-4o-mini"}, + {Name: "search_kb", ModelOverride: "anthropic/claude-haiku"}, + } + calls := []tools.ToolCall{ + {Function: tools.FunctionCall{Name: "read_file"}}, + {Function: tools.FunctionCall{Name: "search_kb"}}, + } + + result := resolveToolCallModelOverride(calls, agentTools) + assert.Equal(t, "openai/gpt-4o-mini", result) +} + +func TestResolveToolCallModelOverride_MixedOverrideAndNonOverride(t *testing.T) { + agentTools := []tools.Tool{ + {Name: "read_file"}, + {Name: "search_kb", ModelOverride: "openai/gpt-4o-mini"}, + } + calls := []tools.ToolCall{ + {Function: tools.FunctionCall{Name: "read_file"}}, + {Function: tools.FunctionCall{Name: "search_kb"}}, + } + + // read_file has no override, search_kb does. Since read_file is first + // but has no override, we skip it and use search_kb's. + result := resolveToolCallModelOverride(calls, agentTools) + assert.Equal(t, "openai/gpt-4o-mini", result) +} + +func TestResolveToolCallModelOverride_UnknownTool(t *testing.T) { + agentTools := []tools.Tool{ + {Name: "read_file"}, + } + calls := []tools.ToolCall{ + {Function: tools.FunctionCall{Name: "unknown_tool"}}, + } + + result := resolveToolCallModelOverride(calls, agentTools) + assert.Empty(t, result) +} diff --git a/pkg/teamloader/model_override.go b/pkg/teamloader/model_override.go new file mode 100644 index 000000000..33de47900 --- /dev/null +++ b/pkg/teamloader/model_override.go @@ -0,0 +1,53 @@ +package teamloader + +import ( + "context" + + "github.com/docker/docker-agent/pkg/tools" +) + +// WithModelOverride wraps a toolset so that every tool it produces carries the +// given model in its ModelOverride field, enabling per-toolset model routing. +func WithModelOverride(inner tools.ToolSet, model string) tools.ToolSet { + if model == "" { + return inner + } + + return &modelOverrideToolset{ + ToolSet: inner, + model: model, + } +} + +type modelOverrideToolset struct { + tools.ToolSet + model string +} + +var ( + _ tools.Instructable = (*modelOverrideToolset)(nil) + _ tools.Unwrapper = (*modelOverrideToolset)(nil) +) + +func (m *modelOverrideToolset) Unwrap() tools.ToolSet { + return m.ToolSet +} + +func (m *modelOverrideToolset) Instructions() string { + return tools.GetInstructions(m.ToolSet) +} + +func (m *modelOverrideToolset) Tools(ctx context.Context) ([]tools.Tool, error) { + innerTools, err := m.ToolSet.Tools(ctx) + if err != nil { + return nil, err + } + + result := make([]tools.Tool, len(innerTools)) + for i, t := range innerTools { + t.ModelOverride = m.model + result[i] = t + } + + return result, nil +} diff --git a/pkg/teamloader/model_override_test.go b/pkg/teamloader/model_override_test.go new file mode 100644 index 000000000..e99cb0ad4 --- /dev/null +++ b/pkg/teamloader/model_override_test.go @@ -0,0 +1,93 @@ +package teamloader + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/tools" +) + +func TestWithModelOverride_Empty(t *testing.T) { + inner := &mockToolSet{ + toolsFunc: func(_ context.Context) ([]tools.Tool, error) { + return []tools.Tool{{Name: "read_file"}}, nil + }, + } + + // Empty model string should return the inner toolset as-is. + wrapped := WithModelOverride(inner, "") + assert.Same(t, inner, wrapped) +} + +func TestWithModelOverride_SetsModelOnTools(t *testing.T) { + inner := &mockToolSet{ + toolsFunc: func(_ context.Context) ([]tools.Tool, error) { + return []tools.Tool{ + {Name: "read_file"}, + {Name: "write_file"}, + }, nil + }, + } + + wrapped := WithModelOverride(inner, "openai/gpt-4o-mini") + result, err := wrapped.Tools(t.Context()) + require.NoError(t, err) + require.Len(t, result, 2) + assert.Equal(t, "openai/gpt-4o-mini", result[0].ModelOverride) + assert.Equal(t, "openai/gpt-4o-mini", result[1].ModelOverride) +} + +func TestWithModelOverride_DoesNotMutateOriginal(t *testing.T) { + inner := &mockToolSet{ + toolsFunc: func(_ context.Context) ([]tools.Tool, error) { + return []tools.Tool{ + {Name: "read_file"}, + }, nil + }, + } + + wrapped := WithModelOverride(inner, "openai/gpt-4o-mini") + result, err := wrapped.Tools(t.Context()) + require.NoError(t, err) + require.Len(t, result, 1) + assert.Equal(t, "openai/gpt-4o-mini", result[0].ModelOverride) + + // Original tools should be unaffected since we copy. + originalTools, err := inner.Tools(t.Context()) + require.NoError(t, err) + assert.Empty(t, originalTools[0].ModelOverride) +} + +func TestWithModelOverride_Unwrap(t *testing.T) { + inner := &mockToolSet{ + toolsFunc: func(_ context.Context) ([]tools.Tool, error) { + return []tools.Tool{{Name: "read_file"}}, nil + }, + } + + wrapped := WithModelOverride(inner, "openai/gpt-4o-mini") + + unwrapper, ok := wrapped.(tools.Unwrapper) + require.True(t, ok) + assert.Same(t, inner, unwrapper.Unwrap()) +} + +func TestWithModelOverride_Instructions(t *testing.T) { + inner := &instructableToolSet{ + mockToolSet: mockToolSet{ + toolsFunc: func(_ context.Context) ([]tools.Tool, error) { + return []tools.Tool{{Name: "read_file"}}, nil + }, + }, + instructions: "Use this for file operations", + } + + wrapped := WithModelOverride(inner, "openai/gpt-4o-mini") + + inst, ok := wrapped.(tools.Instructable) + require.True(t, ok) + assert.Equal(t, "Use this for file operations", inst.Instructions()) +} diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index 7a8dcfa58..adf6bfffd 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -448,6 +448,7 @@ func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir stri wrapped := WithToolsFilter(tool, toolset.Tools...) wrapped = WithInstructions(wrapped, toolset.Instruction) wrapped = WithToon(wrapped, toolset.Toon) + wrapped = WithModelOverride(wrapped, toolset.Model) // Handle deferred tools if !toolset.Defer.IsEmpty() { diff --git a/pkg/tools/tools.go b/pkg/tools/tools.go index f4128a3ec..fbe854522 100644 --- a/pkg/tools/tools.go +++ b/pkg/tools/tools.go @@ -95,6 +95,9 @@ type Tool struct { OutputSchema any `json:"outputSchema"` Handler ToolHandler `json:"-"` AddDescriptionParameter bool `json:"-"` + // ModelOverride is the per-toolset model for the LLM turn that processes + // this tool's results. Set automatically from the toolset "model" field. + ModelOverride string `json:"-"` } type ToolAnnotations mcp.ToolAnnotations