diff --git a/pkg/model/provider/rulebased/client.go b/pkg/model/provider/rulebased/client.go index 8712c178e..213db2984 100644 --- a/pkg/model/provider/rulebased/client.go +++ b/pkg/model/provider/rulebased/client.go @@ -1,8 +1,5 @@ // Package rulebased provides a rule-based model router that selects -// the appropriate model based on NLP analysis of the input using Bleve. -// -// Routes are defined with example texts, and Bleve's full-text search -// determines the best matching route based on text similarity. +// the appropriate model based on text similarity using Bleve full-text search. // // A model becomes a rule-based router when it has routing rules configured. // The model's provider/model fields define the fallback model, and each @@ -43,17 +40,11 @@ type ProviderFactory func(ctx context.Context, modelSpec string, models map[stri // Client implements the Provider interface for rule-based model routing. type Client struct { base.Config - routes []route + routes []Provider fallback Provider index bleve.Index } -// route represents a single routing rule. -type route struct { - model string - provider Provider -} - // NewClient creates a new rule-based routing client. // The cfg parameter should have Routing rules configured. The provider/model // fields of cfg define the fallback model that is used when no routing rule matches. @@ -69,11 +60,21 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, models map[string]l return nil, fmt.Errorf("creating bleve index: %w", err) } - // Create fallback provider from the model's provider/model fields + // On any subsequent error, close the index before returning. + var cleanupErr error + defer func() { + if cleanupErr != nil { + _ = index.Close() + } + }() + + routeOpts := filterOutMaxTokens(opts) + + // Create fallback provider from the model's provider/model fields. fallbackSpec := cfg.Provider + "/" + cfg.Model - fallback, err := providerFactory(ctx, fallbackSpec, models, env, filterOutMaxTokens(opts)...) + fallback, err := providerFactory(ctx, fallbackSpec, models, env, routeOpts...) if err != nil { - _ = index.Close() + cleanupErr = err return nil, fmt.Errorf("creating fallback provider %q: %w", fallbackSpec, err) } @@ -87,27 +88,28 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, models map[string]l fallback: fallback, } - // Process routing rules + // Process routing rules. Each example is indexed with a doc ID + // that encodes the route index (e.g. "r0_e1") so we can map + // search hits back to the corresponding provider. for i, rule := range cfg.Routing { if rule.Model == "" { - _ = index.Close() - return nil, fmt.Errorf("routing rule %d: 'model' field is required", i) + cleanupErr = fmt.Errorf("routing rule %d: 'model' field is required", i) + return nil, cleanupErr } - provider, err := providerFactory(ctx, rule.Model, models, env, filterOutMaxTokens(opts)...) + provider, err := providerFactory(ctx, rule.Model, models, env, routeOpts...) if err != nil { - _ = index.Close() + cleanupErr = err return nil, fmt.Errorf("creating provider for routing rule %q: %w", rule.Model, err) } routeIndex := len(client.routes) - client.routes = append(client.routes, route{model: rule.Model, provider: provider}) + client.routes = append(client.routes, provider) - // Index examples for this route for j, example := range rule.Examples { docID := fmt.Sprintf("r%d_e%d", routeIndex, j) - if err := index.Index(docID, map[string]any{"text": example, "route": routeIndex}); err != nil { - _ = index.Close() + if err := index.Index(docID, map[string]any{"text": example}); err != nil { + cleanupErr = err return nil, fmt.Errorf("indexing example: %w", err) } } @@ -124,7 +126,6 @@ func createIndex() (bleve.Index, error) { textField := mapping.NewTextFieldMapping() textField.Analyzer = "en" docMapping.AddFieldMappingsAt("text", textField) - docMapping.AddFieldMappingsAt("route", mapping.NewNumericFieldMapping()) indexMapping.DefaultMapping = docMapping @@ -132,19 +133,16 @@ func createIndex() (bleve.Index, error) { } // filterOutMaxTokens removes WithMaxTokens options from the slice. -// This is necessary because child providers may have different token limits -// than the parent router, and should determine their own limits. +// Child providers may have different token limits than the parent router. func filterOutMaxTokens(opts []options.Opt) []options.Opt { var filtered []options.Opt for _, opt := range opts { if opt == nil { continue } - // Test if this option sets maxTokens by applying it to an empty ModelOptions - var test options.ModelOptions - opt(&test) - // If maxTokens was set, skip this option - if test.MaxTokens() != 0 { + var probe options.ModelOptions + opt(&probe) + if probe.MaxTokens() != 0 { continue } filtered = append(filtered, opt) @@ -173,6 +171,7 @@ func (c *Client) CreateChatCompletionStream( } // selectProvider finds the best matching provider for the messages. +// Bleve returns hits sorted by score, so the top hit determines the route. func (c *Client) selectProvider(messages []chat.Message) Provider { userMessage := getLastUserMessage(messages) if userMessage == "" { @@ -183,8 +182,7 @@ func (c *Client) selectProvider(messages []chat.Message) Provider { query.SetField("text") searchRequest := bleve.NewSearchRequest(query) - searchRequest.Size = 10 - searchRequest.Fields = []string{"route"} + searchRequest.Size = 1 results, err := c.index.Search(searchRequest) if err != nil { @@ -196,33 +194,28 @@ func (c *Client) selectProvider(messages []chat.Message) Provider { return c.defaultProvider() } - // Find best matching route by aggregating scores - scores := make(map[int]float64) - for _, hit := range results.Hits { - var routeIdx int - if _, err := fmt.Sscanf(hit.ID, "r%d_e", &routeIdx); err == nil { - if hit.Score > scores[routeIdx] { - scores[routeIdx] = hit.Score - } - } + // Parse the route index from the top hit's doc ID (e.g. "r2_e0" → 2). + hit := results.Hits[0] + routeIdx, ok := parseRouteIndex(hit.ID) + if !ok || routeIdx >= len(c.routes) { + return c.defaultProvider() } - bestRoute, bestScore := -1, 0.0 - for idx, score := range scores { - if score > bestScore { - bestRoute, bestScore = idx, score - } - } + selected := c.routes[routeIdx] + slog.Debug("Route matched", + "model", selected.ID(), + "score", hit.Score, + ) + return selected +} - if bestRoute >= 0 && bestRoute < len(c.routes) { - slog.Debug("Route matched", - "model", c.routes[bestRoute].model, - "score", bestScore, - ) - return c.routes[bestRoute].provider +// parseRouteIndex extracts the route index from a doc ID like "r2_e0". +func parseRouteIndex(docID string) (int, bool) { + var idx int + if _, err := fmt.Sscanf(docID, "r%d_e", &idx); err != nil || idx < 0 { + return 0, false } - - return c.defaultProvider() + return idx, true } func (c *Client) defaultProvider() Provider { @@ -230,7 +223,7 @@ func (c *Client) defaultProvider() Provider { return c.fallback } if len(c.routes) > 0 { - return c.routes[0].provider + return c.routes[0] } return nil } diff --git a/pkg/model/provider/rulebased/client_test.go b/pkg/model/provider/rulebased/client_test.go index d918113cd..8ef3a2fee 100644 --- a/pkg/model/provider/rulebased/client_test.go +++ b/pkg/model/provider/rulebased/client_test.go @@ -40,11 +40,9 @@ func (m *mockProvider) BaseConfig() base.Config { // mockProviderFactory creates a mock provider factory for testing. // It resolves model references from the models map or parses inline specs. func mockProviderFactory(_ context.Context, modelSpec string, models map[string]latest.ModelConfig, _ environment.Provider, _ ...options.Opt) (Provider, error) { - // Check if it's a model reference if cfg, exists := models[modelSpec]; exists { return &mockProvider{id: cfg.Provider + "/" + cfg.Model}, nil } - // Otherwise treat as inline spec return &mockProvider{id: modelSpec}, nil } @@ -62,7 +60,7 @@ func TestNewClient(t *testing.T) { name: "valid config with routing rules", modelCfg: latest.ModelConfig{ Provider: "openai", - Model: "gpt-4o", // fallback + Model: "gpt-4o", Routing: []latest.RoutingRule{ { Model: "anthropic/claude-3-haiku", @@ -80,7 +78,7 @@ func TestNewClient(t *testing.T) { name: "routing with model references", modelCfg: latest.ModelConfig{ Provider: "anthropic", - Model: "claude-haiku-4-5", // fallback + Model: "claude-haiku-4-5", Routing: []latest.RoutingRule{ { Model: "fast", @@ -183,7 +181,7 @@ func TestClient_SelectProvider(t *testing.T) { cfg := &latest.ModelConfig{ Provider: "openai", - Model: "gpt-4o", // fallback + Model: "gpt-4o", Routing: []latest.RoutingRule{ { Model: "anthropic/claude-3-haiku", @@ -262,11 +260,9 @@ func TestCreateIndex(t *testing.T) { require.NoError(t, err) defer index.Close() - // Index a document - err = index.Index("test", map[string]any{"text": "hello world", "route": 0}) + err = index.Index("test", map[string]any{"text": "hello world"}) require.NoError(t, err) - // Search for it query := bleve.NewMatchQuery("hello") query.SetField("text") results, err := index.Search(bleve.NewSearchRequest(query)) @@ -298,10 +294,9 @@ func TestClient_ID(t *testing.T) { func TestClient_DefaultProvider(t *testing.T) { t.Parallel() - // Test that fallback is always used for empty messages cfg := &latest.ModelConfig{ Provider: "openai", - Model: "gpt-4o", // fallback + Model: "gpt-4o", Routing: []latest.RoutingRule{ { Model: "anthropic/claude-3-haiku", @@ -314,7 +309,6 @@ func TestClient_DefaultProvider(t *testing.T) { require.NoError(t, err) defer client.Close() - // Empty message should use fallback provider := client.selectProvider(nil) assert.Equal(t, "openai/gpt-4o", provider.ID()) } @@ -322,8 +316,6 @@ func TestClient_DefaultProvider(t *testing.T) { func TestClient_CreateChatCompletionStream_NilProvider(t *testing.T) { t.Parallel() - // Create a client with no routes and no fallback by directly manipulating the struct - // This simulates an edge case where defaultProvider returns nil index, err := createIndex() require.NoError(t, err) @@ -335,7 +327,6 @@ func TestClient_CreateChatCompletionStream_NilProvider(t *testing.T) { } defer client.Close() - // Attempt to create stream should return error, not panic messages := []chat.Message{{Role: chat.MessageRoleUser, Content: "hello"}} _, err = client.CreateChatCompletionStream(t.Context(), messages, nil) require.Error(t, err) @@ -348,8 +339,6 @@ func TestClient_ModelsMapStoredInBaseConfig(t *testing.T) { // This test verifies that the models map and env are stored in the base config. // This is required for CloneWithOptions to work correctly with routers // that use model references (e.g., "fast" instead of "anthropic/claude-haiku-4-5"). - // Without this, cloning a router would fail because model references can't be resolved - // and the environment provider would be nil. models := map[string]latest.ModelConfig{ "fast": {Provider: "anthropic", Model: "claude-haiku-4-5"}, @@ -358,7 +347,7 @@ func TestClient_ModelsMapStoredInBaseConfig(t *testing.T) { cfg := &latest.ModelConfig{ Provider: "anthropic", - Model: "claude-haiku-4-5", // fallback + Model: "claude-haiku-4-5", Routing: []latest.RoutingRule{ { Model: "fast", @@ -371,17 +360,42 @@ func TestClient_ModelsMapStoredInBaseConfig(t *testing.T) { }, } - // Create a mock env provider mockEnv := environment.NewNoEnvProvider() client, err := NewClient(t.Context(), cfg, models, mockEnv, mockProviderFactory) require.NoError(t, err) defer client.Close() - // Verify the models map and env are stored in the base config baseConfig := client.BaseConfig() assert.NotNil(t, baseConfig.Models, "Models map should be stored in base config for cloning") assert.Equal(t, models, baseConfig.Models, "Models map should match what was passed to NewClient") assert.NotNil(t, baseConfig.Env, "Env should be stored in base config for cloning") assert.Equal(t, mockEnv, baseConfig.Env, "Env should match what was passed to NewClient") } + +func TestParseRouteIndex(t *testing.T) { + t.Parallel() + + tests := []struct { + docID string + wantIdx int + wantOK bool + }{ + {"r0_e0", 0, true}, + {"r2_e5", 2, true}, + {"r10_e3", 10, true}, + {"invalid", 0, false}, + {"", 0, false}, + } + + for _, tt := range tests { + t.Run(tt.docID, func(t *testing.T) { + t.Parallel() + idx, ok := parseRouteIndex(tt.docID) + assert.Equal(t, tt.wantOK, ok) + if ok { + assert.Equal(t, tt.wantIdx, idx) + } + }) + } +}