diff --git a/pkg/config/auto_test.go b/pkg/config/auto_test.go index 0d1606e9b..8fb073872 100644 --- a/pkg/config/auto_test.go +++ b/pkg/config/auto_test.go @@ -1,23 +1,14 @@ package config import ( - "context" "testing" "github.com/stretchr/testify/assert" "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/environment" ) -type mockEnvProvider struct { - envVars map[string]string -} - -func (m *mockEnvProvider) Get(_ context.Context, name string) (string, bool) { - val, found := m.envVars[name] - return val, found -} - func TestAvailableProviders_NoGateway(t *testing.T) { t.Parallel() @@ -96,7 +87,7 @@ func TestAvailableProviders_NoGateway(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - providers := AvailableProviders(t.Context(), "", &mockEnvProvider{envVars: tt.envVars}) + providers := AvailableProviders(t.Context(), "", environment.NewMapEnvProvider(tt.envVars)) assert.NotEmpty(t, providers) assert.Equal(t, tt.expectedProvider, providers[0]) @@ -152,7 +143,7 @@ func TestAvailableProviders_WithGateway(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - providers := AvailableProviders(t.Context(), tt.gateway, &mockEnvProvider{envVars: tt.envVars}) + providers := AvailableProviders(t.Context(), tt.gateway, environment.NewMapEnvProvider(tt.envVars)) assert.Len(t, providers, 1) assert.Equal(t, tt.expectedProvider, providers[0]) @@ -228,7 +219,7 @@ func TestAutoModelConfig(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - modelConfig := AutoModelConfig(t.Context(), tt.gateway, &mockEnvProvider{envVars: tt.envVars}, nil) + modelConfig := AutoModelConfig(t.Context(), tt.gateway, environment.NewMapEnvProvider(tt.envVars), nil) assert.Equal(t, tt.expectedProvider, modelConfig.Provider) assert.Equal(t, tt.expectedModel, modelConfig.Model) @@ -328,7 +319,7 @@ func TestAutoModelConfig_IntegrationWithDefaultModels(t *testing.T) { envVars["MISTRAL_API_KEY"] = "test-key" } - modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: envVars}, nil) + modelConfig := AutoModelConfig(t.Context(), "", environment.NewMapEnvProvider(envVars), nil) // Verify the returned model matches the DefaultModels entry expectedModel := DefaultModels[provider] @@ -341,7 +332,7 @@ func TestAutoModelConfig_IntegrationWithDefaultModels(t *testing.T) { t.Run("dmr", func(t *testing.T) { t.Parallel() - modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: map[string]string{}}, nil) + modelConfig := AutoModelConfig(t.Context(), "", environment.NewNoEnvProvider(), nil) assert.Equal(t, "dmr", modelConfig.Provider) assert.Equal(t, DefaultModels["dmr"], modelConfig.Model) @@ -353,51 +344,41 @@ func TestAvailableProviders_PrecedenceOrder(t *testing.T) { t.Parallel() // All keys present - anthropic should win - env := &mockEnvProvider{ - envVars: map[string]string{ - "ANTHROPIC_API_KEY": "test-key", - "OPENAI_API_KEY": "test-key", - "GOOGLE_API_KEY": "test-key", - "MISTRAL_API_KEY": "test-key", - }, - } + var env environment.Provider = environment.NewMapEnvProvider(map[string]string{ + "ANTHROPIC_API_KEY": "test-key", + "OPENAI_API_KEY": "test-key", + "GOOGLE_API_KEY": "test-key", + "MISTRAL_API_KEY": "test-key", + }) providers := AvailableProviders(t.Context(), "", env) assert.Equal(t, "anthropic", providers[0]) // No anthropic - openai should win - env = &mockEnvProvider{ - envVars: map[string]string{ - "OPENAI_API_KEY": "test-key", - "GOOGLE_API_KEY": "test-key", - "MISTRAL_API_KEY": "test-key", - }, - } + env = environment.NewMapEnvProvider(map[string]string{ + "OPENAI_API_KEY": "test-key", + "GOOGLE_API_KEY": "test-key", + "MISTRAL_API_KEY": "test-key", + }) providers = AvailableProviders(t.Context(), "", env) assert.Equal(t, "openai", providers[0]) // No anthropic or openai - google should win - env = &mockEnvProvider{ - envVars: map[string]string{ - "GOOGLE_API_KEY": "test-key", - "MISTRAL_API_KEY": "test-key", - }, - } + env = environment.NewMapEnvProvider(map[string]string{ + "GOOGLE_API_KEY": "test-key", + "MISTRAL_API_KEY": "test-key", + }) providers = AvailableProviders(t.Context(), "", env) assert.Equal(t, "google", providers[0]) // No anthropic, openai, or google - mistral should win - env = &mockEnvProvider{ - envVars: map[string]string{ - "MISTRAL_API_KEY": "test-key", - }, - } + env = environment.NewMapEnvProvider(map[string]string{ + "MISTRAL_API_KEY": "test-key", + }) providers = AvailableProviders(t.Context(), "", env) assert.Equal(t, "mistral", providers[0]) // No keys at all - dmr should be selected - env = &mockEnvProvider{ - envVars: map[string]string{}, - } + env = environment.NewNoEnvProvider() providers = AvailableProviders(t.Context(), "", env) assert.Equal(t, "dmr", providers[0]) } @@ -467,7 +448,7 @@ func TestAutoModelConfig_UserDefaultModel(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: tt.envVars}, tt.defaultModel) + modelConfig := AutoModelConfig(t.Context(), "", environment.NewMapEnvProvider(tt.envVars), tt.defaultModel) assert.Equal(t, tt.expectedProvider, modelConfig.Provider) assert.Equal(t, tt.expectedModel, modelConfig.Model) @@ -490,7 +471,7 @@ func TestAutoModelConfig_UserDefaultModelWithOptions(t *testing.T) { ThinkingBudget: thinkingBudget, } - modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: map[string]string{}}, defaultModel) + modelConfig := AutoModelConfig(t.Context(), "", environment.NewNoEnvProvider(), defaultModel) assert.Equal(t, "anthropic", modelConfig.Provider) assert.Equal(t, "claude-sonnet-4-5", modelConfig.Model) diff --git a/pkg/config/sources_test.go b/pkg/config/sources_test.go index be9bb2bb8..6d06e0073 100644 --- a/pkg/config/sources_test.go +++ b/pkg/config/sources_test.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/environment" ) func TestURLSource_Read(t *testing.T) { @@ -308,11 +310,9 @@ func TestURLSource_Read_WithGitHubAuth(t *testing.T) { t.Cleanup(server.Close) // Create a mock env provider that returns a GitHub token - envProvider := &mockEnvProvider{ - envVars: map[string]string{ - "GITHUB_TOKEN": "test-token-123", - }, - } + envProvider := environment.NewMapEnvProvider(map[string]string{ + "GITHUB_TOKEN": "test-token-123", + }) // For non-GitHub URLs, auth should not be added even with token available source := NewURLSource(server.URL, envProvider) @@ -340,11 +340,9 @@ func TestURLSource_Read_WithGitHubAuth_GitHubURL(t *testing.T) { })) t.Cleanup(server.Close) - envProvider := &mockEnvProvider{ - envVars: map[string]string{ - "GITHUB_TOKEN": "test-token-456", - }, - } + envProvider := environment.NewMapEnvProvider(map[string]string{ + "GITHUB_TOKEN": "test-token-456", + }) // URL with GitHub host in path (not hostname) should NOT receive auth // This prevents token leakage to attacker-controlled domains @@ -369,9 +367,7 @@ func TestURLSource_Read_WithGitHubAuth_NoToken(t *testing.T) { t.Cleanup(server.Close) // Create a mock env provider without a GitHub token - envProvider := &mockEnvProvider{ - envVars: map[string]string{}, - } + envProvider := environment.NewNoEnvProvider() source := NewURLSource(server.URL, envProvider) _, err := source.Read(t.Context()) @@ -436,11 +432,9 @@ func TestIsGitHubURL(t *testing.T) { func TestResolve_URLReference_WithEnvProvider(t *testing.T) { t.Parallel() - envProvider := &mockEnvProvider{ - envVars: map[string]string{ - "GITHUB_TOKEN": "test-token", - }, - } + envProvider := environment.NewMapEnvProvider(map[string]string{ + "GITHUB_TOKEN": "test-token", + }) source, err := Resolve("https://github.com/owner/repo/raw/main/agent.yaml", envProvider) require.NoError(t, err) @@ -455,11 +449,9 @@ func TestResolve_URLReference_WithEnvProvider(t *testing.T) { func TestResolveSources_URLReference_WithEnvProvider(t *testing.T) { t.Parallel() - envProvider := &mockEnvProvider{ - envVars: map[string]string{ - "GITHUB_TOKEN": "test-token", - }, - } + envProvider := environment.NewMapEnvProvider(map[string]string{ + "GITHUB_TOKEN": "test-token", + }) url := "https://github.com/owner/repo/raw/main/agent.yaml" sources, err := ResolveSources(url, envProvider) diff --git a/pkg/environment/env.go b/pkg/environment/env.go index c62c41072..896661aec 100644 --- a/pkg/environment/env.go +++ b/pkg/environment/env.go @@ -38,6 +38,31 @@ func (p *EnvListProvider) Get(_ context.Context, name string) (string, bool) { return "", false } +// MapEnvProvider provides access to a static map of environment variables. +type MapEnvProvider struct { + vars map[string]string +} + +func NewMapEnvProvider(vars map[string]string) *MapEnvProvider { + return &MapEnvProvider{vars: vars} +} + +func (p *MapEnvProvider) Get(_ context.Context, name string) (string, bool) { + v, ok := p.vars[name] + return v, ok +} + +// NoEnvProvider is a provider that never finds any variable. +type NoEnvProvider struct{} + +func NewNoEnvProvider() *NoEnvProvider { + return &NoEnvProvider{} +} + +func (p *NoEnvProvider) Get(context.Context, string) (string, bool) { + return "", false +} + // EnvFilesProvider provides access env files. type EnvFilesProvider struct { values []KeyValuePair diff --git a/pkg/model/provider/bedrock/client_test.go b/pkg/model/provider/bedrock/client_test.go index 829ca6fb7..dd3bb462d 100644 --- a/pkg/model/provider/bedrock/client_test.go +++ b/pkg/model/provider/bedrock/client_test.go @@ -1,7 +1,6 @@ package bedrock import ( - "context" "encoding/base64" "net/http" "net/http/httptest" @@ -342,24 +341,10 @@ func TestConvertImageURL_ValidImage(t *testing.T) { // NewClient validation tests -type mockEnvProvider struct { - values map[string]string -} - -func (m *mockEnvProvider) Get(_ context.Context, key string) (string, bool) { - if m.values == nil { - return "", false - } - v, ok := m.values[key] - return v, ok -} - -var _ environment.Provider = (*mockEnvProvider)(nil) - func TestNewClient_NilConfig(t *testing.T) { t.Parallel() - _, err := NewClient(t.Context(), nil, &mockEnvProvider{}) + _, err := NewClient(t.Context(), nil, environment.NewNoEnvProvider()) require.Error(t, err) assert.Contains(t, err.Error(), "model configuration is required") } @@ -371,7 +356,7 @@ func TestNewClient_WrongProvider(t *testing.T) { Provider: "openai", Model: "gpt-4", } - _, err := NewClient(t.Context(), cfg, &mockEnvProvider{}) + _, err := NewClient(t.Context(), cfg, environment.NewNoEnvProvider()) require.Error(t, err) assert.Contains(t, err.Error(), "model type must be 'amazon-bedrock'") } @@ -422,7 +407,7 @@ func TestBuildAWSConfig_DefaultRegion(t *testing.T) { ProviderOpts: map[string]any{}, } - env := &mockEnvProvider{values: map[string]string{}} + env := environment.NewNoEnvProvider() awsCfg, err := buildAWSConfig(t.Context(), cfg, env) require.NoError(t, err) @@ -442,7 +427,7 @@ func TestBuildAWSConfig_RegionFromProviderOpts(t *testing.T) { }, } - env := &mockEnvProvider{values: map[string]string{}} + env := environment.NewNoEnvProvider() awsCfg, err := buildAWSConfig(t.Context(), cfg, env) require.NoError(t, err) @@ -459,9 +444,9 @@ func TestBuildAWSConfig_RegionFromEnv(t *testing.T) { ProviderOpts: map[string]any{}, } - env := &mockEnvProvider{values: map[string]string{ + env := environment.NewMapEnvProvider(map[string]string{ "AWS_REGION": "ap-northeast-1", - }} + }) awsCfg, err := buildAWSConfig(t.Context(), cfg, env) require.NoError(t, err) @@ -480,9 +465,9 @@ func TestBuildAWSConfig_ProviderOptsOverridesEnv(t *testing.T) { }, } - env := &mockEnvProvider{values: map[string]string{ + env := environment.NewMapEnvProvider(map[string]string{ "AWS_REGION": "us-west-2", - }} + }) awsCfg, err := buildAWSConfig(t.Context(), cfg, env) require.NoError(t, err) @@ -504,9 +489,7 @@ func TestNewClient_ValidConfig(t *testing.T) { }, } - env := &mockEnvProvider{values: map[string]string{}} - - client, err := NewClient(t.Context(), cfg, env) + client, err := NewClient(t.Context(), cfg, environment.NewNoEnvProvider()) require.NoError(t, err) require.NotNil(t, client) @@ -527,11 +510,9 @@ func TestNewClient_WithBearerToken(t *testing.T) { }, } - env := &mockEnvProvider{values: map[string]string{ + client, err := NewClient(t.Context(), cfg, environment.NewMapEnvProvider(map[string]string{ "MY_BEDROCK_TOKEN": "test-bearer-token", - }} - - client, err := NewClient(t.Context(), cfg, env) + })) require.NoError(t, err) require.NotNil(t, client) } @@ -547,11 +528,9 @@ func TestNewClient_WithBearerTokenFromEnv(t *testing.T) { }, } - env := &mockEnvProvider{values: map[string]string{ + client, err := NewClient(t.Context(), cfg, environment.NewMapEnvProvider(map[string]string{ "AWS_BEARER_TOKEN_BEDROCK": "env-bearer-token", - }} - - client, err := NewClient(t.Context(), cfg, env) + })) require.NoError(t, err) require.NotNil(t, client) } diff --git a/pkg/model/provider/custom_provider_test.go b/pkg/model/provider/custom_provider_test.go index f36ffcdd1..cc8b9aa8c 100644 --- a/pkg/model/provider/custom_provider_test.go +++ b/pkg/model/provider/custom_provider_test.go @@ -1,7 +1,6 @@ package provider import ( - "context" "encoding/json" "errors" "io" @@ -20,20 +19,6 @@ import ( "github.com/docker/docker-agent/pkg/tools" ) -// mockEnvProvider is a simple env provider for testing -type mockEnvProvider struct { - values map[string]string -} - -func (m *mockEnvProvider) Get(_ context.Context, name string) (string, bool) { - v, ok := m.values[name] - return v, ok -} - -func newMockEnvProvider(values map[string]string) environment.Provider { - return &mockEnvProvider{values: values} -} - // TestCustomProvider_WithProvidersOption tests the full flow using options.WithProviders func TestCustomProvider_WithProvidersOption(t *testing.T) { t.Parallel() @@ -78,7 +63,7 @@ func TestCustomProvider_WithProvidersOption(t *testing.T) { Model: "gpt-4o", } - env := newMockEnvProvider(map[string]string{ + env := environment.NewMapEnvProvider(map[string]string{ "MY_GATEWAY_TOKEN": "secret-from-provider", }) @@ -161,7 +146,7 @@ func TestCustomProvider_RequestReachesServer(t *testing.T) { }, } - env := newMockEnvProvider(map[string]string{ + env := environment.NewMapEnvProvider(map[string]string{ customTokenKey: expectedToken, }) @@ -224,7 +209,7 @@ func TestCustomProvider_ResponsesAPIType(t *testing.T) { }, } - env := newMockEnvProvider(map[string]string{"API_KEY": "test"}) + env := environment.NewMapEnvProvider(map[string]string{"API_KEY": "test"}) provider, err := New(t.Context(), modelCfg, env) require.NoError(t, err) @@ -284,7 +269,7 @@ func TestCustomProvider_ChatCompletionsAPIType(t *testing.T) { }, } - env := newMockEnvProvider(map[string]string{"OPENAI_API_KEY": "test"}) + env := environment.NewMapEnvProvider(map[string]string{"OPENAI_API_KEY": "test"}) provider, err := New(t.Context(), modelCfg, env) require.NoError(t, err) @@ -319,7 +304,7 @@ func TestCustomProvider_MissingAPIKey(t *testing.T) { }, } - env := newMockEnvProvider(map[string]string{}) // Empty - key not set + env := environment.NewNoEnvProvider() // key not set _, err := New(t.Context(), modelCfg, env) require.Error(t, err) diff --git a/pkg/model/provider/openai/api_type_test.go b/pkg/model/provider/openai/api_type_test.go index a368e8ee6..239b649fb 100644 --- a/pkg/model/provider/openai/api_type_test.go +++ b/pkg/model/provider/openai/api_type_test.go @@ -1,7 +1,6 @@ package openai import ( - "context" "encoding/json" "net/http" "net/http/httptest" @@ -16,20 +15,6 @@ import ( "github.com/docker/docker-agent/pkg/environment" ) -// mockEnvProvider is a simple env provider for testing -type mockEnvProvider struct { - values map[string]string -} - -func (m *mockEnvProvider) Get(_ context.Context, name string) (string, bool) { - v, ok := m.values[name] - return v, ok -} - -func newMockEnvProvider(values map[string]string) environment.Provider { - return &mockEnvProvider{values: values} -} - func TestGetAPIType(t *testing.T) { t.Parallel() @@ -183,7 +168,7 @@ func TestCustomProvider_WithTokenKey(t *testing.T) { }, } - env := newMockEnvProvider(map[string]string{ + env := environment.NewMapEnvProvider(map[string]string{ "MY_CUSTOM_TOKEN": "secret-token-123", }) @@ -234,7 +219,7 @@ func TestCustomProvider_WithoutTokenKey(t *testing.T) { }, } - env := newMockEnvProvider(map[string]string{}) + env := environment.NewNoEnvProvider() client, err := NewClient(t.Context(), cfg, env) require.NoError(t, err) diff --git a/pkg/model/provider/rulebased/client_test.go b/pkg/model/provider/rulebased/client_test.go index c5ea1c790..d918113cd 100644 --- a/pkg/model/provider/rulebased/client_test.go +++ b/pkg/model/provider/rulebased/client_test.go @@ -372,7 +372,7 @@ func TestClient_ModelsMapStoredInBaseConfig(t *testing.T) { } // Create a mock env provider - mockEnv := &mockEnvProvider{} + mockEnv := environment.NewNoEnvProvider() client, err := NewClient(t.Context(), cfg, models, mockEnv, mockProviderFactory) require.NoError(t, err) @@ -385,10 +385,3 @@ func TestClient_ModelsMapStoredInBaseConfig(t *testing.T) { assert.NotNil(t, baseConfig.Env, "Env should be stored in base config for cloning") assert.Equal(t, mockEnv, baseConfig.Env, "Env should match what was passed to NewClient") } - -// mockEnvProvider is a minimal mock for environment.Provider. -type mockEnvProvider struct{} - -func (m *mockEnvProvider) Get(_ context.Context, _ string) (string, bool) { - return "", false -} diff --git a/pkg/runtime/model_switcher_test.go b/pkg/runtime/model_switcher_test.go index 79ee3966e..477c5d29f 100644 --- a/pkg/runtime/model_switcher_test.go +++ b/pkg/runtime/model_switcher_test.go @@ -8,19 +8,10 @@ import ( "github.com/stretchr/testify/require" "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/environment" "github.com/docker/docker-agent/pkg/modelsdev" ) -// mockEnvProvider is a simple environment provider for testing -type mockEnvProvider struct { - vars map[string]string -} - -func (m *mockEnvProvider) Get(_ context.Context, name string) (string, bool) { - v, ok := m.vars[name] - return v, ok -} - // mockCatalogStore implements ModelStore for testing type mockCatalogStore struct { ModelStore @@ -232,7 +223,7 @@ func TestGetAvailableProviders(t *testing.T) { r := &LocalRuntime{ modelSwitcherCfg: &ModelSwitcherConfig{ - EnvProvider: &mockEnvProvider{vars: tt.envVars}, + EnvProvider: environment.NewMapEnvProvider(tt.envVars), ModelsGateway: tt.modelsGateway, }, } @@ -312,10 +303,10 @@ func TestBuildCatalogChoices(t *testing.T) { r := &LocalRuntime{ modelsStore: &mockCatalogStore{db: db}, modelSwitcherCfg: &ModelSwitcherConfig{ - EnvProvider: &mockEnvProvider{vars: map[string]string{ + EnvProvider: environment.NewMapEnvProvider(map[string]string{ "OPENAI_API_KEY": "sk-test", "ANTHROPIC_API_KEY": "sk-ant-test", - }}, + }), Models: map[string]latest.ModelConfig{ "my_model": {Provider: "openai", Model: "gpt-4o"}, // This should be excluded from catalog (duplicate) }, @@ -375,9 +366,9 @@ func TestBuildCatalogChoicesWithDuplicates(t *testing.T) { r := &LocalRuntime{ modelsStore: &mockCatalogStore{db: db}, modelSwitcherCfg: &ModelSwitcherConfig{ - EnvProvider: &mockEnvProvider{vars: map[string]string{ + EnvProvider: environment.NewMapEnvProvider(map[string]string{ "OPENAI_API_KEY": "sk-test", - }}, + }), Models: map[string]latest.ModelConfig{ // This model has the same provider/model as the catalog entry "my_gpt4o": {Provider: "openai", Model: "gpt-4o"},