diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml new file mode 100644 index 000000000..92524ea17 --- /dev/null +++ b/.github/workflows/conformance.yml @@ -0,0 +1,69 @@ +name: Conformance Test + +on: + pull_request: + +permissions: + contents: read + +jobs: + conformance: + runs-on: ubuntu-latest + + steps: + - name: Check out code + uses: actions/checkout@v6 + with: + # Fetch full history to access merge-base + fetch-depth: 0 + + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version-file: "go.mod" + + - name: Download dependencies + run: go mod download + + - name: Run conformance test + id: conformance + run: | + # Run conformance test, capture stdout for summary + script/conformance-test > conformance-summary.txt 2>&1 || true + + # Output the summary + cat conformance-summary.txt + + # Check result + if grep -q "RESULT: ALL TESTS PASSED" conformance-summary.txt; then + echo "status=passed" >> $GITHUB_OUTPUT + else + echo "status=differences" >> $GITHUB_OUTPUT + fi + + - name: Generate Job Summary + run: | + # Add the full markdown report to the job summary + echo "# MCP Server Conformance Report" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "Comparing PR branch against merge-base with \`origin/main\`" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + + # Extract and append the report content (skip the header since we added our own) + tail -n +5 conformance-report/CONFORMANCE_REPORT.md >> $GITHUB_STEP_SUMMARY + + echo "" >> $GITHUB_STEP_SUMMARY + echo "---" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + + # Add interpretation note + if [ "${{ steps.conformance.outputs.status }}" = "passed" ]; then + echo "✅ **All conformance tests passed** - No behavioral differences detected." >> $GITHUB_STEP_SUMMARY + else + echo "⚠️ **Differences detected** - Review the diffs above to ensure changes are intentional." >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "Common expected differences:" >> $GITHUB_STEP_SUMMARY + echo "- New tools/toolsets added" >> $GITHUB_STEP_SUMMARY + echo "- Tool descriptions updated" >> $GITHUB_STEP_SUMMARY + echo "- Capability changes (intentional improvements)" >> $GITHUB_STEP_SUMMARY + fi diff --git a/.gitignore b/.gitignore index b018fafac..5684108b0 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,5 @@ bin/ # binary github-mcp-server -.history \ No newline at end of file +.history +conformance-report/ diff --git a/README.md b/README.md index bcd9f85c8..e8737cd25 100644 --- a/README.md +++ b/README.md @@ -384,6 +384,7 @@ You can also configure specific tools using the `--tools` flag. Tools can be use - Tools, toolsets, and dynamic toolsets can all be used together - Read-only mode takes priority: write tools are skipped if `--read-only` is set, even if explicitly requested via `--tools` - Tool names must match exactly (e.g., `get_file_contents`, not `getFileContents`). Invalid tool names will cause the server to fail at startup with an error message +- When tools are renamed, old names are preserved as aliases for backward compatibility. See [Deprecated Tool Aliases](docs/deprecated-tool-aliases.md) for details. ### Using Toolsets With Docker @@ -459,7 +460,6 @@ The following sets of tools are available: | `code_security` | Code security related tools, such as GitHub Code Scanning | | `dependabot` | Dependabot tools | | `discussions` | GitHub Discussions related tools | -| `experiments` | Experimental features that are not considered stable yet | | `gists` | GitHub Gist related tools | | `git` | GitHub Git API related tools for low-level Git operations | | `issues` | GitHub Issues related tools | diff --git a/cmd/github-mcp-server/generate_docs.go b/cmd/github-mcp-server/generate_docs.go index 61459d7f0..8760c3f32 100644 --- a/cmd/github-mcp-server/generate_docs.go +++ b/cmd/github-mcp-server/generate_docs.go @@ -1,23 +1,17 @@ package main import ( - "context" "fmt" "net/url" "os" - "regexp" "sort" "strings" "github.com/github/github-mcp-server/pkg/github" - "github.com/github/github-mcp-server/pkg/lockdown" - "github.com/github/github-mcp-server/pkg/raw" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" - gogithub "github.com/google/go-github/v79/github" "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/mcp" - "github.com/shurcooL/githubv4" "github.com/spf13/cobra" ) @@ -34,30 +28,21 @@ func init() { rootCmd.AddCommand(generateDocsCmd) } -// mockGetClient returns a mock GitHub client for documentation generation -func mockGetClient(_ context.Context) (*gogithub.Client, error) { - return gogithub.NewClient(nil), nil -} - -// mockGetGQLClient returns a mock GraphQL client for documentation generation -func mockGetGQLClient(_ context.Context) (*githubv4.Client, error) { - return githubv4.NewClient(nil), nil -} - -// mockGetRawClient returns a mock raw client for documentation generation -func mockGetRawClient(_ context.Context) (*raw.Client, error) { - return nil, nil -} - func generateAllDocs() error { - if err := generateReadmeDocs("README.md"); err != nil { - return fmt.Errorf("failed to generate README docs: %w", err) - } - - if err := generateRemoteServerDocs("docs/remote-server.md"); err != nil { - return fmt.Errorf("failed to generate remote-server docs: %w", err) + for _, doc := range []struct { + path string + fn func(string) error + }{ + // File to edit, function to generate its docs + {"README.md", generateReadmeDocs}, + {"docs/remote-server.md", generateRemoteServerDocs}, + {"docs/deprecated-tool-aliases.md", generateDeprecatedAliasesDocs}, + } { + if err := doc.fn(doc.path); err != nil { + return fmt.Errorf("failed to generate docs for %s: %w", doc.path, err) + } + fmt.Printf("Successfully updated %s with automated documentation\n", doc.path) } - return nil } @@ -65,15 +50,14 @@ func generateReadmeDocs(readmePath string) error { // Create translation helper t, _ := translations.TranslationHelper() - // Create toolset group with mock clients - repoAccessCache := lockdown.GetInstance(nil) - tsg := github.DefaultToolsetGroup(false, mockGetClient, mockGetGQLClient, mockGetRawClient, t, 5000, github.FeatureFlags{}, repoAccessCache) + // Build registry - stateless, no dependencies needed for doc generation + r := github.NewRegistry(t).Build() // Generate toolsets documentation - toolsetsDoc := generateToolsetsDoc(tsg) + toolsetsDoc := generateToolsetsDoc(r) // Generate tools documentation - toolsDoc := generateToolsDoc(tsg) + toolsDoc := generateToolsDoc(r) // Read the current README.md // #nosec G304 - readmePath is controlled by command line flag, not user input @@ -83,10 +67,16 @@ func generateReadmeDocs(readmePath string) error { } // Replace toolsets section - updatedContent := replaceSection(string(content), "START AUTOMATED TOOLSETS", "END AUTOMATED TOOLSETS", toolsetsDoc) + updatedContent, err := replaceSection(string(content), "START AUTOMATED TOOLSETS", "END AUTOMATED TOOLSETS", toolsetsDoc) + if err != nil { + return err + } // Replace tools section - updatedContent = replaceSection(updatedContent, "START AUTOMATED TOOLS", "END AUTOMATED TOOLS", toolsDoc) + updatedContent, err = replaceSection(updatedContent, "START AUTOMATED TOOLS", "END AUTOMATED TOOLS", toolsDoc) + if err != nil { + return err + } // Write back to file err = os.WriteFile(readmePath, []byte(updatedContent), 0600) @@ -94,7 +84,6 @@ func generateReadmeDocs(readmePath string) error { return fmt.Errorf("failed to write README.md: %w", err) } - fmt.Println("Successfully updated README.md with automated documentation") return nil } @@ -107,93 +96,73 @@ func generateRemoteServerDocs(docsPath string) error { toolsetsDoc := generateRemoteToolsetsDoc() // Replace content between markers - startMarker := "" - endMarker := "" - - contentStr := string(content) - startIndex := strings.Index(contentStr, startMarker) - endIndex := strings.Index(contentStr, endMarker) - - if startIndex == -1 || endIndex == -1 { - return fmt.Errorf("automation markers not found in %s", docsPath) + updatedContent, err := replaceSection(string(content), "START AUTOMATED TOOLSETS", "END AUTOMATED TOOLSETS", toolsetsDoc) + if err != nil { + return err } - newContent := contentStr[:startIndex] + startMarker + "\n" + toolsetsDoc + "\n" + endMarker + contentStr[endIndex+len(endMarker):] - - return os.WriteFile(docsPath, []byte(newContent), 0600) //#nosec G306 + return os.WriteFile(docsPath, []byte(updatedContent), 0600) //#nosec G306 } -func generateToolsetsDoc(tsg *toolsets.ToolsetGroup) string { - var lines []string +func generateToolsetsDoc(r *registry.Registry) string { + var buf strings.Builder // Add table header and separator - lines = append(lines, "| Toolset | Description |") - lines = append(lines, "| ----------------------- | ------------------------------------------------------------- |") + buf.WriteString("| Toolset | Description |\n") + buf.WriteString("| ----------------------- | ------------------------------------------------------------- |\n") - // Add the context toolset row (handled separately in README) - lines = append(lines, "| `context` | **Strongly recommended**: Tools that provide context about the current user and GitHub context you are operating in |") + // Add the context toolset row with custom description (strongly recommended) + buf.WriteString("| `context` | **Strongly recommended**: Tools that provide context about the current user and GitHub context you are operating in |\n") - // Get all toolsets except context (which is handled separately above) - var toolsetNames []string - for name := range tsg.Toolsets { - if name != "context" && name != "dynamic" { // Skip context and dynamic toolsets as they're handled separately - toolsetNames = append(toolsetNames, name) - } + // AvailableToolsets() returns toolsets that have tools, sorted by ID + // Exclude context (custom description above) and dynamic (internal only) + for _, ts := range r.AvailableToolsets("context", "dynamic") { + fmt.Fprintf(&buf, "| `%s` | %s |\n", ts.ID, ts.Description) } - // Sort toolset names for consistent output - sort.Strings(toolsetNames) - - for _, name := range toolsetNames { - toolset := tsg.Toolsets[name] - lines = append(lines, fmt.Sprintf("| `%s` | %s |", name, toolset.Description)) - } - - return strings.Join(lines, "\n") + return strings.TrimSuffix(buf.String(), "\n") } -func generateToolsDoc(tsg *toolsets.ToolsetGroup) string { - var sections []string - - // Get all toolset names and sort them alphabetically for deterministic order - var toolsetNames []string - for name := range tsg.Toolsets { - if name != "dynamic" { // Skip dynamic toolset as it's handled separately - toolsetNames = append(toolsetNames, name) - } +func generateToolsDoc(r *registry.Registry) string { + // AllTools() returns tools sorted by toolset ID then tool name. + // We iterate once, grouping by toolset as we encounter them. + tools := r.AllTools() + if len(tools) == 0 { + return "" } - sort.Strings(toolsetNames) - for _, toolsetName := range toolsetNames { - toolset := tsg.Toolsets[toolsetName] + var buf strings.Builder + var toolBuf strings.Builder + var currentToolsetID registry.ToolsetID + firstSection := true - tools := toolset.GetAvailableTools() - if len(tools) == 0 { - continue + writeSection := func() { + if toolBuf.Len() == 0 { + return } - - // Sort tools by name for deterministic order - sort.Slice(tools, func(i, j int) bool { - return tools[i].Tool.Name < tools[j].Tool.Name - }) - - // Generate section header - capitalize first letter and replace underscores - sectionName := formatToolsetName(toolsetName) - - var toolDocs []string - for _, serverTool := range tools { - toolDoc := generateToolDoc(serverTool.Tool) - toolDocs = append(toolDocs, toolDoc) + if !firstSection { + buf.WriteString("\n\n") } + firstSection = false + sectionName := formatToolsetName(string(currentToolsetID)) + fmt.Fprintf(&buf, "
\n\n%s\n\n%s\n\n
", sectionName, strings.TrimSuffix(toolBuf.String(), "\n\n")) + toolBuf.Reset() + } - if len(toolDocs) > 0 { - section := fmt.Sprintf("
\n\n%s\n\n%s\n\n
", - sectionName, strings.Join(toolDocs, "\n\n")) - sections = append(sections, section) + for _, tool := range tools { + // When toolset changes, emit the previous section + if tool.Toolset.ID != currentToolsetID { + writeSection() + currentToolsetID = tool.Toolset.ID } + writeToolDoc(&toolBuf, tool.Tool) + toolBuf.WriteString("\n\n") } - return strings.Join(sections, "\n\n") + // Emit the last section + writeSection() + + return buf.String() } func formatToolsetName(name string) string { @@ -220,21 +189,19 @@ func formatToolsetName(name string) string { } } -func generateToolDoc(tool mcp.Tool) string { - var lines []string - +func writeToolDoc(buf *strings.Builder, tool mcp.Tool) { // Tool name only (using annotation name instead of verbose description) - lines = append(lines, fmt.Sprintf("- **%s** - %s", tool.Name, tool.Annotations.Title)) + fmt.Fprintf(buf, "- **%s** - %s\n", tool.Name, tool.Annotations.Title) // Parameters if tool.InputSchema == nil { - lines = append(lines, " - No parameters required") - return strings.Join(lines, "\n") + buf.WriteString(" - No parameters required") + return } schema, ok := tool.InputSchema.(*jsonschema.Schema) if !ok || schema == nil { - lines = append(lines, " - No parameters required") - return strings.Join(lines, "\n") + buf.WriteString(" - No parameters required") + return } if len(schema.Properties) > 0 { @@ -245,7 +212,7 @@ func generateToolDoc(tool mcp.Tool) string { } sort.Strings(paramNames) - for _, propName := range paramNames { + for i, propName := range paramNames { prop := schema.Properties[propName] required := contains(schema.Required, propName) requiredStr := "optional" @@ -253,7 +220,7 @@ func generateToolDoc(tool mcp.Tool) string { requiredStr = "required" } - var typeStr, description string + var typeStr string // Get the type and description switch prop.Type { @@ -267,19 +234,17 @@ func generateToolDoc(tool mcp.Tool) string { typeStr = prop.Type } - description = prop.Description - // Indent any continuation lines in the description to maintain markdown formatting - description = indentMultilineDescription(description, " ") + description := indentMultilineDescription(prop.Description, " ") - paramLine := fmt.Sprintf(" - `%s`: %s (%s, %s)", propName, description, typeStr, requiredStr) - lines = append(lines, paramLine) + fmt.Fprintf(buf, " - `%s`: %s (%s, %s)", propName, description, typeStr, requiredStr) + if i < len(paramNames)-1 { + buf.WriteString("\n") + } } } else { - lines = append(lines, " - No parameters required") + buf.WriteString(" - No parameters required") } - - return strings.Join(lines, "\n") } func contains(slice []string, item string) bool { @@ -294,25 +259,38 @@ func contains(slice []string, item string) bool { // indentMultilineDescription adds the specified indent to all lines after the first line. // This ensures that multi-line descriptions maintain proper markdown list formatting. func indentMultilineDescription(description, indent string) string { - lines := strings.Split(description, "\n") - if len(lines) <= 1 { + if !strings.Contains(description, "\n") { return description } + var buf strings.Builder + lines := strings.Split(description, "\n") + buf.WriteString(lines[0]) for i := 1; i < len(lines); i++ { - lines[i] = indent + lines[i] + buf.WriteString("\n") + buf.WriteString(indent) + buf.WriteString(lines[i]) } - return strings.Join(lines, "\n") + return buf.String() } -func replaceSection(content, startMarker, endMarker, newContent string) string { - startPattern := fmt.Sprintf(``, regexp.QuoteMeta(startMarker)) - endPattern := fmt.Sprintf(``, regexp.QuoteMeta(endMarker)) - - re := regexp.MustCompile(fmt.Sprintf(`(?s)%s.*?%s`, startPattern, endPattern)) +func replaceSection(content, startMarker, endMarker, newContent string) (string, error) { + start := fmt.Sprintf("", startMarker) + end := fmt.Sprintf("", endMarker) - replacement := fmt.Sprintf("\n%s\n", startMarker, newContent, endMarker) + startIdx := strings.Index(content, start) + endIdx := strings.Index(content, end) + if startIdx == -1 || endIdx == -1 { + return "", fmt.Errorf("markers not found: %s / %s", start, end) + } - return re.ReplaceAllString(content, replacement) + var buf strings.Builder + buf.WriteString(content[:startIdx]) + buf.WriteString(start) + buf.WriteString("\n") + buf.WriteString(newContent) + buf.WriteString("\n") + buf.WriteString(content[endIdx:]) + return buf.String(), nil } func generateRemoteToolsetsDoc() string { @@ -321,34 +299,24 @@ func generateRemoteToolsetsDoc() string { // Create translation helper t, _ := translations.TranslationHelper() - // Create toolset group with mock clients - repoAccessCache := lockdown.GetInstance(nil) - tsg := github.DefaultToolsetGroup(false, mockGetClient, mockGetGQLClient, mockGetRawClient, t, 5000, github.FeatureFlags{}, repoAccessCache) + // Build registry - stateless + r := github.NewRegistry(t).Build() // Generate table header buf.WriteString("| Name | Description | API URL | 1-Click Install (VS Code) | Read-only Link | 1-Click Read-only Install (VS Code) |\n") buf.WriteString("|----------------|--------------------------------------------------|-------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n") - // Get all toolsets - toolsetNames := make([]string, 0, len(tsg.Toolsets)) - for name := range tsg.Toolsets { - if name != "context" && name != "dynamic" { // Skip context and dynamic toolsets as they're handled separately - toolsetNames = append(toolsetNames, name) - } - } - sort.Strings(toolsetNames) - // Add "all" toolset first (special case) buf.WriteString("| all | All available GitHub MCP tools | https://api.githubcopilot.com/mcp/ | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=github&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2F%22%7D) | [read-only](https://api.githubcopilot.com/mcp/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=github&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Freadonly%22%7D) |\n") - // Add individual toolsets - for _, name := range toolsetNames { - toolset := tsg.Toolsets[name] + // AvailableToolsets() returns toolsets that have tools, sorted by ID + // Exclude context (handled separately) and dynamic (internal only) + for _, ts := range r.AvailableToolsets("context", "dynamic") { + idStr := string(ts.ID) - formattedName := formatToolsetName(name) - description := toolset.Description - apiURL := fmt.Sprintf("https://api.githubcopilot.com/mcp/x/%s", name) - readonlyURL := fmt.Sprintf("https://api.githubcopilot.com/mcp/x/%s/readonly", name) + formattedName := formatToolsetName(idStr) + apiURL := fmt.Sprintf("https://api.githubcopilot.com/mcp/x/%s", idStr) + readonlyURL := fmt.Sprintf("https://api.githubcopilot.com/mcp/x/%s/readonly", idStr) // Create install config JSON (URL encoded) installConfig := url.QueryEscape(fmt.Sprintf(`{"type": "http","url": "%s"}`, apiURL)) @@ -358,17 +326,72 @@ func generateRemoteToolsetsDoc() string { installConfig = strings.ReplaceAll(installConfig, "+", "%20") readonlyConfig = strings.ReplaceAll(readonlyConfig, "+", "%20") - installLink := fmt.Sprintf("[Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-%s&config=%s)", name, installConfig) - readonlyInstallLink := fmt.Sprintf("[Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-%s&config=%s)", name, readonlyConfig) + installLink := fmt.Sprintf("[Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-%s&config=%s)", idStr, installConfig) + readonlyInstallLink := fmt.Sprintf("[Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-%s&config=%s)", idStr, readonlyConfig) - buf.WriteString(fmt.Sprintf("| %-14s | %-48s | %-53s | %-218s | %-110s | %-288s |\n", + fmt.Fprintf(&buf, "| %-14s | %-48s | %-53s | %-218s | %-110s | %-288s |\n", formattedName, - description, + ts.Description, apiURL, installLink, fmt.Sprintf("[read-only](%s)", readonlyURL), readonlyInstallLink, - )) + ) + } + + return strings.TrimSuffix(buf.String(), "\n") +} + +func generateDeprecatedAliasesDocs(docsPath string) error { + // Read the current file + content, err := os.ReadFile(docsPath) //#nosec G304 + if err != nil { + return fmt.Errorf("failed to read docs file: %w", err) + } + + // Generate the table + aliasesDoc := generateDeprecatedAliasesTable() + + // Replace content between markers + updatedContent, err := replaceSection(string(content), "START AUTOMATED ALIASES", "END AUTOMATED ALIASES", aliasesDoc) + if err != nil { + return err + } + + // Write back to file + err = os.WriteFile(docsPath, []byte(updatedContent), 0600) + if err != nil { + return fmt.Errorf("failed to write deprecated aliases docs: %w", err) + } + + return nil +} + +func generateDeprecatedAliasesTable() string { + var buf strings.Builder + + // Add table header + buf.WriteString("| Old Name | New Name |\n") + buf.WriteString("|----------|----------|\n") + + aliases := github.DeprecatedToolAliases + if len(aliases) == 0 { + buf.WriteString("| *(none currently)* | |") + } else { + // Sort keys for deterministic output + var oldNames []string + for oldName := range aliases { + oldNames = append(oldNames, oldName) + } + sort.Strings(oldNames) + + for i, oldName := range oldNames { + newName := aliases[oldName] + fmt.Fprintf(&buf, "| `%s` | `%s` |", oldName, newName) + if i < len(oldNames)-1 { + buf.WriteString("\n") + } + } } return buf.String() diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 87eeedd2e..034b0e238 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -41,10 +41,16 @@ var ( // it's because viper doesn't handle comma-separated values correctly for env // vars when using GetStringSlice. // https://github.com/spf13/viper/issues/380 + // + // Additionally, viper.UnmarshalKey returns an empty slice even when the flag + // is not set, but we need nil to indicate "use defaults". So we check IsSet first. var enabledToolsets []string - if err := viper.UnmarshalKey("toolsets", &enabledToolsets); err != nil { - return fmt.Errorf("failed to unmarshal toolsets: %w", err) + if viper.IsSet("toolsets") { + if err := viper.UnmarshalKey("toolsets", &enabledToolsets); err != nil { + return fmt.Errorf("failed to unmarshal toolsets: %w", err) + } } + // else: enabledToolsets stays nil, meaning "use defaults" // Parse tools (similar to toolsets) var enabledTools []string @@ -52,9 +58,10 @@ var ( return fmt.Errorf("failed to unmarshal tools: %w", err) } - // If neither toolset config nor tools config is passed we enable the default toolset - if len(enabledToolsets) == 0 && len(enabledTools) == 0 { - enabledToolsets = []string{github.ToolsetMetadataDefault.ID} + // Parse enabled features (similar to toolsets) + var enabledFeatures []string + if err := viper.UnmarshalKey("features", &enabledFeatures); err != nil { + return fmt.Errorf("failed to unmarshal features: %w", err) } ttl := viper.GetDuration("repo-access-cache-ttl") @@ -64,6 +71,7 @@ var ( Token: token, EnabledToolsets: enabledToolsets, EnabledTools: enabledTools, + EnabledFeatures: enabledFeatures, DynamicToolsets: viper.GetBool("dynamic_toolsets"), ReadOnly: viper.GetBool("read-only"), ExportTranslations: viper.GetBool("export-translations"), @@ -87,6 +95,7 @@ func init() { // Add global flags that will be shared by all commands rootCmd.PersistentFlags().StringSlice("toolsets", nil, github.GenerateToolsetsHelp()) rootCmd.PersistentFlags().StringSlice("tools", nil, "Comma-separated list of specific tools to enable") + rootCmd.PersistentFlags().StringSlice("features", nil, "Comma-separated list of feature flags to enable") rootCmd.PersistentFlags().Bool("dynamic-toolsets", false, "Enable dynamic toolsets") rootCmd.PersistentFlags().Bool("read-only", false, "Restrict the server to read-only operations") rootCmd.PersistentFlags().String("log-file", "", "Path to log file") @@ -100,6 +109,7 @@ func init() { // Bind flag to viper _ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets")) _ = viper.BindPFlag("tools", rootCmd.PersistentFlags().Lookup("tools")) + _ = viper.BindPFlag("features", rootCmd.PersistentFlags().Lookup("features")) _ = viper.BindPFlag("dynamic_toolsets", rootCmd.PersistentFlags().Lookup("dynamic-toolsets")) _ = viper.BindPFlag("read-only", rootCmd.PersistentFlags().Lookup("read-only")) _ = viper.BindPFlag("log-file", rootCmd.PersistentFlags().Lookup("log-file")) diff --git a/docs/deprecated-tool-aliases.md b/docs/deprecated-tool-aliases.md new file mode 100644 index 000000000..6ea2ba1de --- /dev/null +++ b/docs/deprecated-tool-aliases.md @@ -0,0 +1,31 @@ +# Deprecated Tool Aliases + +This document tracks tool renames in the GitHub MCP Server. When tools are renamed, the old names are preserved as aliases for backward compatibility. Using a deprecated alias will still work, but clients should migrate to the new canonical name. + +## Current Deprecations + + +| Old Name | New Name | +|----------|----------| +| *(none currently)* | | + + +## How It Works + +When a tool is renamed: + +1. The old name is added to `DeprecatedToolAliases` in [pkg/github/deprecated_tool_aliases.go](../pkg/github/deprecated_tool_aliases.go) +2. Clients using the old name will receive the new tool +3. A deprecation notice is logged when the alias is used + +## For Developers + +To deprecate a tool name when renaming: + +```go +var DeprecatedToolAliases = map[string]string{ + "old_tool_name": "new_tool_name", +} +``` + +The alias resolution happens at server startup, ensuring backward compatibility for existing client configurations. diff --git a/docs/remote-server.md b/docs/remote-server.md index e06d41a75..53fe36127 100644 --- a/docs/remote-server.md +++ b/docs/remote-server.md @@ -24,7 +24,6 @@ Below is a table of available toolsets for the remote GitHub MCP Server. Each to | Code Security | Code security related tools, such as GitHub Code Scanning | https://api.githubcopilot.com/mcp/x/code_security | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-code_security&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fcode_security%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/code_security/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-code_security&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fcode_security%2Freadonly%22%7D) | | Dependabot | Dependabot tools | https://api.githubcopilot.com/mcp/x/dependabot | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-dependabot&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fdependabot%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/dependabot/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-dependabot&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fdependabot%2Freadonly%22%7D) | | Discussions | GitHub Discussions related tools | https://api.githubcopilot.com/mcp/x/discussions | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-discussions&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fdiscussions%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/discussions/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-discussions&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fdiscussions%2Freadonly%22%7D) | -| Experiments | Experimental features that are not considered stable yet | https://api.githubcopilot.com/mcp/x/experiments | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-experiments&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fexperiments%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/experiments/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-experiments&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fexperiments%2Freadonly%22%7D) | | Gists | GitHub Gist related tools | https://api.githubcopilot.com/mcp/x/gists | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-gists&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgists%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/gists/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-gists&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgists%2Freadonly%22%7D) | | Git | GitHub Git API related tools for low-level Git operations | https://api.githubcopilot.com/mcp/x/git | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-git&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgit%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/git/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-git&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgit%2Freadonly%22%7D) | | Issues | GitHub Issues related tools | https://api.githubcopilot.com/mcp/x/issues | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-issues&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fissues%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/issues/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-issues&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fissues%2Freadonly%22%7D) | @@ -38,7 +37,6 @@ Below is a table of available toolsets for the remote GitHub MCP Server. Each to | Security Advisories | Security advisories related tools | https://api.githubcopilot.com/mcp/x/security_advisories | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-security_advisories&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fsecurity_advisories%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/security_advisories/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-security_advisories&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fsecurity_advisories%2Freadonly%22%7D) | | Stargazers | GitHub Stargazers related tools | https://api.githubcopilot.com/mcp/x/stargazers | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-stargazers&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fstargazers%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/stargazers/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-stargazers&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fstargazers%2Freadonly%22%7D) | | Users | GitHub User related tools | https://api.githubcopilot.com/mcp/x/users | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-users&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fusers%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/users/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-users&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fusers%2Freadonly%22%7D) | - ### Additional _Remote_ Server Toolsets diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 5f67fb84c..e286930a2 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -178,7 +178,7 @@ func setupMCPClient(t *testing.T, options ...clientOption) *mcp.ClientSession { // so that there is a shared setup mechanism, but let's wait till we feel more friction. enabledToolsets := opts.enabledToolsets if enabledToolsets == nil { - enabledToolsets = github.GetDefaultToolsetIDs() + enabledToolsets = github.NewRegistry(translations.NullTranslationHelper).Build().DefaultToolsetIDs() } ghServer, err := ghmcp.NewMCPServer(ghmcp.MCPServerConfig{ diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index c0f4e25e7..e98637067 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -18,6 +18,7 @@ import ( "github.com/github/github-mcp-server/pkg/lockdown" mcplog "github.com/github/github-mcp-server/pkg/log" "github.com/github/github-mcp-server/pkg/raw" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" gogithub "github.com/google/go-github/v79/github" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -42,6 +43,10 @@ type MCPServerConfig struct { // When specified, these tools are registered in addition to any specified toolset tools EnabledTools []string + // EnabledFeatures is a list of feature flags that are enabled + // Items with FeatureFlagEnable matching an entry in this list will be available + EnabledFeatures []string + // Whether to enable dynamic toolsets // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#dynamic-tool-discovery DynamicToolsets bool @@ -64,149 +69,193 @@ type MCPServerConfig struct { RepoAccessTTL *time.Duration } -func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { - apiHost, err := parseAPIHost(cfg.Host) - if err != nil { - return nil, fmt.Errorf("failed to parse API host: %w", err) - } +// githubClients holds all the GitHub API clients created for a server instance. +type githubClients struct { + rest *gogithub.Client + gql *githubv4.Client + gqlHTTP *http.Client // retained for middleware to modify transport + raw *raw.Client + repoAccess *lockdown.RepoAccessCache +} - // Construct our REST client +// createGitHubClients creates all the GitHub API clients needed by the server. +func createGitHubClients(cfg MCPServerConfig, apiHost apiHost) (*githubClients, error) { + // Construct REST client restClient := gogithub.NewClient(nil).WithAuthToken(cfg.Token) restClient.UserAgent = fmt.Sprintf("github-mcp-server/%s", cfg.Version) restClient.BaseURL = apiHost.baseRESTURL restClient.UploadURL = apiHost.uploadURL - // Construct our GraphQL client - // We're using NewEnterpriseClient here unconditionally as opposed to NewClient because we already - // did the necessary API host parsing so that github.com will return the correct URL anyway. + // Construct GraphQL client + // We use NewEnterpriseClient unconditionally since we already parsed the API host gqlHTTPClient := &http.Client{ Transport: &bearerAuthTransport{ transport: http.DefaultTransport, token: cfg.Token, }, - } // We're going to wrap the Transport later in beforeInit - gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient) - repoAccessOpts := []lockdown.RepoAccessOption{} - if cfg.RepoAccessTTL != nil { - repoAccessOpts = append(repoAccessOpts, lockdown.WithTTL(*cfg.RepoAccessTTL)) } + gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient) - repoAccessLogger := cfg.Logger.With("component", "lockdown") - repoAccessOpts = append(repoAccessOpts, lockdown.WithLogger(repoAccessLogger)) + // Create raw content client (shares REST client's HTTP transport) + rawClient := raw.NewClient(restClient, apiHost.rawURL) + + // Set up repo access cache for lockdown mode var repoAccessCache *lockdown.RepoAccessCache if cfg.LockdownMode { - repoAccessCache = lockdown.GetInstance(gqlClient, repoAccessOpts...) + opts := []lockdown.RepoAccessOption{ + lockdown.WithLogger(cfg.Logger.With("component", "lockdown")), + } + if cfg.RepoAccessTTL != nil { + opts = append(opts, lockdown.WithTTL(*cfg.RepoAccessTTL)) + } + repoAccessCache = lockdown.GetInstance(gqlClient, opts...) } + return &githubClients{ + rest: restClient, + gql: gqlClient, + gqlHTTP: gqlHTTPClient, + raw: rawClient, + repoAccess: repoAccessCache, + }, nil +} + +// resolveEnabledToolsets determines which toolsets should be enabled based on config. +// Returns nil for "use defaults", empty slice for "none", or explicit list. +func resolveEnabledToolsets(cfg MCPServerConfig) []string { enabledToolsets := cfg.EnabledToolsets - // If dynamic toolsets are enabled, remove "all" and "default" from the enabled toolsets - if cfg.DynamicToolsets { - enabledToolsets = github.RemoveToolset(enabledToolsets, github.ToolsetMetadataAll.ID) - enabledToolsets = github.RemoveToolset(enabledToolsets, github.ToolsetMetadataDefault.ID) + // In dynamic mode, remove "all" and "default" since users enable toolsets on demand + if cfg.DynamicToolsets && enabledToolsets != nil { + enabledToolsets = github.RemoveToolset(enabledToolsets, string(github.ToolsetMetadataAll.ID)) + enabledToolsets = github.RemoveToolset(enabledToolsets, string(github.ToolsetMetadataDefault.ID)) } - // Clean up the passed toolsets - enabledToolsets, invalidToolsets := github.CleanToolsets(enabledToolsets) - - // If "all" is present, override all other toolsets - if github.ContainsToolset(enabledToolsets, github.ToolsetMetadataAll.ID) { - enabledToolsets = []string{github.ToolsetMetadataAll.ID} + if enabledToolsets != nil { + return enabledToolsets + } + if cfg.DynamicToolsets { + // Dynamic mode with no toolsets specified: start empty so users enable on demand + return []string{} + } + if len(cfg.EnabledTools) > 0 { + // When specific tools are requested but no toolsets, don't use default toolsets + // This matches the original behavior: --tools=X alone registers only X + return []string{} } - // If "default" is present, expand to real toolset IDs - if github.ContainsToolset(enabledToolsets, github.ToolsetMetadataDefault.ID) { - enabledToolsets = github.AddDefaultToolset(enabledToolsets) + // nil means "use defaults" in WithToolsets + return nil +} + +func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { + apiHost, err := parseAPIHost(cfg.Host) + if err != nil { + return nil, fmt.Errorf("failed to parse API host: %w", err) } - if len(invalidToolsets) > 0 { - fmt.Fprintf(os.Stderr, "Invalid toolsets ignored: %s\n", strings.Join(invalidToolsets, ", ")) + clients, err := createGitHubClients(cfg, apiHost) + if err != nil { + return nil, fmt.Errorf("failed to create GitHub clients: %w", err) } - // Generate instructions based on enabled toolsets - instructions := github.GenerateInstructions(enabledToolsets) + enabledToolsets := resolveEnabledToolsets(cfg) - getClient := func(_ context.Context) (*gogithub.Client, error) { - return restClient, nil // closing over client + // For instruction generation, we need actual toolset names (not nil). + // nil means "use defaults" in registry, so expand it for instructions. + instructionToolsets := enabledToolsets + if instructionToolsets == nil { + instructionToolsets = github.GetDefaultToolsetIDs() } - getGQLClient := func(_ context.Context) (*githubv4.Client, error) { - return gqlClient, nil // closing over client + // Create the MCP server + serverOpts := &mcp.ServerOptions{ + Instructions: github.GenerateInstructions(instructionToolsets), + Logger: cfg.Logger, + CompletionHandler: github.CompletionsHandler(func(_ context.Context) (*gogithub.Client, error) { + return clients.rest, nil + }), } - getRawClient := func(ctx context.Context) (*raw.Client, error) { - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - return raw.NewClient(client, apiHost.rawURL), nil // closing over client + // In dynamic mode, explicitly advertise capabilities since tools/resources/prompts + // may be enabled at runtime even if none are registered initially. + if cfg.DynamicToolsets { + serverOpts.HasTools = true + serverOpts.HasResources = true + serverOpts.HasPrompts = true } - ghServer := github.NewServer(cfg.Version, &mcp.ServerOptions{ - Instructions: instructions, - Logger: cfg.Logger, - CompletionHandler: github.CompletionsHandler(getClient), - }) + ghServer := github.NewServer(cfg.Version, serverOpts) // Add middlewares ghServer.AddReceivingMiddleware(addGitHubAPIErrorToContext) - ghServer.AddReceivingMiddleware(addUserAgentsMiddleware(cfg, restClient, gqlHTTPClient)) - - // Create the dependencies struct for tool handlers - deps := github.ToolDependencies{ - GetClient: getClient, - GetGQLClient: getGQLClient, - GetRawClient: getRawClient, - RepoAccessCache: repoAccessCache, - T: cfg.Translator, - Flags: github.FeatureFlags{LockdownMode: cfg.LockdownMode}, - ContentWindowSize: cfg.ContentWindowSize, - } - - // Create default toolsets - tsg := github.DefaultToolsetGroup( - cfg.ReadOnly, - getClient, - getGQLClient, - getRawClient, + ghServer.AddReceivingMiddleware(addUserAgentsMiddleware(cfg, clients.rest, clients.gqlHTTP)) + + // Create dependencies for tool handlers + deps := github.NewBaseDeps( + clients.rest, + clients.gql, + clients.raw, + clients.repoAccess, cfg.Translator, - cfg.ContentWindowSize, github.FeatureFlags{LockdownMode: cfg.LockdownMode}, - repoAccessCache, + cfg.ContentWindowSize, ) - // Enable and register toolsets if configured - // This always happens if toolsets are specified, regardless of whether tools are also specified - if len(enabledToolsets) > 0 { - err = tsg.EnableToolsets(enabledToolsets, nil) - if err != nil { - return nil, fmt.Errorf("failed to enable toolsets: %w", err) - } + // Build and register the tool/resource/prompt registry + registry := github.NewRegistry(cfg.Translator). + WithDeprecatedAliases(github.DeprecatedToolAliases). + WithReadOnly(cfg.ReadOnly). + WithToolsets(enabledToolsets). + WithTools(github.CleanTools(cfg.EnabledTools)). + WithFeatureChecker(createFeatureChecker(cfg.EnabledFeatures)). + Build() - // Register all mcp functionality with the server - tsg.RegisterAll(ghServer) + if unrecognized := registry.UnrecognizedToolsets(); len(unrecognized) > 0 { + fmt.Fprintf(os.Stderr, "Warning: unrecognized toolsets ignored: %s\n", strings.Join(unrecognized, ", ")) } - // Register specific tools if configured - if len(cfg.EnabledTools) > 0 { - enabledTools := github.CleanTools(cfg.EnabledTools) - enabledTools, _ = tsg.ResolveToolAliases(enabledTools) + // Register GitHub tools/resources/prompts from the registry. + // In dynamic mode with no explicit toolsets, this is a no-op since enabledToolsets + // is empty - users enable toolsets at runtime via the dynamic tools below (but can + // enable toolsets or tools explicitly that do need registration). + registry.RegisterAll(context.Background(), ghServer, deps) - // Register the specified tools (additive to any toolsets already enabled) - err = tsg.RegisterSpecificTools(ghServer, enabledTools, cfg.ReadOnly, deps) - if err != nil { - return nil, fmt.Errorf("failed to register tools: %w", err) - } - } - - // Register dynamic toolsets if configured (additive to toolsets and tools) + // Register dynamic toolset management tools (enable/disable) - these are separate + // meta-tools that control the registry, not part of the registry itself if cfg.DynamicToolsets { - dynamic := github.InitDynamicToolset(ghServer, tsg, cfg.Translator) - dynamic.RegisterTools(ghServer) + registerDynamicTools(ghServer, registry, deps, cfg.Translator) } return ghServer, nil } +// registerDynamicTools adds the dynamic toolset enable/disable tools to the server. +func registerDynamicTools(server *mcp.Server, registry *registry.Registry, deps *github.BaseDeps, t translations.TranslationHelperFunc) { + dynamicDeps := github.DynamicToolDependencies{ + Server: server, + Registry: registry, + ToolDeps: deps, + T: t, + } + for _, tool := range github.DynamicTools(registry) { + tool.RegisterFunc(server, dynamicDeps) + } +} + +// createFeatureChecker returns a FeatureFlagChecker that checks if a flag name +// is present in the provided list of enabled features. For the local server, +// this is populated from the --features CLI flag. +func createFeatureChecker(enabledFeatures []string) registry.FeatureFlagChecker { + // Build a set for O(1) lookup + featureSet := make(map[string]bool, len(enabledFeatures)) + for _, f := range enabledFeatures { + featureSet[f] = true + } + return func(_ context.Context, flagName string) (bool, error) { + return featureSet[flagName], nil + } +} + type StdioServerConfig struct { // Version of the server Version string @@ -225,6 +274,10 @@ type StdioServerConfig struct { // When specified, these tools are registered in addition to any specified toolset tools EnabledTools []string + // EnabledFeatures is a list of feature flags that are enabled + // Items with FeatureFlagEnable matching an entry in this list will be available + EnabledFeatures []string + // Whether to enable dynamic toolsets // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#dynamic-tool-discovery DynamicToolsets bool @@ -282,6 +335,7 @@ func RunStdioServer(cfg StdioServerConfig) error { Token: cfg.Token, EnabledToolsets: cfg.EnabledToolsets, EnabledTools: cfg.EnabledTools, + EnabledFeatures: cfg.EnabledFeatures, DynamicToolsets: cfg.DynamicToolsets, ReadOnly: cfg.ReadOnly, Translator: t, diff --git a/pkg/github/actions.go b/pkg/github/actions.go index e9c7c11a8..584a23200 100644 --- a/pkg/github/actions.go +++ b/pkg/github/actions.go @@ -11,7 +11,7 @@ import ( "github.com/github/github-mcp-server/internal/profiler" buffer "github.com/github/github-mcp-server/pkg/buffer" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -25,8 +25,9 @@ const ( ) // ListWorkflows creates a tool to list workflows in a repository -func ListWorkflows(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListWorkflows(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "list_workflows", Description: t("TOOL_LIST_WORKFLOWS_DESCRIPTION", "List workflows in a repository"), @@ -95,8 +96,9 @@ func ListWorkflows(t translations.TranslationHelperFunc) toolsets.ServerTool { } // ListWorkflowRuns creates a tool to list workflow runs for a specific workflow -func ListWorkflowRuns(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListWorkflowRuns(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "list_workflow_runs", Description: t("TOOL_LIST_WORKFLOW_RUNS_DESCRIPTION", "List workflow runs for a specific workflow"), @@ -248,8 +250,9 @@ func ListWorkflowRuns(t translations.TranslationHelperFunc) toolsets.ServerTool } // RunWorkflow creates a tool to run an Actions workflow -func RunWorkflow(t translations.TranslationHelperFunc) toolsets.ServerTool { +func RunWorkflow(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "run_workflow", Description: t("TOOL_RUN_WORKFLOW_DESCRIPTION", "Run an Actions workflow by workflow ID or filename"), @@ -359,8 +362,9 @@ func RunWorkflow(t translations.TranslationHelperFunc) toolsets.ServerTool { } // GetWorkflowRun creates a tool to get details of a specific workflow run -func GetWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetWorkflowRun(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "get_workflow_run", Description: t("TOOL_GET_WORKFLOW_RUN_DESCRIPTION", "Get details of a specific workflow run"), @@ -426,8 +430,9 @@ func GetWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool { } // GetWorkflowRunLogs creates a tool to download logs for a specific workflow run -func GetWorkflowRunLogs(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetWorkflowRunLogs(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "get_workflow_run_logs", Description: t("TOOL_GET_WORKFLOW_RUN_LOGS_DESCRIPTION", "Download logs for a specific workflow run (EXPENSIVE: downloads ALL logs as ZIP. Consider using get_job_logs with failed_only=true for debugging failed jobs)"), @@ -503,8 +508,9 @@ func GetWorkflowRunLogs(t translations.TranslationHelperFunc) toolsets.ServerToo } // ListWorkflowJobs creates a tool to list jobs for a specific workflow run -func ListWorkflowJobs(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListWorkflowJobs(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "list_workflow_jobs", Description: t("TOOL_LIST_WORKFLOW_JOBS_DESCRIPTION", "List jobs for a specific workflow run"), @@ -602,8 +608,9 @@ func ListWorkflowJobs(t translations.TranslationHelperFunc) toolsets.ServerTool } // GetJobLogs creates a tool to download logs for a specific workflow job or efficiently get all failed job logs for a workflow run -func GetJobLogs(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetJobLogs(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "get_job_logs", Description: t("TOOL_GET_JOB_LOGS_DESCRIPTION", "Download logs for a specific workflow job or efficiently get all failed job logs for a workflow run"), @@ -699,10 +706,10 @@ func GetJobLogs(t translations.TranslationHelperFunc) toolsets.ServerTool { if failedOnly && runID > 0 { // Handle failed-only mode: get logs for all failed jobs in the workflow run - return handleFailedJobLogs(ctx, client, owner, repo, int64(runID), returnContent, tailLines, deps.ContentWindowSize) + return handleFailedJobLogs(ctx, client, owner, repo, int64(runID), returnContent, tailLines, deps.GetContentWindowSize()) } else if jobID > 0 { // Handle single job mode - return handleSingleJobLogs(ctx, client, owner, repo, int64(jobID), returnContent, tailLines, deps.ContentWindowSize) + return handleSingleJobLogs(ctx, client, owner, repo, int64(jobID), returnContent, tailLines, deps.GetContentWindowSize()) } return utils.NewToolResultError("Either job_id must be provided for single job logs, or run_id with failed_only=true for failed job logs"), nil, nil @@ -866,8 +873,9 @@ func downloadLogContent(ctx context.Context, logURL string, tailLines int, maxLi } // RerunWorkflowRun creates a tool to re-run an entire workflow run -func RerunWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool { +func RerunWorkflowRun(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "rerun_workflow_run", Description: t("TOOL_RERUN_WORKFLOW_RUN_DESCRIPTION", "Re-run an entire workflow run"), @@ -940,8 +948,9 @@ func RerunWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool } // RerunFailedJobs creates a tool to re-run only the failed jobs in a workflow run -func RerunFailedJobs(t translations.TranslationHelperFunc) toolsets.ServerTool { +func RerunFailedJobs(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "rerun_failed_jobs", Description: t("TOOL_RERUN_FAILED_JOBS_DESCRIPTION", "Re-run only the failed jobs in a workflow run"), @@ -1014,8 +1023,9 @@ func RerunFailedJobs(t translations.TranslationHelperFunc) toolsets.ServerTool { } // CancelWorkflowRun creates a tool to cancel a workflow run -func CancelWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool { +func CancelWorkflowRun(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "cancel_workflow_run", Description: t("TOOL_CANCEL_WORKFLOW_RUN_DESCRIPTION", "Cancel a workflow run"), @@ -1090,8 +1100,9 @@ func CancelWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool } // ListWorkflowRunArtifacts creates a tool to list artifacts for a workflow run -func ListWorkflowRunArtifacts(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListWorkflowRunArtifacts(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "list_workflow_run_artifacts", Description: t("TOOL_LIST_WORKFLOW_RUN_ARTIFACTS_DESCRIPTION", "List artifacts for a workflow run"), @@ -1169,8 +1180,9 @@ func ListWorkflowRunArtifacts(t translations.TranslationHelperFunc) toolsets.Ser } // DownloadWorkflowRunArtifact creates a tool to download a workflow run artifact -func DownloadWorkflowRunArtifact(t translations.TranslationHelperFunc) toolsets.ServerTool { +func DownloadWorkflowRunArtifact(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "download_workflow_run_artifact", Description: t("TOOL_DOWNLOAD_WORKFLOW_RUN_ARTIFACT_DESCRIPTION", "Get download URL for a workflow run artifact"), @@ -1245,8 +1257,9 @@ func DownloadWorkflowRunArtifact(t translations.TranslationHelperFunc) toolsets. } // DeleteWorkflowRunLogs creates a tool to delete logs for a workflow run -func DeleteWorkflowRunLogs(t translations.TranslationHelperFunc) toolsets.ServerTool { +func DeleteWorkflowRunLogs(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "delete_workflow_run_logs", Description: t("TOOL_DELETE_WORKFLOW_RUN_LOGS_DESCRIPTION", "Delete logs for a workflow run"), @@ -1320,8 +1333,9 @@ func DeleteWorkflowRunLogs(t translations.TranslationHelperFunc) toolsets.Server } // GetWorkflowRunUsage creates a tool to get usage metrics for a workflow run -func GetWorkflowRunUsage(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetWorkflowRunUsage(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "get_workflow_run_usage", Description: t("TOOL_GET_WORKFLOW_RUN_USAGE_DESCRIPTION", "Get usage metrics for a workflow run"), diff --git a/pkg/github/actions_test.go b/pkg/github/actions_test.go index 09ab3b2cc..4d56f01aa 100644 --- a/pkg/github/actions_test.go +++ b/pkg/github/actions_test.go @@ -105,8 +105,8 @@ func Test_ListWorkflows(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -194,8 +194,8 @@ func Test_RunWorkflow(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -290,8 +290,8 @@ func Test_RunWorkflow_WithFilename(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -398,8 +398,8 @@ func Test_CancelWorkflowRun(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -528,8 +528,8 @@ func Test_ListWorkflowRunArtifacts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -618,8 +618,8 @@ func Test_DownloadWorkflowRunArtifact(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -704,8 +704,8 @@ func Test_DeleteWorkflowRunLogs(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -808,8 +808,8 @@ func Test_GetWorkflowRunUsage(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -1072,8 +1072,8 @@ func Test_GetJobLogs(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, ContentWindowSize: 5000, } handler := toolDef.Handler(deps) @@ -1136,8 +1136,8 @@ func Test_GetJobLogs_WithContentReturn(t *testing.T) { client := github.NewClient(mockedClient) toolDef := GetJobLogs(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, ContentWindowSize: 5000, } handler := toolDef.Handler(deps) @@ -1188,8 +1188,8 @@ func Test_GetJobLogs_WithContentReturnAndTailLines(t *testing.T) { client := github.NewClient(mockedClient) toolDef := GetJobLogs(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, ContentWindowSize: 5000, } handler := toolDef.Handler(deps) @@ -1240,8 +1240,8 @@ func Test_GetJobLogs_WithContentReturnAndLargeTailLines(t *testing.T) { client := github.NewClient(mockedClient) toolDef := GetJobLogs(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, ContentWindowSize: 5000, } handler := toolDef.Handler(deps) diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index 518855a59..8826e4cf6 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -8,7 +8,7 @@ import ( "net/http" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -16,8 +16,9 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func GetCodeScanningAlert(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetCodeScanningAlert(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataCodeSecurity, mcp.Tool{ Name: "get_code_scanning_alert", Description: t("TOOL_GET_CODE_SCANNING_ALERT_DESCRIPTION", "Get details of a specific code scanning alert in a GitHub repository."), @@ -93,8 +94,9 @@ func GetCodeScanningAlert(t translations.TranslationHelperFunc) toolsets.ServerT ) } -func ListCodeScanningAlerts(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListCodeScanningAlerts(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataCodeSecurity, mcp.Tool{ Name: "list_code_scanning_alerts", Description: t("TOOL_LIST_CODE_SCANNING_ALERTS_DESCRIPTION", "List code scanning alerts in a GitHub repository."), diff --git a/pkg/github/code_scanning_test.go b/pkg/github/code_scanning_test.go index 5e56e6788..44c7a7e95 100644 --- a/pkg/github/code_scanning_test.go +++ b/pkg/github/code_scanning_test.go @@ -88,8 +88,8 @@ func Test_GetCodeScanningAlert(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -220,8 +220,8 @@ func Test_ListCodeScanningAlerts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) diff --git a/pkg/github/context_tools.go b/pkg/github/context_tools.go index e8043731a..837de00f7 100644 --- a/pkg/github/context_tools.go +++ b/pkg/github/context_tools.go @@ -6,7 +6,7 @@ import ( "time" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/jsonschema-go/jsonschema" @@ -37,8 +37,9 @@ type UserDetails struct { } // GetMe creates a tool to get details of the authenticated user. -func GetMe(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetMe(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataContext, mcp.Tool{ Name: "get_me", Description: t("TOOL_GET_ME_DESCRIPTION", "Get details of the authenticated GitHub user. Use this when a request is about the user's own profile for GitHub. Or when information is missing to build other tool calls."), @@ -110,8 +111,9 @@ type OrganizationTeams struct { Teams []TeamInfo `json:"teams"` } -func GetTeams(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetTeams(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataContext, mcp.Tool{ Name: "get_teams", Description: t("TOOL_GET_TEAMS_DESCRIPTION", "Get details of the teams the user is a member of. Limited to organizations accessible with current credentials"), @@ -208,8 +210,9 @@ func GetTeams(t translations.TranslationHelperFunc) toolsets.ServerTool { ) } -func GetTeamMembers(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetTeamMembers(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataContext, mcp.Tool{ Name: "get_team_members", Description: t("TOOL_GET_TEAM_MEMBERS_DESCRIPTION", "Get member usernames of a specific team in an organization. Limited to organizations accessible with current credentials"), diff --git a/pkg/github/context_tools_test.go b/pkg/github/context_tools_test.go index 0e28aad49..e9faefc40 100644 --- a/pkg/github/context_tools_test.go +++ b/pkg/github/context_tools_test.go @@ -3,7 +3,7 @@ package github import ( "context" "encoding/json" - "fmt" + "net/http" "testing" "time" @@ -48,7 +48,8 @@ func Test_GetMe(t *testing.T) { tests := []struct { name string - stubbedGetClientFn GetClientFn + mockedClient *http.Client + clientErr string // if set, GetClient returns this error requestArgs map[string]any expectToolError bool expectedUser *github.User @@ -56,12 +57,10 @@ func Test_GetMe(t *testing.T) { }{ { name: "successful get user", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetUser, - mockUser, - ), + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetUser, + mockUser, ), ), requestArgs: map[string]any{}, @@ -70,12 +69,10 @@ func Test_GetMe(t *testing.T) { }, { name: "successful get user with reason", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetUser, - mockUser, - ), + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetUser, + mockUser, ), ), requestArgs: map[string]any{ @@ -86,19 +83,17 @@ func Test_GetMe(t *testing.T) { }, { name: "getting client fails", - stubbedGetClientFn: stubGetClientFnErr("expected test error"), + clientErr: "expected test error", requestArgs: map[string]any{}, expectToolError: true, expectedToolErrMsg: "failed to get GitHub client: expected test error", }, { name: "get user fails", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetUser, - badRequestHandler("expected test failure"), - ), + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetUser, + badRequestHandler("expected test failure"), ), ), requestArgs: map[string]any{}, @@ -109,8 +104,11 @@ func Test_GetMe(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - deps := ToolDependencies{ - GetClient: tc.stubbedGetClientFn, + var deps ToolDependencies + if tc.clientErr != "" { + deps = stubDeps{clientFn: stubClientFnErr(tc.clientErr)} + } else { + deps = BaseDeps{Client: github.NewClient(tc.mockedClient)} } handler := serverTool.Handler(deps) @@ -223,49 +221,83 @@ func Test_GetTeams(t *testing.T) { }, }) + // Create GQL clients for different test scenarios - these are factory functions + // to ensure each test gets a fresh client + gqlClientForTestuser := func() *githubv4.Client { + queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" + vars := map[string]interface{}{ + "login": "testuser", + } + matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamsResponse) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + return githubv4.NewClient(httpClient) + } + + gqlClientForSpecificuser := func() *githubv4.Client { + queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" + vars := map[string]interface{}{ + "login": "specificuser", + } + matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamsResponse) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + return githubv4.NewClient(httpClient) + } + + gqlClientNoTeams := func() *githubv4.Client { + queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" + vars := map[string]interface{}{ + "login": "testuser", + } + matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockNoTeamsResponse) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + return githubv4.NewClient(httpClient) + } + + // Factory function for mock HTTP clients with user response + httpClientWithUser := func() *http.Client { + return mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetUser, + mockUser, + ), + ) + } + + httpClientUserFails := func() *http.Client { + return mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetUser, + badRequestHandler("expected test failure"), + ), + ) + } + tests := []struct { - name string - stubbedGetClientFn GetClientFn - stubbedGetGQLClientFn GetGQLClientFn - requestArgs map[string]any - expectToolError bool - expectedToolErrMsg string - expectedTeamsCount int + name string + makeDeps func() ToolDependencies + requestArgs map[string]any + expectToolError bool + expectedToolErrMsg string + expectedTeamsCount int }{ { name: "successful get teams", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetUser, - mockUser, - ), - ), - ), - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" - vars := map[string]interface{}{ - "login": "testuser", + makeDeps: func() ToolDependencies { + return BaseDeps{ + Client: github.NewClient(httpClientWithUser()), + GQLClient: gqlClientForTestuser(), } - matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamsResponse) - httpClient := githubv4mock.NewMockedHTTPClient(matcher) - return githubv4.NewClient(httpClient), nil }, requestArgs: map[string]any{}, expectToolError: false, expectedTeamsCount: 2, }, { - name: "successful get teams for specific user", - stubbedGetClientFn: nil, - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" - vars := map[string]interface{}{ - "login": "specificuser", + name: "successful get teams for specific user", + makeDeps: func() ToolDependencies { + return BaseDeps{ + GQLClient: gqlClientForSpecificuser(), } - matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamsResponse) - httpClient := githubv4mock.NewMockedHTTPClient(matcher) - return githubv4.NewClient(httpClient), nil }, requestArgs: map[string]any{ "user": "specificuser", @@ -275,62 +307,43 @@ func Test_GetTeams(t *testing.T) { }, { name: "no teams found", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetUser, - mockUser, - ), - ), - ), - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" - vars := map[string]interface{}{ - "login": "testuser", + makeDeps: func() ToolDependencies { + return BaseDeps{ + Client: github.NewClient(httpClientWithUser()), + GQLClient: gqlClientNoTeams(), } - matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockNoTeamsResponse) - httpClient := githubv4mock.NewMockedHTTPClient(matcher) - return githubv4.NewClient(httpClient), nil }, requestArgs: map[string]any{}, expectToolError: false, expectedTeamsCount: 0, }, { - name: "getting client fails", - stubbedGetClientFn: stubGetClientFnErr("expected test error"), - stubbedGetGQLClientFn: nil, - requestArgs: map[string]any{}, - expectToolError: true, - expectedToolErrMsg: "failed to get GitHub client: expected test error", + name: "getting client fails", + makeDeps: func() ToolDependencies { + return stubDeps{clientFn: stubClientFnErr("expected test error")} + }, + requestArgs: map[string]any{}, + expectToolError: true, + expectedToolErrMsg: "failed to get GitHub client: expected test error", }, { name: "get user fails", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetUser, - badRequestHandler("expected test failure"), - ), - ), - ), - stubbedGetGQLClientFn: nil, - requestArgs: map[string]any{}, - expectToolError: true, - expectedToolErrMsg: "expected test failure", + makeDeps: func() ToolDependencies { + return BaseDeps{ + Client: github.NewClient(httpClientUserFails()), + } + }, + requestArgs: map[string]any{}, + expectToolError: true, + expectedToolErrMsg: "expected test failure", }, { name: "getting GraphQL client fails", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetUser, - mockUser, - ), - ), - ), - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - return nil, fmt.Errorf("GraphQL client error") + makeDeps: func() ToolDependencies { + return stubDeps{ + clientFn: stubClientFnFromHTTP(httpClientWithUser()), + gqlClientFn: stubGQLClientFnErr("GraphQL client error"), + } }, requestArgs: map[string]any{}, expectToolError: true, @@ -340,11 +353,7 @@ func Test_GetTeams(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - deps := ToolDependencies{ - GetClient: tc.stubbedGetClientFn, - GetGQLClient: tc.stubbedGetGQLClientFn, - } - handler := serverTool.Handler(deps) + handler := serverTool.Handler(tc.makeDeps()) request := createMCPRequest(tc.requestArgs) result, err := handler(context.Background(), &request) @@ -422,26 +431,40 @@ func Test_GetTeamMembers(t *testing.T) { }, }) + // Create GQL clients for different test scenarios + gqlClientWithMembers := func() *githubv4.Client { + queryStr := "query($org:String!$teamSlug:String!){organization(login: $org){team(slug: $teamSlug){members(first: 100){nodes{login}}}}}" + vars := map[string]interface{}{ + "org": "testorg", + "teamSlug": "testteam", + } + matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamMembersResponse) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + return githubv4.NewClient(httpClient) + } + + gqlClientNoMembers := func() *githubv4.Client { + queryStr := "query($org:String!$teamSlug:String!){organization(login: $org){team(slug: $teamSlug){members(first: 100){nodes{login}}}}}" + vars := map[string]interface{}{ + "org": "testorg", + "teamSlug": "emptyteam", + } + matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockNoMembersResponse) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + return githubv4.NewClient(httpClient) + } + tests := []struct { - name string - stubbedGetGQLClientFn GetGQLClientFn - requestArgs map[string]any - expectToolError bool - expectedToolErrMsg string - expectedMembersCount int + name string + deps ToolDependencies + requestArgs map[string]any + expectToolError bool + expectedToolErrMsg string + expectedMembersCount int }{ { name: "successful get team members", - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - queryStr := "query($org:String!$teamSlug:String!){organization(login: $org){team(slug: $teamSlug){members(first: 100){nodes{login}}}}}" - vars := map[string]interface{}{ - "org": "testorg", - "teamSlug": "testteam", - } - matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamMembersResponse) - httpClient := githubv4mock.NewMockedHTTPClient(matcher) - return githubv4.NewClient(httpClient), nil - }, + deps: BaseDeps{GQLClient: gqlClientWithMembers()}, requestArgs: map[string]any{ "org": "testorg", "team_slug": "testteam", @@ -451,16 +474,7 @@ func Test_GetTeamMembers(t *testing.T) { }, { name: "team with no members", - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - queryStr := "query($org:String!$teamSlug:String!){organization(login: $org){team(slug: $teamSlug){members(first: 100){nodes{login}}}}}" - vars := map[string]interface{}{ - "org": "testorg", - "teamSlug": "emptyteam", - } - matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockNoMembersResponse) - httpClient := githubv4mock.NewMockedHTTPClient(matcher) - return githubv4.NewClient(httpClient), nil - }, + deps: BaseDeps{GQLClient: gqlClientNoMembers()}, requestArgs: map[string]any{ "org": "testorg", "team_slug": "emptyteam", @@ -470,9 +484,7 @@ func Test_GetTeamMembers(t *testing.T) { }, { name: "getting GraphQL client fails", - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - return nil, fmt.Errorf("GraphQL client error") - }, + deps: stubDeps{gqlClientFn: stubGQLClientFnErr("GraphQL client error")}, requestArgs: map[string]any{ "org": "testorg", "team_slug": "testteam", @@ -484,10 +496,7 @@ func Test_GetTeamMembers(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - deps := ToolDependencies{ - GetGQLClient: tc.stubbedGetGQLClientFn, - } - handler := serverTool.Handler(deps) + handler := serverTool.Handler(tc.deps) request := createMCPRequest(tc.requestArgs) result, err := handler(context.Background(), &request) diff --git a/pkg/github/dependabot.go b/pkg/github/dependabot.go index b80fd0aa0..daa2a124a 100644 --- a/pkg/github/dependabot.go +++ b/pkg/github/dependabot.go @@ -8,7 +8,7 @@ import ( "net/http" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -16,8 +16,9 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func GetDependabotAlert(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetDependabotAlert(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataDependabot, mcp.Tool{ Name: "get_dependabot_alert", Description: t("TOOL_GET_DEPENDABOT_ALERT_DESCRIPTION", "Get details of a specific dependabot alert in a GitHub repository."), @@ -93,8 +94,9 @@ func GetDependabotAlert(t translations.TranslationHelperFunc) toolsets.ServerToo ) } -func ListDependabotAlerts(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListDependabotAlerts(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataDependabot, mcp.Tool{ Name: "list_dependabot_alerts", Description: t("TOOL_LIST_DEPENDABOT_ALERTS_DESCRIPTION", "List dependabot alerts in a GitHub repository."), diff --git a/pkg/github/dependabot_test.go b/pkg/github/dependabot_test.go index ace0eb07a..614c6f383 100644 --- a/pkg/github/dependabot_test.go +++ b/pkg/github/dependabot_test.go @@ -81,7 +81,7 @@ func Test_GetDependabotAlert(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{GetClient: stubGetClientFn(client)} + deps := BaseDeps{Client: client} handler := toolDef.Handler(deps) // Create call request @@ -232,7 +232,7 @@ func Test_ListDependabotAlerts(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{GetClient: stubGetClientFn(client)} + deps := BaseDeps{Client: client} handler := toolDef.Handler(deps) request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/dependencies.go b/pkg/github/dependencies.go index 3124f6bd0..040e61883 100644 --- a/pkg/github/dependencies.go +++ b/pkg/github/dependencies.go @@ -1,53 +1,123 @@ package github import ( + "context" + "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/raw" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" + gogithub "github.com/google/go-github/v79/github" "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/shurcooL/githubv4" ) -// ToolDependencies contains all dependencies that tool handlers might need. -// This is a properly-typed struct that lives in pkg/github to avoid circular -// dependencies. The toolsets package uses `any` for deps and tool handlers -// type-assert to this struct. -type ToolDependencies struct { +// ToolDependencies defines the interface for dependencies that tool handlers need. +// This is an interface to allow different implementations: +// - Local server: stores closures that create clients on demand +// - Remote server: can store pre-created clients per-request for efficiency +// +// The toolsets package uses `any` for deps and tool handlers type-assert to this interface. +type ToolDependencies interface { // GetClient returns a GitHub REST API client - GetClient GetClientFn + GetClient(ctx context.Context) (*gogithub.Client, error) // GetGQLClient returns a GitHub GraphQL client - GetGQLClient GetGQLClientFn + GetGQLClient(ctx context.Context) (*githubv4.Client, error) + + // GetRawClient returns a raw content client for GitHub + GetRawClient(ctx context.Context) (*raw.Client, error) - // GetRawClient returns a raw HTTP client for GitHub - GetRawClient raw.GetRawClientFn + // GetRepoAccessCache returns the lockdown mode repo access cache + GetRepoAccessCache() *lockdown.RepoAccessCache - // RepoAccessCache is the lockdown mode repo access cache - RepoAccessCache *lockdown.RepoAccessCache + // GetT returns the translation helper function + GetT() translations.TranslationHelperFunc - // T is the translation helper function - T translations.TranslationHelperFunc + // GetFlags returns feature flags + GetFlags() FeatureFlags + + // GetContentWindowSize returns the content window size for log truncation + GetContentWindowSize() int +} - // Flags are feature flags - Flags FeatureFlags +// BaseDeps is the standard implementation of ToolDependencies for the local server. +// It stores pre-created clients. The remote server can create its own struct +// implementing ToolDependencies with different client creation strategies. +type BaseDeps struct { + // Pre-created clients + Client *gogithub.Client + GQLClient *githubv4.Client + RawClient *raw.Client - // ContentWindowSize is the size of the content window for log truncation + // Static dependencies + RepoAccessCache *lockdown.RepoAccessCache + T translations.TranslationHelperFunc + Flags FeatureFlags ContentWindowSize int } -// NewTool creates a ServerTool with fully-typed ToolDependencies. +// NewBaseDeps creates a BaseDeps with the provided clients and configuration. +func NewBaseDeps( + client *gogithub.Client, + gqlClient *githubv4.Client, + rawClient *raw.Client, + repoAccessCache *lockdown.RepoAccessCache, + t translations.TranslationHelperFunc, + flags FeatureFlags, + contentWindowSize int, +) *BaseDeps { + return &BaseDeps{ + Client: client, + GQLClient: gqlClient, + RawClient: rawClient, + RepoAccessCache: repoAccessCache, + T: t, + Flags: flags, + ContentWindowSize: contentWindowSize, + } +} + +// GetClient implements ToolDependencies. +func (d BaseDeps) GetClient(_ context.Context) (*gogithub.Client, error) { + return d.Client, nil +} + +// GetGQLClient implements ToolDependencies. +func (d BaseDeps) GetGQLClient(_ context.Context) (*githubv4.Client, error) { + return d.GQLClient, nil +} + +// GetRawClient implements ToolDependencies. +func (d BaseDeps) GetRawClient(_ context.Context) (*raw.Client, error) { + return d.RawClient, nil +} + +// GetRepoAccessCache implements ToolDependencies. +func (d BaseDeps) GetRepoAccessCache() *lockdown.RepoAccessCache { return d.RepoAccessCache } + +// GetT implements ToolDependencies. +func (d BaseDeps) GetT() translations.TranslationHelperFunc { return d.T } + +// GetFlags implements ToolDependencies. +func (d BaseDeps) GetFlags() FeatureFlags { return d.Flags } + +// GetContentWindowSize implements ToolDependencies. +func (d BaseDeps) GetContentWindowSize() int { return d.ContentWindowSize } + +// NewTool creates a ServerTool with fully-typed ToolDependencies and toolset metadata. // This helper isolates the type assertion from `any` to `ToolDependencies`, // so tool implementations remain fully typed without assertions scattered throughout. -func NewTool[In, Out any](tool mcp.Tool, handler func(deps ToolDependencies) mcp.ToolHandlerFor[In, Out]) toolsets.ServerTool { - return toolsets.NewServerTool(tool, func(d any) mcp.ToolHandlerFor[In, Out] { +func NewTool[In, Out any](toolset registry.ToolsetMetadata, tool mcp.Tool, handler func(deps ToolDependencies) mcp.ToolHandlerFor[In, Out]) registry.ServerTool { + return registry.NewServerTool(tool, toolset, func(d any) mcp.ToolHandlerFor[In, Out] { return handler(d.(ToolDependencies)) }) } -// NewToolFromHandler creates a ServerTool with fully-typed ToolDependencies +// NewToolFromHandler creates a ServerTool with fully-typed ToolDependencies and toolset metadata // for handlers that conform to mcp.ToolHandler directly. -func NewToolFromHandler(tool mcp.Tool, handler func(deps ToolDependencies) mcp.ToolHandler) toolsets.ServerTool { - return toolsets.NewServerToolFromHandler(tool, func(d any) mcp.ToolHandler { +func NewToolFromHandler(toolset registry.ToolsetMetadata, tool mcp.Tool, handler func(deps ToolDependencies) mcp.ToolHandler) registry.ServerTool { + return registry.NewServerToolFromHandler(tool, toolset, func(d any) mcp.ToolHandler { return handler(d.(ToolDependencies)) }) } diff --git a/pkg/github/discussions.go b/pkg/github/discussions.go index 94f7f6f1b..50364fa58 100644 --- a/pkg/github/discussions.go +++ b/pkg/github/discussions.go @@ -5,7 +5,7 @@ import ( "encoding/json" "fmt" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/go-viper/mapstructure/v2" @@ -122,8 +122,9 @@ func getQueryType(useOrdering bool, categoryID *githubv4.ID) any { return &BasicNoOrder{} } -func ListDiscussions(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListDiscussions(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataDiscussions, mcp.Tool{ Name: "list_discussions", Description: t("TOOL_LIST_DISCUSSIONS_DESCRIPTION", "List discussions for a repository or organisation."), @@ -275,8 +276,9 @@ func ListDiscussions(t translations.TranslationHelperFunc) toolsets.ServerTool { ) } -func GetDiscussion(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetDiscussion(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataDiscussions, mcp.Tool{ Name: "get_discussion", Description: t("TOOL_GET_DISCUSSION_DESCRIPTION", "Get a specific discussion by ID"), @@ -379,8 +381,9 @@ func GetDiscussion(t translations.TranslationHelperFunc) toolsets.ServerTool { ) } -func GetDiscussionComments(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetDiscussionComments(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataDiscussions, mcp.Tool{ Name: "get_discussion_comments", Description: t("TOOL_GET_DISCUSSION_COMMENTS_DESCRIPTION", "Get comments from a discussion"), @@ -506,8 +509,9 @@ func GetDiscussionComments(t translations.TranslationHelperFunc) toolsets.Server ) } -func ListDiscussionCategories(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListDiscussionCategories(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataDiscussions, mcp.Tool{ Name: "list_discussion_categories", Description: t("TOOL_LIST_DISCUSSION_CATEGORIES_DESCRIPTION", "List discussion categories with their id and name, for a repository or organisation."), diff --git a/pkg/github/discussions_test.go b/pkg/github/discussions_test.go index 758c82200..73ae66748 100644 --- a/pkg/github/discussions_test.go +++ b/pkg/github/discussions_test.go @@ -447,7 +447,7 @@ func Test_ListDiscussions(t *testing.T) { } gqlClient := githubv4.NewClient(httpClient) - deps := ToolDependencies{GetGQLClient: stubGetGQLClientFn(gqlClient)} + deps := BaseDeps{GQLClient: gqlClient} handler := toolDef.Handler(deps) req := createMCPRequest(tc.reqParams) @@ -559,7 +559,7 @@ func Test_GetDiscussion(t *testing.T) { matcher := githubv4mock.NewQueryMatcher(qGetDiscussion, vars, tc.response) httpClient := githubv4mock.NewMockedHTTPClient(matcher) gqlClient := githubv4.NewClient(httpClient) - deps := ToolDependencies{GetGQLClient: stubGetGQLClientFn(gqlClient)} + deps := BaseDeps{GQLClient: gqlClient} handler := toolDef.Handler(deps) reqParams := map[string]interface{}{"owner": "owner", "repo": "repo", "discussionNumber": int32(1)} @@ -639,7 +639,7 @@ func Test_GetDiscussionComments(t *testing.T) { matcher := githubv4mock.NewQueryMatcher(qGetComments, vars, mockResponse) httpClient := githubv4mock.NewMockedHTTPClient(matcher) gqlClient := githubv4.NewClient(httpClient) - deps := ToolDependencies{GetGQLClient: stubGetGQLClientFn(gqlClient)} + deps := BaseDeps{GQLClient: gqlClient} handler := toolDef.Handler(deps) reqParams := map[string]interface{}{ @@ -791,7 +791,7 @@ func Test_ListDiscussionCategories(t *testing.T) { httpClient := githubv4mock.NewMockedHTTPClient(matcher) gqlClient := githubv4.NewClient(httpClient) - deps := ToolDependencies{GetGQLClient: stubGetGQLClientFn(gqlClient)} + deps := BaseDeps{GQLClient: gqlClient} handler := toolDef.Handler(deps) req := createMCPRequest(tc.reqParams) diff --git a/pkg/github/dynamic_tools.go b/pkg/github/dynamic_tools.go index 75d74f676..a749ecd1b 100644 --- a/pkg/github/dynamic_tools.go +++ b/pkg/github/dynamic_tools.go @@ -5,28 +5,64 @@ import ( "encoding/json" "fmt" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/mcp" ) -func ToolsetEnum(toolsetGroup *toolsets.ToolsetGroup) []any { - toolsetNames := make([]any, 0, len(toolsetGroup.Toolsets)) - for name := range toolsetGroup.Toolsets { - toolsetNames = append(toolsetNames, name) +// DynamicToolDependencies contains dependencies for dynamic toolset management tools. +// It includes the managed Registry, the server for registration, and the deps +// that will be passed to tools when they are dynamically enabled. +type DynamicToolDependencies struct { + // Server is the MCP server to register tools with + Server *mcp.Server + // Registry contains all available tools that can be enabled dynamically + Registry *registry.Registry + // ToolDeps are the dependencies passed to tools when they are registered + ToolDeps any + // T is the translation helper function + T translations.TranslationHelperFunc +} + +// NewDynamicTool creates a ServerTool with fully-typed DynamicToolDependencies. +func NewDynamicTool(toolset registry.ToolsetMetadata, tool mcp.Tool, handler func(deps DynamicToolDependencies) mcp.ToolHandlerFor[map[string]any, any]) registry.ServerTool { + return registry.NewServerTool(tool, toolset, func(d any) mcp.ToolHandlerFor[map[string]any, any] { + return handler(d.(DynamicToolDependencies)) + }) +} + +// toolsetIDsEnum returns the list of toolset IDs as an enum for JSON Schema. +func toolsetIDsEnum(r *registry.Registry) []any { + toolsetIDs := r.ToolsetIDs() + result := make([]any, len(toolsetIDs)) + for i, id := range toolsetIDs { + result[i] = id + } + return result +} + +// DynamicTools returns the tools for dynamic toolset management. +// These tools allow runtime discovery and enablement of registry. +// The r parameter provides the available toolset IDs for JSON Schema enums. +func DynamicTools(r *registry.Registry) []registry.ServerTool { + return []registry.ServerTool{ + ListAvailableToolsets(), + GetToolsetsTools(r), + EnableToolset(r), } - return toolsetNames } -func EnableToolset(s *mcp.Server, toolsetGroup *toolsets.ToolsetGroup, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +// EnableToolset creates a tool that enables a toolset at runtime. +func EnableToolset(r *registry.Registry) registry.ServerTool { + return NewDynamicTool( + ToolsetMetadataDynamic, + mcp.Tool{ Name: "enable_toolset", - Description: t("TOOL_ENABLE_TOOLSET_DESCRIPTION", "Enable one of the sets of tools the GitHub MCP server provides, use get_toolset_tools and list_available_toolsets first to see what this will enable"), + Description: "Enable one of the sets of tools the GitHub MCP server provides, use get_toolset_tools and list_available_toolsets first to see what this will enable", Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_ENABLE_TOOLSET_USER_TITLE", "Enable a toolset"), - // Not modifying GitHub data so no need to show a warning + Title: "Enable a toolset", ReadOnlyHint: true, }, InputSchema: &jsonschema.Schema{ @@ -35,44 +71,53 @@ func EnableToolset(s *mcp.Server, toolsetGroup *toolsets.ToolsetGroup, t transla "toolset": { Type: "string", Description: "The name of the toolset to enable", - Enum: ToolsetEnum(toolsetGroup), + Enum: toolsetIDsEnum(r), }, }, Required: []string{"toolset"}, }, }, - mcp.ToolHandlerFor[map[string]any, any](func(_ context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - // We need to convert the toolsets back to a map for JSON serialization - toolsetName, err := RequiredParam[string](args, "toolset") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - toolset := toolsetGroup.Toolsets[toolsetName] - if toolset == nil { - return utils.NewToolResultError(fmt.Sprintf("Toolset %s not found", toolsetName)), nil, nil - } - if toolset.Enabled { - return utils.NewToolResultText(fmt.Sprintf("Toolset %s is already enabled", toolsetName)), nil, nil - } + func(deps DynamicToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(_ context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + toolsetName, err := RequiredParam[string](args, "toolset") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + toolsetID := registry.ToolsetID(toolsetName) + + if !deps.Registry.HasToolset(toolsetID) { + return utils.NewToolResultError(fmt.Sprintf("Toolset %s not found", toolsetName)), nil, nil + } + + if deps.Registry.IsToolsetEnabled(toolsetID) { + return utils.NewToolResultText(fmt.Sprintf("Toolset %s is already enabled", toolsetName)), nil, nil + } - toolset.Enabled = true + // Mark the toolset as enabled so IsToolsetEnabled returns true + deps.Registry.EnableToolset(toolsetID) - // caution: this currently affects the global tools and notifies all clients: - // - // Send notification to all initialized sessions - // s.sendNotificationToAllClients("notifications/tools/list_changed", nil) - toolset.RegisterTools(s) + // Get tools for this toolset and register them with the managed deps + toolsForToolset := deps.Registry.ToolsForToolset(toolsetID) + for _, st := range toolsForToolset { + st.RegisterFunc(deps.Server, deps.ToolDeps) + } - return utils.NewToolResultText(fmt.Sprintf("Toolset %s enabled", toolsetName)), nil, nil - }) + return utils.NewToolResultText(fmt.Sprintf("Toolset %s enabled with %d tools", toolsetName, len(toolsForToolset))), nil, nil + } + }, + ) } -func ListAvailableToolsets(toolsetGroup *toolsets.ToolsetGroup, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +// ListAvailableToolsets creates a tool that lists all available registry. +func ListAvailableToolsets() registry.ServerTool { + return NewDynamicTool( + ToolsetMetadataDynamic, + mcp.Tool{ Name: "list_available_toolsets", - Description: t("TOOL_LIST_AVAILABLE_TOOLSETS_DESCRIPTION", "List all available toolsets this GitHub MCP server can offer, providing the enabled status of each. Use this when a task could be achieved with a GitHub tool and the currently available tools aren't enough. Call get_toolset_tools with these toolset names to discover specific tools you can call"), + Description: "List all available toolsets this GitHub MCP server can offer, providing the enabled status of each. Use this when a task could be achieved with a GitHub tool and the currently available tools aren't enough. Call get_toolset_tools with these toolset names to discover specific tools you can call", Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LIST_AVAILABLE_TOOLSETS_USER_TITLE", "List available toolsets"), + Title: "List available toolsets", ReadOnlyHint: true, }, InputSchema: &jsonschema.Schema{ @@ -80,38 +125,42 @@ func ListAvailableToolsets(toolsetGroup *toolsets.ToolsetGroup, t translations.T Properties: map[string]*jsonschema.Schema{}, }, }, - mcp.ToolHandlerFor[map[string]any, any](func(_ context.Context, _ *mcp.CallToolRequest, _ map[string]any) (*mcp.CallToolResult, any, error) { - // We need to convert the toolsetGroup back to a map for JSON serialization - - payload := []map[string]string{} + func(deps DynamicToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(_ context.Context, _ *mcp.CallToolRequest, _ map[string]any) (*mcp.CallToolResult, any, error) { + toolsetIDs := deps.Registry.ToolsetIDs() + descriptions := deps.Registry.ToolsetDescriptions() - for name, ts := range toolsetGroup.Toolsets { - { + payload := make([]map[string]string, 0, len(toolsetIDs)) + for _, id := range toolsetIDs { t := map[string]string{ - "name": name, - "description": ts.Description, + "name": string(id), + "description": descriptions[id], "can_enable": "true", - "currently_enabled": fmt.Sprintf("%t", ts.Enabled), + "currently_enabled": fmt.Sprintf("%t", deps.Registry.IsToolsetEnabled(id)), } payload = append(payload, t) } - } - r, err := json.Marshal(payload) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal features: %w", err) - } + r, err := json.Marshal(payload) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal features: %w", err) + } - return utils.NewToolResultText(string(r)), nil, nil - }) + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } -func GetToolsetsTools(toolsetGroup *toolsets.ToolsetGroup, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +// GetToolsetsTools creates a tool that lists all tools in a specific toolset. +func GetToolsetsTools(r *registry.Registry) registry.ServerTool { + return NewDynamicTool( + ToolsetMetadataDynamic, + mcp.Tool{ Name: "get_toolset_tools", - Description: t("TOOL_GET_TOOLSET_TOOLS_DESCRIPTION", "Lists all the capabilities that are enabled with the specified toolset, use this to get clarity on whether enabling a toolset would help you to complete a task"), + Description: "Lists all the capabilities that are enabled with the specified toolset, use this to get clarity on whether enabling a toolset would help you to complete a task", Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_GET_TOOLSET_TOOLS_USER_TITLE", "List all tools in a toolset"), + Title: "List all tools in a toolset", ReadOnlyHint: true, }, InputSchema: &jsonschema.Schema{ @@ -120,39 +169,46 @@ func GetToolsetsTools(toolsetGroup *toolsets.ToolsetGroup, t translations.Transl "toolset": { Type: "string", Description: "The name of the toolset you want to get the tools for", - Enum: ToolsetEnum(toolsetGroup), + Enum: toolsetIDsEnum(r), }, }, Required: []string{"toolset"}, }, }, - mcp.ToolHandlerFor[map[string]any, any](func(_ context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - // We need to convert the toolsetGroup back to a map for JSON serialization - toolsetName, err := RequiredParam[string](args, "toolset") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - toolset := toolsetGroup.Toolsets[toolsetName] - if toolset == nil { - return utils.NewToolResultError(fmt.Sprintf("Toolset %s not found", toolsetName)), nil, nil - } - payload := []map[string]string{} - - for _, st := range toolset.GetAvailableTools() { - tool := map[string]string{ - "name": st.Tool.Name, - "description": st.Tool.Description, - "can_enable": "true", - "toolset": toolsetName, + func(deps DynamicToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(_ context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + toolsetName, err := RequiredParam[string](args, "toolset") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil } - payload = append(payload, tool) - } - r, err := json.Marshal(payload) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal features: %w", err) - } + toolsetID := registry.ToolsetID(toolsetName) - return utils.NewToolResultText(string(r)), nil, nil - }) + if !deps.Registry.HasToolset(toolsetID) { + return utils.NewToolResultError(fmt.Sprintf("Toolset %s not found", toolsetName)), nil, nil + } + + // Get all tools for this toolset (ignoring current filters for discovery) + toolsInToolset := deps.Registry.ToolsForToolset(toolsetID) + payload := make([]map[string]string, 0, len(toolsInToolset)) + + for _, st := range toolsInToolset { + tool := map[string]string{ + "name": st.Tool.Name, + "description": st.Tool.Description, + "can_enable": "true", + "toolset": toolsetName, + } + payload = append(payload, tool) + } + + r, err := json.Marshal(payload) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal features: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } diff --git a/pkg/github/dynamic_tools_test.go b/pkg/github/dynamic_tools_test.go new file mode 100644 index 000000000..4558204dc --- /dev/null +++ b/pkg/github/dynamic_tools_test.go @@ -0,0 +1,231 @@ +package github + +import ( + "context" + "encoding/json" + "testing" + + "github.com/github/github-mcp-server/pkg/registry" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// createDynamicRequest creates an MCP request with the given arguments for dynamic tools. +func createDynamicRequest(args map[string]any) *mcp.CallToolRequest { + argsJSON, _ := json.Marshal(args) + return &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Arguments: json.RawMessage(argsJSON), + }, + } +} + +func TestDynamicTools_ListAvailableToolsets(t *testing.T) { + // Build a registry with no toolsets enabled (dynamic mode) + reg := NewRegistry(translations.NullTranslationHelper). + WithToolsets([]string{}). + Build() + + // Create a mock server + server := mcp.NewServer(&mcp.Implementation{Name: "test"}, nil) + + // Create dynamic tool dependencies + deps := DynamicToolDependencies{ + Server: server, + Registry: reg, + ToolDeps: nil, + T: translations.NullTranslationHelper, + } + + // Get the list_available_toolsets tool + tool := ListAvailableToolsets() + handler := tool.Handler(deps) + + // Call the handler + result, err := handler(context.Background(), createDynamicRequest(map[string]any{})) + require.NoError(t, err) + require.NotNil(t, result) + require.Len(t, result.Content, 1) + + // Parse the result + var toolsets []map[string]string + textContent := result.Content[0].(*mcp.TextContent) + err = json.Unmarshal([]byte(textContent.Text), &toolsets) + require.NoError(t, err) + + // Verify we got toolsets + assert.NotEmpty(t, toolsets, "should have available toolsets") + + // Find the repos toolset and verify it's not enabled + var reposToolset map[string]string + for _, ts := range toolsets { + if ts["name"] == "repos" { + reposToolset = ts + break + } + } + require.NotNil(t, reposToolset, "repos toolset should exist") + assert.Equal(t, "false", reposToolset["currently_enabled"], "repos should not be enabled initially") +} + +func TestDynamicTools_GetToolsetTools(t *testing.T) { + // Build a registry with no toolsets enabled (dynamic mode) + reg := NewRegistry(translations.NullTranslationHelper). + WithToolsets([]string{}). + Build() + + // Create a mock server + server := mcp.NewServer(&mcp.Implementation{Name: "test"}, nil) + + // Create dynamic tool dependencies + deps := DynamicToolDependencies{ + Server: server, + Registry: reg, + ToolDeps: nil, + T: translations.NullTranslationHelper, + } + + // Get the get_toolset_tools tool + tool := GetToolsetsTools(reg) + handler := tool.Handler(deps) + + // Call the handler for repos toolset + result, err := handler(context.Background(), createDynamicRequest(map[string]any{ + "toolset": "repos", + })) + require.NoError(t, err) + require.NotNil(t, result) + require.Len(t, result.Content, 1) + + // Parse the result + var tools []map[string]string + textContent := result.Content[0].(*mcp.TextContent) + err = json.Unmarshal([]byte(textContent.Text), &tools) + require.NoError(t, err) + + // Verify we got tools for the repos toolset + assert.NotEmpty(t, tools, "repos toolset should have tools") + + // Verify at least get_commit is there (a repos toolset tool) + var foundGetCommit bool + for _, tool := range tools { + if tool["name"] == "get_commit" { + foundGetCommit = true + break + } + } + assert.True(t, foundGetCommit, "get_commit should be in repos toolset") +} + +func TestDynamicTools_EnableToolset(t *testing.T) { + // Build a registry with no toolsets enabled (dynamic mode) + reg := NewRegistry(translations.NullTranslationHelper). + WithToolsets([]string{}). + Build() + + // Create a mock server + server := mcp.NewServer(&mcp.Implementation{Name: "test"}, nil) + + // Create dynamic tool dependencies + deps := DynamicToolDependencies{ + Server: server, + Registry: reg, + ToolDeps: NewBaseDeps(nil, nil, nil, nil, translations.NullTranslationHelper, FeatureFlags{}, 0), + T: translations.NullTranslationHelper, + } + + // Verify repos is not enabled initially + assert.False(t, reg.IsToolsetEnabled(registry.ToolsetID("repos"))) + + // Get the enable_toolset tool + tool := EnableToolset(reg) + handler := tool.Handler(deps) + + // Enable the repos toolset + result, err := handler(context.Background(), createDynamicRequest(map[string]any{ + "toolset": "repos", + })) + require.NoError(t, err) + require.NotNil(t, result) + require.Len(t, result.Content, 1) + + // Verify the toolset is now enabled + assert.True(t, reg.IsToolsetEnabled(registry.ToolsetID("repos")), "repos should be enabled after enable_toolset") + + // Verify the success message + textContent := result.Content[0].(*mcp.TextContent) + assert.Contains(t, textContent.Text, "enabled") + + // Try enabling again - should say already enabled + result2, err := handler(context.Background(), createDynamicRequest(map[string]any{ + "toolset": "repos", + })) + require.NoError(t, err) + textContent2 := result2.Content[0].(*mcp.TextContent) + assert.Contains(t, textContent2.Text, "already enabled") +} + +func TestDynamicTools_EnableToolset_InvalidToolset(t *testing.T) { + // Build a registry with no toolsets enabled (dynamic mode) + reg := NewRegistry(translations.NullTranslationHelper). + WithToolsets([]string{}). + Build() + + // Create a mock server + server := mcp.NewServer(&mcp.Implementation{Name: "test"}, nil) + + // Create dynamic tool dependencies + deps := DynamicToolDependencies{ + Server: server, + Registry: reg, + ToolDeps: nil, + T: translations.NullTranslationHelper, + } + + // Get the enable_toolset tool + tool := EnableToolset(reg) + handler := tool.Handler(deps) + + // Try to enable a non-existent toolset + result, err := handler(context.Background(), createDynamicRequest(map[string]any{ + "toolset": "nonexistent", + })) + require.NoError(t, err) + require.NotNil(t, result) + + // Should be an error result + textContent := result.Content[0].(*mcp.TextContent) + assert.Contains(t, textContent.Text, "not found") +} + +func TestDynamicTools_ToolsetsEnum(t *testing.T) { + // Build a registry + reg := NewRegistry(translations.NullTranslationHelper).Build() + + // Get tools to verify they have proper enum values + tools := DynamicTools(reg) + + // Find enable_toolset and get_toolset_tools + for _, tool := range tools { + if tool.Tool.Name == "enable_toolset" || tool.Tool.Name == "get_toolset_tools" { + // Verify the toolset property has an enum + schema := tool.Tool.InputSchema.(*jsonschema.Schema) + toolsetProp := schema.Properties["toolset"] + require.NotNil(t, toolsetProp, "toolset property should exist") + assert.NotEmpty(t, toolsetProp.Enum, "toolset property should have enum values") + + // Verify repos is in the enum + var foundRepos bool + for _, v := range toolsetProp.Enum { + if v == registry.ToolsetID("repos") { + foundRepos = true + break + } + } + assert.True(t, foundRepos, "repos should be in toolset enum for %s", tool.Tool.Name) + } + } +} diff --git a/pkg/github/gists.go b/pkg/github/gists.go index baca42399..7b8313f37 100644 --- a/pkg/github/gists.go +++ b/pkg/github/gists.go @@ -7,7 +7,7 @@ import ( "io" "net/http" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -16,8 +16,9 @@ import ( ) // ListGists creates a tool to list gists for a user -func ListGists(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListGists(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataGists, mcp.Tool{ Name: "list_gists", Description: t("TOOL_LIST_GISTS_DESCRIPTION", "List gists for a user"), @@ -103,8 +104,9 @@ func ListGists(t translations.TranslationHelperFunc) toolsets.ServerTool { } // GetGist creates a tool to get the content of a gist -func GetGist(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetGist(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataGists, mcp.Tool{ Name: "get_gist", Description: t("TOOL_GET_GIST_DESCRIPTION", "Get gist content of a particular gist, by gist ID"), @@ -161,8 +163,9 @@ func GetGist(t translations.TranslationHelperFunc) toolsets.ServerTool { } // CreateGist creates a tool to create a new gist -func CreateGist(t translations.TranslationHelperFunc) toolsets.ServerTool { +func CreateGist(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataGists, mcp.Tool{ Name: "create_gist", Description: t("TOOL_CREATE_GIST_DESCRIPTION", "Create a new gist"), @@ -264,8 +267,9 @@ func CreateGist(t translations.TranslationHelperFunc) toolsets.ServerTool { } // UpdateGist creates a tool to edit an existing gist -func UpdateGist(t translations.TranslationHelperFunc) toolsets.ServerTool { +func UpdateGist(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataGists, mcp.Tool{ Name: "update_gist", Description: t("TOOL_UPDATE_GIST_DESCRIPTION", "Update an existing gist"), diff --git a/pkg/github/gists_test.go b/pkg/github/gists_test.go index 44b294eb6..7c6f69833 100644 --- a/pkg/github/gists_test.go +++ b/pkg/github/gists_test.go @@ -158,8 +158,8 @@ func Test_ListGists(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -275,8 +275,8 @@ func Test_GetGist(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -421,8 +421,8 @@ func Test_CreateGist(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -580,8 +580,8 @@ func Test_UpdateGist(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) diff --git a/pkg/github/git.go b/pkg/github/git.go index c5fdded7c..4755e2eb0 100644 --- a/pkg/github/git.go +++ b/pkg/github/git.go @@ -7,7 +7,7 @@ import ( "strings" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -38,8 +38,9 @@ type TreeResponse struct { } // GetRepositoryTree creates a tool to get the tree structure of a GitHub repository. -func GetRepositoryTree(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetRepositoryTree(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataGit, mcp.Tool{ Name: "get_repository_tree", Description: t("TOOL_GET_REPOSITORY_TREE_DESCRIPTION", "Get the tree structure (files and directories) of a GitHub repository at a specific ref or SHA"), diff --git a/pkg/github/git_test.go b/pkg/github/git_test.go index 69442e312..c971995b2 100644 --- a/pkg/github/git_test.go +++ b/pkg/github/git_test.go @@ -148,8 +148,8 @@ func Test_GetRepositoryTree(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) diff --git a/pkg/github/helper_test.go b/pkg/github/helper_test.go index 9c55ba841..4fdf5f928 100644 --- a/pkg/github/helper_test.go +++ b/pkg/github/helper_test.go @@ -216,7 +216,7 @@ func TestOptionalParamOK(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Test with string type assertion if _, isString := tc.expectedVal.(string); isString || tc.errorMsg == "parameter myParam is not of type string, is bool" { - val, ok, err := OptionalParamOK[string, map[string]any](tc.args, tc.paramName) + val, ok, err := OptionalParamOK[string](tc.args, tc.paramName) if tc.expectError { require.Error(t, err) assert.Contains(t, err.Error(), tc.errorMsg) @@ -231,7 +231,7 @@ func TestOptionalParamOK(t *testing.T) { // Test with bool type assertion if _, isBool := tc.expectedVal.(bool); isBool || tc.errorMsg == "parameter myParam is not of type bool, is string" { - val, ok, err := OptionalParamOK[bool, map[string]any](tc.args, tc.paramName) + val, ok, err := OptionalParamOK[bool](tc.args, tc.paramName) if tc.expectError { require.Error(t, err) assert.Contains(t, err.Error(), tc.errorMsg) @@ -246,7 +246,7 @@ func TestOptionalParamOK(t *testing.T) { // Test with float64 type assertion (for number case) if _, isFloat := tc.expectedVal.(float64); isFloat { - val, ok, err := OptionalParamOK[float64, map[string]any](tc.args, tc.paramName) + val, ok, err := OptionalParamOK[float64](tc.args, tc.paramName) if tc.expectError { // This case shouldn't happen for float64 in the defined tests require.Fail(t, "Unexpected error case for float64") diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 142bdd421..3e449f8c5 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -11,8 +11,8 @@ import ( ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/lockdown" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/sanitize" - "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/go-viper/mapstructure/v2" @@ -230,7 +230,7 @@ func fragmentToIssue(fragment IssueFragment) *github.Issue { } // IssueRead creates a tool to get details of a specific issue in a GitHub repository. -func IssueRead(t translations.TranslationHelperFunc) toolsets.ServerTool { +func IssueRead(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -263,6 +263,7 @@ Options are: WithPagination(schema) return NewTool( + ToolsetMetadataIssues, mcp.Tool{ Name: "issue_read", Description: t("TOOL_ISSUE_READ_DESCRIPTION", "Get information about a specific issue in a GitHub repository."), @@ -309,13 +310,13 @@ Options are: switch method { case "get": - result, err := GetIssue(ctx, client, deps.RepoAccessCache, owner, repo, issueNumber, deps.Flags) + result, err := GetIssue(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, deps.GetFlags()) return result, nil, err case "get_comments": - result, err := GetIssueComments(ctx, client, deps.RepoAccessCache, owner, repo, issueNumber, pagination, deps.Flags) + result, err := GetIssueComments(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, pagination, deps.GetFlags()) return result, nil, err case "get_sub_issues": - result, err := GetSubIssues(ctx, client, deps.RepoAccessCache, owner, repo, issueNumber, pagination, deps.Flags) + result, err := GetSubIssues(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, pagination, deps.GetFlags()) return result, nil, err case "get_labels": result, err := GetIssueLabels(ctx, gqlClient, owner, repo, issueNumber) @@ -544,8 +545,9 @@ func GetIssueLabels(ctx context.Context, client *githubv4.Client, owner string, } // ListIssueTypes creates a tool to list defined issue types for an organization. This can be used to understand supported issue type values for creating or updating issues. -func ListIssueTypes(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListIssueTypes(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataIssues, mcp.Tool{ Name: "list_issue_types", Description: t("TOOL_LIST_ISSUE_TYPES_FOR_ORG", "List supported issue types for repository owner (organization)."), @@ -600,8 +602,9 @@ func ListIssueTypes(t translations.TranslationHelperFunc) toolsets.ServerTool { } // AddIssueComment creates a tool to add a comment to an issue. -func AddIssueComment(t translations.TranslationHelperFunc) toolsets.ServerTool { +func AddIssueComment(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataIssues, mcp.Tool{ Name: "add_issue_comment", Description: t("TOOL_ADD_ISSUE_COMMENT_DESCRIPTION", "Add a comment to a specific issue in a GitHub repository. Use this tool to add comments to pull requests as well (in this case pass pull request number as issue_number), but only if user is not asking specifically to add review comments."), @@ -684,8 +687,9 @@ func AddIssueComment(t translations.TranslationHelperFunc) toolsets.ServerTool { } // SubIssueWrite creates a tool to add a sub-issue to a parent issue. -func SubIssueWrite(t translations.TranslationHelperFunc) toolsets.ServerTool { +func SubIssueWrite(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataIssues, mcp.Tool{ Name: "sub_issue_write", Description: t("TOOL_SUB_ISSUE_WRITE_DESCRIPTION", "Add a sub-issue to a parent issue in a GitHub repository."), @@ -912,7 +916,7 @@ func ReprioritizeSubIssue(ctx context.Context, client *github.Client, owner stri } // SearchIssues creates a tool to search for issues. -func SearchIssues(t translations.TranslationHelperFunc) toolsets.ServerTool { +func SearchIssues(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -956,6 +960,7 @@ func SearchIssues(t translations.TranslationHelperFunc) toolsets.ServerTool { WithPagination(schema) return NewTool( + ToolsetMetadataIssues, mcp.Tool{ Name: "search_issues", Description: t("TOOL_SEARCH_ISSUES_DESCRIPTION", "Search for issues in GitHub repositories using issues search syntax already scoped to is:issue"), @@ -974,8 +979,9 @@ func SearchIssues(t translations.TranslationHelperFunc) toolsets.ServerTool { } // IssueWrite creates a tool to create a new or update an existing issue in a GitHub repository. -func IssueWrite(t translations.TranslationHelperFunc) toolsets.ServerTool { +func IssueWrite(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataIssues, mcp.Tool{ Name: "issue_write", Description: t("TOOL_ISSUE_WRITE_DESCRIPTION", "Create a new or update an existing issue in a GitHub repository."), @@ -1332,7 +1338,7 @@ func UpdateIssue(ctx context.Context, client *github.Client, gqlClient *githubv4 } // ListIssues creates a tool to list and filter repository issues -func ListIssues(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListIssues(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1376,6 +1382,7 @@ func ListIssues(t translations.TranslationHelperFunc) toolsets.ServerTool { WithCursorPagination(schema) return NewTool( + ToolsetMetadataIssues, mcp.Tool{ Name: "list_issues", Description: t("TOOL_LIST_ISSUES_DESCRIPTION", "List issues in a GitHub repository. For pagination, use the 'endCursor' from the previous response's 'pageInfo' in the 'after' parameter."), @@ -1599,7 +1606,7 @@ func (d *mvpDescription) String() string { return sb.String() } -func AssignCopilotToIssue(t translations.TranslationHelperFunc) toolsets.ServerTool { +func AssignCopilotToIssue(t translations.TranslationHelperFunc) registry.ServerTool { description := mvpDescription{ summary: "Assign Copilot to a specific issue in a GitHub repository.", outcomes: []string{ @@ -1611,6 +1618,7 @@ func AssignCopilotToIssue(t translations.TranslationHelperFunc) toolsets.ServerT } return NewTool( + ToolsetMetadataIssues, mcp.Tool{ Name: "assign_copilot_to_issue", Description: t("TOOL_ASSIGN_COPILOT_TO_ISSUE_DESCRIPTION", description.String()), @@ -1797,8 +1805,10 @@ func parseISOTimestamp(timestamp string) (time.Time, error) { return time.Time{}, fmt.Errorf("invalid ISO 8601 timestamp: %s (supported formats: YYYY-MM-DDThh:mm:ssZ or YYYY-MM-DD)", timestamp) } -func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) (mcp.Prompt, mcp.PromptHandler) { - return mcp.Prompt{ +func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) registry.ServerPrompt { + return registry.NewServerPrompt( + ToolsetMetadataIssues, + mcp.Prompt{ Name: "AssignCodingAgent", Description: t("PROMPT_ASSIGN_CODING_AGENT_DESCRIPTION", "Assign GitHub Coding Agent to multiple tasks in a GitHub repository."), Arguments: []*mcp.PromptArgument{ @@ -1808,7 +1818,8 @@ func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) (mcp.Prompt, Required: true, }, }, - }, func(_ context.Context, request *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + }, + func(_ context.Context, request *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { repo := request.Params.Arguments["repo"] messages := []*mcp.PromptMessage{ @@ -1852,5 +1863,6 @@ func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) (mcp.Prompt, return &mcp.GetPromptResult{ Messages: messages, }, nil - } + }, + ) } diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index c832f031a..4c686cc57 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -330,9 +330,9 @@ func Test_GetIssue(t *testing.T) { } flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, RepoAccessCache: cache, Flags: flags, } @@ -447,8 +447,8 @@ func Test_AddIssueComment(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -781,8 +781,8 @@ func Test_SearchIssues(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -952,9 +952,9 @@ func Test_CreateIssue(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(nil) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, } handler := serverTool.Handler(deps) @@ -1268,8 +1268,8 @@ func Test_ListIssues(t *testing.T) { } gqlClient := githubv4.NewClient(httpClient) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + GQLClient: gqlClient, } handler := serverTool.Handler(deps) @@ -1769,9 +1769,9 @@ func Test_UpdateIssue(t *testing.T) { // Setup clients with mocks restClient := github.NewClient(tc.mockedRESTClient) gqlClient := githubv4.NewClient(tc.mockedGQLClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(restClient), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: restClient, + GQLClient: gqlClient, } handler := serverTool.Handler(deps) @@ -2016,9 +2016,9 @@ func Test_GetIssueComments(t *testing.T) { } cache := stubRepoAccessCache(gqlClient, 15*time.Minute) flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, RepoAccessCache: cache, Flags: flags, } @@ -2136,9 +2136,9 @@ func Test_GetIssueLabels(t *testing.T) { t.Run(tc.name, func(t *testing.T) { gqlClient := githubv4.NewClient(tc.mockedClient) client := github.NewClient(nil) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, RepoAccessCache: stubRepoAccessCache(gqlClient, 15*time.Minute), Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), } @@ -2560,8 +2560,8 @@ func TestAssignCopilotToIssue(t *testing.T) { t.Parallel() // Setup client with mock client := githubv4.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) @@ -2791,8 +2791,8 @@ func Test_AddSubIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -3035,9 +3035,9 @@ func Test_GetSubIssues(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(nil) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, RepoAccessCache: stubRepoAccessCache(gqlClient, 15*time.Minute), Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), } @@ -3275,8 +3275,8 @@ func Test_RemoveSubIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -3564,8 +3564,8 @@ func Test_ReprioritizeSubIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -3698,8 +3698,8 @@ func Test_ListIssueTypes(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) diff --git a/pkg/github/labels.go b/pkg/github/labels.go index 25ac9f7fe..6088aaa8e 100644 --- a/pkg/github/labels.go +++ b/pkg/github/labels.go @@ -7,6 +7,7 @@ import ( "strings" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/jsonschema-go/jsonschema" @@ -15,377 +16,391 @@ import ( ) // GetLabel retrieves a specific label by name from a GitHub repository -func GetLabel(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "get_label", - Description: t("TOOL_GET_LABEL_DESCRIPTION", "Get a specific label from a repository."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_GET_LABEL_TITLE", "Get a specific label from a repository."), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner (username or organization name)", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, - "name": { - Type: "string", - Description: "Label name.", +func GetLabel(t translations.TranslationHelperFunc) registry.ServerTool { + return NewTool( + ToolsetMetadataIssues, + mcp.Tool{ + Name: "get_label", + Description: t("TOOL_GET_LABEL_DESCRIPTION", "Get a specific label from a repository."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_GET_LABEL_TITLE", "Get a specific label from a repository."), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner (username or organization name)", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "name": { + Type: "string", + Description: "Label name.", + }, }, + Required: []string{"owner", "repo", "name"}, }, - Required: []string{"owner", "repo", "name"}, }, - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + name, err := RequiredParam[string](args, "name") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + var query struct { + Repository struct { + Label struct { + ID githubv4.ID + Name githubv4.String + Color githubv4.String + Description githubv4.String + } `graphql:"label(name: $name)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + + vars := map[string]any{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "name": githubv4.String(name), + } + + client, err := deps.GetGQLClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + if err := client.Query(ctx, &query, vars); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find label", err), nil, nil + } + + if query.Repository.Label.Name == "" { + return utils.NewToolResultError(fmt.Sprintf("label '%s' not found in %s/%s", name, owner, repo)), nil, nil + } + + label := map[string]any{ + "id": fmt.Sprintf("%v", query.Repository.Label.ID), + "name": string(query.Repository.Label.Name), + "color": string(query.Repository.Label.Color), + "description": string(query.Repository.Label.Description), + } + + out, err := json.Marshal(label) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal label: %w", err) + } + + return utils.NewToolResultText(string(out)), nil, nil + } + }, + ) +} - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - name, err := RequiredParam[string](args, "name") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - var query struct { - Repository struct { - Label struct { - ID githubv4.ID - Name githubv4.String - Color githubv4.String - Description githubv4.String - } `graphql:"label(name: $name)"` - } `graphql:"repository(owner: $owner, name: $repo)"` - } - - vars := map[string]any{ - "owner": githubv4.String(owner), - "repo": githubv4.String(repo), - "name": githubv4.String(name), - } - - client, err := getGQLClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - if err := client.Query(ctx, &query, vars); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find label", err), nil, nil - } - - if query.Repository.Label.Name == "" { - return utils.NewToolResultError(fmt.Sprintf("label '%s' not found in %s/%s", name, owner, repo)), nil, nil - } - - label := map[string]any{ - "id": fmt.Sprintf("%v", query.Repository.Label.ID), - "name": string(query.Repository.Label.Name), - "color": string(query.Repository.Label.Color), - "description": string(query.Repository.Label.Description), - } - - out, err := json.Marshal(label) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal label: %w", err) - } - - return utils.NewToolResultText(string(out)), nil, nil - }) - - return tool, handler +// GetLabelForLabelsToolset returns the same GetLabel tool but registered in the labels toolset. +// This provides conformance with the original behavior where get_label was in both toolsets. +func GetLabelForLabelsToolset(t translations.TranslationHelperFunc) registry.ServerTool { + tool := GetLabel(t) + tool.Toolset = ToolsetLabels + return tool } // ListLabels lists labels from a repository -func ListLabels(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "list_label", - Description: t("TOOL_LIST_LABEL_DESCRIPTION", "List labels from a repository"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LIST_LABEL_DESCRIPTION", "List labels from a repository."), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner (username or organization name) - required for all operations", - }, - "repo": { - Type: "string", - Description: "Repository name - required for all operations", +func ListLabels(t translations.TranslationHelperFunc) registry.ServerTool { + return NewTool( + ToolsetLabels, + mcp.Tool{ + Name: "list_label", + Description: t("TOOL_LIST_LABEL_DESCRIPTION", "List labels from a repository"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_LIST_LABEL_DESCRIPTION", "List labels from a repository."), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner (username or organization name) - required for all operations", + }, + "repo": { + Type: "string", + Description: "Repository name - required for all operations", + }, }, + Required: []string{"owner", "repo"}, }, - Required: []string{"owner", "repo"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := getGQLClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - var query struct { - Repository struct { - Labels struct { - Nodes []struct { - ID githubv4.ID - Name githubv4.String - Color githubv4.String - Description githubv4.String + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + client, err := deps.GetGQLClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + var query struct { + Repository struct { + Labels struct { + Nodes []struct { + ID githubv4.ID + Name githubv4.String + Color githubv4.String + Description githubv4.String + } + TotalCount githubv4.Int + } `graphql:"labels(first: 100)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + + vars := map[string]any{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + } + + if err := client.Query(ctx, &query, vars); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to list labels", err), nil, nil + } + + labels := make([]map[string]any, len(query.Repository.Labels.Nodes)) + for i, labelNode := range query.Repository.Labels.Nodes { + labels[i] = map[string]any{ + "id": fmt.Sprintf("%v", labelNode.ID), + "name": string(labelNode.Name), + "color": string(labelNode.Color), + "description": string(labelNode.Description), } - TotalCount githubv4.Int - } `graphql:"labels(first: 100)"` - } `graphql:"repository(owner: $owner, name: $repo)"` - } - - vars := map[string]any{ - "owner": githubv4.String(owner), - "repo": githubv4.String(repo), - } - - if err := client.Query(ctx, &query, vars); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to list labels", err), nil, nil - } - - labels := make([]map[string]any, len(query.Repository.Labels.Nodes)) - for i, labelNode := range query.Repository.Labels.Nodes { - labels[i] = map[string]any{ - "id": fmt.Sprintf("%v", labelNode.ID), - "name": string(labelNode.Name), - "color": string(labelNode.Color), - "description": string(labelNode.Description), - } - } + } - response := map[string]any{ - "labels": labels, - "totalCount": int(query.Repository.Labels.TotalCount), - } + response := map[string]any{ + "labels": labels, + "totalCount": int(query.Repository.Labels.TotalCount), + } - out, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal labels: %w", err) - } + out, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal labels: %w", err) + } - return utils.NewToolResultText(string(out)), nil, nil - }) - - return tool, handler + return utils.NewToolResultText(string(out)), nil, nil + } + }, + ) } // LabelWrite handles create, update, and delete operations for GitHub labels -func LabelWrite(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "label_write", - Description: t("TOOL_LABEL_WRITE_DESCRIPTION", "Perform write operations on repository labels. To set labels on issues, use the 'update_issue' tool."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LABEL_WRITE_TITLE", "Write operations on repository labels."), - ReadOnlyHint: false, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "method": { - Type: "string", - Description: "Operation to perform: 'create', 'update', or 'delete'", - Enum: []any{"create", "update", "delete"}, - }, - "owner": { - Type: "string", - Description: "Repository owner (username or organization name)", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, - "name": { - Type: "string", - Description: "Label name - required for all operations", - }, - "new_name": { - Type: "string", - Description: "New name for the label (used only with 'update' method to rename)", - }, - "color": { - Type: "string", - Description: "Label color as 6-character hex code without '#' prefix (e.g., 'f29513'). Required for 'create', optional for 'update'.", - }, - "description": { - Type: "string", - Description: "Label description text. Optional for 'create' and 'update'.", +func LabelWrite(t translations.TranslationHelperFunc) registry.ServerTool { + return NewTool( + ToolsetLabels, + mcp.Tool{ + Name: "label_write", + Description: t("TOOL_LABEL_WRITE_DESCRIPTION", "Perform write operations on repository labels. To set labels on issues, use the 'update_issue' tool."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_LABEL_WRITE_TITLE", "Write operations on repository labels."), + ReadOnlyHint: false, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "method": { + Type: "string", + Description: "Operation to perform: 'create', 'update', or 'delete'", + Enum: []any{"create", "update", "delete"}, + }, + "owner": { + Type: "string", + Description: "Repository owner (username or organization name)", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "name": { + Type: "string", + Description: "Label name - required for all operations", + }, + "new_name": { + Type: "string", + Description: "New name for the label (used only with 'update' method to rename)", + }, + "color": { + Type: "string", + Description: "Label color as 6-character hex code without '#' prefix (e.g., 'f29513'). Required for 'create', optional for 'update'.", + }, + "description": { + Type: "string", + Description: "Label description text. Optional for 'create' and 'update'.", + }, }, + Required: []string{"method", "owner", "repo", "name"}, }, - Required: []string{"method", "owner", "repo", "name"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - // Get and validate required parameters - method, err := RequiredParam[string](args, "method") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - method = strings.ToLower(method) - - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - name, err := RequiredParam[string](args, "name") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - // Get optional parameters - newName, _ := OptionalParam[string](args, "new_name") - color, _ := OptionalParam[string](args, "color") - description, _ := OptionalParam[string](args, "description") - - client, err := getGQLClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - switch method { - case "create": - // Validate required params for create - if color == "" { - return utils.NewToolResultError("color is required for create"), nil, nil - } - - // Get repository ID - repoID, err := getRepositoryID(ctx, client, owner, repo) - if err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find repository", err), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + // Get and validate required parameters + method, err := RequiredParam[string](args, "method") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + method = strings.ToLower(method) + + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + name, err := RequiredParam[string](args, "name") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + // Get optional parameters + newName, _ := OptionalParam[string](args, "new_name") + color, _ := OptionalParam[string](args, "color") + description, _ := OptionalParam[string](args, "description") + + client, err := deps.GetGQLClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + switch method { + case "create": + // Validate required params for create + if color == "" { + return utils.NewToolResultError("color is required for create"), nil, nil + } - input := githubv4.CreateLabelInput{ - RepositoryID: repoID, - Name: githubv4.String(name), - Color: githubv4.String(color), - } - if description != "" { - d := githubv4.String(description) - input.Description = &d - } + // Get repository ID + repoID, err := getRepositoryID(ctx, client, owner, repo) + if err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find repository", err), nil, nil + } - var mutation struct { - CreateLabel struct { - Label struct { - Name githubv4.String - ID githubv4.ID + input := githubv4.CreateLabelInput{ + RepositoryID: repoID, + Name: githubv4.String(name), + Color: githubv4.String(color), + } + if description != "" { + d := githubv4.String(description) + input.Description = &d } - } `graphql:"createLabel(input: $input)"` - } - if err := client.Mutate(ctx, &mutation, input, nil); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to create label", err), nil, nil - } + var mutation struct { + CreateLabel struct { + Label struct { + Name githubv4.String + ID githubv4.ID + } + } `graphql:"createLabel(input: $input)"` + } - return utils.NewToolResultText(fmt.Sprintf("label '%s' created successfully", mutation.CreateLabel.Label.Name)), nil, nil + if err := client.Mutate(ctx, &mutation, input, nil); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to create label", err), nil, nil + } - case "update": - // Validate required params for update - if newName == "" && color == "" && description == "" { - return utils.NewToolResultError("at least one of new_name, color, or description must be provided for update"), nil, nil - } + return utils.NewToolResultText(fmt.Sprintf("label '%s' created successfully", mutation.CreateLabel.Label.Name)), nil, nil - // Get the label ID - labelID, err := getLabelID(ctx, client, owner, repo, name) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + case "update": + // Validate required params for update + if newName == "" && color == "" && description == "" { + return utils.NewToolResultError("at least one of new_name, color, or description must be provided for update"), nil, nil + } - input := githubv4.UpdateLabelInput{ - ID: labelID, - } - if newName != "" { - n := githubv4.String(newName) - input.Name = &n - } - if color != "" { - c := githubv4.String(color) - input.Color = &c - } - if description != "" { - d := githubv4.String(description) - input.Description = &d - } + // Get the label ID + labelID, err := getLabelID(ctx, client, owner, repo, name) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var mutation struct { - UpdateLabel struct { - Label struct { - Name githubv4.String - ID githubv4.ID + input := githubv4.UpdateLabelInput{ + ID: labelID, + } + if newName != "" { + n := githubv4.String(newName) + input.Name = &n + } + if color != "" { + c := githubv4.String(color) + input.Color = &c + } + if description != "" { + d := githubv4.String(description) + input.Description = &d } - } `graphql:"updateLabel(input: $input)"` - } - if err := client.Mutate(ctx, &mutation, input, nil); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to update label", err), nil, nil - } + var mutation struct { + UpdateLabel struct { + Label struct { + Name githubv4.String + ID githubv4.ID + } + } `graphql:"updateLabel(input: $input)"` + } - return utils.NewToolResultText(fmt.Sprintf("label '%s' updated successfully", mutation.UpdateLabel.Label.Name)), nil, nil + if err := client.Mutate(ctx, &mutation, input, nil); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to update label", err), nil, nil + } - case "delete": - // Get the label ID - labelID, err := getLabelID(ctx, client, owner, repo, name) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + return utils.NewToolResultText(fmt.Sprintf("label '%s' updated successfully", mutation.UpdateLabel.Label.Name)), nil, nil - input := githubv4.DeleteLabelInput{ - ID: labelID, - } + case "delete": + // Get the label ID + labelID, err := getLabelID(ctx, client, owner, repo, name) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var mutation struct { - DeleteLabel struct { - ClientMutationID githubv4.String - } `graphql:"deleteLabel(input: $input)"` - } + input := githubv4.DeleteLabelInput{ + ID: labelID, + } - if err := client.Mutate(ctx, &mutation, input, nil); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to delete label", err), nil, nil - } + var mutation struct { + DeleteLabel struct { + ClientMutationID githubv4.String + } `graphql:"deleteLabel(input: $input)"` + } - return utils.NewToolResultText(fmt.Sprintf("label '%s' deleted successfully", name)), nil, nil + if err := client.Mutate(ctx, &mutation, input, nil); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to delete label", err), nil, nil + } - default: - return utils.NewToolResultError(fmt.Sprintf("unknown method: %s. Supported methods are: create, update, delete", method)), nil, nil - } - }) + return utils.NewToolResultText(fmt.Sprintf("label '%s' deleted successfully", name)), nil, nil - return tool, handler + default: + return utils.NewToolResultError(fmt.Sprintf("unknown method: %s. Supported methods are: create, update, delete", method)), nil, nil + } + } + }, + ) } // Helper function to get repository ID diff --git a/pkg/github/labels_test.go b/pkg/github/labels_test.go index 12d447d72..fa646e884 100644 --- a/pkg/github/labels_test.go +++ b/pkg/github/labels_test.go @@ -17,8 +17,8 @@ func TestGetLabel(t *testing.T) { t.Parallel() // Verify tool definition - mockClient := githubv4.NewClient(nil) - tool, _ := GetLabel(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) + serverTool := GetLabel(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "get_label", tool.Name) @@ -114,10 +114,13 @@ func TestGetLabel(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := githubv4.NewClient(tc.mockedClient) - _, handler := GetLabel(stubGetGQLClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + GQLClient: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) assert.NotNil(t, result) @@ -139,8 +142,8 @@ func TestListLabels(t *testing.T) { t.Parallel() // Verify tool definition - mockClient := githubv4.NewClient(nil) - tool, _ := ListLabels(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListLabels(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_label", tool.Name) @@ -209,10 +212,13 @@ func TestListLabels(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := githubv4.NewClient(tc.mockedClient) - _, handler := ListLabels(stubGetGQLClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + GQLClient: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) assert.NotNil(t, result) @@ -234,8 +240,8 @@ func TestWriteLabel(t *testing.T) { t.Parallel() // Verify tool definition - mockClient := githubv4.NewClient(nil) - tool, _ := LabelWrite(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) + serverTool := LabelWrite(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "label_write", tool.Name) @@ -457,10 +463,13 @@ func TestWriteLabel(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := githubv4.NewClient(tc.mockedClient) - _, handler := LabelWrite(stubGetGQLClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + GQLClient: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) assert.NotNil(t, result) diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go index 23d63c946..569bef002 100644 --- a/pkg/github/notifications.go +++ b/pkg/github/notifications.go @@ -10,7 +10,7 @@ import ( "time" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -25,8 +25,9 @@ const ( ) // ListNotifications creates a tool to list notifications for the current user. -func ListNotifications(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListNotifications(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataNotifications, mcp.Tool{ Name: "list_notifications", Description: t("TOOL_LIST_NOTIFICATIONS_DESCRIPTION", "Lists all GitHub notifications for the authenticated user, including unread notifications, mentions, review requests, assignments, and updates on issues or pull requests. Use this tool whenever the user asks what to work on next, requests a summary of their GitHub activity, wants to see pending reviews, or needs to check for new updates or tasks. This tool is the primary way to discover actionable items, reminders, and outstanding work on GitHub. Always call this tool when asked what to work on next, what is pending, or what needs attention in GitHub."), @@ -162,8 +163,9 @@ func ListNotifications(t translations.TranslationHelperFunc) toolsets.ServerTool } // DismissNotification creates a tool to mark a notification as read/done. -func DismissNotification(t translations.TranslationHelperFunc) toolsets.ServerTool { +func DismissNotification(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataNotifications, mcp.Tool{ Name: "dismiss_notification", Description: t("TOOL_DISMISS_NOTIFICATION_DESCRIPTION", "Dismiss a notification by marking it as read or done"), @@ -244,8 +246,9 @@ func DismissNotification(t translations.TranslationHelperFunc) toolsets.ServerTo } // MarkAllNotificationsRead creates a tool to mark all notifications as read. -func MarkAllNotificationsRead(t translations.TranslationHelperFunc) toolsets.ServerTool { +func MarkAllNotificationsRead(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataNotifications, mcp.Tool{ Name: "mark_all_notifications_read", Description: t("TOOL_MARK_ALL_NOTIFICATIONS_READ_DESCRIPTION", "Mark all notifications as read"), @@ -336,8 +339,9 @@ func MarkAllNotificationsRead(t translations.TranslationHelperFunc) toolsets.Ser } // GetNotificationDetails creates a tool to get details for a specific notification. -func GetNotificationDetails(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetNotificationDetails(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataNotifications, mcp.Tool{ Name: "get_notification_details", Description: t("TOOL_GET_NOTIFICATION_DETAILS_DESCRIPTION", "Get detailed information for a specific GitHub notification, always call this tool when the user asks for details about a specific notification, if you don't know the ID list notifications first."), @@ -405,8 +409,9 @@ const ( ) // ManageNotificationSubscription creates a tool to manage a notification subscription (ignore, watch, delete) -func ManageNotificationSubscription(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ManageNotificationSubscription(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataNotifications, mcp.Tool{ Name: "manage_notification_subscription", Description: t("TOOL_MANAGE_NOTIFICATION_SUBSCRIPTION_DESCRIPTION", "Manage a notification subscription: ignore, watch, or delete a notification thread subscription."), @@ -501,8 +506,9 @@ const ( ) // ManageRepositoryNotificationSubscription creates a tool to manage a repository notification subscription (ignore, watch, delete) -func ManageRepositoryNotificationSubscription(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ManageRepositoryNotificationSubscription(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataNotifications, mcp.Tool{ Name: "manage_repository_notification_subscription", Description: t("TOOL_MANAGE_REPOSITORY_NOTIFICATION_SUBSCRIPTION_DESCRIPTION", "Manage a repository notification subscription: ignore, watch, or delete repository notifications subscription for the provided repository."), diff --git a/pkg/github/notifications_test.go b/pkg/github/notifications_test.go index 0a330c316..f730654db 100644 --- a/pkg/github/notifications_test.go +++ b/pkg/github/notifications_test.go @@ -125,8 +125,8 @@ func Test_ListNotifications(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -258,8 +258,8 @@ func Test_ManageNotificationSubscription(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -421,8 +421,8 @@ func Test_ManageRepositoryNotificationSubscription(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -563,8 +563,8 @@ func Test_DismissNotification(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -688,8 +688,8 @@ func Test_MarkAllNotificationsRead(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -772,8 +772,8 @@ func Test_GetNotificationDetails(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/projects.go b/pkg/github/projects.go index 79dfb25ce..4c0b5c09d 100644 --- a/pkg/github/projects.go +++ b/pkg/github/projects.go @@ -9,6 +9,7 @@ import ( "strings" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -24,8 +25,10 @@ const ( MaxProjectsPerPage = 50 ) -func ListProjects(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListProjects(t translations.TranslationHelperFunc) registry.ServerTool { + return NewTool( + ToolsetMetadataProjects, + mcp.Tool{ Name: "list_projects", Description: t("TOOL_LIST_PROJECTS_DESCRIPTION", `List Projects for a user or organization`), Annotations: &mcp.ToolAnnotations{ @@ -63,81 +66,88 @@ func ListProjects(getClient GetClientFn, t translations.TranslationHelperFunc) ( }, Required: []string{"owner_type", "owner"}, }, - }, func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + }, + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - ownerType, err := RequiredParam[string](args, "owner_type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - queryStr, err := OptionalParam[string](args, "query") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + ownerType, err := RequiredParam[string](args, "owner_type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - pagination, err := extractPaginationOptionsFromArgs(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + queryStr, err := OptionalParam[string](args, "query") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + pagination, err := extractPaginationOptionsFromArgs(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var resp *github.Response - var projects []*github.ProjectV2 - var queryPtr *string + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if queryStr != "" { - queryPtr = &queryStr - } + var resp *github.Response + var projects []*github.ProjectV2 + var queryPtr *string - minimalProjects := []MinimalProject{} - opts := &github.ListProjectsOptions{ - ListProjectsPaginationOptions: pagination, - Query: queryPtr, - } + if queryStr != "" { + queryPtr = &queryStr + } - if ownerType == "org" { - projects, resp, err = client.Projects.ListOrganizationProjects(ctx, owner, opts) - } else { - projects, resp, err = client.Projects.ListUserProjects(ctx, owner, opts) - } + minimalProjects := []MinimalProject{} + opts := &github.ListProjectsOptions{ + ListProjectsPaginationOptions: pagination, + Query: queryPtr, + } - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to list projects", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + if ownerType == "org" { + projects, resp, err = client.Projects.ListOrganizationProjects(ctx, owner, opts) + } else { + projects, resp, err = client.Projects.ListUserProjects(ctx, owner, opts) + } - for _, project := range projects { - minimalProjects = append(minimalProjects, *convertToMinimalProject(project)) - } + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to list projects", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - response := map[string]any{ - "projects": minimalProjects, - "pageInfo": buildPageInfo(resp), - } + for _, project := range projects { + minimalProjects = append(minimalProjects, *convertToMinimalProject(project)) + } - r, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + response := map[string]any{ + "projects": minimalProjects, + "pageInfo": buildPageInfo(resp), + } - return utils.NewToolResultText(string(r)), nil, nil - } + r, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } -func GetProject(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetProject(t translations.TranslationHelperFunc) registry.ServerTool { + return NewTool( + ToolsetMetadataProjects, + mcp.Tool{ Name: "get_project", Description: t("TOOL_GET_PROJECT_DESCRIPTION", "Get Project for a user or org"), Annotations: &mcp.ToolAnnotations{ @@ -163,65 +173,71 @@ func GetProject(getClient GetClientFn, t translations.TranslationHelperFunc) (mc }, Required: []string{"project_number", "owner_type", "owner"}, }, - }, func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + }, + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - projectNumber, err := RequiredInt(args, "project_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + projectNumber, err := RequiredInt(args, "project_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - ownerType, err := RequiredParam[string](args, "owner_type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + ownerType, err := RequiredParam[string](args, "owner_type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var resp *github.Response - var project *github.ProjectV2 + var resp *github.Response + var project *github.ProjectV2 - if ownerType == "org" { - project, resp, err = client.Projects.GetOrganizationProject(ctx, owner, projectNumber) - } else { - project, resp, err = client.Projects.GetUserProject(ctx, owner, projectNumber) - } - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get project", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + if ownerType == "org" { + project, resp, err = client.Projects.GetOrganizationProject(ctx, owner, projectNumber) + } else { + project, resp, err = client.Projects.GetUserProject(ctx, owner, projectNumber) + } + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get project", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return utils.NewToolResultError(fmt.Sprintf("failed to get project: %s", string(body))), nil, nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + minimalProject := convertToMinimalProject(project) + r, err := json.Marshal(minimalProject) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to get project: %s", string(body))), nil, nil - } - minimalProject := convertToMinimalProject(project) - r, err := json.Marshal(minimalProject) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + return utils.NewToolResultText(string(r)), nil, nil } - - return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } -func ListProjectFields(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListProjectFields(t translations.TranslationHelperFunc) registry.ServerTool { + return NewTool( + ToolsetMetadataProjects, + mcp.Tool{ Name: "list_project_fields", Description: t("TOOL_LIST_PROJECT_FIELDS_DESCRIPTION", "List Project fields for a user or org"), Annotations: &mcp.ToolAnnotations{ @@ -259,70 +275,77 @@ func ListProjectFields(getClient GetClientFn, t translations.TranslationHelperFu }, Required: []string{"owner_type", "owner", "project_number"}, }, - }, func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + }, + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - ownerType, err := RequiredParam[string](args, "owner_type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - projectNumber, err := RequiredInt(args, "project_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + ownerType, err := RequiredParam[string](args, "owner_type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - pagination, err := extractPaginationOptionsFromArgs(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + projectNumber, err := RequiredInt(args, "project_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + pagination, err := extractPaginationOptionsFromArgs(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var resp *github.Response - var projectFields []*github.ProjectV2Field + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - opts := &github.ListProjectsOptions{ - ListProjectsPaginationOptions: pagination, - } + var resp *github.Response + var projectFields []*github.ProjectV2Field - if ownerType == "org" { - projectFields, resp, err = client.Projects.ListOrganizationProjectFields(ctx, owner, projectNumber, opts) - } else { - projectFields, resp, err = client.Projects.ListUserProjectFields(ctx, owner, projectNumber, opts) - } + opts := &github.ListProjectsOptions{ + ListProjectsPaginationOptions: pagination, + } - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to list project fields", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + if ownerType == "org" { + projectFields, resp, err = client.Projects.ListOrganizationProjectFields(ctx, owner, projectNumber, opts) + } else { + projectFields, resp, err = client.Projects.ListUserProjectFields(ctx, owner, projectNumber, opts) + } - response := map[string]any{ - "fields": projectFields, - "pageInfo": buildPageInfo(resp), - } + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to list project fields", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + response := map[string]any{ + "fields": projectFields, + "pageInfo": buildPageInfo(resp), + } - return utils.NewToolResultText(string(r)), nil, nil - } + r, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } -func GetProjectField(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetProjectField(t translations.TranslationHelperFunc) registry.ServerTool { + return NewTool( + ToolsetMetadataProjects, + mcp.Tool{ Name: "get_project_field", Description: t("TOOL_GET_PROJECT_FIELD_DESCRIPTION", "Get Project field for a user or org"), Annotations: &mcp.ToolAnnotations{ @@ -352,64 +375,71 @@ func GetProjectField(getClient GetClientFn, t translations.TranslationHelperFunc }, Required: []string{"owner_type", "owner", "project_number", "field_id"}, }, - }, func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - ownerType, err := RequiredParam[string](args, "owner_type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - projectNumber, err := RequiredInt(args, "project_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - fieldID, err := RequiredBigInt(args, "field_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + }, + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - var resp *github.Response - var projectField *github.ProjectV2Field + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + ownerType, err := RequiredParam[string](args, "owner_type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + projectNumber, err := RequiredInt(args, "project_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + fieldID, err := RequiredBigInt(args, "field_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if ownerType == "org" { - projectField, resp, err = client.Projects.GetOrganizationProjectField(ctx, owner, projectNumber, fieldID) - } else { - projectField, resp, err = client.Projects.GetUserProjectField(ctx, owner, projectNumber, fieldID) - } + var resp *github.Response + var projectField *github.ProjectV2Field - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get project field", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + if ownerType == "org" { + projectField, resp, err = client.Projects.GetOrganizationProjectField(ctx, owner, projectNumber, fieldID) + } else { + projectField, resp, err = client.Projects.GetUserProjectField(ctx, owner, projectNumber, fieldID) + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get project field", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return utils.NewToolResultError(fmt.Sprintf("failed to get project field: %s", string(body))), nil, nil + } + r, err := json.Marshal(projectField) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to get project field: %s", string(body))), nil, nil - } - r, err := json.Marshal(projectField) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } - return utils.NewToolResultText(string(r)), nil, nil - } + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } -func ListProjectItems(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListProjectItems(t translations.TranslationHelperFunc) registry.ServerTool { + return NewTool( + ToolsetMetadataProjects, + mcp.Tool{ Name: "list_project_items", Description: t("TOOL_LIST_PROJECT_ITEMS_DESCRIPTION", `Search project items with advanced filtering`), Annotations: &mcp.ToolAnnotations{ @@ -458,89 +488,96 @@ func ListProjectItems(getClient GetClientFn, t translations.TranslationHelperFun }, Required: []string{"owner_type", "owner", "project_number"}, }, - }, func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + }, + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - ownerType, err := RequiredParam[string](args, "owner_type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - projectNumber, err := RequiredInt(args, "project_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + ownerType, err := RequiredParam[string](args, "owner_type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - queryStr, err := OptionalParam[string](args, "query") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + projectNumber, err := RequiredInt(args, "project_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - fields, err := OptionalBigIntArrayParam(args, "fields") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + queryStr, err := OptionalParam[string](args, "query") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - pagination, err := extractPaginationOptionsFromArgs(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + fields, err := OptionalBigIntArrayParam(args, "fields") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + pagination, err := extractPaginationOptionsFromArgs(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var resp *github.Response - var projectItems []*github.ProjectV2Item - var queryPtr *string + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if queryStr != "" { - queryPtr = &queryStr - } + var resp *github.Response + var projectItems []*github.ProjectV2Item + var queryPtr *string - opts := &github.ListProjectItemsOptions{ - Fields: fields, - ListProjectsOptions: github.ListProjectsOptions{ - ListProjectsPaginationOptions: pagination, - Query: queryPtr, - }, - } + if queryStr != "" { + queryPtr = &queryStr + } - if ownerType == "org" { - projectItems, resp, err = client.Projects.ListOrganizationProjectItems(ctx, owner, projectNumber, opts) - } else { - projectItems, resp, err = client.Projects.ListUserProjectItems(ctx, owner, projectNumber, opts) - } + opts := &github.ListProjectItemsOptions{ + Fields: fields, + ListProjectsOptions: github.ListProjectsOptions{ + ListProjectsPaginationOptions: pagination, + Query: queryPtr, + }, + } - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - ProjectListFailedError, - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + if ownerType == "org" { + projectItems, resp, err = client.Projects.ListOrganizationProjectItems(ctx, owner, projectNumber, opts) + } else { + projectItems, resp, err = client.Projects.ListUserProjectItems(ctx, owner, projectNumber, opts) + } - response := map[string]any{ - "items": projectItems, - "pageInfo": buildPageInfo(resp), - } + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + ProjectListFailedError, + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + response := map[string]any{ + "items": projectItems, + "pageInfo": buildPageInfo(resp), + } - return utils.NewToolResultText(string(r)), nil, nil - } + r, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } -func GetProjectItem(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetProjectItem(t translations.TranslationHelperFunc) registry.ServerTool { + return NewTool( + ToolsetMetadataProjects, + mcp.Tool{ Name: "get_project_item", Description: t("TOOL_GET_PROJECT_ITEM_DESCRIPTION", "Get a specific Project item for a user or org"), Annotations: &mcp.ToolAnnotations{ @@ -577,71 +614,78 @@ func GetProjectItem(getClient GetClientFn, t translations.TranslationHelperFunc) }, Required: []string{"owner_type", "owner", "project_number", "item_id"}, }, - }, func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + }, + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - ownerType, err := RequiredParam[string](args, "owner_type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - projectNumber, err := RequiredInt(args, "project_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - itemID, err := RequiredBigInt(args, "item_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - fields, err := OptionalBigIntArrayParam(args, "fields") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + ownerType, err := RequiredParam[string](args, "owner_type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + projectNumber, err := RequiredInt(args, "project_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + itemID, err := RequiredBigInt(args, "item_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + fields, err := OptionalBigIntArrayParam(args, "fields") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var resp *github.Response - var projectItem *github.ProjectV2Item - var opts *github.GetProjectItemOptions + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if len(fields) > 0 { - opts = &github.GetProjectItemOptions{ - Fields: fields, + var resp *github.Response + var projectItem *github.ProjectV2Item + var opts *github.GetProjectItemOptions + + if len(fields) > 0 { + opts = &github.GetProjectItemOptions{ + Fields: fields, + } } - } - if ownerType == "org" { - projectItem, resp, err = client.Projects.GetOrganizationProjectItem(ctx, owner, projectNumber, itemID, opts) - } else { - projectItem, resp, err = client.Projects.GetUserProjectItem(ctx, owner, projectNumber, itemID, opts) - } + if ownerType == "org" { + projectItem, resp, err = client.Projects.GetOrganizationProjectItem(ctx, owner, projectNumber, itemID, opts) + } else { + projectItem, resp, err = client.Projects.GetUserProjectItem(ctx, owner, projectNumber, itemID, opts) + } - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get project item", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get project item", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(projectItem) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + r, err := json.Marshal(projectItem) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return utils.NewToolResultText(string(r)), nil, nil - } + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } -func AddProjectItem(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func AddProjectItem(t translations.TranslationHelperFunc) registry.ServerTool { + return NewTool( + ToolsetMetadataProjects, + mcp.Tool{ Name: "add_project_item", Description: t("TOOL_ADD_PROJECT_ITEM_DESCRIPTION", "Add a specific Project item for a user or org"), Annotations: &mcp.ToolAnnotations{ @@ -676,78 +720,85 @@ func AddProjectItem(getClient GetClientFn, t translations.TranslationHelperFunc) }, Required: []string{"owner_type", "owner", "project_number", "item_type", "item_id"}, }, - }, func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - ownerType, err := RequiredParam[string](args, "owner_type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - projectNumber, err := RequiredInt(args, "project_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - itemID, err := RequiredBigInt(args, "item_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + }, + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - itemType, err := RequiredParam[string](args, "item_type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - if itemType != "issue" && itemType != "pull_request" { - return utils.NewToolResultError("item_type must be either 'issue' or 'pull_request'"), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + ownerType, err := RequiredParam[string](args, "owner_type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + projectNumber, err := RequiredInt(args, "project_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + itemID, err := RequiredBigInt(args, "item_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + itemType, err := RequiredParam[string](args, "item_type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + if itemType != "issue" && itemType != "pull_request" { + return utils.NewToolResultError("item_type must be either 'issue' or 'pull_request'"), nil, nil + } - newItem := &github.AddProjectItemOptions{ - ID: itemID, - Type: toNewProjectType(itemType), - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var resp *github.Response - var addedItem *github.ProjectV2Item + newItem := &github.AddProjectItemOptions{ + ID: itemID, + Type: toNewProjectType(itemType), + } - if ownerType == "org" { - addedItem, resp, err = client.Projects.AddOrganizationProjectItem(ctx, owner, projectNumber, newItem) - } else { - addedItem, resp, err = client.Projects.AddUserProjectItem(ctx, owner, projectNumber, newItem) - } + var resp *github.Response + var addedItem *github.ProjectV2Item - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - ProjectAddFailedError, - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + if ownerType == "org" { + addedItem, resp, err = client.Projects.AddOrganizationProjectItem(ctx, owner, projectNumber, newItem) + } else { + addedItem, resp, err = client.Projects.AddUserProjectItem(ctx, owner, projectNumber, newItem) + } - if resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + ProjectAddFailedError, + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return utils.NewToolResultError(fmt.Sprintf("%s: %s", ProjectAddFailedError, string(body))), nil, nil + } + r, err := json.Marshal(addedItem) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("%s: %s", ProjectAddFailedError, string(body))), nil, nil - } - r, err := json.Marshal(addedItem) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } - return utils.NewToolResultText(string(r)), nil, nil - } + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } -func UpdateProjectItem(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func UpdateProjectItem(t translations.TranslationHelperFunc) registry.ServerTool { + return NewTool( + ToolsetMetadataProjects, + mcp.Tool{ Name: "update_project_item", Description: t("TOOL_UPDATE_PROJECT_ITEM_DESCRIPTION", "Update a specific Project item for a user or org"), Annotations: &mcp.ToolAnnotations{ @@ -781,80 +832,87 @@ func UpdateProjectItem(getClient GetClientFn, t translations.TranslationHelperFu }, Required: []string{"owner_type", "owner", "project_number", "item_id", "updated_field"}, }, - }, func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - ownerType, err := RequiredParam[string](args, "owner_type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - projectNumber, err := RequiredInt(args, "project_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - itemID, err := RequiredBigInt(args, "item_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + }, + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - rawUpdatedField, exists := args["updated_field"] - if !exists { - return utils.NewToolResultError("missing required parameter: updated_field"), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + ownerType, err := RequiredParam[string](args, "owner_type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + projectNumber, err := RequiredInt(args, "project_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + itemID, err := RequiredBigInt(args, "item_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - fieldValue, ok := rawUpdatedField.(map[string]any) - if !ok || fieldValue == nil { - return utils.NewToolResultError("field_value must be an object"), nil, nil - } + rawUpdatedField, exists := args["updated_field"] + if !exists { + return utils.NewToolResultError("missing required parameter: updated_field"), nil, nil + } - updatePayload, err := buildUpdateProjectItem(fieldValue) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + fieldValue, ok := rawUpdatedField.(map[string]any) + if !ok || fieldValue == nil { + return utils.NewToolResultError("field_value must be an object"), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + updatePayload, err := buildUpdateProjectItem(fieldValue) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var resp *github.Response - var updatedItem *github.ProjectV2Item + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if ownerType == "org" { - updatedItem, resp, err = client.Projects.UpdateOrganizationProjectItem(ctx, owner, projectNumber, itemID, updatePayload) - } else { - updatedItem, resp, err = client.Projects.UpdateUserProjectItem(ctx, owner, projectNumber, itemID, updatePayload) - } + var resp *github.Response + var updatedItem *github.ProjectV2Item - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - ProjectUpdateFailedError, - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + if ownerType == "org" { + updatedItem, resp, err = client.Projects.UpdateOrganizationProjectItem(ctx, owner, projectNumber, itemID, updatePayload) + } else { + updatedItem, resp, err = client.Projects.UpdateUserProjectItem(ctx, owner, projectNumber, itemID, updatePayload) + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + ProjectUpdateFailedError, + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return utils.NewToolResultError(fmt.Sprintf("%s: %s", ProjectUpdateFailedError, string(body))), nil, nil + } + r, err := json.Marshal(updatedItem) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("%s: %s", ProjectUpdateFailedError, string(body))), nil, nil - } - r, err := json.Marshal(updatedItem) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } - return utils.NewToolResultText(string(r)), nil, nil - } + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } -func DeleteProjectItem(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func DeleteProjectItem(t translations.TranslationHelperFunc) registry.ServerTool { + return NewTool( + ToolsetMetadataProjects, + mcp.Tool{ Name: "delete_project_item", Description: t("TOOL_DELETE_PROJECT_ITEM_DESCRIPTION", "Delete a specific Project item for a user or org"), Annotations: &mcp.ToolAnnotations{ @@ -884,53 +942,58 @@ func DeleteProjectItem(getClient GetClientFn, t translations.TranslationHelperFu }, Required: []string{"owner_type", "owner", "project_number", "item_id"}, }, - }, func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - ownerType, err := RequiredParam[string](args, "owner_type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - projectNumber, err := RequiredInt(args, "project_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - itemID, err := RequiredBigInt(args, "item_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + }, + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - var resp *github.Response - if ownerType == "org" { - resp, err = client.Projects.DeleteOrganizationProjectItem(ctx, owner, projectNumber, itemID) - } else { - resp, err = client.Projects.DeleteUserProjectItem(ctx, owner, projectNumber, itemID) - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + ownerType, err := RequiredParam[string](args, "owner_type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + projectNumber, err := RequiredInt(args, "project_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + itemID, err := RequiredBigInt(args, "item_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - ProjectDeleteFailedError, - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + var resp *github.Response + if ownerType == "org" { + resp, err = client.Projects.DeleteOrganizationProjectItem(ctx, owner, projectNumber, itemID) + } else { + resp, err = client.Projects.DeleteUserProjectItem(ctx, owner, projectNumber, itemID) + } - if resp.StatusCode != http.StatusNoContent { - body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + ProjectDeleteFailedError, + resp, + err, + ), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("%s: %s", ProjectDeleteFailedError, string(body))), nil, nil + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusNoContent { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return utils.NewToolResultError(fmt.Sprintf("%s: %s", ProjectDeleteFailedError, string(body))), nil, nil + } + return utils.NewToolResultText("project item successfully deleted"), nil, nil } - return utils.NewToolResultText("project item successfully deleted"), nil, nil - } + }, + ) } type pageInfo struct { diff --git a/pkg/github/projects_test.go b/pkg/github/projects_test.go index e2814c8f9..67ecd8800 100644 --- a/pkg/github/projects_test.go +++ b/pkg/github/projects_test.go @@ -17,8 +17,8 @@ import ( ) func Test_ListProjects(t *testing.T) { - mockClient := gh.NewClient(nil) - tool, _ := ListProjects(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListProjects(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_projects", tool.Name) @@ -141,9 +141,12 @@ func Test_ListProjects(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - _, handler := ListProjects(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) if tc.expectError { @@ -177,8 +180,8 @@ func Test_ListProjects(t *testing.T) { } func Test_GetProject(t *testing.T) { - mockClient := gh.NewClient(nil) - tool, _ := GetProject(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := GetProject(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "get_project", tool.Name) @@ -277,9 +280,12 @@ func Test_GetProject(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - _, handler := GetProject(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) if tc.expectError { @@ -310,8 +316,8 @@ func Test_GetProject(t *testing.T) { } func Test_ListProjectFields(t *testing.T) { - mockClient := gh.NewClient(nil) - tool, _ := ListProjectFields(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListProjectFields(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_project_fields", tool.Name) @@ -426,9 +432,12 @@ func Test_ListProjectFields(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - _, handler := ListProjectFields(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) if tc.expectError { @@ -464,8 +473,8 @@ func Test_ListProjectFields(t *testing.T) { } func Test_GetProjectField(t *testing.T) { - mockClient := gh.NewClient(nil) - tool, _ := GetProjectField(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := GetProjectField(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "get_project_field", tool.Name) @@ -583,9 +592,12 @@ func Test_GetProjectField(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - _, handler := GetProjectField(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) if tc.expectError { @@ -622,8 +634,8 @@ func Test_GetProjectField(t *testing.T) { } func Test_ListProjectItems(t *testing.T) { - mockClient := gh.NewClient(nil) - tool, _ := ListProjectItems(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListProjectItems(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_project_items", tool.Name) @@ -786,9 +798,12 @@ func Test_ListProjectItems(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - _, handler := ListProjectItems(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) if tc.expectError { @@ -824,8 +839,8 @@ func Test_ListProjectItems(t *testing.T) { } func Test_GetProjectItem(t *testing.T) { - mockClient := gh.NewClient(nil) - tool, _ := GetProjectItem(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := GetProjectItem(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "get_project_item", tool.Name) @@ -980,9 +995,12 @@ func Test_GetProjectItem(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - _, handler := GetProjectItem(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) if tc.expectError { @@ -1019,8 +1037,8 @@ func Test_GetProjectItem(t *testing.T) { } func Test_AddProjectItem(t *testing.T) { - mockClient := gh.NewClient(nil) - tool, _ := AddProjectItem(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := AddProjectItem(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "add_project_item", tool.Name) @@ -1206,10 +1224,13 @@ func Test_AddProjectItem(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - _, handler := AddProjectItem(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) if tc.expectError { @@ -1255,8 +1276,8 @@ func Test_AddProjectItem(t *testing.T) { } func Test_UpdateProjectItem(t *testing.T) { - mockClient := gh.NewClient(nil) - tool, _ := UpdateProjectItem(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := UpdateProjectItem(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "update_project_item", tool.Name) @@ -1488,9 +1509,12 @@ func Test_UpdateProjectItem(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - _, handler := UpdateProjectItem(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) if tc.expectError { @@ -1532,8 +1556,8 @@ func Test_UpdateProjectItem(t *testing.T) { } func Test_DeleteProjectItem(t *testing.T) { - mockClient := gh.NewClient(nil) - tool, _ := DeleteProjectItem(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := DeleteProjectItem(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "delete_project_item", tool.Name) @@ -1652,9 +1676,12 @@ func Test_DeleteProjectItem(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - _, handler := DeleteProjectItem(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) if tc.expectError { diff --git a/pkg/github/prompts.go b/pkg/github/prompts.go new file mode 100644 index 000000000..229902d90 --- /dev/null +++ b/pkg/github/prompts.go @@ -0,0 +1,16 @@ +package github + +import ( + "github.com/github/github-mcp-server/pkg/registry" + "github.com/github/github-mcp-server/pkg/translations" +) + +// AllPrompts returns all prompts with their embedded toolset metadata. +// Prompt functions return ServerPrompt directly with toolset info. +func AllPrompts(t translations.TranslationHelperFunc) []registry.ServerPrompt { + return []registry.ServerPrompt{ + // Issue prompts + AssignCodingAgentPrompt(t), + IssueToFixWorkflowPrompt(t), + } +} diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index bfe870775..4e7ede755 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -15,14 +15,14 @@ import ( ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/lockdown" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/sanitize" - "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" ) // PullRequestRead creates a tool to get details of a specific pull request. -func PullRequestRead(t translations.TranslationHelperFunc) toolsets.ServerTool { +func PullRequestRead(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -58,6 +58,7 @@ Possible options: WithPagination(schema) return NewTool( + ToolsetMetadataPullRequests, mcp.Tool{ Name: "pull_request_read", Description: t("TOOL_PULL_REQUEST_READ_DESCRIPTION", "Get information on a specific pull request in GitHub repository."), @@ -98,7 +99,7 @@ Possible options: switch method { case "get": - result, err := GetPullRequest(ctx, client, deps.RepoAccessCache, owner, repo, pullNumber, deps.Flags) + result, err := GetPullRequest(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, deps.GetFlags()) return result, nil, err case "get_diff": result, err := GetPullRequestDiff(ctx, client, owner, repo, pullNumber) @@ -110,13 +111,13 @@ Possible options: result, err := GetPullRequestFiles(ctx, client, owner, repo, pullNumber, pagination) return result, nil, err case "get_review_comments": - result, err := GetPullRequestReviewComments(ctx, client, deps.RepoAccessCache, owner, repo, pullNumber, pagination, deps.Flags) + result, err := GetPullRequestReviewComments(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, pagination, deps.GetFlags()) return result, nil, err case "get_reviews": - result, err := GetPullRequestReviews(ctx, client, deps.RepoAccessCache, owner, repo, pullNumber, deps.Flags) + result, err := GetPullRequestReviews(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, deps.GetFlags()) return result, nil, err case "get_comments": - result, err := GetIssueComments(ctx, client, deps.RepoAccessCache, owner, repo, pullNumber, pagination, deps.Flags) + result, err := GetIssueComments(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, pagination, deps.GetFlags()) return result, nil, err default: return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil @@ -389,7 +390,7 @@ func GetPullRequestReviews(ctx context.Context, client *github.Client, cache *lo } // CreatePullRequest creates a tool to create a new pull request. -func CreatePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool { +func CreatePullRequest(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -430,6 +431,7 @@ func CreatePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool } return NewTool( + ToolsetMetadataPullRequests, mcp.Tool{ Name: "create_pull_request", Description: t("TOOL_CREATE_PULL_REQUEST_DESCRIPTION", "Create a new pull request in a GitHub repository."), @@ -529,7 +531,7 @@ func CreatePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool } // UpdatePullRequest creates a tool to update an existing pull request. -func UpdatePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool { +func UpdatePullRequest(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -582,6 +584,7 @@ func UpdatePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool } return NewTool( + ToolsetMetadataPullRequests, mcp.Tool{ Name: "update_pull_request", Description: t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository."), @@ -823,7 +826,7 @@ func UpdatePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool } // ListPullRequests creates a tool to list and filter repository pull requests. -func ListPullRequests(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListPullRequests(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -864,6 +867,7 @@ func ListPullRequests(t translations.TranslationHelperFunc) toolsets.ServerTool WithPagination(schema) return NewTool( + ToolsetMetadataPullRequests, mcp.Tool{ Name: "list_pull_requests", Description: t("TOOL_LIST_PULL_REQUESTS_DESCRIPTION", "List pull requests in a GitHub repository. If the user specifies an author, then DO NOT use this tool and use the search_pull_requests tool instead."), @@ -966,7 +970,7 @@ func ListPullRequests(t translations.TranslationHelperFunc) toolsets.ServerTool } // MergePullRequest creates a tool to merge a pull request. -func MergePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool { +func MergePullRequest(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1000,6 +1004,7 @@ func MergePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool } return NewTool( + ToolsetMetadataPullRequests, mcp.Tool{ Name: "merge_pull_request", Description: t("TOOL_MERGE_PULL_REQUEST_DESCRIPTION", "Merge a pull request in a GitHub repository."), @@ -1074,7 +1079,7 @@ func MergePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool } // SearchPullRequests creates a tool to search for pull requests. -func SearchPullRequests(t translations.TranslationHelperFunc) toolsets.ServerTool { +func SearchPullRequests(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1118,6 +1123,7 @@ func SearchPullRequests(t translations.TranslationHelperFunc) toolsets.ServerToo WithPagination(schema) return NewTool( + ToolsetMetadataPullRequests, mcp.Tool{ Name: "search_pull_requests", Description: t("TOOL_SEARCH_PULL_REQUESTS_DESCRIPTION", "Search for pull requests in GitHub repositories using issues search syntax already scoped to is:pr"), @@ -1136,7 +1142,7 @@ func SearchPullRequests(t translations.TranslationHelperFunc) toolsets.ServerToo } // UpdatePullRequestBranch creates a tool to update a pull request branch with the latest changes from the base branch. -func UpdatePullRequestBranch(t translations.TranslationHelperFunc) toolsets.ServerTool { +func UpdatePullRequestBranch(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1161,6 +1167,7 @@ func UpdatePullRequestBranch(t translations.TranslationHelperFunc) toolsets.Serv } return NewTool( + ToolsetMetadataPullRequests, mcp.Tool{ Name: "update_pull_request_branch", Description: t("TOOL_UPDATE_PULL_REQUEST_BRANCH_DESCRIPTION", "Update the branch of a pull request with the latest changes from the base branch."), @@ -1240,7 +1247,7 @@ type PullRequestReviewWriteParams struct { CommitID *string } -func PullRequestReviewWrite(t translations.TranslationHelperFunc) toolsets.ServerTool { +func PullRequestReviewWrite(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1282,6 +1289,7 @@ func PullRequestReviewWrite(t translations.TranslationHelperFunc) toolsets.Serve } return NewTool( + ToolsetMetadataPullRequests, mcp.Tool{ Name: "pull_request_review_write", Description: t("TOOL_PULL_REQUEST_REVIEW_WRITE_DESCRIPTION", `Create and/or submit, delete review of a pull request. @@ -1551,7 +1559,7 @@ func DeletePendingPullRequestReview(ctx context.Context, client *githubv4.Client } // AddCommentToPendingReview creates a tool to add a comment to a pull request review. -func AddCommentToPendingReview(t translations.TranslationHelperFunc) toolsets.ServerTool { +func AddCommentToPendingReview(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1612,6 +1620,7 @@ func AddCommentToPendingReview(t translations.TranslationHelperFunc) toolsets.Se } return NewTool( + ToolsetMetadataPullRequests, mcp.Tool{ Name: "add_comment_to_pending_review", Description: t("TOOL_ADD_COMMENT_TO_PENDING_REVIEW_DESCRIPTION", "Add review comment to the requester's latest pending pull request review. A pending review needs to already exist to call this (check with the user if not sure)."), @@ -1743,7 +1752,7 @@ func AddCommentToPendingReview(t translations.TranslationHelperFunc) toolsets.Se // RequestCopilotReview creates a tool to request a Copilot review for a pull request. // Note that this tool will not work on GHES where this feature is unsupported. In future, we should not expose this // tool if the configured host does not support it. -func RequestCopilotReview(t translations.TranslationHelperFunc) toolsets.ServerTool { +func RequestCopilotReview(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1764,6 +1773,7 @@ func RequestCopilotReview(t translations.TranslationHelperFunc) toolsets.ServerT } return NewTool( + ToolsetMetadataPullRequests, mcp.Tool{ Name: "request_copilot_review", Description: t("TOOL_REQUEST_COPILOT_REVIEW_DESCRIPTION", "Request a GitHub Copilot code review for a pull request. Use this for automated feedback on pull requests, usually before requesting a human reviewer."), diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 7531edf6d..a22245700 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -105,9 +105,9 @@ func Test_GetPullRequest(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(githubv4mock.NewMockedHTTPClient()) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, RepoAccessCache: stubRepoAccessCache(gqlClient, 5*time.Minute), Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), } @@ -370,9 +370,9 @@ func Test_UpdatePullRequest(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(nil) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, } handler := serverTool.Handler(deps) @@ -561,9 +561,9 @@ func Test_UpdatePullRequest_Draft(t *testing.T) { gqlClient := githubv4.NewClient(tc.mockedClient) serverTool := UpdatePullRequest(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(restClient), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: restClient, + GQLClient: gqlClient, } handler := serverTool.Handler(deps) @@ -693,8 +693,8 @@ func Test_ListPullRequests(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := ListPullRequests(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -817,8 +817,8 @@ func Test_MergePullRequest(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := MergePullRequest(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -1123,8 +1123,8 @@ func Test_SearchPullRequests(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := SearchPullRequests(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -1276,8 +1276,8 @@ func Test_GetPullRequestFiles(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := PullRequestRead(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, RepoAccessCache: stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), } @@ -1451,8 +1451,8 @@ func Test_GetPullRequestStatus(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := PullRequestRead(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, RepoAccessCache: stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), } @@ -1589,8 +1589,8 @@ func Test_UpdatePullRequestBranch(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := UpdatePullRequestBranch(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -1763,8 +1763,8 @@ func Test_GetPullRequestComments(t *testing.T) { cache := stubRepoAccessCache(gqlClient, 5*time.Minute) flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) serverTool := PullRequestRead(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, RepoAccessCache: cache, Flags: flags, } @@ -1951,8 +1951,8 @@ func Test_GetPullRequestReviews(t *testing.T) { cache := stubRepoAccessCache(gqlClient, 5*time.Minute) flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) serverTool := PullRequestRead(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, RepoAccessCache: cache, Flags: flags, } @@ -2115,8 +2115,8 @@ func Test_CreatePullRequest(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := CreatePullRequest(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -2331,8 +2331,8 @@ func TestCreateAndSubmitPullRequestReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) @@ -2447,8 +2447,8 @@ func Test_RequestCopilotReview(t *testing.T) { client := github.NewClient(tc.mockedClient) serverTool := RequestCopilotReview(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -2641,8 +2641,8 @@ func TestCreatePendingPullRequestReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) @@ -2824,8 +2824,8 @@ func TestAddPullRequestReviewCommentToPendingReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) serverTool := AddCommentToPendingReview(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) @@ -2929,8 +2929,8 @@ func TestSubmitPendingPullRequestReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) @@ -3028,8 +3028,8 @@ func TestDeletePendingPullRequestReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) @@ -3119,8 +3119,8 @@ index 5d6e7b2..8a4f5c3 100644 // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := PullRequestRead(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, RepoAccessCache: stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), } diff --git a/pkg/github/registry.go b/pkg/github/registry.go new file mode 100644 index 000000000..88795ee1e --- /dev/null +++ b/pkg/github/registry.go @@ -0,0 +1,18 @@ +package github + +import ( + "github.com/github/github-mcp-server/pkg/registry" + "github.com/github/github-mcp-server/pkg/translations" +) + +// NewRegistry creates a Registry with all available tools, resources, and prompts. +// Tools, resources, and prompts are self-describing with their toolset metadata embedded. +// This function is stateless - no dependencies are captured. +// Handlers are generated on-demand during registration via RegisterAll(ctx, server, deps). +// The "default" keyword in WithToolsets will expand to toolsets marked with Default: true. +func NewRegistry(t translations.TranslationHelperFunc) *registry.Builder { + return registry.NewBuilder(). + SetTools(AllTools(t)). + SetResources(AllResources(t)). + SetPrompts(AllPrompts(t)) +} diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index de5aaea5e..854da679a 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -11,7 +11,7 @@ import ( ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/raw" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -19,8 +19,9 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func GetCommit(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetCommit(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "get_commit", Description: t("TOOL_GET_COMMITS_DESCRIPTION", "Get details for a commit from a GitHub repository"), @@ -117,8 +118,9 @@ func GetCommit(t translations.TranslationHelperFunc) toolsets.ServerTool { } // ListCommits creates a tool to get commits of a branch in a repository. -func ListCommits(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListCommits(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "list_commits", Description: t("TOOL_LIST_COMMITS_DESCRIPTION", "Get list of commits of a branch in a GitHub repository. Returns at least 30 results per page by default, but can return more if specified using the perPage parameter (up to 100)."), @@ -225,8 +227,9 @@ func ListCommits(t translations.TranslationHelperFunc) toolsets.ServerTool { } // ListBranches creates a tool to list branches in a GitHub repository. -func ListBranches(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListBranches(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "list_branches", Description: t("TOOL_LIST_BRANCHES_DESCRIPTION", "List branches in a GitHub repository"), @@ -312,8 +315,9 @@ func ListBranches(t translations.TranslationHelperFunc) toolsets.ServerTool { } // CreateOrUpdateFile creates a tool to create or update a file in a GitHub repository. -func CreateOrUpdateFile(t translations.TranslationHelperFunc) toolsets.ServerTool { +func CreateOrUpdateFile(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "create_or_update_file", Description: t("TOOL_CREATE_OR_UPDATE_FILE_DESCRIPTION", "Create or update a single file in a GitHub repository. If updating, you must provide the SHA of the file you want to update. Use this tool to create or update a file in a GitHub repository remotely; do not use it for local file operations."), @@ -439,8 +443,9 @@ func CreateOrUpdateFile(t translations.TranslationHelperFunc) toolsets.ServerToo } // CreateRepository creates a tool to create a new GitHub repository. -func CreateRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { +func CreateRepository(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "create_repository", Description: t("TOOL_CREATE_REPOSITORY_DESCRIPTION", "Create a new GitHub repository in your account or specified organization"), @@ -545,8 +550,9 @@ func CreateRepository(t translations.TranslationHelperFunc) toolsets.ServerTool } // GetFileContents creates a tool to get the contents of a file or directory from a GitHub repository. -func GetFileContents(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetFileContents(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "get_file_contents", Description: t("TOOL_GET_FILE_CONTENTS_DESCRIPTION", "Get the contents of a file or directory from a GitHub repository"), @@ -764,8 +770,9 @@ func GetFileContents(t translations.TranslationHelperFunc) toolsets.ServerTool { } // ForkRepository creates a tool to fork a repository. -func ForkRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ForkRepository(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "fork_repository", Description: t("TOOL_FORK_REPOSITORY_DESCRIPTION", "Fork a GitHub repository to your account or specified organization"), @@ -862,8 +869,9 @@ func ForkRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { // unlike how the endpoint backing the create_or_update_files tool does. This appears to be a quirk of the API. // The approach implemented here gets automatic commit signing when used with either the github-actions user or as an app, // both of which suit an LLM well. -func DeleteFile(t translations.TranslationHelperFunc) toolsets.ServerTool { +func DeleteFile(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "delete_file", Description: t("TOOL_DELETE_FILE_DESCRIPTION", "Delete a file from a GitHub repository"), @@ -1047,8 +1055,9 @@ func DeleteFile(t translations.TranslationHelperFunc) toolsets.ServerTool { } // CreateBranch creates a tool to create a new branch. -func CreateBranch(t translations.TranslationHelperFunc) toolsets.ServerTool { +func CreateBranch(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "create_branch", Description: t("TOOL_CREATE_BRANCH_DESCRIPTION", "Create a new branch in a GitHub repository"), @@ -1160,8 +1169,9 @@ func CreateBranch(t translations.TranslationHelperFunc) toolsets.ServerTool { } // PushFiles creates a tool to push multiple files in a single commit to a GitHub repository. -func PushFiles(t translations.TranslationHelperFunc) toolsets.ServerTool { +func PushFiles(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "push_files", Description: t("TOOL_PUSH_FILES_DESCRIPTION", "Push multiple files to a GitHub repository in a single commit"), @@ -1344,8 +1354,9 @@ func PushFiles(t translations.TranslationHelperFunc) toolsets.ServerTool { } // ListTags creates a tool to list tags in a GitHub repository. -func ListTags(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListTags(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "list_tags", Description: t("TOOL_LIST_TAGS_DESCRIPTION", "List git tags in a GitHub repository"), @@ -1423,8 +1434,9 @@ func ListTags(t translations.TranslationHelperFunc) toolsets.ServerTool { } // GetTag creates a tool to get details about a specific tag in a GitHub repository. -func GetTag(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetTag(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "get_tag", Description: t("TOOL_GET_TAG_DESCRIPTION", "Get details about a specific git tag in a GitHub repository"), @@ -1521,8 +1533,9 @@ func GetTag(t translations.TranslationHelperFunc) toolsets.ServerTool { } // ListReleases creates a tool to list releases in a GitHub repository. -func ListReleases(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListReleases(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "list_releases", Description: t("TOOL_LIST_RELEASES_DESCRIPTION", "List releases in a GitHub repository"), @@ -1596,8 +1609,9 @@ func ListReleases(t translations.TranslationHelperFunc) toolsets.ServerTool { } // GetLatestRelease creates a tool to get the latest release in a GitHub repository. -func GetLatestRelease(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetLatestRelease(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "get_latest_release", Description: t("TOOL_GET_LATEST_RELEASE_DESCRIPTION", "Get the latest release in a GitHub repository"), @@ -1661,8 +1675,9 @@ func GetLatestRelease(t translations.TranslationHelperFunc) toolsets.ServerTool ) } -func GetReleaseByTag(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetReleaseByTag(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "get_release_by_tag", Description: t("TOOL_GET_RELEASE_BY_TAG_DESCRIPTION", "Get a specific release by its tag name in a GitHub repository"), @@ -1874,8 +1889,9 @@ func resolveGitReference(ctx context.Context, githubClient *github.Client, owner } // ListStarredRepositories creates a tool to list starred repositories for the authenticated user or a specified user. -func ListStarredRepositories(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListStarredRepositories(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataStargazers, mcp.Tool{ Name: "list_starred_repositories", Description: t("TOOL_LIST_STARRED_REPOSITORIES_DESCRIPTION", "List starred repositories"), @@ -2006,8 +2022,9 @@ func ListStarredRepositories(t translations.TranslationHelperFunc) toolsets.Serv } // StarRepository creates a tool to star a repository. -func StarRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { +func StarRepository(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataStargazers, mcp.Tool{ Name: "star_repository", Description: t("TOOL_STAR_REPOSITORY_DESCRIPTION", "Star a GitHub repository"), @@ -2071,8 +2088,9 @@ func StarRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { } // UnstarRepository creates a tool to unstar a repository. -func UnstarRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { +func UnstarRepository(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataStargazers, mcp.Tool{ Name: "unstar_repository", Description: t("TOOL_UNSTAR_REPOSITORY_DESCRIPTION", "Unstar a GitHub repository"), diff --git a/pkg/github/repositories_test.go b/pkg/github/repositories_test.go index 949686d92..55f0866cb 100644 --- a/pkg/github/repositories_test.go +++ b/pkg/github/repositories_test.go @@ -286,9 +286,9 @@ func Test_GetFileContents(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) mockRawClient := raw.NewClient(client, &url.URL{Scheme: "https", Host: "raw.example.com", Path: "/"}) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetRawClient: stubGetRawClientFn(mockRawClient), + deps := BaseDeps{ + Client: client, + RawClient: mockRawClient, } handler := serverTool.Handler(deps) @@ -410,8 +410,8 @@ func Test_ForkRepository(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -606,8 +606,8 @@ func Test_CreateBranch(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -738,8 +738,8 @@ func Test_GetCommit(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -964,8 +964,8 @@ func Test_ListCommits(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -1143,8 +1143,8 @@ func Test_CreateOrUpdateFile(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -1329,8 +1329,8 @@ func Test_CreateRepository(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -1668,8 +1668,8 @@ func Test_PushFiles(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -1789,8 +1789,8 @@ func Test_ListBranches(t *testing.T) { t.Run(tt.name, func(t *testing.T) { // Create mock client mockClient := github.NewClient(mock.NewMockedHTTPClient(tt.mockResponses...)) - deps := ToolDependencies{ - GetClient: stubGetClientFn(mockClient), + deps := BaseDeps{ + Client: mockClient, } handler := serverTool.Handler(deps) @@ -1977,8 +1977,8 @@ func Test_DeleteFile(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -2104,8 +2104,8 @@ func Test_ListTags(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -2264,8 +2264,8 @@ func Test_GetTag(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -2377,8 +2377,8 @@ func Test_ListReleases(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -2468,8 +2468,8 @@ func Test_GetLatestRelease(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -2616,8 +2616,8 @@ func Test_GetReleaseByTag(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -3128,8 +3128,8 @@ func Test_ListStarredRepositories(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -3229,8 +3229,8 @@ func Test_StarRepository(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -3320,8 +3320,8 @@ func Test_UnstarRepository(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) diff --git a/pkg/github/repository_resource.go b/pkg/github/repository_resource.go index 5dea9f4e9..af001af6f 100644 --- a/pkg/github/repository_resource.go +++ b/pkg/github/repository_resource.go @@ -14,6 +14,7 @@ import ( "strings" "github.com/github/github-mcp-server/pkg/raw" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -28,58 +29,81 @@ var ( repositoryResourcePrContentURITemplate = uritemplate.MustNew("repo://{owner}/{repo}/refs/pull/{prNumber}/head/contents{/path*}") ) -// GetRepositoryResourceContent defines the resource template and handler for getting repository content. -func GetRepositoryResourceContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, mcp.ResourceHandler) { - return mcp.ResourceTemplate{ +// GetRepositoryResourceContent defines the resource template for getting repository content. +func GetRepositoryResourceContent(t translations.TranslationHelperFunc) registry.ServerResourceTemplate { + return registry.NewServerResourceTemplate( + ToolsetMetadataRepos, + mcp.ResourceTemplate{ Name: "repository_content", - URITemplate: repositoryResourceContentURITemplate.Raw(), // Resource template + URITemplate: repositoryResourceContentURITemplate.Raw(), Description: t("RESOURCE_REPOSITORY_CONTENT_DESCRIPTION", "Repository Content"), }, - RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceContentURITemplate) + repositoryResourceContentsHandlerFunc(repositoryResourceContentURITemplate), + ) } -// GetRepositoryResourceBranchContent defines the resource template and handler for getting repository content for a branch. -func GetRepositoryResourceBranchContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, mcp.ResourceHandler) { - return mcp.ResourceTemplate{ +// GetRepositoryResourceBranchContent defines the resource template for getting repository content for a branch. +func GetRepositoryResourceBranchContent(t translations.TranslationHelperFunc) registry.ServerResourceTemplate { + return registry.NewServerResourceTemplate( + ToolsetMetadataRepos, + mcp.ResourceTemplate{ Name: "repository_content_branch", - URITemplate: repositoryResourceBranchContentURITemplate.Raw(), // Resource template + URITemplate: repositoryResourceBranchContentURITemplate.Raw(), Description: t("RESOURCE_REPOSITORY_CONTENT_BRANCH_DESCRIPTION", "Repository Content for specific branch"), }, - RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceBranchContentURITemplate) + repositoryResourceContentsHandlerFunc(repositoryResourceBranchContentURITemplate), + ) } -// GetRepositoryResourceCommitContent defines the resource template and handler for getting repository content for a commit. -func GetRepositoryResourceCommitContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, mcp.ResourceHandler) { - return mcp.ResourceTemplate{ +// GetRepositoryResourceCommitContent defines the resource template for getting repository content for a commit. +func GetRepositoryResourceCommitContent(t translations.TranslationHelperFunc) registry.ServerResourceTemplate { + return registry.NewServerResourceTemplate( + ToolsetMetadataRepos, + mcp.ResourceTemplate{ Name: "repository_content_commit", - URITemplate: repositoryResourceCommitContentURITemplate.Raw(), // Resource template + URITemplate: repositoryResourceCommitContentURITemplate.Raw(), Description: t("RESOURCE_REPOSITORY_CONTENT_COMMIT_DESCRIPTION", "Repository Content for specific commit"), }, - RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceCommitContentURITemplate) + repositoryResourceContentsHandlerFunc(repositoryResourceCommitContentURITemplate), + ) } -// GetRepositoryResourceTagContent defines the resource template and handler for getting repository content for a tag. -func GetRepositoryResourceTagContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, mcp.ResourceHandler) { - return mcp.ResourceTemplate{ +// GetRepositoryResourceTagContent defines the resource template for getting repository content for a tag. +func GetRepositoryResourceTagContent(t translations.TranslationHelperFunc) registry.ServerResourceTemplate { + return registry.NewServerResourceTemplate( + ToolsetMetadataRepos, + mcp.ResourceTemplate{ Name: "repository_content_tag", - URITemplate: repositoryResourceTagContentURITemplate.Raw(), // Resource template + URITemplate: repositoryResourceTagContentURITemplate.Raw(), Description: t("RESOURCE_REPOSITORY_CONTENT_TAG_DESCRIPTION", "Repository Content for specific tag"), }, - RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceTagContentURITemplate) + repositoryResourceContentsHandlerFunc(repositoryResourceTagContentURITemplate), + ) } -// GetRepositoryResourcePrContent defines the resource template and handler for getting repository content for a pull request. -func GetRepositoryResourcePrContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, mcp.ResourceHandler) { - return mcp.ResourceTemplate{ +// GetRepositoryResourcePrContent defines the resource template for getting repository content for a pull request. +func GetRepositoryResourcePrContent(t translations.TranslationHelperFunc) registry.ServerResourceTemplate { + return registry.NewServerResourceTemplate( + ToolsetMetadataRepos, + mcp.ResourceTemplate{ Name: "repository_content_pr", - URITemplate: repositoryResourcePrContentURITemplate.Raw(), // Resource template + URITemplate: repositoryResourcePrContentURITemplate.Raw(), Description: t("RESOURCE_REPOSITORY_CONTENT_PR_DESCRIPTION", "Repository Content for specific pull request"), }, - RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourcePrContentURITemplate) + repositoryResourceContentsHandlerFunc(repositoryResourcePrContentURITemplate), + ) +} + +// repositoryResourceContentsHandlerFunc returns a ResourceHandlerFunc that creates handlers on-demand. +func repositoryResourceContentsHandlerFunc(resourceURITemplate *uritemplate.Template) registry.ResourceHandlerFunc { + return func(deps any) mcp.ResourceHandler { + d := deps.(ToolDependencies) + return RepositoryResourceContentsHandler(d, resourceURITemplate) + } } // RepositoryResourceContentsHandler returns a handler function for repository content requests. -func RepositoryResourceContentsHandler(getClient GetClientFn, getRawClient raw.GetRawClientFn, resourceURITemplate *uritemplate.Template) mcp.ResourceHandler { +func RepositoryResourceContentsHandler(deps ToolDependencies, resourceURITemplate *uritemplate.Template) mcp.ResourceHandler { return func(ctx context.Context, request *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { // Match the URI to extract parameters uriValues := resourceURITemplate.Match(request.Params.URI) @@ -133,7 +157,7 @@ func RepositoryResourceContentsHandler(getClient GetClientFn, getRawClient raw.G prNumber := uriValues.Get("prNumber").String() if prNumber != "" { // fetch the PR from the API to get the latest commit and use SHA - githubClient, err := getClient(ctx) + githubClient, err := deps.GetClient(ctx) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -153,7 +177,7 @@ func RepositoryResourceContentsHandler(getClient GetClientFn, getRawClient raw.G if path == "" || strings.HasSuffix(path, "/") { return nil, fmt.Errorf("directories are not supported: %s", path) } - rawClient, err := getRawClient(ctx) + rawClient, err := deps.GetRawClient(ctx) if err != nil { return nil, fmt.Errorf("failed to get GitHub raw content client: %w", err) diff --git a/pkg/github/repository_resource_test.go b/pkg/github/repository_resource_test.go index 113f46d89..99c06cdd6 100644 --- a/pkg/github/repository_resource_test.go +++ b/pkg/github/repository_resource_test.go @@ -7,7 +7,6 @@ import ( "testing" "github.com/github/github-mcp-server/pkg/raw" - "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/migueleliasweb/go-github-mock/src/mock" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -28,7 +27,7 @@ func Test_repositoryResourceContents(t *testing.T) { name string mockedClient *http.Client uri string - handlerFn func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler + handlerFn func(deps ToolDependencies) mcp.ResourceHandler expectedResponseType resourceResponseType expectError string expectedResult *mcp.ReadResourceResult @@ -46,9 +45,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo:///repo/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceContent(getClient, getRawClient, t) - return handler + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeText, // Ignored as error is expected expectError: "owner is required", @@ -66,9 +64,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner//refs/heads/main/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceBranchContent(getClient, getRawClient, t) - return handler + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceBranchContentURITemplate) }, expectedResponseType: resourceResponseTypeText, // Ignored as error is expected expectError: "repo is required", @@ -86,9 +83,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/contents/data.png", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceContent(getClient, getRawClient, t) - return handler + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeBlob, expectedResult: &mcp.ReadResourceResult{ @@ -111,9 +107,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceContent(getClient, getRawClient, t) - return handler + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -138,9 +133,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/contents/pkg/github/actions.go", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceContent(getClient, getRawClient, t) - return handler + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -163,9 +157,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/refs/heads/main/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceBranchContent(getClient, getRawClient, t) - return handler + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceBranchContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -188,9 +181,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/refs/tags/v1.0.0/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceTagContent(getClient, getRawClient, t) - return handler + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceTagContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -213,9 +205,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/sha/abc123/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceCommitContent(getClient, getRawClient, t) - return handler + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceCommitContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -246,9 +237,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/refs/pull/42/head/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourcePrContent(getClient, getRawClient, t) - return handler + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourcePrContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -270,9 +260,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/contents/nonexistent.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceContent(getClient, getRawClient, t) - return handler + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeText, // Ignored as error is expected expectError: "404 Not Found", @@ -283,7 +272,11 @@ func Test_repositoryResourceContents(t *testing.T) { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) mockRawClient := raw.NewClient(client, base) - handler := tc.handlerFn(stubGetClientFn(client), stubGetRawClientFn(mockRawClient), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + RawClient: mockRawClient, + } + handler := tc.handlerFn(deps) request := &mcp.ReadResourceRequest{ Params: &mcp.ReadResourceParams{ diff --git a/pkg/github/resources.go b/pkg/github/resources.go new file mode 100644 index 000000000..6acf5eb6a --- /dev/null +++ b/pkg/github/resources.go @@ -0,0 +1,19 @@ +package github + +import ( + "github.com/github/github-mcp-server/pkg/registry" + "github.com/github/github-mcp-server/pkg/translations" +) + +// AllResources returns all resource templates with their embedded toolset metadata. +// Resource definitions are stateless - handlers are generated on-demand during registration. +func AllResources(t translations.TranslationHelperFunc) []registry.ServerResourceTemplate { + return []registry.ServerResourceTemplate{ + // Repository resources + GetRepositoryResourceContent(t), + GetRepositoryResourceBranchContent(t), + GetRepositoryResourceCommitContent(t), + GetRepositoryResourceTagContent(t), + GetRepositoryResourcePrContent(t), + } +} diff --git a/pkg/github/search.go b/pkg/github/search.go index eaaf49369..4b35f3f0d 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -8,7 +8,7 @@ import ( "net/http" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -17,7 +17,7 @@ import ( ) // SearchRepositories creates a tool to search for GitHub repositories. -func SearchRepositories(t translations.TranslationHelperFunc) toolsets.ServerTool { +func SearchRepositories(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -46,6 +46,7 @@ func SearchRepositories(t translations.TranslationHelperFunc) toolsets.ServerToo WithPagination(schema) return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "search_repositories", Description: t("TOOL_SEARCH_REPOSITORIES_DESCRIPTION", "Find GitHub repositories by name, description, readme, topics, or other metadata. Perfect for discovering projects, finding examples, or locating specific repositories across GitHub."), @@ -166,7 +167,7 @@ func SearchRepositories(t translations.TranslationHelperFunc) toolsets.ServerToo } // SearchCode creates a tool to search for code across GitHub repositories. -func SearchCode(t translations.TranslationHelperFunc) toolsets.ServerTool { +func SearchCode(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -189,6 +190,7 @@ func SearchCode(t translations.TranslationHelperFunc) toolsets.ServerTool { WithPagination(schema) return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "search_code", Description: t("TOOL_SEARCH_CODE_DESCRIPTION", "Fast and precise code search across ALL GitHub repositories using GitHub's native search engine. Best for finding exact symbols, functions, classes, or specific code patterns."), @@ -349,7 +351,7 @@ func userOrOrgHandler(accountType string, deps ToolDependencies) mcp.ToolHandler } // SearchUsers creates a tool to search for GitHub users. -func SearchUsers(t translations.TranslationHelperFunc) toolsets.ServerTool { +func SearchUsers(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -373,6 +375,7 @@ func SearchUsers(t translations.TranslationHelperFunc) toolsets.ServerTool { WithPagination(schema) return NewTool( + ToolsetMetadataUsers, mcp.Tool{ Name: "search_users", Description: t("TOOL_SEARCH_USERS_DESCRIPTION", "Find GitHub users by username, real name, or other profile information. Useful for locating developers, contributors, or team members."), @@ -389,7 +392,7 @@ func SearchUsers(t translations.TranslationHelperFunc) toolsets.ServerTool { } // SearchOrgs creates a tool to search for GitHub organizations. -func SearchOrgs(t translations.TranslationHelperFunc) toolsets.ServerTool { +func SearchOrgs(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -413,6 +416,7 @@ func SearchOrgs(t translations.TranslationHelperFunc) toolsets.ServerTool { WithPagination(schema) return NewTool( + ToolsetMetadataOrgs, mcp.Tool{ Name: "search_orgs", Description: t("TOOL_SEARCH_ORGS_DESCRIPTION", "Find GitHub organizations by name, location, or other organization metadata. Ideal for discovering companies, open source foundations, or teams."), diff --git a/pkg/github/search_test.go b/pkg/github/search_test.go index 41d12df1b..707b55349 100644 --- a/pkg/github/search_test.go +++ b/pkg/github/search_test.go @@ -134,8 +134,8 @@ func Test_SearchRepositories(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -209,8 +209,8 @@ func Test_SearchRepositories_FullOutput(t *testing.T) { client := github.NewClient(mockedClient) serverTool := SearchRepositories(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -358,8 +358,8 @@ func Test_SearchCode(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -558,8 +558,8 @@ func Test_SearchUsers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -733,8 +733,8 @@ func Test_SearchOrgs(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) diff --git a/pkg/github/secret_scanning.go b/pkg/github/secret_scanning.go index fa5618791..e840072b0 100644 --- a/pkg/github/secret_scanning.go +++ b/pkg/github/secret_scanning.go @@ -8,7 +8,7 @@ import ( "net/http" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -16,8 +16,9 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func GetSecretScanningAlert(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetSecretScanningAlert(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataSecretProtection, mcp.Tool{ Name: "get_secret_scanning_alert", Description: t("TOOL_GET_SECRET_SCANNING_ALERT_DESCRIPTION", "Get details of a specific secret scanning alert in a GitHub repository."), @@ -93,8 +94,9 @@ func GetSecretScanningAlert(t translations.TranslationHelperFunc) toolsets.Serve ) } -func ListSecretScanningAlerts(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListSecretScanningAlerts(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataSecretProtection, mcp.Tool{ Name: "list_secret_scanning_alerts", Description: t("TOOL_LIST_SECRET_SCANNING_ALERTS_DESCRIPTION", "List secret scanning alerts in a GitHub repository."), diff --git a/pkg/github/secret_scanning_test.go b/pkg/github/secret_scanning_test.go index 83de16409..b63617a46 100644 --- a/pkg/github/secret_scanning_test.go +++ b/pkg/github/secret_scanning_test.go @@ -87,8 +87,8 @@ func Test_GetSecretScanningAlert(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -228,8 +228,8 @@ func Test_ListSecretScanningAlerts(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) diff --git a/pkg/github/security_advisories.go b/pkg/github/security_advisories.go index 8c7df4265..28acb8156 100644 --- a/pkg/github/security_advisories.go +++ b/pkg/github/security_advisories.go @@ -7,7 +7,7 @@ import ( "io" "net/http" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -15,8 +15,9 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func ListGlobalSecurityAdvisories(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListGlobalSecurityAdvisories(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataSecurityAdvisories, mcp.Tool{ Name: "list_global_security_advisories", Description: t("TOOL_LIST_GLOBAL_SECURITY_ADVISORIES_DESCRIPTION", "List global security advisories from GitHub."), @@ -206,8 +207,9 @@ func ListGlobalSecurityAdvisories(t translations.TranslationHelperFunc) toolsets ) } -func ListRepositorySecurityAdvisories(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListRepositorySecurityAdvisories(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataSecurityAdvisories, mcp.Tool{ Name: "list_repository_security_advisories", Description: t("TOOL_LIST_REPOSITORY_SECURITY_ADVISORIES_DESCRIPTION", "List repository security advisories for a GitHub repository."), @@ -310,8 +312,9 @@ func ListRepositorySecurityAdvisories(t translations.TranslationHelperFunc) tool ) } -func GetGlobalSecurityAdvisory(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetGlobalSecurityAdvisory(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataSecurityAdvisories, mcp.Tool{ Name: "get_global_security_advisory", Description: t("TOOL_GET_GLOBAL_SECURITY_ADVISORY_DESCRIPTION", "Get a global security advisory"), @@ -367,8 +370,9 @@ func GetGlobalSecurityAdvisory(t translations.TranslationHelperFunc) toolsets.Se ) } -func ListOrgRepositorySecurityAdvisories(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListOrgRepositorySecurityAdvisories(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( + ToolsetMetadataSecurityAdvisories, mcp.Tool{ Name: "list_org_repository_security_advisories", Description: t("TOOL_LIST_ORG_REPOSITORY_SECURITY_ADVISORIES_DESCRIPTION", "List repository security advisories for a GitHub organization."), diff --git a/pkg/github/security_advisories_test.go b/pkg/github/security_advisories_test.go index 16506a3e8..3970949ec 100644 --- a/pkg/github/security_advisories_test.go +++ b/pkg/github/security_advisories_test.go @@ -103,7 +103,7 @@ func Test_ListGlobalSecurityAdvisories(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{GetClient: stubGetClientFn(client)} + deps := BaseDeps{Client: client} handler := toolDef.Handler(deps) // Create call request @@ -224,7 +224,7 @@ func Test_GetGlobalSecurityAdvisory(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{GetClient: stubGetClientFn(client)} + deps := BaseDeps{Client: client} handler := toolDef.Handler(deps) // Create call request @@ -372,7 +372,7 @@ func Test_ListRepositorySecurityAdvisories(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{GetClient: stubGetClientFn(client)} + deps := BaseDeps{Client: client} handler := toolDef.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -517,7 +517,7 @@ func Test_ListOrgRepositorySecurityAdvisories(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{GetClient: stubGetClientFn(client)} + deps := BaseDeps{Client: client} handler := toolDef.Handler(deps) request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/server.go b/pkg/github/server.go index e74596906..a9c9305a2 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -18,12 +18,7 @@ import ( func NewServer(version string, opts *mcp.ServerOptions) *mcp.Server { if opts == nil { - // Add default options - opts = &mcp.ServerOptions{ - HasTools: true, - HasResources: true, - HasPrompts: true, - } + opts = &mcp.ServerOptions{} } // Create a new MCP server diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index 2e9ab43a3..a59cd9a93 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -11,32 +11,67 @@ import ( "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/raw" + "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/shurcooL/githubv4" "github.com/stretchr/testify/assert" ) -func stubGetClientFn(client *github.Client) GetClientFn { - return func(_ context.Context) (*github.Client, error) { - return client, nil +// stubDeps is a test helper that implements ToolDependencies with configurable behavior. +// Use this when you need to test error paths or when you need closure-based client creation. +type stubDeps struct { + clientFn func(context.Context) (*github.Client, error) + gqlClientFn func(context.Context) (*githubv4.Client, error) + rawClientFn func(context.Context) (*raw.Client, error) + + repoAccessCache *lockdown.RepoAccessCache + t translations.TranslationHelperFunc + flags FeatureFlags + contentWindowSize int +} + +func (s stubDeps) GetClient(ctx context.Context) (*github.Client, error) { + if s.clientFn != nil { + return s.clientFn(ctx) } + return nil, nil } -func stubGetClientFromHTTPFn(client *http.Client) GetClientFn { +func (s stubDeps) GetGQLClient(ctx context.Context) (*githubv4.Client, error) { + if s.gqlClientFn != nil { + return s.gqlClientFn(ctx) + } + return nil, nil +} + +func (s stubDeps) GetRawClient(ctx context.Context) (*raw.Client, error) { + if s.rawClientFn != nil { + return s.rawClientFn(ctx) + } + return nil, nil +} + +func (s stubDeps) GetRepoAccessCache() *lockdown.RepoAccessCache { return s.repoAccessCache } +func (s stubDeps) GetT() translations.TranslationHelperFunc { return s.t } +func (s stubDeps) GetFlags() FeatureFlags { return s.flags } +func (s stubDeps) GetContentWindowSize() int { return s.contentWindowSize } + +// Helper functions to create stub client functions for error testing +func stubClientFnFromHTTP(httpClient *http.Client) func(context.Context) (*github.Client, error) { return func(_ context.Context) (*github.Client, error) { - return github.NewClient(client), nil + return github.NewClient(httpClient), nil } } -func stubGetClientFnErr(err string) GetClientFn { +func stubClientFnErr(errMsg string) func(context.Context) (*github.Client, error) { return func(_ context.Context) (*github.Client, error) { - return nil, errors.New(err) + return nil, errors.New(errMsg) } } -func stubGetGQLClientFn(client *githubv4.Client) GetGQLClientFn { +func stubGQLClientFnErr(errMsg string) func(context.Context) (*githubv4.Client, error) { return func(_ context.Context) (*githubv4.Client, error) { - return client, nil + return nil, errors.New(errMsg) } } @@ -51,12 +86,6 @@ func stubFeatureFlags(enabledFlags map[string]bool) FeatureFlags { } } -func stubGetRawClientFn(client *raw.Client) raw.GetRawClientFn { - return func(_ context.Context) (*raw.Client, error) { - return client, nil - } -} - func badRequestHandler(msg string) http.HandlerFunc { return func(w http.ResponseWriter, _ *http.Request) { structuredErrorResponse := github.ErrorResponse{ diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 6af9ce40d..1fde6e6bb 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -2,432 +2,248 @@ package github import ( "context" - "fmt" "strings" - "github.com/github/github-mcp-server/pkg/lockdown" - "github.com/github/github-mcp-server/pkg/raw" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" - "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/shurcooL/githubv4" ) type GetClientFn func(context.Context) (*github.Client, error) type GetGQLClientFn func(context.Context) (*githubv4.Client, error) -// ToolsetMetadata holds metadata for a toolset including its ID and description -type ToolsetMetadata struct { - ID string - Description string -} - +// Toolset metadata constants - these define all available toolsets and their descriptions. +// Tools use these constants to declare which toolset they belong to. var ( - ToolsetMetadataAll = ToolsetMetadata{ + ToolsetMetadataAll = registry.ToolsetMetadata{ ID: "all", Description: "Special toolset that enables all available toolsets", } - ToolsetMetadataDefault = ToolsetMetadata{ + ToolsetMetadataDefault = registry.ToolsetMetadata{ ID: "default", Description: "Special toolset that enables the default toolset configuration. When no toolsets are specified, this is the set that is enabled", } - ToolsetMetadataContext = ToolsetMetadata{ + ToolsetMetadataContext = registry.ToolsetMetadata{ ID: "context", Description: "Tools that provide context about the current user and GitHub context you are operating in", + Default: true, } - ToolsetMetadataRepos = ToolsetMetadata{ + ToolsetMetadataRepos = registry.ToolsetMetadata{ ID: "repos", Description: "GitHub Repository related tools", + Default: true, } - ToolsetMetadataGit = ToolsetMetadata{ + ToolsetMetadataGit = registry.ToolsetMetadata{ ID: "git", Description: "GitHub Git API related tools for low-level Git operations", } - ToolsetMetadataIssues = ToolsetMetadata{ + ToolsetMetadataIssues = registry.ToolsetMetadata{ ID: "issues", Description: "GitHub Issues related tools", + Default: true, } - ToolsetMetadataPullRequests = ToolsetMetadata{ + ToolsetMetadataPullRequests = registry.ToolsetMetadata{ ID: "pull_requests", Description: "GitHub Pull Request related tools", + Default: true, } - ToolsetMetadataUsers = ToolsetMetadata{ + ToolsetMetadataUsers = registry.ToolsetMetadata{ ID: "users", Description: "GitHub User related tools", + Default: true, } - ToolsetMetadataOrgs = ToolsetMetadata{ + ToolsetMetadataOrgs = registry.ToolsetMetadata{ ID: "orgs", Description: "GitHub Organization related tools", } - ToolsetMetadataActions = ToolsetMetadata{ + ToolsetMetadataActions = registry.ToolsetMetadata{ ID: "actions", Description: "GitHub Actions workflows and CI/CD operations", } - ToolsetMetadataCodeSecurity = ToolsetMetadata{ + ToolsetMetadataCodeSecurity = registry.ToolsetMetadata{ ID: "code_security", Description: "Code security related tools, such as GitHub Code Scanning", } - ToolsetMetadataSecretProtection = ToolsetMetadata{ + ToolsetMetadataSecretProtection = registry.ToolsetMetadata{ ID: "secret_protection", Description: "Secret protection related tools, such as GitHub Secret Scanning", } - ToolsetMetadataDependabot = ToolsetMetadata{ + ToolsetMetadataDependabot = registry.ToolsetMetadata{ ID: "dependabot", Description: "Dependabot tools", } - ToolsetMetadataNotifications = ToolsetMetadata{ + ToolsetMetadataNotifications = registry.ToolsetMetadata{ ID: "notifications", Description: "GitHub Notifications related tools", } - ToolsetMetadataExperiments = ToolsetMetadata{ + ToolsetMetadataExperiments = registry.ToolsetMetadata{ ID: "experiments", Description: "Experimental features that are not considered stable yet", } - ToolsetMetadataDiscussions = ToolsetMetadata{ + ToolsetMetadataDiscussions = registry.ToolsetMetadata{ ID: "discussions", Description: "GitHub Discussions related tools", } - ToolsetMetadataGists = ToolsetMetadata{ + ToolsetMetadataGists = registry.ToolsetMetadata{ ID: "gists", Description: "GitHub Gist related tools", } - ToolsetMetadataSecurityAdvisories = ToolsetMetadata{ + ToolsetMetadataSecurityAdvisories = registry.ToolsetMetadata{ ID: "security_advisories", Description: "Security advisories related tools", } - ToolsetMetadataProjects = ToolsetMetadata{ + ToolsetMetadataProjects = registry.ToolsetMetadata{ ID: "projects", Description: "GitHub Projects related tools", } - ToolsetMetadataStargazers = ToolsetMetadata{ + ToolsetMetadataStargazers = registry.ToolsetMetadata{ ID: "stargazers", Description: "GitHub Stargazers related tools", } - ToolsetMetadataDynamic = ToolsetMetadata{ + ToolsetMetadataDynamic = registry.ToolsetMetadata{ ID: "dynamic", Description: "Discover GitHub MCP tools that can help achieve tasks by enabling additional sets of tools, you can control the enablement of any toolset to access its tools when this toolset is enabled.", } - ToolsetLabels = ToolsetMetadata{ + ToolsetLabels = registry.ToolsetMetadata{ ID: "labels", Description: "GitHub Labels related tools", } ) -func AvailableTools() []ToolsetMetadata { - return []ToolsetMetadata{ - ToolsetMetadataContext, - ToolsetMetadataRepos, - ToolsetMetadataIssues, - ToolsetMetadataPullRequests, - ToolsetMetadataUsers, - ToolsetMetadataOrgs, - ToolsetMetadataActions, - ToolsetMetadataCodeSecurity, - ToolsetMetadataSecretProtection, - ToolsetMetadataDependabot, - ToolsetMetadataNotifications, - ToolsetMetadataExperiments, - ToolsetMetadataDiscussions, - ToolsetMetadataGists, - ToolsetMetadataSecurityAdvisories, - ToolsetMetadataProjects, - ToolsetMetadataStargazers, - ToolsetMetadataDynamic, - ToolsetLabels, - } -} - -// GetValidToolsetIDs returns a map of all valid toolset IDs for quick lookup -func GetValidToolsetIDs() map[string]bool { - validIDs := make(map[string]bool) - for _, tool := range AvailableTools() { - validIDs[tool.ID] = true +// AllTools returns all tools with their embedded toolset metadata. +// Tool functions return ServerTool directly with toolset info. +func AllTools(t translations.TranslationHelperFunc) []registry.ServerTool { + return []registry.ServerTool{ + // Context tools + GetMe(t), + GetTeams(t), + GetTeamMembers(t), + + // Repository tools + SearchRepositories(t), + GetFileContents(t), + ListCommits(t), + SearchCode(t), + GetCommit(t), + ListBranches(t), + ListTags(t), + GetTag(t), + ListReleases(t), + GetLatestRelease(t), + GetReleaseByTag(t), + CreateOrUpdateFile(t), + CreateRepository(t), + ForkRepository(t), + CreateBranch(t), + PushFiles(t), + DeleteFile(t), + ListStarredRepositories(t), + StarRepository(t), + UnstarRepository(t), + + // Git tools + GetRepositoryTree(t), + + // Issue tools + IssueRead(t), + SearchIssues(t), + ListIssues(t), + ListIssueTypes(t), + IssueWrite(t), + AddIssueComment(t), + AssignCopilotToIssue(t), + SubIssueWrite(t), + + // User tools + SearchUsers(t), + + // Organization tools + SearchOrgs(t), + + // Pull request tools + PullRequestRead(t), + ListPullRequests(t), + SearchPullRequests(t), + MergePullRequest(t), + UpdatePullRequestBranch(t), + CreatePullRequest(t), + UpdatePullRequest(t), + RequestCopilotReview(t), + PullRequestReviewWrite(t), + AddCommentToPendingReview(t), + + // Code security tools + GetCodeScanningAlert(t), + ListCodeScanningAlerts(t), + + // Secret protection tools + GetSecretScanningAlert(t), + ListSecretScanningAlerts(t), + + // Dependabot tools + GetDependabotAlert(t), + ListDependabotAlerts(t), + + // Notification tools + ListNotifications(t), + GetNotificationDetails(t), + DismissNotification(t), + MarkAllNotificationsRead(t), + ManageNotificationSubscription(t), + ManageRepositoryNotificationSubscription(t), + + // Discussion tools + ListDiscussions(t), + GetDiscussion(t), + GetDiscussionComments(t), + ListDiscussionCategories(t), + + // Actions tools + ListWorkflows(t), + ListWorkflowRuns(t), + GetWorkflowRun(t), + GetWorkflowRunLogs(t), + ListWorkflowJobs(t), + GetJobLogs(t), + ListWorkflowRunArtifacts(t), + DownloadWorkflowRunArtifact(t), + GetWorkflowRunUsage(t), + RunWorkflow(t), + RerunWorkflowRun(t), + RerunFailedJobs(t), + CancelWorkflowRun(t), + DeleteWorkflowRunLogs(t), + + // Security advisories tools + ListGlobalSecurityAdvisories(t), + GetGlobalSecurityAdvisory(t), + ListRepositorySecurityAdvisories(t), + ListOrgRepositorySecurityAdvisories(t), + + // Gist tools + ListGists(t), + GetGist(t), + CreateGist(t), + UpdateGist(t), + + // Project tools + ListProjects(t), + GetProject(t), + ListProjectFields(t), + GetProjectField(t), + ListProjectItems(t), + GetProjectItem(t), + AddProjectItem(t), + DeleteProjectItem(t), + UpdateProjectItem(t), + + // Label tools + GetLabel(t), + GetLabelForLabelsToolset(t), + ListLabels(t), + LabelWrite(t), } - // Add special keywords - validIDs[ToolsetMetadataAll.ID] = true - validIDs[ToolsetMetadataDefault.ID] = true - return validIDs -} - -func GetDefaultToolsetIDs() []string { - return []string{ - ToolsetMetadataContext.ID, - ToolsetMetadataRepos.ID, - ToolsetMetadataIssues.ID, - ToolsetMetadataPullRequests.ID, - ToolsetMetadataUsers.ID, - } -} - -func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetGQLClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc, contentWindowSize int, flags FeatureFlags, cache *lockdown.RepoAccessCache) *toolsets.ToolsetGroup { - tsg := toolsets.NewToolsetGroup(readOnly) - - // Create the dependencies struct that will be passed to all tool handlers - deps := ToolDependencies{ - GetClient: getClient, - GetGQLClient: getGQLClient, - GetRawClient: getRawClient, - RepoAccessCache: cache, - T: t, - Flags: flags, - ContentWindowSize: contentWindowSize, - } - - // Define all available features with their default state (disabled) - // Create toolsets - repos := toolsets.NewToolset(ToolsetMetadataRepos.ID, ToolsetMetadataRepos.Description). - SetDependencies(deps). - AddReadTools( - SearchRepositories(t), - GetFileContents(t), - ListCommits(t), - SearchCode(t), - GetCommit(t), - ListBranches(t), - ListTags(t), - GetTag(t), - ListReleases(t), - GetLatestRelease(t), - GetReleaseByTag(t), - ). - AddWriteTools( - CreateOrUpdateFile(t), - CreateRepository(t), - ForkRepository(t), - CreateBranch(t), - PushFiles(t), - DeleteFile(t), - ). - AddResourceTemplates( - toolsets.NewServerResourceTemplate(GetRepositoryResourceContent(getClient, getRawClient, t)), - toolsets.NewServerResourceTemplate(GetRepositoryResourceBranchContent(getClient, getRawClient, t)), - toolsets.NewServerResourceTemplate(GetRepositoryResourceCommitContent(getClient, getRawClient, t)), - toolsets.NewServerResourceTemplate(GetRepositoryResourceTagContent(getClient, getRawClient, t)), - toolsets.NewServerResourceTemplate(GetRepositoryResourcePrContent(getClient, getRawClient, t)), - ) - git := toolsets.NewToolset(ToolsetMetadataGit.ID, ToolsetMetadataGit.Description). - SetDependencies(deps). - AddReadTools( - GetRepositoryTree(t), - ) - issues := toolsets.NewToolset(ToolsetMetadataIssues.ID, ToolsetMetadataIssues.Description). - SetDependencies(deps). - AddReadTools( - IssueRead(t), - SearchIssues(t), - ListIssues(t), - ListIssueTypes(t), - toolsets.NewServerToolLegacy(GetLabel(getGQLClient, t)), - ). - AddWriteTools( - IssueWrite(t), - AddIssueComment(t), - AssignCopilotToIssue(t), - SubIssueWrite(t), - ).AddPrompts( - toolsets.NewServerPrompt(AssignCodingAgentPrompt(t)), - toolsets.NewServerPrompt(IssueToFixWorkflowPrompt(t)), - ) - users := toolsets.NewToolset(ToolsetMetadataUsers.ID, ToolsetMetadataUsers.Description). - SetDependencies(deps). - AddReadTools( - SearchUsers(t), - ) - orgs := toolsets.NewToolset(ToolsetMetadataOrgs.ID, ToolsetMetadataOrgs.Description). - SetDependencies(deps). - AddReadTools( - SearchOrgs(t), - ) - pullRequests := toolsets.NewToolset(ToolsetMetadataPullRequests.ID, ToolsetMetadataPullRequests.Description). - SetDependencies(deps). - AddReadTools( - PullRequestRead(t), - ListPullRequests(t), - SearchPullRequests(t), - ). - AddWriteTools( - MergePullRequest(t), - UpdatePullRequestBranch(t), - CreatePullRequest(t), - UpdatePullRequest(t), - RequestCopilotReview(t), - // Reviews - PullRequestReviewWrite(t), - AddCommentToPendingReview(t), - ) - codeSecurity := toolsets.NewToolset(ToolsetMetadataCodeSecurity.ID, ToolsetMetadataCodeSecurity.Description). - SetDependencies(deps). - AddReadTools( - GetCodeScanningAlert(t), - ListCodeScanningAlerts(t), - ) - secretProtection := toolsets.NewToolset(ToolsetMetadataSecretProtection.ID, ToolsetMetadataSecretProtection.Description). - SetDependencies(deps). - AddReadTools( - GetSecretScanningAlert(t), - ListSecretScanningAlerts(t), - ) - dependabot := toolsets.NewToolset(ToolsetMetadataDependabot.ID, ToolsetMetadataDependabot.Description). - SetDependencies(deps). - AddReadTools( - GetDependabotAlert(t), - ListDependabotAlerts(t), - ) - - notifications := toolsets.NewToolset(ToolsetMetadataNotifications.ID, ToolsetMetadataNotifications.Description). - SetDependencies(deps). - AddReadTools( - ListNotifications(t), - GetNotificationDetails(t), - ). - AddWriteTools( - DismissNotification(t), - MarkAllNotificationsRead(t), - ManageNotificationSubscription(t), - ManageRepositoryNotificationSubscription(t), - ) - - discussions := toolsets.NewToolset(ToolsetMetadataDiscussions.ID, ToolsetMetadataDiscussions.Description). - SetDependencies(deps). - AddReadTools( - ListDiscussions(t), - GetDiscussion(t), - GetDiscussionComments(t), - ListDiscussionCategories(t), - ) - - actions := toolsets.NewToolset(ToolsetMetadataActions.ID, ToolsetMetadataActions.Description). - SetDependencies(deps). - AddReadTools( - ListWorkflows(t), - ListWorkflowRuns(t), - GetWorkflowRun(t), - GetWorkflowRunLogs(t), - ListWorkflowJobs(t), - GetJobLogs(t), - ListWorkflowRunArtifacts(t), - DownloadWorkflowRunArtifact(t), - GetWorkflowRunUsage(t), - ). - AddWriteTools( - RunWorkflow(t), - RerunWorkflowRun(t), - RerunFailedJobs(t), - CancelWorkflowRun(t), - DeleteWorkflowRunLogs(t), - ) - - securityAdvisories := toolsets.NewToolset(ToolsetMetadataSecurityAdvisories.ID, ToolsetMetadataSecurityAdvisories.Description). - SetDependencies(deps). - AddReadTools( - ListGlobalSecurityAdvisories(t), - GetGlobalSecurityAdvisory(t), - ListRepositorySecurityAdvisories(t), - ListOrgRepositorySecurityAdvisories(t), - ) - - // // Keep experiments alive so the system doesn't error out when it's always enabled - experiments := toolsets.NewToolset(ToolsetMetadataExperiments.ID, ToolsetMetadataExperiments.Description). - SetDependencies(deps) - - contextTools := toolsets.NewToolset(ToolsetMetadataContext.ID, ToolsetMetadataContext.Description). - SetDependencies(deps). - AddReadTools( - GetMe(t), - GetTeams(t), - GetTeamMembers(t), - ) - - gists := toolsets.NewToolset(ToolsetMetadataGists.ID, ToolsetMetadataGists.Description). - SetDependencies(deps). - AddReadTools( - ListGists(t), - GetGist(t), - ). - AddWriteTools( - CreateGist(t), - UpdateGist(t), - ) - - projects := toolsets.NewToolset(ToolsetMetadataProjects.ID, ToolsetMetadataProjects.Description). - SetDependencies(deps). - AddReadTools( - toolsets.NewServerToolLegacy(ListProjects(getClient, t)), - toolsets.NewServerToolLegacy(GetProject(getClient, t)), - toolsets.NewServerToolLegacy(ListProjectFields(getClient, t)), - toolsets.NewServerToolLegacy(GetProjectField(getClient, t)), - toolsets.NewServerToolLegacy(ListProjectItems(getClient, t)), - toolsets.NewServerToolLegacy(GetProjectItem(getClient, t)), - ). - AddWriteTools( - toolsets.NewServerToolLegacy(AddProjectItem(getClient, t)), - toolsets.NewServerToolLegacy(DeleteProjectItem(getClient, t)), - toolsets.NewServerToolLegacy(UpdateProjectItem(getClient, t)), - ) - stargazers := toolsets.NewToolset(ToolsetMetadataStargazers.ID, ToolsetMetadataStargazers.Description). - SetDependencies(deps). - AddReadTools( - ListStarredRepositories(t), - ). - AddWriteTools( - StarRepository(t), - UnstarRepository(t), - ) - labels := toolsets.NewToolset(ToolsetLabels.ID, ToolsetLabels.Description). - SetDependencies(deps). - AddReadTools( - // get - toolsets.NewServerToolLegacy(GetLabel(getGQLClient, t)), - // list labels on repo or issue - toolsets.NewServerToolLegacy(ListLabels(getGQLClient, t)), - ). - AddWriteTools( - // create or update - toolsets.NewServerToolLegacy(LabelWrite(getGQLClient, t)), - ) - - // Add toolsets to the group - tsg.AddToolset(contextTools) - tsg.AddToolset(repos) - tsg.AddToolset(git) - tsg.AddToolset(issues) - tsg.AddToolset(orgs) - tsg.AddToolset(users) - tsg.AddToolset(pullRequests) - tsg.AddToolset(actions) - tsg.AddToolset(codeSecurity) - tsg.AddToolset(dependabot) - tsg.AddToolset(secretProtection) - tsg.AddToolset(notifications) - tsg.AddToolset(experiments) - tsg.AddToolset(discussions) - tsg.AddToolset(gists) - tsg.AddToolset(securityAdvisories) - tsg.AddToolset(projects) - tsg.AddToolset(stargazers) - tsg.AddToolset(labels) - - tsg.AddDeprecatedToolAliases(DeprecatedToolAliases) - - return tsg -} - -// InitDynamicToolset creates a dynamic toolset that can be used to enable other toolsets, and so requires the server and toolset group as arguments -// -//nolint:unused -func InitDynamicToolset(s *mcp.Server, tsg *toolsets.ToolsetGroup, t translations.TranslationHelperFunc) *toolsets.Toolset { - // Create a new dynamic toolset - // Need to add the dynamic toolset last so it can be used to enable other toolsets - dynamicToolSelection := toolsets.NewToolset(ToolsetMetadataDynamic.ID, ToolsetMetadataDynamic.Description). - AddReadTools( - toolsets.NewServerToolLegacy(ListAvailableToolsets(tsg, t)), - toolsets.NewServerToolLegacy(GetToolsetsTools(tsg, t)), - toolsets.NewServerToolLegacy(EnableToolset(s, tsg, t)), - ) - - dynamicToolSelection.Enabled = true - return dynamicToolSelection } // ToBoolPtr converts a bool to a *bool pointer. @@ -446,52 +262,77 @@ func ToStringPtr(s string) *string { // GenerateToolsetsHelp generates the help text for the toolsets flag func GenerateToolsetsHelp() string { - // Format default tools - defaultTools := strings.Join(GetDefaultToolsetIDs(), ", ") + // Get toolset group to derive defaults and available toolsets + r := NewRegistry(stubTranslator).Build() + + // Format default tools from metadata using strings.Builder + var defaultBuf strings.Builder + defaultIDs := r.DefaultToolsetIDs() + for i, id := range defaultIDs { + if i > 0 { + defaultBuf.WriteString(", ") + } + defaultBuf.WriteString(string(id)) + } - // Format available tools with line breaks for better readability - allTools := AvailableTools() - var availableToolsLines []string + // Get all available toolsets (excludes context and dynamic for display) + allToolsets := r.AvailableToolsets("context", "dynamic") + var availableBuf strings.Builder const maxLineLength = 70 currentLine := "" - for i, tool := range allTools { + for i, toolset := range allToolsets { + id := string(toolset.ID) switch { case i == 0: - currentLine = tool.ID - case len(currentLine)+len(tool.ID)+2 <= maxLineLength: - currentLine += ", " + tool.ID + currentLine = id + case len(currentLine)+len(id)+2 <= maxLineLength: + currentLine += ", " + id default: - availableToolsLines = append(availableToolsLines, currentLine) - currentLine = tool.ID + if availableBuf.Len() > 0 { + availableBuf.WriteString(",\n\t ") + } + availableBuf.WriteString(currentLine) + currentLine = id } } if currentLine != "" { - availableToolsLines = append(availableToolsLines, currentLine) - } - - availableTools := strings.Join(availableToolsLines, ",\n\t ") - - toolsetsHelp := fmt.Sprintf("Comma-separated list of tool groups to enable (no spaces).\n"+ - "Available: %s\n", availableTools) + - "Special toolset keywords:\n" + - " - all: Enables all available toolsets\n" + - fmt.Sprintf(" - default: Enables the default toolset configuration of:\n\t %s\n", defaultTools) + - "Examples:\n" + - " - --toolsets=actions,gists,notifications\n" + - " - Default + additional: --toolsets=default,actions,gists\n" + - " - All tools: --toolsets=all" - - return toolsetsHelp + if availableBuf.Len() > 0 { + availableBuf.WriteString(",\n\t ") + } + availableBuf.WriteString(currentLine) + } + + // Build the complete help text using strings.Builder + var buf strings.Builder + buf.WriteString("Comma-separated list of tool groups to enable (no spaces).\n") + buf.WriteString("Available: ") + buf.WriteString(availableBuf.String()) + buf.WriteString("\n") + buf.WriteString("Special toolset keywords:\n") + buf.WriteString(" - all: Enables all available toolsets\n") + buf.WriteString(" - default: Enables the default toolset configuration of:\n\t ") + buf.WriteString(defaultBuf.String()) + buf.WriteString("\n") + buf.WriteString("Examples:\n") + buf.WriteString(" - --toolsets=actions,gists,notifications\n") + buf.WriteString(" - Default + additional: --toolsets=default,actions,gists\n") + buf.WriteString(" - All tools: --toolsets=all") + + return buf.String() } +// stubTranslator is a passthrough translator for cases where we need a Registry +// but don't need actual translations (e.g., getting toolset IDs for CLI help). +func stubTranslator(_, fallback string) string { return fallback } + // AddDefaultToolset removes the default toolset and expands it to the actual default toolset IDs func AddDefaultToolset(result []string) []string { hasDefault := false seen := make(map[string]bool) for _, toolset := range result { seen[toolset] = true - if toolset == ToolsetMetadataDefault.ID { + if toolset == string(ToolsetMetadataDefault.ID) { hasDefault = true } } @@ -501,45 +342,18 @@ func AddDefaultToolset(result []string) []string { return result } - result = RemoveToolset(result, ToolsetMetadataDefault.ID) + result = RemoveToolset(result, string(ToolsetMetadataDefault.ID)) - for _, defaultToolset := range GetDefaultToolsetIDs() { - if !seen[defaultToolset] { - result = append(result, defaultToolset) + // Get default toolset IDs from the Registry + r := NewRegistry(stubTranslator).Build() + for _, id := range r.DefaultToolsetIDs() { + if !seen[string(id)] { + result = append(result, string(id)) } } return result } -// cleanToolsets cleans and handles special toolset keywords: -// - Duplicates are removed from the result -// - Removes whitespaces -// - Validates toolset names and returns invalid ones separately - for warning reporting -// Returns: (toolsets, invalidToolsets) -func CleanToolsets(enabledToolsets []string) ([]string, []string) { - seen := make(map[string]bool) - result := make([]string, 0, len(enabledToolsets)) - invalid := make([]string, 0) - validIDs := GetValidToolsetIDs() - - // Add non-default toolsets, removing duplicates and trimming whitespace - for _, toolset := range enabledToolsets { - trimmed := strings.TrimSpace(toolset) - if trimmed == "" { - continue - } - if !seen[trimmed] { - seen[trimmed] = true - result = append(result, trimmed) - if !validIDs[trimmed] { - invalid = append(invalid, trimmed) - } - } - } - - return result, invalid -} - func RemoveToolset(tools []string, toRemove string) []string { result := make([]string, 0, len(tools)) for _, tool := range tools { @@ -579,3 +393,15 @@ func CleanTools(toolNames []string) []string { return result } + +// GetDefaultToolsetIDs returns the IDs of toolsets marked as Default. +// This is a convenience function that builds a registry to determine defaults. +func GetDefaultToolsetIDs() []string { + r := NewRegistry(stubTranslator).Build() + ids := r.DefaultToolsetIDs() + result := make([]string, len(ids)) + for i, id := range ids { + result[i] = string(id) + } + return result +} diff --git a/pkg/github/tools_test.go b/pkg/github/tools_test.go index 45c1e746f..80270d2bc 100644 --- a/pkg/github/tools_test.go +++ b/pkg/github/tools_test.go @@ -7,135 +7,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestCleanToolsets(t *testing.T) { - tests := []struct { - name string - input []string - expected []string - expectedInvalid []string - }{ - { - name: "empty slice", - input: []string{}, - expected: []string{}, - }, - { - name: "nil input slice", - input: nil, - expected: []string{}, - }, - // CleanToolsets only cleans - it does NOT filter out special keywords - { - name: "default keyword preserved", - input: []string{"default"}, - expected: []string{"default"}, - }, - { - name: "default with additional toolsets", - input: []string{"default", "actions", "gists"}, - expected: []string{"default", "actions", "gists"}, - }, - { - name: "all keyword preserved", - input: []string{"all", "actions"}, - expected: []string{"all", "actions"}, - }, - { - name: "no special keywords", - input: []string{"actions", "gists", "notifications"}, - expected: []string{"actions", "gists", "notifications"}, - }, - { - name: "duplicate toolsets without special keywords", - input: []string{"actions", "gists", "actions"}, - expected: []string{"actions", "gists"}, - }, - { - name: "duplicate toolsets with default", - input: []string{"context", "repos", "issues", "pull_requests", "users", "default"}, - expected: []string{"context", "repos", "issues", "pull_requests", "users", "default"}, - }, - { - name: "default appears multiple times - duplicates removed", - input: []string{"default", "actions", "default", "gists", "default"}, - expected: []string{"default", "actions", "gists"}, - }, - // Whitespace test cases - { - name: "whitespace check - leading and trailing whitespace on regular toolsets", - input: []string{" actions ", " gists ", "notifications"}, - expected: []string{"actions", "gists", "notifications"}, - }, - { - name: "whitespace check - default toolset with whitespace", - input: []string{" actions ", " default ", "notifications"}, - expected: []string{"actions", "default", "notifications"}, - }, - { - name: "whitespace check - all toolset with whitespace", - input: []string{" all ", " actions "}, - expected: []string{"all", "actions"}, - }, - // Invalid toolset test cases - { - name: "mix of valid and invalid toolsets", - input: []string{"actions", "invalid_toolset", "gists", "typo_repo"}, - expected: []string{"actions", "invalid_toolset", "gists", "typo_repo"}, - expectedInvalid: []string{"invalid_toolset", "typo_repo"}, - }, - { - name: "invalid with whitespace", - input: []string{" invalid_tool ", " actions ", " typo_gist "}, - expected: []string{"invalid_tool", "actions", "typo_gist"}, - expectedInvalid: []string{"invalid_tool", "typo_gist"}, - }, - { - name: "empty string in toolsets", - input: []string{"", "actions", " ", "gists"}, - expected: []string{"actions", "gists"}, - expectedInvalid: []string{}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, invalid := CleanToolsets(tt.input) - - require.Len(t, result, len(tt.expected), "result length should match expected length") - - if tt.expectedInvalid == nil { - tt.expectedInvalid = []string{} - } - require.Len(t, invalid, len(tt.expectedInvalid), "invalid length should match expected invalid length") - - resultMap := make(map[string]bool) - for _, toolset := range result { - resultMap[toolset] = true - } - - expectedMap := make(map[string]bool) - for _, toolset := range tt.expected { - expectedMap[toolset] = true - } - - invalidMap := make(map[string]bool) - for _, toolset := range invalid { - invalidMap[toolset] = true - } - - expectedInvalidMap := make(map[string]bool) - for _, toolset := range tt.expectedInvalid { - expectedInvalidMap[toolset] = true - } - - assert.Equal(t, expectedMap, resultMap, "result should contain all expected toolsets without duplicates") - assert.Equal(t, expectedInvalidMap, invalidMap, "invalid should contain all expected invalid toolsets") - - assert.Len(t, resultMap, len(result), "result should not contain duplicates") - }) - } -} - func TestAddDefaultToolset(t *testing.T) { tests := []struct { name string @@ -280,3 +151,34 @@ func TestContainsToolset(t *testing.T) { }) } } + +func TestGenerateToolsetsHelp(t *testing.T) { + // Generate the help text + helpText := GenerateToolsetsHelp() + + // Verify help text is not empty + require.NotEmpty(t, helpText) + + // Verify it contains expected sections + assert.Contains(t, helpText, "Comma-separated list of tool groups to enable") + assert.Contains(t, helpText, "Available:") + assert.Contains(t, helpText, "Special toolset keywords:") + assert.Contains(t, helpText, "all: Enables all available toolsets") + assert.Contains(t, helpText, "default: Enables the default toolset configuration") + assert.Contains(t, helpText, "Examples:") + assert.Contains(t, helpText, "--toolsets=actions,gists,notifications") + assert.Contains(t, helpText, "--toolsets=default,actions,gists") + assert.Contains(t, helpText, "--toolsets=all") + + // Verify it contains some expected default toolsets + assert.Contains(t, helpText, "context") + assert.Contains(t, helpText, "repos") + assert.Contains(t, helpText, "issues") + assert.Contains(t, helpText, "pull_requests") + assert.Contains(t, helpText, "users") + + // Verify it contains some expected available toolsets + assert.Contains(t, helpText, "actions") + assert.Contains(t, helpText, "gists") + assert.Contains(t, helpText, "notifications") +} diff --git a/pkg/github/tools_validation_test.go b/pkg/github/tools_validation_test.go new file mode 100644 index 000000000..aa809dfa6 --- /dev/null +++ b/pkg/github/tools_validation_test.go @@ -0,0 +1,177 @@ +package github + +import ( + "testing" + + "github.com/github/github-mcp-server/pkg/registry" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// stubTranslation is a simple translation function for testing +func stubTranslation(_, fallback string) string { + return fallback +} + +// TestAllToolsHaveRequiredMetadata validates that all tools have mandatory metadata: +// - Toolset must be set (non-empty ID) +// - ReadOnlyHint annotation must be explicitly set (not nil) +func TestAllToolsHaveRequiredMetadata(t *testing.T) { + tools := AllTools(stubTranslation) + + require.NotEmpty(t, tools, "AllTools should return at least one tool") + + for _, tool := range tools { + t.Run(tool.Tool.Name, func(t *testing.T) { + // Toolset ID must be set + assert.NotEmpty(t, tool.Toolset.ID, + "Tool %q must have a Toolset.ID", tool.Tool.Name) + + // Toolset description should be set for documentation + assert.NotEmpty(t, tool.Toolset.Description, + "Tool %q should have a Toolset.Description", tool.Tool.Name) + + // Annotations must exist and have ReadOnlyHint explicitly set + require.NotNil(t, tool.Tool.Annotations, + "Tool %q must have Annotations set (for ReadOnlyHint)", tool.Tool.Name) + + // We can't distinguish between "not set" and "set to false" for a bool, + // but having Annotations non-nil confirms the developer thought about it. + // The ReadOnlyHint value itself is validated by ensuring Annotations exist. + }) + } +} + +// TestAllResourcesHaveRequiredMetadata validates that all resources have mandatory metadata +func TestAllResourcesHaveRequiredMetadata(t *testing.T) { + // Resources are now stateless - no client functions needed + resources := AllResources(stubTranslation) + + require.NotEmpty(t, resources, "AllResources should return at least one resource") + + for _, res := range resources { + t.Run(res.Template.Name, func(t *testing.T) { + // Toolset ID must be set + assert.NotEmpty(t, res.Toolset.ID, + "Resource %q must have a Toolset.ID", res.Template.Name) + + // HandlerFunc must be set + assert.True(t, res.HasHandler(), + "Resource %q must have a HandlerFunc", res.Template.Name) + }) + } +} + +// TestAllPromptsHaveRequiredMetadata validates that all prompts have mandatory metadata +func TestAllPromptsHaveRequiredMetadata(t *testing.T) { + prompts := AllPrompts(stubTranslation) + + require.NotEmpty(t, prompts, "AllPrompts should return at least one prompt") + + for _, prompt := range prompts { + t.Run(prompt.Prompt.Name, func(t *testing.T) { + // Toolset ID must be set + assert.NotEmpty(t, prompt.Toolset.ID, + "Prompt %q must have a Toolset.ID", prompt.Prompt.Name) + + // Handler must be set + assert.NotNil(t, prompt.Handler, + "Prompt %q must have a Handler", prompt.Prompt.Name) + }) + } +} + +// TestToolReadOnlyHintConsistency validates that read-only tools are correctly annotated +func TestToolReadOnlyHintConsistency(t *testing.T) { + tools := AllTools(stubTranslation) + + for _, tool := range tools { + t.Run(tool.Tool.Name, func(t *testing.T) { + require.NotNil(t, tool.Tool.Annotations, + "Tool %q must have Annotations", tool.Tool.Name) + + // Verify IsReadOnly() method matches the annotation + assert.Equal(t, tool.Tool.Annotations.ReadOnlyHint, tool.IsReadOnly(), + "Tool %q: IsReadOnly() should match Annotations.ReadOnlyHint", tool.Tool.Name) + }) + } +} + +// TestNoDuplicateToolNames ensures all tools have unique names +func TestNoDuplicateToolNames(t *testing.T) { + tools := AllTools(stubTranslation) + seen := make(map[string]bool) + + // get_label is intentionally in both issues and labels toolsets for conformance + // with original behavior where it was registered in both + allowedDuplicates := map[string]bool{ + "get_label": true, + } + + for _, tool := range tools { + name := tool.Tool.Name + if !allowedDuplicates[name] { + assert.False(t, seen[name], + "Duplicate tool name found: %q", name) + } + seen[name] = true + } +} + +// TestNoDuplicateResourceNames ensures all resources have unique names +func TestNoDuplicateResourceNames(t *testing.T) { + resources := AllResources(stubTranslation) + seen := make(map[string]bool) + + for _, res := range resources { + name := res.Template.Name + assert.False(t, seen[name], + "Duplicate resource name found: %q", name) + seen[name] = true + } +} + +// TestNoDuplicatePromptNames ensures all prompts have unique names +func TestNoDuplicatePromptNames(t *testing.T) { + prompts := AllPrompts(stubTranslation) + seen := make(map[string]bool) + + for _, prompt := range prompts { + name := prompt.Prompt.Name + assert.False(t, seen[name], + "Duplicate prompt name found: %q", name) + seen[name] = true + } +} + +// TestAllToolsHaveHandlerFunc ensures all tools have a handler function +func TestAllToolsHaveHandlerFunc(t *testing.T) { + tools := AllTools(stubTranslation) + + for _, tool := range tools { + t.Run(tool.Tool.Name, func(t *testing.T) { + assert.NotNil(t, tool.HandlerFunc, + "Tool %q must have a HandlerFunc", tool.Tool.Name) + assert.True(t, tool.HasHandler(), + "Tool %q HasHandler() should return true", tool.Tool.Name) + }) + } +} + +// TestToolsetMetadataConsistency ensures tools in the same toolset have consistent descriptions +func TestToolsetMetadataConsistency(t *testing.T) { + tools := AllTools(stubTranslation) + toolsetDescriptions := make(map[registry.ToolsetID]string) + + for _, tool := range tools { + id := tool.Toolset.ID + desc := tool.Toolset.Description + + if existing, ok := toolsetDescriptions[id]; ok { + assert.Equal(t, existing, desc, + "Toolset %q has inconsistent descriptions across tools", id) + } else { + toolsetDescriptions[id] = desc + } + } +} diff --git a/pkg/github/workflow_prompts.go b/pkg/github/workflow_prompts.go index bc7c7581f..603a98087 100644 --- a/pkg/github/workflow_prompts.go +++ b/pkg/github/workflow_prompts.go @@ -4,13 +4,16 @@ import ( "context" "fmt" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/modelcontextprotocol/go-sdk/mcp" ) // IssueToFixWorkflowPrompt provides a guided workflow for creating an issue and then generating a PR to fix it -func IssueToFixWorkflowPrompt(t translations.TranslationHelperFunc) (tool mcp.Prompt, handler mcp.PromptHandler) { - return mcp.Prompt{ +func IssueToFixWorkflowPrompt(t translations.TranslationHelperFunc) registry.ServerPrompt { + return registry.NewServerPrompt( + ToolsetMetadataIssues, + mcp.Prompt{ Name: "issue_to_fix_workflow", Description: t("PROMPT_ISSUE_TO_FIX_WORKFLOW_DESCRIPTION", "Create an issue for a problem and then generate a pull request to fix it"), Arguments: []*mcp.PromptArgument{ @@ -102,5 +105,6 @@ func IssueToFixWorkflowPrompt(t translations.TranslationHelperFunc) (tool mcp.Pr return &mcp.GetPromptResult{ Messages: messages, }, nil - } + }, + ) } diff --git a/pkg/registry/builder.go b/pkg/registry/builder.go new file mode 100644 index 000000000..813712064 --- /dev/null +++ b/pkg/registry/builder.go @@ -0,0 +1,274 @@ +package registry + +import ( + "context" + "sort" + "strings" +) + +// ToolFilter is a function that determines if a tool should be included. +// Returns true if the tool should be included, false to exclude it. +type ToolFilter func(ctx context.Context, tool *ServerTool) (bool, error) + +// Builder builds a Registry with the specified configuration. +// Use NewBuilder to create a builder, chain configuration methods, +// then call Build() to create the final Registry. +// +// Example: +// +// reg := NewBuilder(). +// SetTools(tools). +// SetResources(resources). +// SetPrompts(prompts). +// WithDeprecatedAliases(aliases). +// WithReadOnly(true). +// WithToolsets([]string{"repos", "issues"}). +// WithFeatureChecker(checker). +// WithFilter(myFilter). +// Build() +type Builder struct { + tools []ServerTool + resourceTemplates []ServerResourceTemplate + prompts []ServerPrompt + deprecatedAliases map[string]string + + // Configuration options (processed at Build time) + readOnly bool + toolsetIDs []string // raw input, processed at Build() + toolsetIDsIsNil bool // tracks if nil was passed (nil = defaults) + additionalTools []string // raw input, processed at Build() + featureChecker FeatureFlagChecker + filters []ToolFilter // filters to apply to all tools +} + +// NewBuilder creates a new Builder. +func NewBuilder() *Builder { + return &Builder{ + deprecatedAliases: make(map[string]string), + toolsetIDsIsNil: true, // default to nil (use defaults) + } +} + +// SetTools sets the tools for the registry. Returns self for chaining. +func (b *Builder) SetTools(tools []ServerTool) *Builder { + b.tools = tools + return b +} + +// SetResources sets the resource templates for the registry. Returns self for chaining. +func (b *Builder) SetResources(resources []ServerResourceTemplate) *Builder { + b.resourceTemplates = resources + return b +} + +// SetPrompts sets the prompts for the registry. Returns self for chaining. +func (b *Builder) SetPrompts(prompts []ServerPrompt) *Builder { + b.prompts = prompts + return b +} + +// WithDeprecatedAliases adds deprecated tool name aliases that map to canonical names. +// Returns self for chaining. +func (b *Builder) WithDeprecatedAliases(aliases map[string]string) *Builder { + for oldName, newName := range aliases { + b.deprecatedAliases[oldName] = newName + } + return b +} + +// WithReadOnly sets whether only read-only tools should be available. +// When true, write tools are filtered out. Returns self for chaining. +func (b *Builder) WithReadOnly(readOnly bool) *Builder { + b.readOnly = readOnly + return b +} + +// WithToolsets specifies which toolsets should be enabled. +// Special keywords: +// - "all": enables all toolsets +// - "default": expands to toolsets marked with Default: true in their metadata +// +// Input strings are trimmed of whitespace and duplicates are removed. +// Pass nil to use default toolsets. Pass an empty slice to disable all toolsets +// (useful for dynamic toolsets mode where tools are enabled on demand). +// Returns self for chaining. +func (b *Builder) WithToolsets(toolsetIDs []string) *Builder { + b.toolsetIDs = toolsetIDs + b.toolsetIDsIsNil = toolsetIDs == nil + return b +} + +// WithTools specifies additional tools that bypass toolset filtering. +// These tools are additive - they will be included even if their toolset is not enabled. +// Read-only filtering still applies to these tools. +// Deprecated tool aliases are automatically resolved to their canonical names during Build(). +// Returns self for chaining. +func (b *Builder) WithTools(toolNames []string) *Builder { + b.additionalTools = toolNames + return b +} + +// WithFeatureChecker sets the feature flag checker function. +// The checker receives a context (for actor extraction) and feature flag name, +// returns (enabled, error). If error occurs, it will be logged and treated as false. +// If checker is nil, all feature flag checks return false. +// Returns self for chaining. +func (b *Builder) WithFeatureChecker(checker FeatureFlagChecker) *Builder { + b.featureChecker = checker + return b +} + +// WithFilter adds a filter function that will be applied to all tools. +// Multiple filters can be added and are evaluated in order. +// If any filter returns false or an error, the tool is excluded. +// Returns self for chaining. +func (b *Builder) WithFilter(filter ToolFilter) *Builder { + b.filters = append(b.filters, filter) + return b +} + +// Build creates the final Registry with all configuration applied. +// This processes toolset filtering, tool name resolution, and sets up +// the registry for use. The returned Registry is ready for use with +// AvailableTools(), RegisterAll(), etc. +func (b *Builder) Build() *Registry { + r := &Registry{ + tools: b.tools, + resourceTemplates: b.resourceTemplates, + prompts: b.prompts, + deprecatedAliases: b.deprecatedAliases, + readOnly: b.readOnly, + featureChecker: b.featureChecker, + filters: b.filters, + } + + // Process toolsets and pre-compute metadata in a single pass + r.enabledToolsets, r.unrecognizedToolsets, r.toolsetIDs, r.toolsetIDSet, r.defaultToolsetIDs, r.toolsetDescriptions = b.processToolsets() + + // Process additional tools (resolve aliases) + if len(b.additionalTools) > 0 { + r.additionalTools = make(map[string]bool, len(b.additionalTools)) + for _, name := range b.additionalTools { + // Resolve deprecated aliases to canonical names + if canonical, isAlias := b.deprecatedAliases[name]; isAlias { + r.additionalTools[canonical] = true + } else { + r.additionalTools[name] = true + } + } + } + + return r +} + +// processToolsets processes the toolsetIDs configuration and returns: +// - enabledToolsets map (nil means all enabled) +// - unrecognizedToolsets list for warnings +// - allToolsetIDs sorted list of all toolset IDs +// - toolsetIDSet map for O(1) HasToolset lookup +// - defaultToolsetIDs sorted list of default toolset IDs +// - toolsetDescriptions map of toolset ID to description +func (b *Builder) processToolsets() (map[ToolsetID]bool, []string, []ToolsetID, map[ToolsetID]bool, []ToolsetID, map[ToolsetID]string) { + // Single pass: collect all toolset metadata together + validIDs := make(map[ToolsetID]bool) + defaultIDs := make(map[ToolsetID]bool) + descriptions := make(map[ToolsetID]string) + + for i := range b.tools { + t := &b.tools[i] + validIDs[t.Toolset.ID] = true + if t.Toolset.Default { + defaultIDs[t.Toolset.ID] = true + } + if t.Toolset.Description != "" { + descriptions[t.Toolset.ID] = t.Toolset.Description + } + } + for i := range b.resourceTemplates { + r := &b.resourceTemplates[i] + validIDs[r.Toolset.ID] = true + if r.Toolset.Default { + defaultIDs[r.Toolset.ID] = true + } + if r.Toolset.Description != "" { + descriptions[r.Toolset.ID] = r.Toolset.Description + } + } + for i := range b.prompts { + p := &b.prompts[i] + validIDs[p.Toolset.ID] = true + if p.Toolset.Default { + defaultIDs[p.Toolset.ID] = true + } + if p.Toolset.Description != "" { + descriptions[p.Toolset.ID] = p.Toolset.Description + } + } + + // Build sorted slices from the collected maps + allToolsetIDs := make([]ToolsetID, 0, len(validIDs)) + for id := range validIDs { + allToolsetIDs = append(allToolsetIDs, id) + } + sort.Slice(allToolsetIDs, func(i, j int) bool { return allToolsetIDs[i] < allToolsetIDs[j] }) + + defaultToolsetIDList := make([]ToolsetID, 0, len(defaultIDs)) + for id := range defaultIDs { + defaultToolsetIDList = append(defaultToolsetIDList, id) + } + sort.Slice(defaultToolsetIDList, func(i, j int) bool { return defaultToolsetIDList[i] < defaultToolsetIDList[j] }) + + toolsetIDs := b.toolsetIDs + + // Check for "all" keyword - enables all toolsets + for _, id := range toolsetIDs { + if strings.TrimSpace(id) == "all" { + return nil, nil, allToolsetIDs, validIDs, defaultToolsetIDList, descriptions // nil means all enabled + } + } + + // nil means use defaults, empty slice means no toolsets + if b.toolsetIDsIsNil { + toolsetIDs = []string{"default"} + } + + // Expand "default" keyword, trim whitespace, collect other IDs, and track unrecognized + seen := make(map[ToolsetID]bool) + expanded := make([]ToolsetID, 0, len(toolsetIDs)) + var unrecognized []string + + for _, id := range toolsetIDs { + trimmed := strings.TrimSpace(id) + if trimmed == "" { + continue + } + if trimmed == "default" { + for _, defaultID := range defaultToolsetIDList { + if !seen[defaultID] { + seen[defaultID] = true + expanded = append(expanded, defaultID) + } + } + } else { + tsID := ToolsetID(trimmed) + if !seen[tsID] { + seen[tsID] = true + expanded = append(expanded, tsID) + // Track if this toolset doesn't exist + if !validIDs[tsID] { + unrecognized = append(unrecognized, trimmed) + } + } + } + } + + if len(expanded) == 0 { + return make(map[ToolsetID]bool), unrecognized, allToolsetIDs, validIDs, defaultToolsetIDList, descriptions + } + + enabledToolsets := make(map[ToolsetID]bool, len(expanded)) + for _, id := range expanded { + enabledToolsets[id] = true + } + return enabledToolsets, unrecognized, allToolsetIDs, validIDs, defaultToolsetIDList, descriptions +} diff --git a/pkg/registry/errors.go b/pkg/registry/errors.go new file mode 100644 index 000000000..75cbb6f82 --- /dev/null +++ b/pkg/registry/errors.go @@ -0,0 +1,41 @@ +package registry + +import "fmt" + +// ToolsetDoesNotExistError is returned when a toolset is not found. +type ToolsetDoesNotExistError struct { + Name string +} + +func (e *ToolsetDoesNotExistError) Error() string { + return fmt.Sprintf("toolset %s does not exist", e.Name) +} + +func (e *ToolsetDoesNotExistError) Is(target error) bool { + if target == nil { + return false + } + if _, ok := target.(*ToolsetDoesNotExistError); ok { + return true + } + return false +} + +// NewToolsetDoesNotExistError creates a new ToolsetDoesNotExistError. +func NewToolsetDoesNotExistError(name string) *ToolsetDoesNotExistError { + return &ToolsetDoesNotExistError{Name: name} +} + +// ToolDoesNotExistError is returned when a tool is not found. +type ToolDoesNotExistError struct { + Name string +} + +func (e *ToolDoesNotExistError) Error() string { + return fmt.Sprintf("tool %s does not exist", e.Name) +} + +// NewToolDoesNotExistError creates a new ToolDoesNotExistError. +func NewToolDoesNotExistError(name string) *ToolDoesNotExistError { + return &ToolDoesNotExistError{Name: name} +} diff --git a/pkg/registry/filters.go b/pkg/registry/filters.go new file mode 100644 index 000000000..5bce63570 --- /dev/null +++ b/pkg/registry/filters.go @@ -0,0 +1,289 @@ +package registry + +import ( + "context" + "fmt" + "os" + "sort" +) + +// FeatureFlagChecker is a function that checks if a feature flag is enabled. +// The context can be used to extract actor/user information for flag evaluation. +// Returns (enabled, error). If error occurs, the caller should log and treat as false. +type FeatureFlagChecker func(ctx context.Context, flagName string) (bool, error) + +// isToolsetEnabled checks if a toolset is enabled based on current filters. +func (r *Registry) isToolsetEnabled(toolsetID ToolsetID) bool { + // Check enabled toolsets filter + if r.enabledToolsets != nil { + return r.enabledToolsets[toolsetID] + } + return true +} + +// checkFeatureFlag checks a feature flag using the feature checker. +// Returns false if checker is nil or returns an error (errors are logged). +func (r *Registry) checkFeatureFlag(ctx context.Context, flagName string) bool { + if r.featureChecker == nil || flagName == "" { + return false + } + enabled, err := r.featureChecker(ctx, flagName) + if err != nil { + fmt.Fprintf(os.Stderr, "Feature flag check error for %q: %v\n", flagName, err) + return false + } + return enabled +} + +// isFeatureFlagAllowed checks if an item passes feature flag filtering. +// - If FeatureFlagEnable is set, the item is only allowed if the flag is enabled +// - If FeatureFlagDisable is set, the item is excluded if the flag is enabled +func (r *Registry) isFeatureFlagAllowed(ctx context.Context, enableFlag, disableFlag string) bool { + // Check enable flag - item requires this flag to be on + if enableFlag != "" && !r.checkFeatureFlag(ctx, enableFlag) { + return false + } + // Check disable flag - item is excluded if this flag is on + if disableFlag != "" && r.checkFeatureFlag(ctx, disableFlag) { + return false + } + return true +} + +// isToolEnabled checks if a specific tool is enabled based on current filters. +// Filter evaluation order: +// 1. Tool.Enabled (tool self-filtering) +// 2. FeatureFlagEnable/FeatureFlagDisable +// 3. Read-only filter +// 4. Builder filters (via WithFilter) +// 5. Toolset/additional tools +func (r *Registry) isToolEnabled(ctx context.Context, tool *ServerTool) bool { + // 1. Check tool's own Enabled function first + if tool.Enabled != nil { + enabled, err := tool.Enabled(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, "Tool.Enabled check error for %q: %v\n", tool.Tool.Name, err) + return false + } + if !enabled { + return false + } + } + // 2. Check feature flags + if !r.isFeatureFlagAllowed(ctx, tool.FeatureFlagEnable, tool.FeatureFlagDisable) { + return false + } + // 3. Check read-only filter (applies to all tools) + if r.readOnly && !tool.IsReadOnly() { + return false + } + // 4. Apply builder filters + for _, filter := range r.filters { + allowed, err := filter(ctx, tool) + if err != nil { + fmt.Fprintf(os.Stderr, "Builder filter error for tool %q: %v\n", tool.Tool.Name, err) + return false + } + if !allowed { + return false + } + } + // 5. Check if tool is in additionalTools (bypasses toolset filter) + if r.additionalTools != nil && r.additionalTools[tool.Tool.Name] { + return true + } + // 5. Check toolset filter + if !r.isToolsetEnabled(tool.Toolset.ID) { + return false + } + return true +} + +// AvailableTools returns the tools that pass all current filters, +// sorted deterministically by toolset ID, then tool name. +// The context is used for feature flag evaluation. +func (r *Registry) AvailableTools(ctx context.Context) []ServerTool { + var result []ServerTool + for i := range r.tools { + tool := &r.tools[i] + if r.isToolEnabled(ctx, tool) { + result = append(result, *tool) + } + } + + // Sort deterministically: by toolset ID, then by tool name + sort.Slice(result, func(i, j int) bool { + if result[i].Toolset.ID != result[j].Toolset.ID { + return result[i].Toolset.ID < result[j].Toolset.ID + } + return result[i].Tool.Name < result[j].Tool.Name + }) + + return result +} + +// AvailableResourceTemplates returns resource templates that pass all current filters, +// sorted deterministically by toolset ID, then template name. +// The context is used for feature flag evaluation. +func (r *Registry) AvailableResourceTemplates(ctx context.Context) []ServerResourceTemplate { + var result []ServerResourceTemplate + for i := range r.resourceTemplates { + res := &r.resourceTemplates[i] + // Check feature flags + if !r.isFeatureFlagAllowed(ctx, res.FeatureFlagEnable, res.FeatureFlagDisable) { + continue + } + if r.isToolsetEnabled(res.Toolset.ID) { + result = append(result, *res) + } + } + + // Sort deterministically: by toolset ID, then by template name + sort.Slice(result, func(i, j int) bool { + if result[i].Toolset.ID != result[j].Toolset.ID { + return result[i].Toolset.ID < result[j].Toolset.ID + } + return result[i].Template.Name < result[j].Template.Name + }) + + return result +} + +// AvailablePrompts returns prompts that pass all current filters, +// sorted deterministically by toolset ID, then prompt name. +// The context is used for feature flag evaluation. +func (r *Registry) AvailablePrompts(ctx context.Context) []ServerPrompt { + var result []ServerPrompt + for i := range r.prompts { + prompt := &r.prompts[i] + // Check feature flags + if !r.isFeatureFlagAllowed(ctx, prompt.FeatureFlagEnable, prompt.FeatureFlagDisable) { + continue + } + if r.isToolsetEnabled(prompt.Toolset.ID) { + result = append(result, *prompt) + } + } + + // Sort deterministically: by toolset ID, then by prompt name + sort.Slice(result, func(i, j int) bool { + if result[i].Toolset.ID != result[j].Toolset.ID { + return result[i].Toolset.ID < result[j].Toolset.ID + } + return result[i].Prompt.Name < result[j].Prompt.Name + }) + + return result +} + +// filterToolsByName returns tools matching the given name, checking deprecated aliases. +// Uses linear scan - optimized for single-lookup per-request scenarios (ForMCPRequest). +func (r *Registry) filterToolsByName(name string) []ServerTool { + // First check for exact match + for i := range r.tools { + if r.tools[i].Tool.Name == name { + return []ServerTool{r.tools[i]} + } + } + // Check if name is a deprecated alias + if canonical, isAlias := r.deprecatedAliases[name]; isAlias { + for i := range r.tools { + if r.tools[i].Tool.Name == canonical { + return []ServerTool{r.tools[i]} + } + } + } + return []ServerTool{} +} + +// filterResourcesByURI returns resource templates matching the given URI pattern. +// Uses linear scan - optimized for single-lookup per-request scenarios (ForMCPRequest). +func (r *Registry) filterResourcesByURI(uri string) []ServerResourceTemplate { + for i := range r.resourceTemplates { + if r.resourceTemplates[i].Template.URITemplate == uri { + return []ServerResourceTemplate{r.resourceTemplates[i]} + } + } + return []ServerResourceTemplate{} +} + +// filterPromptsByName returns prompts matching the given name. +// Uses linear scan - optimized for single-lookup per-request scenarios (ForMCPRequest). +func (r *Registry) filterPromptsByName(name string) []ServerPrompt { + for i := range r.prompts { + if r.prompts[i].Prompt.Name == name { + return []ServerPrompt{r.prompts[i]} + } + } + return []ServerPrompt{} +} + +// ToolsForToolset returns all tools belonging to a specific toolset. +// This method bypasses the toolset enabled filter (for dynamic toolset registration), +// but still respects the read-only filter. +func (r *Registry) ToolsForToolset(toolsetID ToolsetID) []ServerTool { + var result []ServerTool + for i := range r.tools { + tool := &r.tools[i] + // Only check read-only filter, not toolset enabled filter + if tool.Toolset.ID == toolsetID { + if r.readOnly && !tool.IsReadOnly() { + continue + } + result = append(result, *tool) + } + } + + // Sort by tool name for deterministic order + sort.Slice(result, func(i, j int) bool { + return result[i].Tool.Name < result[j].Tool.Name + }) + + return result +} + +// IsToolsetEnabled checks if a toolset is currently enabled based on filters. +func (r *Registry) IsToolsetEnabled(toolsetID ToolsetID) bool { + return r.isToolsetEnabled(toolsetID) +} + +// EnableToolset marks a toolset as enabled in this group. +// This is used by dynamic toolset management to track which toolsets have been enabled. +func (r *Registry) EnableToolset(toolsetID ToolsetID) { + if r.enabledToolsets == nil { + // nil means all enabled, so nothing to do + return + } + r.enabledToolsets[toolsetID] = true +} + +// EnabledToolsetIDs returns the list of enabled toolset IDs based on current filters. +// Returns all toolset IDs if no filter is set. +func (r *Registry) EnabledToolsetIDs() []ToolsetID { + if r.enabledToolsets == nil { + return r.ToolsetIDs() + } + + ids := make([]ToolsetID, 0, len(r.enabledToolsets)) + for id := range r.enabledToolsets { + if r.HasToolset(id) { + ids = append(ids, id) + } + } + sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) + return ids +} + +// FilteredTools returns tools filtered by the Enabled function and builder filters. +// This provides an explicit API for accessing filtered tools, currently implemented +// as an alias for AvailableTools. +// +// The error return is currently always nil but is included for future extensibility. +// Library consumers (e.g., remote server implementations) may need to surface +// recoverable filter errors rather than silently logging them. Having the error +// return in the API now avoids breaking changes later. +// +// The context is used for Enabled function evaluation and builder filter checks. +func (r *Registry) FilteredTools(ctx context.Context) ([]ServerTool, error) { + return r.AvailableTools(ctx), nil +} diff --git a/pkg/registry/prompts.go b/pkg/registry/prompts.go new file mode 100644 index 000000000..02dda6c9c --- /dev/null +++ b/pkg/registry/prompts.go @@ -0,0 +1,26 @@ +package registry + +import "github.com/modelcontextprotocol/go-sdk/mcp" + +// ServerPrompt pairs a prompt with its toolset metadata. +type ServerPrompt struct { + Prompt mcp.Prompt + Handler mcp.PromptHandler + // Toolset identifies which toolset this prompt belongs to + Toolset ToolsetMetadata + // FeatureFlagEnable specifies a feature flag that must be enabled for this prompt + // to be available. If set and the flag is not enabled, the prompt is omitted. + FeatureFlagEnable string + // FeatureFlagDisable specifies a feature flag that, when enabled, causes this prompt + // to be omitted. Used to disable prompts when a feature flag is on. + FeatureFlagDisable string +} + +// NewServerPrompt creates a new ServerPrompt with toolset metadata. +func NewServerPrompt(toolset ToolsetMetadata, prompt mcp.Prompt, handler mcp.PromptHandler) ServerPrompt { + return ServerPrompt{ + Prompt: prompt, + Handler: handler, + Toolset: toolset, + } +} diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go new file mode 100644 index 000000000..7376a4c9c --- /dev/null +++ b/pkg/registry/registry.go @@ -0,0 +1,282 @@ +package registry + +import ( + "context" + "fmt" + "os" + "slices" + "sort" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// Registry holds a collection of tools, resources, and prompts with filtering applied. +// Create a Registry using Builder: +// +// reg := NewBuilder(). +// SetTools(tools). +// WithReadOnly(true). +// WithToolsets([]string{"repos"}). +// Build() +// +// The Registry is configured at build time and provides: +// - Filtered access to tools/resources/prompts via Available* methods +// - Deterministic ordering for documentation generation +// - Lazy dependency injection during registration via RegisterAll() +// - Runtime toolset enabling for dynamic toolsets mode +type Registry struct { + // tools holds all tools in this group (ordered for iteration) + tools []ServerTool + // resourceTemplates holds all resource templates in this group (ordered for iteration) + resourceTemplates []ServerResourceTemplate + // prompts holds all prompts in this group (ordered for iteration) + prompts []ServerPrompt + // deprecatedAliases maps old tool names to new canonical names + deprecatedAliases map[string]string + + // Pre-computed toolset metadata (set during Build) + toolsetIDs []ToolsetID // sorted list of all toolset IDs + toolsetIDSet map[ToolsetID]bool // set for O(1) HasToolset lookup + defaultToolsetIDs []ToolsetID // sorted list of default toolset IDs + toolsetDescriptions map[ToolsetID]string // toolset ID -> description + + // Filters - these control what's returned by Available* methods + // readOnly when true filters out write tools + readOnly bool + // enabledToolsets when non-nil, only include tools/resources/prompts from these toolsets + // when nil, all toolsets are enabled + enabledToolsets map[ToolsetID]bool + // additionalTools are specific tools that bypass toolset filtering (but still respect read-only) + // These are additive - a tool is included if it matches toolset filters OR is in this set + additionalTools map[string]bool + // featureChecker when non-nil, checks if a feature flag is enabled. + // Takes context and flag name, returns (enabled, error). If error, log and treat as false. + // If checker is nil, all flag checks return false. + featureChecker FeatureFlagChecker + // filters are functions that will be applied to all tools during filtering. + // If any filter returns false or an error, the tool is excluded. + filters []ToolFilter + // unrecognizedToolsets holds toolset IDs that were requested but don't match any registered toolsets + unrecognizedToolsets []string +} + +// UnrecognizedToolsets returns toolset IDs that were passed to WithToolsets but don't +// match any registered toolsets. This is useful for warning users about typos. +func (r *Registry) UnrecognizedToolsets() []string { + return r.unrecognizedToolsets +} + +// MCP method constants for use with ForMCPRequest. +const ( + MCPMethodInitialize = "initialize" + MCPMethodToolsList = "tools/list" + MCPMethodToolsCall = "tools/call" + MCPMethodResourcesList = "resources/list" + MCPMethodResourcesRead = "resources/read" + MCPMethodResourcesTemplatesList = "resources/templates/list" + MCPMethodPromptsList = "prompts/list" + MCPMethodPromptsGet = "prompts/get" +) + +// ForMCPRequest returns a Registry optimized for a specific MCP request. +// This is designed for servers that create a new instance per request (like the remote server), +// allowing them to only register the items needed for that specific request rather than all ~90 tools. +// +// Parameters: +// - method: The MCP method being called (use MCP* constants) +// - itemName: Name of specific item for call/get methods (tool name, resource URI, or prompt name) +// +// Returns a new Registry containing only the items relevant to the request: +// - MCPMethodInitialize: Empty (capabilities are set via ServerOptions, not registration) +// - MCPMethodToolsList: All available tools (no resources/prompts) +// - MCPMethodToolsCall: Only the named tool +// - MCPMethodResourcesList, MCPMethodResourcesTemplatesList: All available resources (no tools/prompts) +// - MCPMethodResourcesRead: Only the named resource template +// - MCPMethodPromptsList: All available prompts (no tools/resources) +// - MCPMethodPromptsGet: Only the named prompt +// - Unknown methods: Empty (no items registered) +// +// All existing filters (read-only, toolsets, etc.) still apply to the returned items. +func (r *Registry) ForMCPRequest(method string, itemName string) *Registry { + // Create a shallow copy with shared filter settings + // Note: lazy-init maps (toolsByName, etc.) are NOT copied - the new Registry + // will initialize its own maps on first use if needed + result := &Registry{ + tools: r.tools, + resourceTemplates: r.resourceTemplates, + prompts: r.prompts, + deprecatedAliases: r.deprecatedAliases, + readOnly: r.readOnly, + enabledToolsets: r.enabledToolsets, // shared, not modified + additionalTools: r.additionalTools, // shared, not modified + featureChecker: r.featureChecker, + filters: r.filters, // shared, not modified + unrecognizedToolsets: r.unrecognizedToolsets, + } + + // Helper to clear all item types + clearAll := func() { + result.tools = []ServerTool{} + result.resourceTemplates = []ServerResourceTemplate{} + result.prompts = []ServerPrompt{} + } + + switch method { + case MCPMethodInitialize: + clearAll() + case MCPMethodToolsList: + result.resourceTemplates, result.prompts = nil, nil + case MCPMethodToolsCall: + result.resourceTemplates, result.prompts = nil, nil + if itemName != "" { + result.tools = r.filterToolsByName(itemName) + } + case MCPMethodResourcesList, MCPMethodResourcesTemplatesList: + result.tools, result.prompts = nil, nil + case MCPMethodResourcesRead: + result.tools, result.prompts = nil, nil + if itemName != "" { + result.resourceTemplates = r.filterResourcesByURI(itemName) + } + case MCPMethodPromptsList: + result.tools, result.resourceTemplates = nil, nil + case MCPMethodPromptsGet: + result.tools, result.resourceTemplates = nil, nil + if itemName != "" { + result.prompts = r.filterPromptsByName(itemName) + } + default: + clearAll() + } + + return result +} + +// ToolsetIDs returns a sorted list of unique toolset IDs from all tools in this group. +func (r *Registry) ToolsetIDs() []ToolsetID { + return r.toolsetIDs +} + +// DefaultToolsetIDs returns the IDs of toolsets marked as Default in their metadata. +// The IDs are returned in sorted order for deterministic output. +func (r *Registry) DefaultToolsetIDs() []ToolsetID { + return r.defaultToolsetIDs +} + +// ToolsetDescriptions returns a map of toolset ID to description for all toolsets. +func (r *Registry) ToolsetDescriptions() map[ToolsetID]string { + return r.toolsetDescriptions +} + +// RegisterTools registers all available tools with the server using the provided dependencies. +// The context is used for feature flag evaluation. +func (r *Registry) RegisterTools(ctx context.Context, s *mcp.Server, deps any) { + for _, tool := range r.AvailableTools(ctx) { + tool.RegisterFunc(s, deps) + } +} + +// RegisterResourceTemplates registers all available resource templates with the server. +// The context is used for feature flag evaluation. +func (r *Registry) RegisterResourceTemplates(ctx context.Context, s *mcp.Server, deps any) { + for _, res := range r.AvailableResourceTemplates(ctx) { + s.AddResourceTemplate(&res.Template, res.Handler(deps)) + } +} + +// RegisterPrompts registers all available prompts with the server. +// The context is used for feature flag evaluation. +func (r *Registry) RegisterPrompts(ctx context.Context, s *mcp.Server) { + for _, prompt := range r.AvailablePrompts(ctx) { + s.AddPrompt(&prompt.Prompt, prompt.Handler) + } +} + +// RegisterAll registers all available tools, resources, and prompts with the server. +// The context is used for feature flag evaluation. +func (r *Registry) RegisterAll(ctx context.Context, s *mcp.Server, deps any) { + r.RegisterTools(ctx, s, deps) + r.RegisterResourceTemplates(ctx, s, deps) + r.RegisterPrompts(ctx, s) +} + +// ResolveToolAliases resolves deprecated tool aliases to their canonical names. +// It logs a warning to stderr for each deprecated alias that is resolved. +// Returns: +// - resolved: tool names with aliases replaced by canonical names +// - aliasesUsed: map of oldName → newName for each alias that was resolved +func (r *Registry) ResolveToolAliases(toolNames []string) (resolved []string, aliasesUsed map[string]string) { + resolved = make([]string, 0, len(toolNames)) + aliasesUsed = make(map[string]string) + for _, toolName := range toolNames { + if canonicalName, isAlias := r.deprecatedAliases[toolName]; isAlias { + fmt.Fprintf(os.Stderr, "Warning: tool %q is deprecated, use %q instead\n", toolName, canonicalName) + aliasesUsed[toolName] = canonicalName + resolved = append(resolved, canonicalName) + } else { + resolved = append(resolved, toolName) + } + } + return resolved, aliasesUsed +} + +// FindToolByName searches all tools for one matching the given name. +// Returns the tool, its toolset ID, and an error if not found. +// This searches ALL tools regardless of filters. +func (r *Registry) FindToolByName(toolName string) (*ServerTool, ToolsetID, error) { + for i := range r.tools { + if r.tools[i].Tool.Name == toolName { + return &r.tools[i], r.tools[i].Toolset.ID, nil + } + } + return nil, "", NewToolDoesNotExistError(toolName) +} + +// HasToolset checks if any tool/resource/prompt belongs to the given toolset. +func (r *Registry) HasToolset(toolsetID ToolsetID) bool { + return r.toolsetIDSet[toolsetID] +} + +// AllTools returns all tools without any filtering, sorted deterministically. +func (r *Registry) AllTools() []ServerTool { + result := slices.Clone(r.tools) + + // Sort deterministically: by toolset ID, then by tool name + sort.Slice(result, func(i, j int) bool { + if result[i].Toolset.ID != result[j].Toolset.ID { + return result[i].Toolset.ID < result[j].Toolset.ID + } + return result[i].Tool.Name < result[j].Tool.Name + }) + + return result +} + +// AvailableToolsets returns the unique toolsets that have tools, in sorted order. +// This is the ordered intersection of toolsets with reality - only toolsets that +// actually contain tools are returned, sorted by toolset ID. +// Optional exclude parameter filters out specific toolset IDs from the result. +func (r *Registry) AvailableToolsets(exclude ...ToolsetID) []ToolsetMetadata { + tools := r.AllTools() + if len(tools) == 0 { + return nil + } + + // Build exclude set for O(1) lookup + excludeSet := make(map[ToolsetID]bool, len(exclude)) + for _, id := range exclude { + excludeSet[id] = true + } + + var result []ToolsetMetadata + var lastID ToolsetID + for _, tool := range tools { + if tool.Toolset.ID != lastID { + lastID = tool.Toolset.ID + if !excludeSet[lastID] { + result = append(result, tool.Toolset) + } + } + } + return result +} diff --git a/pkg/registry/registry_test.go b/pkg/registry/registry_test.go new file mode 100644 index 000000000..44ed3d773 --- /dev/null +++ b/pkg/registry/registry_test.go @@ -0,0 +1,1645 @@ +package registry + +import ( + "context" + "encoding/json" + "fmt" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// testToolsetMetadata returns a ToolsetMetadata for testing +func testToolsetMetadata(id string) ToolsetMetadata { + return ToolsetMetadata{ + ID: ToolsetID(id), + Description: "Test toolset: " + id, + } +} + +// testToolsetMetadataWithDefault returns a ToolsetMetadata with Default flag for testing +func testToolsetMetadataWithDefault(id string, isDefault bool) ToolsetMetadata { + return ToolsetMetadata{ + ID: ToolsetID(id), + Description: "Test toolset: " + id, + Default: isDefault, + } +} + +// mockToolWithDefault creates a mock tool with a default toolset flag +func mockToolWithDefault(name string, toolsetID string, readOnly bool, isDefault bool) ServerTool { + return NewServerToolFromHandler( + mcp.Tool{ + Name: name, + Annotations: &mcp.ToolAnnotations{ + ReadOnlyHint: readOnly, + }, + InputSchema: json.RawMessage(`{"type":"object","properties":{}}`), + }, + testToolsetMetadataWithDefault(toolsetID, isDefault), + func(_ any) mcp.ToolHandler { + return func(_ context.Context, _ *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return nil, nil + } + }, + ) +} + +// mockTool creates a minimal ServerTool for testing +func mockTool(name string, toolsetID string, readOnly bool) ServerTool { + return NewServerToolFromHandler( + mcp.Tool{ + Name: name, + Annotations: &mcp.ToolAnnotations{ + ReadOnlyHint: readOnly, + }, + InputSchema: json.RawMessage(`{"type":"object","properties":{}}`), + }, + testToolsetMetadata(toolsetID), + func(_ any) mcp.ToolHandler { + return func(_ context.Context, _ *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return nil, nil + } + }, + ) +} + +func TestNewRegistryEmpty(t *testing.T) { + reg := NewBuilder().Build() + if len(reg.AvailableTools(context.Background())) != 0 { + t.Fatalf("Expected tools to be empty") + } + if len(reg.AvailableResourceTemplates(context.Background())) != 0 { + t.Fatalf("Expected resourceTemplates to be empty") + } + if len(reg.AvailablePrompts(context.Background())) != 0 { + t.Fatalf("Expected prompts to be empty") + } +} + +func TestNewRegistryWithTools(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", false), + mockTool("tool3", "toolset2", true), + } + + reg := NewBuilder().SetTools(tools).Build() + + if len(reg.AllTools()) != 3 { + t.Errorf("Expected 3 tools, got %d", len(reg.AllTools())) + } +} + +func TestAvailableTools_NoFilters(t *testing.T) { + tools := []ServerTool{ + mockTool("tool_b", "toolset1", true), + mockTool("tool_a", "toolset1", false), + mockTool("tool_c", "toolset2", true), + } + + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + available := reg.AvailableTools(context.Background()) + + if len(available) != 3 { + t.Fatalf("Expected 3 available tools, got %d", len(available)) + } + + // Verify deterministic sorting: by toolset ID, then tool name + expectedOrder := []string{"tool_a", "tool_b", "tool_c"} + for i, tool := range available { + if tool.Tool.Name != expectedOrder[i] { + t.Errorf("Tool at index %d: expected %s, got %s", i, expectedOrder[i], tool.Tool.Name) + } + } +} + +func TestWithReadOnly(t *testing.T) { + tools := []ServerTool{ + mockTool("read_tool", "toolset1", true), + mockTool("write_tool", "toolset1", false), + } + + // Build without read-only - should have both tools + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + allTools := reg.AvailableTools(context.Background()) + if len(allTools) != 2 { + t.Fatalf("Expected 2 tools without read-only, got %d", len(allTools)) + } + + // Build with read-only - should filter out write tools + readOnlyReg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithReadOnly(true).Build() + readOnlyTools := readOnlyReg.AvailableTools(context.Background()) + if len(readOnlyTools) != 1 { + t.Fatalf("Expected 1 tool in read-only, got %d", len(readOnlyTools)) + } + if readOnlyTools[0].Tool.Name != "read_tool" { + t.Errorf("Expected read_tool, got %s", readOnlyTools[0].Tool.Name) + } +} + +func TestWithToolsets(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset2", true), + mockTool("tool3", "toolset3", true), + } + + // Build with all toolsets + allReg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + allTools := allReg.AvailableTools(context.Background()) + if len(allTools) != 3 { + t.Fatalf("Expected 3 tools without filter, got %d", len(allTools)) + } + + // Build with specific toolsets + filteredReg := NewBuilder().SetTools(tools).WithToolsets([]string{"toolset1", "toolset3"}).Build() + filteredTools := filteredReg.AvailableTools(context.Background()) + + if len(filteredTools) != 2 { + t.Fatalf("Expected 2 filtered tools, got %d", len(filteredTools)) + } + + // Verify correct tools are included + toolNames := make(map[string]bool) + for _, tool := range filteredTools { + toolNames[tool.Tool.Name] = true + } + if !toolNames["tool1"] || !toolNames["tool3"] { + t.Errorf("Expected tool1 and tool3, got %v", toolNames) + } +} + +func TestWithToolsetsTrimsWhitespace(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset2", true), + } + + // Whitespace should be trimmed + filteredReg := NewBuilder().SetTools(tools).WithToolsets([]string{" toolset1 ", " toolset2 "}).Build() + filteredTools := filteredReg.AvailableTools(context.Background()) + + if len(filteredTools) != 2 { + t.Fatalf("Expected 2 tools after whitespace trimming, got %d", len(filteredTools)) + } +} + +func TestWithToolsetsDeduplicates(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + } + + // Duplicates should be removed + filteredReg := NewBuilder().SetTools(tools).WithToolsets([]string{"toolset1", "toolset1", " toolset1 "}).Build() + filteredTools := filteredReg.AvailableTools(context.Background()) + + if len(filteredTools) != 1 { + t.Fatalf("Expected 1 tool after deduplication, got %d", len(filteredTools)) + } +} + +func TestWithToolsetsIgnoresEmptyStrings(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + } + + // Empty strings should be ignored + filteredReg := NewBuilder().SetTools(tools).WithToolsets([]string{"", "toolset1", " ", ""}).Build() + filteredTools := filteredReg.AvailableTools(context.Background()) + + if len(filteredTools) != 1 { + t.Fatalf("Expected 1 tool, got %d", len(filteredTools)) + } +} + +func TestUnrecognizedToolsets(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset2", true), + } + + tests := []struct { + name string + input []string + expectedUnrecognized []string + }{ + { + name: "all valid", + input: []string{"toolset1", "toolset2"}, + expectedUnrecognized: nil, + }, + { + name: "one invalid", + input: []string{"toolset1", "invalid_toolset"}, + expectedUnrecognized: []string{"invalid_toolset"}, + }, + { + name: "multiple invalid", + input: []string{"typo1", "toolset1", "typo2"}, + expectedUnrecognized: []string{"typo1", "typo2"}, + }, + { + name: "invalid with whitespace trimmed", + input: []string{" invalid_tool "}, + expectedUnrecognized: []string{"invalid_tool"}, + }, + { + name: "empty input", + input: []string{}, + expectedUnrecognized: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + filtered := NewBuilder().SetTools(tools).WithToolsets(tt.input).Build() + unrecognized := filtered.UnrecognizedToolsets() + + if len(unrecognized) != len(tt.expectedUnrecognized) { + t.Fatalf("Expected %d unrecognized, got %d: %v", + len(tt.expectedUnrecognized), len(unrecognized), unrecognized) + } + + for i, expected := range tt.expectedUnrecognized { + if unrecognized[i] != expected { + t.Errorf("Expected unrecognized[%d] = %q, got %q", i, expected, unrecognized[i]) + } + } + }) + } +} + +func TestWithTools(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", true), + mockTool("tool3", "toolset2", true), + } + + // WithTools adds additional tools that bypass toolset filtering + // When combined with WithToolsets([]), only the additional tools should be available + filteredReg := NewBuilder().SetTools(tools).WithToolsets([]string{}).WithTools([]string{"tool1", "tool3"}).Build() + filteredTools := filteredReg.AvailableTools(context.Background()) + + if len(filteredTools) != 2 { + t.Fatalf("Expected 2 filtered tools, got %d", len(filteredTools)) + } + + toolNames := make(map[string]bool) + for _, tool := range filteredTools { + toolNames[tool.Tool.Name] = true + } + if !toolNames["tool1"] || !toolNames["tool3"] { + t.Errorf("Expected tool1 and tool3, got %v", toolNames) + } +} + +func TestChainedFilters(t *testing.T) { + tools := []ServerTool{ + mockTool("read1", "toolset1", true), + mockTool("write1", "toolset1", false), + mockTool("read2", "toolset2", true), + mockTool("write2", "toolset2", false), + } + + // Chain read-only and toolset filter + filtered := NewBuilder().SetTools(tools).WithReadOnly(true).WithToolsets([]string{"toolset1"}).Build() + result := filtered.AvailableTools(context.Background()) + + if len(result) != 1 { + t.Fatalf("Expected 1 tool after chained filters, got %d", len(result)) + } + if result[0].Tool.Name != "read1" { + t.Errorf("Expected read1, got %s", result[0].Tool.Name) + } +} + +func TestToolsetIDs(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset_b", true), + mockTool("tool2", "toolset_a", true), + mockTool("tool3", "toolset_b", true), // duplicate toolset + } + + reg := NewBuilder().SetTools(tools).Build() + ids := reg.ToolsetIDs() + + if len(ids) != 2 { + t.Fatalf("Expected 2 unique toolset IDs, got %d", len(ids)) + } + + // Should be sorted + if ids[0] != "toolset_a" || ids[1] != "toolset_b" { + t.Errorf("Expected sorted IDs [toolset_a, toolset_b], got %v", ids) + } +} + +func TestToolsetDescriptions(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset2", true), + } + + reg := NewBuilder().SetTools(tools).Build() + descriptions := reg.ToolsetDescriptions() + + if len(descriptions) != 2 { + t.Fatalf("Expected 2 descriptions, got %d", len(descriptions)) + } + + if descriptions["toolset1"] != "Test toolset: toolset1" { + t.Errorf("Wrong description for toolset1: %s", descriptions["toolset1"]) + } +} + +func TestToolsForToolset(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", true), + mockTool("tool3", "toolset2", true), + } + + reg := NewBuilder().SetTools(tools).Build() + toolset1Tools := reg.ToolsForToolset("toolset1") + + if len(toolset1Tools) != 2 { + t.Fatalf("Expected 2 tools for toolset1, got %d", len(toolset1Tools)) + } +} + +func TestWithDeprecatedAliases(t *testing.T) { + tools := []ServerTool{ + mockTool("new_name", "toolset1", true), + } + + reg := NewBuilder().SetTools(tools).WithDeprecatedAliases(map[string]string{ + "old_name": "new_name", + "get_issue": "issue_read", + }).Build() + + // Test resolving aliases + resolved, aliasesUsed := reg.ResolveToolAliases([]string{"old_name"}) + if len(resolved) != 1 || resolved[0] != "new_name" { + t.Errorf("expected alias to resolve to 'new_name', got %v", resolved) + } + if len(aliasesUsed) != 1 || aliasesUsed["old_name"] != "new_name" { + t.Errorf("expected alias mapping, got %v", aliasesUsed) + } +} + +func TestResolveToolAliases(t *testing.T) { + tools := []ServerTool{ + mockTool("issue_read", "toolset1", true), + mockTool("some_tool", "toolset1", true), + } + + reg := NewBuilder().SetTools(tools). + WithDeprecatedAliases(map[string]string{ + "get_issue": "issue_read", + }).Build() + + // Test resolving a mix of aliases and canonical names + input := []string{"get_issue", "some_tool"} + resolved, aliasesUsed := reg.ResolveToolAliases(input) + + if len(resolved) != 2 { + t.Fatalf("expected 2 resolved names, got %d", len(resolved)) + } + if resolved[0] != "issue_read" { + t.Errorf("expected 'issue_read', got '%s'", resolved[0]) + } + if resolved[1] != "some_tool" { + t.Errorf("expected 'some_tool' (unchanged), got '%s'", resolved[1]) + } + + if len(aliasesUsed) != 1 { + t.Fatalf("expected 1 alias used, got %d", len(aliasesUsed)) + } + if aliasesUsed["get_issue"] != "issue_read" { + t.Errorf("expected aliasesUsed['get_issue'] = 'issue_read', got '%s'", aliasesUsed["get_issue"]) + } +} + +func TestFindToolByName(t *testing.T) { + tools := []ServerTool{ + mockTool("issue_read", "toolset1", true), + } + + reg := NewBuilder().SetTools(tools).Build() + + // Find by name + tool, toolsetID, err := reg.FindToolByName("issue_read") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if tool.Tool.Name != "issue_read" { + t.Errorf("expected tool name 'issue_read', got '%s'", tool.Tool.Name) + } + if toolsetID != "toolset1" { + t.Errorf("expected toolset ID 'toolset1', got '%s'", toolsetID) + } + + // Non-existent tool + _, _, err = reg.FindToolByName("nonexistent") + if err == nil { + t.Error("expected error for non-existent tool") + } +} + +func TestWithToolsAdditive(t *testing.T) { + tools := []ServerTool{ + mockTool("issue_read", "toolset1", true), + mockTool("issue_write", "toolset1", false), + mockTool("repo_read", "toolset2", true), + } + + // Test WithTools bypasses toolset filtering + // Enable only toolset2, but add issue_read as additional tool + filtered := NewBuilder().SetTools(tools).WithToolsets([]string{"toolset2"}).WithTools([]string{"issue_read"}).Build() + + available := filtered.AvailableTools(context.Background()) + if len(available) != 2 { + t.Errorf("expected 2 tools (repo_read from toolset + issue_read additional), got %d", len(available)) + } + + // Verify both tools are present + toolNames := make(map[string]bool) + for _, tool := range available { + toolNames[tool.Tool.Name] = true + } + if !toolNames["issue_read"] { + t.Error("expected issue_read to be included as additional tool") + } + if !toolNames["repo_read"] { + t.Error("expected repo_read to be included from toolset2") + } + + // Test WithTools respects read-only mode + readOnlyFiltered := NewBuilder().SetTools(tools).WithReadOnly(true).WithTools([]string{"issue_write"}).Build() + available = readOnlyFiltered.AvailableTools(context.Background()) + + // issue_write should be excluded because read-only applies to additional tools too + for _, tool := range available { + if tool.Tool.Name == "issue_write" { + t.Error("expected issue_write to be excluded in read-only mode") + } + } + + // Test WithTools with non-existent tool (should not error, just won't match anything) + nonexistent := NewBuilder().SetTools(tools).WithToolsets([]string{}).WithTools([]string{"nonexistent"}).Build() + available = nonexistent.AvailableTools(context.Background()) + if len(available) != 0 { + t.Errorf("expected 0 tools for non-existent additional tool, got %d", len(available)) + } +} + +func TestWithToolsResolvesAliases(t *testing.T) { + tools := []ServerTool{ + mockTool("issue_read", "toolset1", true), + } + + // Using deprecated alias should resolve to canonical name + filtered := NewBuilder().SetTools(tools). + WithDeprecatedAliases(map[string]string{ + "get_issue": "issue_read", + }). + WithToolsets([]string{}). + WithTools([]string{"get_issue"}). + Build() + available := filtered.AvailableTools(context.Background()) + + if len(available) != 1 { + t.Errorf("expected 1 tool, got %d", len(available)) + } + if available[0].Tool.Name != "issue_read" { + t.Errorf("expected issue_read, got %s", available[0].Tool.Name) + } +} + +func TestHasToolset(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + } + + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + + if !reg.HasToolset("toolset1") { + t.Error("expected HasToolset to return true for existing toolset") + } + if reg.HasToolset("nonexistent") { + t.Error("expected HasToolset to return false for non-existent toolset") + } +} + +func TestEnabledToolsetIDs(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset2", true), + } + + // Without filter, all toolsets are enabled + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + ids := reg.EnabledToolsetIDs() + if len(ids) != 2 { + t.Fatalf("Expected 2 enabled toolset IDs, got %d", len(ids)) + } + + // With filter + filtered := NewBuilder().SetTools(tools).WithToolsets([]string{"toolset1"}).Build() + filteredIDs := filtered.EnabledToolsetIDs() + if len(filteredIDs) != 1 { + t.Fatalf("Expected 1 enabled toolset ID, got %d", len(filteredIDs)) + } + if filteredIDs[0] != "toolset1" { + t.Errorf("Expected toolset1, got %s", filteredIDs[0]) + } +} + +func TestAllTools(t *testing.T) { + tools := []ServerTool{ + mockTool("read_tool", "toolset1", true), + mockTool("write_tool", "toolset1", false), + } + + // Even with read-only filter, AllTools returns everything + readOnlyReg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithReadOnly(true).Build() + + allTools := readOnlyReg.AllTools() + if len(allTools) != 2 { + t.Fatalf("Expected 2 tools from AllTools, got %d", len(allTools)) + } + + // But AvailableTools respects the filter + availableTools := readOnlyReg.AvailableTools(context.Background()) + if len(availableTools) != 1 { + t.Fatalf("Expected 1 tool from AvailableTools, got %d", len(availableTools)) + } +} + +func TestServerToolIsReadOnly(t *testing.T) { + readTool := mockTool("read_tool", "toolset1", true) + writeTool := mockTool("write_tool", "toolset1", false) + + if !readTool.IsReadOnly() { + t.Error("Expected read tool to be read-only") + } + if writeTool.IsReadOnly() { + t.Error("Expected write tool to not be read-only") + } +} + +// mockResource creates a minimal ServerResourceTemplate for testing +func mockResource(name string, toolsetID string, uriTemplate string) ServerResourceTemplate { + return NewServerResourceTemplate( + testToolsetMetadata(toolsetID), + mcp.ResourceTemplate{ + Name: name, + URITemplate: uriTemplate, + }, + func(_ any) mcp.ResourceHandler { + return func(_ context.Context, _ *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { + return nil, nil + } + }, + ) +} + +// mockPrompt creates a minimal ServerPrompt for testing +func mockPrompt(name string, toolsetID string) ServerPrompt { + return NewServerPrompt( + testToolsetMetadata(toolsetID), + mcp.Prompt{Name: name}, + func(_ context.Context, _ *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return nil, nil + }, + ) +} + +func TestForMCPRequest_Initialize(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "repos", true), + mockTool("tool2", "issues", false), + } + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + } + prompts := []ServerPrompt{ + mockPrompt("prompt1", "repos"), + } + + reg := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodInitialize, "") + + // Initialize should return empty - capabilities come from ServerOptions + if len(filtered.AvailableTools(context.Background())) != 0 { + t.Errorf("Expected 0 tools for initialize, got %d", len(filtered.AvailableTools(context.Background()))) + } + if len(filtered.AvailableResourceTemplates(context.Background())) != 0 { + t.Errorf("Expected 0 resources for initialize, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) + } + if len(filtered.AvailablePrompts(context.Background())) != 0 { + t.Errorf("Expected 0 prompts for initialize, got %d", len(filtered.AvailablePrompts(context.Background()))) + } +} + +func TestForMCPRequest_ToolsList(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "repos", true), + mockTool("tool2", "issues", true), + } + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + } + prompts := []ServerPrompt{ + mockPrompt("prompt1", "repos"), + } + + reg := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodToolsList, "") + + // tools/list should return all tools, no resources or prompts + if len(filtered.AvailableTools(context.Background())) != 2 { + t.Errorf("Expected 2 tools for tools/list, got %d", len(filtered.AvailableTools(context.Background()))) + } + if len(filtered.AvailableResourceTemplates(context.Background())) != 0 { + t.Errorf("Expected 0 resources for tools/list, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) + } + if len(filtered.AvailablePrompts(context.Background())) != 0 { + t.Errorf("Expected 0 prompts for tools/list, got %d", len(filtered.AvailablePrompts(context.Background()))) + } +} + +func TestForMCPRequest_ToolsCall(t *testing.T) { + tools := []ServerTool{ + mockTool("get_me", "context", true), + mockTool("create_issue", "issues", false), + mockTool("list_repos", "repos", true), + } + + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodToolsCall, "get_me") + + available := filtered.AvailableTools(context.Background()) + if len(available) != 1 { + t.Fatalf("Expected 1 tool for tools/call with name, got %d", len(available)) + } + if available[0].Tool.Name != "get_me" { + t.Errorf("Expected tool name 'get_me', got %q", available[0].Tool.Name) + } +} + +func TestForMCPRequest_ToolsCall_NotFound(t *testing.T) { + tools := []ServerTool{ + mockTool("get_me", "context", true), + } + + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodToolsCall, "nonexistent") + + if len(filtered.AvailableTools(context.Background())) != 0 { + t.Errorf("Expected 0 tools for nonexistent tool, got %d", len(filtered.AvailableTools(context.Background()))) + } +} + +func TestForMCPRequest_ToolsCall_DeprecatedAlias(t *testing.T) { + tools := []ServerTool{ + mockTool("get_me", "context", true), + mockTool("list_commits", "repos", true), + } + + reg := NewBuilder().SetTools(tools). + WithToolsets([]string{"all"}). + WithDeprecatedAliases(map[string]string{ + "old_get_me": "get_me", + }).Build() + + // Request using the deprecated alias + filtered := reg.ForMCPRequest(MCPMethodToolsCall, "old_get_me") + + available := filtered.AvailableTools(context.Background()) + if len(available) != 1 { + t.Fatalf("Expected 1 tool when using deprecated alias, got %d", len(available)) + } + if available[0].Tool.Name != "get_me" { + t.Errorf("Expected canonical name 'get_me', got %q", available[0].Tool.Name) + } +} + +func TestForMCPRequest_ToolsCall_RespectsFilters(t *testing.T) { + tools := []ServerTool{ + mockTool("create_issue", "issues", false), // write tool + } + + // Apply read-only filter at build time, then ForMCPRequest + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithReadOnly(true).Build() + filtered := reg.ForMCPRequest(MCPMethodToolsCall, "create_issue") + + // The tool exists in the filtered group, but AvailableTools respects read-only + available := filtered.AvailableTools(context.Background()) + if len(available) != 0 { + t.Errorf("Expected 0 tools - write tool should be filtered by read-only, got %d", len(available)) + } +} + +func TestForMCPRequest_ResourcesList(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "repos", true), + } + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + mockResource("res2", "repos", "branch://{owner}/{repo}/{branch}"), + } + prompts := []ServerPrompt{ + mockPrompt("prompt1", "repos"), + } + + reg := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodResourcesList, "") + + if len(filtered.AvailableTools(context.Background())) != 0 { + t.Errorf("Expected 0 tools for resources/list, got %d", len(filtered.AvailableTools(context.Background()))) + } + if len(filtered.AvailableResourceTemplates(context.Background())) != 2 { + t.Errorf("Expected 2 resources for resources/list, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) + } + if len(filtered.AvailablePrompts(context.Background())) != 0 { + t.Errorf("Expected 0 prompts for resources/list, got %d", len(filtered.AvailablePrompts(context.Background()))) + } +} + +func TestForMCPRequest_ResourcesRead(t *testing.T) { + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + mockResource("res2", "repos", "branch://{owner}/{repo}/{branch}"), + } + + reg := NewBuilder().SetResources(resources).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodResourcesRead, "repo://{owner}/{repo}") + + available := filtered.AvailableResourceTemplates(context.Background()) + if len(available) != 1 { + t.Fatalf("Expected 1 resource for resources/read, got %d", len(available)) + } + if available[0].Template.URITemplate != "repo://{owner}/{repo}" { + t.Errorf("Expected URI template 'repo://{owner}/{repo}', got %q", available[0].Template.URITemplate) + } +} + +func TestForMCPRequest_PromptsList(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "repos", true), + } + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + } + prompts := []ServerPrompt{ + mockPrompt("prompt1", "repos"), + mockPrompt("prompt2", "issues"), + } + + reg := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodPromptsList, "") + + if len(filtered.AvailableTools(context.Background())) != 0 { + t.Errorf("Expected 0 tools for prompts/list, got %d", len(filtered.AvailableTools(context.Background()))) + } + if len(filtered.AvailableResourceTemplates(context.Background())) != 0 { + t.Errorf("Expected 0 resources for prompts/list, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) + } + if len(filtered.AvailablePrompts(context.Background())) != 2 { + t.Errorf("Expected 2 prompts for prompts/list, got %d", len(filtered.AvailablePrompts(context.Background()))) + } +} + +func TestForMCPRequest_PromptsGet(t *testing.T) { + prompts := []ServerPrompt{ + mockPrompt("prompt1", "repos"), + mockPrompt("prompt2", "issues"), + } + + reg := NewBuilder().SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodPromptsGet, "prompt1") + + available := filtered.AvailablePrompts(context.Background()) + if len(available) != 1 { + t.Fatalf("Expected 1 prompt for prompts/get, got %d", len(available)) + } + if available[0].Prompt.Name != "prompt1" { + t.Errorf("Expected prompt name 'prompt1', got %q", available[0].Prompt.Name) + } +} + +func TestForMCPRequest_UnknownMethod(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "repos", true), + } + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + } + prompts := []ServerPrompt{ + mockPrompt("prompt1", "repos"), + } + + reg := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest("unknown/method", "") + + // Unknown methods should return empty + if len(filtered.AvailableTools(context.Background())) != 0 { + t.Errorf("Expected 0 tools for unknown method, got %d", len(filtered.AvailableTools(context.Background()))) + } + if len(filtered.AvailableResourceTemplates(context.Background())) != 0 { + t.Errorf("Expected 0 resources for unknown method, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) + } + if len(filtered.AvailablePrompts(context.Background())) != 0 { + t.Errorf("Expected 0 prompts for unknown method, got %d", len(filtered.AvailablePrompts(context.Background()))) + } +} + +func TestForMCPRequest_DoesNotMutateOriginal(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "repos", true), + mockTool("tool2", "issues", true), + } + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + } + prompts := []ServerPrompt{ + mockPrompt("prompt1", "repos"), + } + + original := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := original.ForMCPRequest(MCPMethodToolsCall, "tool1") + + // Original should be unchanged + if len(original.AvailableTools(context.Background())) != 2 { + t.Errorf("Original was mutated! Expected 2 tools, got %d", len(original.AvailableTools(context.Background()))) + } + if len(original.AvailableResourceTemplates(context.Background())) != 1 { + t.Errorf("Original was mutated! Expected 1 resource, got %d", len(original.AvailableResourceTemplates(context.Background()))) + } + if len(original.AvailablePrompts(context.Background())) != 1 { + t.Errorf("Original was mutated! Expected 1 prompt, got %d", len(original.AvailablePrompts(context.Background()))) + } + + // Filtered should have only the requested tool + if len(filtered.AvailableTools(context.Background())) != 1 { + t.Errorf("Expected 1 tool in filtered, got %d", len(filtered.AvailableTools(context.Background()))) + } + if len(filtered.AvailableResourceTemplates(context.Background())) != 0 { + t.Errorf("Expected 0 resources in filtered, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) + } + if len(filtered.AvailablePrompts(context.Background())) != 0 { + t.Errorf("Expected 0 prompts in filtered, got %d", len(filtered.AvailablePrompts(context.Background()))) + } +} + +func TestForMCPRequest_ChainedWithOtherFilters(t *testing.T) { + tools := []ServerTool{ + mockToolWithDefault("get_me", "context", true, true), // default toolset + mockToolWithDefault("create_issue", "issues", false, false), // not default + mockToolWithDefault("list_repos", "repos", true, true), // default toolset + mockToolWithDefault("delete_repo", "repos", false, true), // default but write + } + + // Chain: default toolsets -> read-only -> specific method + reg := NewBuilder().SetTools(tools). + WithToolsets([]string{"default"}). + WithReadOnly(true). + Build() + filtered := reg.ForMCPRequest(MCPMethodToolsList, "") + + available := filtered.AvailableTools(context.Background()) + + // Should have: get_me (context, read), list_repos (repos, read) + // Should NOT have: create_issue (issues not in default), delete_repo (write) + if len(available) != 2 { + t.Fatalf("Expected 2 tools after filter chain, got %d", len(available)) + } + + toolNames := make(map[string]bool) + for _, tool := range available { + toolNames[tool.Tool.Name] = true + } + + if !toolNames["get_me"] { + t.Error("Expected get_me to be available") + } + if !toolNames["list_repos"] { + t.Error("Expected list_repos to be available") + } + if toolNames["create_issue"] { + t.Error("create_issue should not be available (toolset not enabled)") + } + if toolNames["delete_repo"] { + t.Error("delete_repo should not be available (write tool in read-only mode)") + } +} + +func TestForMCPRequest_ResourcesTemplatesList(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "repos", true), + } + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + } + + reg := NewBuilder().SetTools(tools).SetResources(resources).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodResourcesTemplatesList, "") + + // Same behavior as resources/list + if len(filtered.AvailableTools(context.Background())) != 0 { + t.Errorf("Expected 0 tools, got %d", len(filtered.AvailableTools(context.Background()))) + } + if len(filtered.AvailableResourceTemplates(context.Background())) != 1 { + t.Errorf("Expected 1 resource, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) + } +} + +func TestMCPMethodConstants(t *testing.T) { + // Verify constants match expected MCP method names + tests := []struct { + constant string + expected string + }{ + {MCPMethodInitialize, "initialize"}, + {MCPMethodToolsList, "tools/list"}, + {MCPMethodToolsCall, "tools/call"}, + {MCPMethodResourcesList, "resources/list"}, + {MCPMethodResourcesRead, "resources/read"}, + {MCPMethodResourcesTemplatesList, "resources/templates/list"}, + {MCPMethodPromptsList, "prompts/list"}, + {MCPMethodPromptsGet, "prompts/get"}, + } + + for _, tt := range tests { + if tt.constant != tt.expected { + t.Errorf("Constant mismatch: got %q, expected %q", tt.constant, tt.expected) + } + } +} + +// mockToolWithFlags creates a ServerTool with feature flags for testing +func mockToolWithFlags(name string, toolsetID string, readOnly bool, enableFlag, disableFlag string) ServerTool { + tool := mockTool(name, toolsetID, readOnly) + tool.FeatureFlagEnable = enableFlag + tool.FeatureFlagDisable = disableFlag + return tool +} + +func TestFeatureFlagEnable(t *testing.T) { + tools := []ServerTool{ + mockTool("always_available", "toolset1", true), + mockToolWithFlags("needs_flag", "toolset1", true, "my_feature", ""), + } + + // Without feature checker, tool with FeatureFlagEnable should be excluded + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + available := reg.AvailableTools(context.Background()) + if len(available) != 1 { + t.Fatalf("Expected 1 tool without feature checker, got %d", len(available)) + } + if available[0].Tool.Name != "always_available" { + t.Errorf("Expected always_available, got %s", available[0].Tool.Name) + } + + // With feature checker returning false, tool should still be excluded + checkerFalse := func(_ context.Context, _ string) (bool, error) { return false, nil } + regFalse := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checkerFalse).Build() + availableFalse := regFalse.AvailableTools(context.Background()) + if len(availableFalse) != 1 { + t.Fatalf("Expected 1 tool with false checker, got %d", len(availableFalse)) + } + + // With feature checker returning true for "my_feature", tool should be included + checkerTrue := func(_ context.Context, flag string) (bool, error) { + return flag == "my_feature", nil + } + regTrue := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checkerTrue).Build() + availableTrue := regTrue.AvailableTools(context.Background()) + if len(availableTrue) != 2 { + t.Fatalf("Expected 2 tools with true checker, got %d", len(availableTrue)) + } +} + +func TestFeatureFlagDisable(t *testing.T) { + tools := []ServerTool{ + mockTool("always_available", "toolset1", true), + mockToolWithFlags("disabled_by_flag", "toolset1", true, "", "kill_switch"), + } + + // Without feature checker, tool with FeatureFlagDisable should be included (flag is false) + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + available := reg.AvailableTools(context.Background()) + if len(available) != 2 { + t.Fatalf("Expected 2 tools without feature checker, got %d", len(available)) + } + + // With feature checker returning true for "kill_switch", tool should be excluded + checkerTrue := func(_ context.Context, flag string) (bool, error) { + return flag == "kill_switch", nil + } + regFiltered := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checkerTrue).Build() + availableFiltered := regFiltered.AvailableTools(context.Background()) + if len(availableFiltered) != 1 { + t.Fatalf("Expected 1 tool with kill_switch enabled, got %d", len(availableFiltered)) + } + if availableFiltered[0].Tool.Name != "always_available" { + t.Errorf("Expected always_available, got %s", availableFiltered[0].Tool.Name) + } +} + +func TestFeatureFlagBoth(t *testing.T) { + // Tool that requires "new_feature" AND is disabled by "kill_switch" + tools := []ServerTool{ + mockToolWithFlags("complex_tool", "toolset1", true, "new_feature", "kill_switch"), + } + + // Enable flag not set -> excluded + checker1 := func(_ context.Context, _ string) (bool, error) { return false, nil } + reg1 := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checker1).Build() + if len(reg1.AvailableTools(context.Background())) != 0 { + t.Error("Tool should be excluded when enable flag is false") + } + + // Enable flag set, disable flag not set -> included + checker2 := func(_ context.Context, flag string) (bool, error) { return flag == "new_feature", nil } + reg2 := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checker2).Build() + if len(reg2.AvailableTools(context.Background())) != 1 { + t.Error("Tool should be included when enable flag is true and disable flag is false") + } + + // Enable flag set, disable flag also set -> excluded (disable wins) + checker3 := func(_ context.Context, _ string) (bool, error) { return true, nil } + reg3 := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checker3).Build() + if len(reg3.AvailableTools(context.Background())) != 0 { + t.Error("Tool should be excluded when both flags are true (disable wins)") + } +} + +func TestFeatureFlagError(t *testing.T) { + tools := []ServerTool{ + mockToolWithFlags("needs_flag", "toolset1", true, "my_feature", ""), + } + + // Checker that returns error should treat as false (tool excluded) + checkerError := func(_ context.Context, _ string) (bool, error) { + return false, fmt.Errorf("simulated error") + } + reg := NewBuilder().SetTools(tools).WithFeatureChecker(checkerError).Build() + available := reg.AvailableTools(context.Background()) + if len(available) != 0 { + t.Errorf("Expected 0 tools when checker errors, got %d", len(available)) + } +} + +func TestFeatureFlagResources(t *testing.T) { + resources := []ServerResourceTemplate{ + mockResource("always_available", "toolset1", "uri1"), + { + Template: mcp.ResourceTemplate{Name: "needs_flag", URITemplate: "uri2"}, + Toolset: testToolsetMetadata("toolset1"), + FeatureFlagEnable: "my_feature", + }, + } + + // Without checker, resource with enable flag should be excluded + reg := NewBuilder().SetResources(resources).WithToolsets([]string{"all"}).Build() + available := reg.AvailableResourceTemplates(context.Background()) + if len(available) != 1 { + t.Fatalf("Expected 1 resource without checker, got %d", len(available)) + } + + // With checker returning true, both should be included + checker := func(_ context.Context, _ string) (bool, error) { return true, nil } + regWithChecker := NewBuilder().SetResources(resources).WithToolsets([]string{"all"}).WithFeatureChecker(checker).Build() + if len(regWithChecker.AvailableResourceTemplates(context.Background())) != 2 { + t.Errorf("Expected 2 resources with checker, got %d", len(regWithChecker.AvailableResourceTemplates(context.Background()))) + } +} + +func TestFeatureFlagPrompts(t *testing.T) { + prompts := []ServerPrompt{ + mockPrompt("always_available", "toolset1"), + { + Prompt: mcp.Prompt{Name: "needs_flag"}, + Toolset: testToolsetMetadata("toolset1"), + FeatureFlagEnable: "my_feature", + }, + } + + // Without checker, prompt with enable flag should be excluded + reg := NewBuilder().SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + available := reg.AvailablePrompts(context.Background()) + if len(available) != 1 { + t.Fatalf("Expected 1 prompt without checker, got %d", len(available)) + } + + // With checker returning true, both should be included + checker := func(_ context.Context, _ string) (bool, error) { return true, nil } + regWithChecker := NewBuilder().SetPrompts(prompts).WithToolsets([]string{"all"}).WithFeatureChecker(checker).Build() + if len(regWithChecker.AvailablePrompts(context.Background())) != 2 { + t.Errorf("Expected 2 prompts with checker, got %d", len(regWithChecker.AvailablePrompts(context.Background()))) + } +} + +func TestServerToolHasHandler(t *testing.T) { + // Tool with handler + toolWithHandler := mockTool("has_handler", "toolset1", true) + if !toolWithHandler.HasHandler() { + t.Error("Expected HasHandler() to return true for tool with handler") + } + + // Tool without handler + toolWithoutHandler := ServerTool{ + Tool: mcp.Tool{Name: "no_handler"}, + Toolset: testToolsetMetadata("toolset1"), + } + if toolWithoutHandler.HasHandler() { + t.Error("Expected HasHandler() to return false for tool without handler") + } +} + +func TestServerToolHandlerPanicOnNil(t *testing.T) { + tool := ServerTool{ + Tool: mcp.Tool{Name: "no_handler"}, + Toolset: testToolsetMetadata("toolset1"), + } + + defer func() { + if r := recover(); r == nil { + t.Error("Expected Handler() to panic when HandlerFunc is nil") + } + }() + + tool.Handler(nil) +} + +// Tests for Enabled function on ServerTool +func TestServerToolEnabled(t *testing.T) { + tests := []struct { + name string + enabledFunc func(ctx context.Context) (bool, error) + expectedCount int + expectInResult bool + }{ + { + name: "nil Enabled function - tool included", + enabledFunc: nil, + expectedCount: 1, + expectInResult: true, + }, + { + name: "Enabled returns true - tool included", + enabledFunc: func(_ context.Context) (bool, error) { + return true, nil + }, + expectedCount: 1, + expectInResult: true, + }, + { + name: "Enabled returns false - tool excluded", + enabledFunc: func(_ context.Context) (bool, error) { + return false, nil + }, + expectedCount: 0, + expectInResult: false, + }, + { + name: "Enabled returns error - tool excluded", + enabledFunc: func(_ context.Context) (bool, error) { + return false, fmt.Errorf("simulated error") + }, + expectedCount: 0, + expectInResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tool := mockTool("test_tool", "toolset1", true) + tool.Enabled = tt.enabledFunc + + reg := NewBuilder().SetTools([]ServerTool{tool}).WithToolsets([]string{"all"}).Build() + available := reg.AvailableTools(context.Background()) + + if len(available) != tt.expectedCount { + t.Errorf("Expected %d tools, got %d", tt.expectedCount, len(available)) + } + + found := false + for _, t := range available { + if t.Tool.Name == "test_tool" { + found = true + break + } + } + if found != tt.expectInResult { + t.Errorf("Expected tool in result: %v, got: %v", tt.expectInResult, found) + } + }) + } +} + +func TestServerToolEnabledWithContext(t *testing.T) { + type contextKey string + const userKey contextKey = "user" + + // Tool that checks context for user + tool := mockTool("context_aware_tool", "toolset1", true) + tool.Enabled = func(ctx context.Context) (bool, error) { + user := ctx.Value(userKey) + return user != nil && user.(string) == "authorized", nil + } + + reg := NewBuilder().SetTools([]ServerTool{tool}).WithToolsets([]string{"all"}).Build() + + // Without user in context - tool should be excluded + available := reg.AvailableTools(context.Background()) + if len(available) != 0 { + t.Errorf("Expected 0 tools without user, got %d", len(available)) + } + + // With authorized user - tool should be included + ctxWithUser := context.WithValue(context.Background(), userKey, "authorized") + availableWithUser := reg.AvailableTools(ctxWithUser) + if len(availableWithUser) != 1 { + t.Errorf("Expected 1 tool with authorized user, got %d", len(availableWithUser)) + } + + // With unauthorized user - tool should be excluded + ctxWithBadUser := context.WithValue(context.Background(), userKey, "unauthorized") + availableWithBadUser := reg.AvailableTools(ctxWithBadUser) + if len(availableWithBadUser) != 0 { + t.Errorf("Expected 0 tools with unauthorized user, got %d", len(availableWithBadUser)) + } +} + +// Tests for WithFilter builder method +func TestBuilderWithFilter(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", true), + mockTool("tool3", "toolset1", true), + } + + // Filter that excludes tool2 + filter := func(_ context.Context, tool *ServerTool) (bool, error) { + return tool.Tool.Name != "tool2", nil + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 2 { + t.Fatalf("Expected 2 tools after filter, got %d", len(available)) + } + + for _, tool := range available { + if tool.Tool.Name == "tool2" { + t.Error("tool2 should have been filtered out") + } + } +} + +func TestBuilderWithMultipleFilters(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", true), + mockTool("tool3", "toolset1", true), + mockTool("tool4", "toolset1", true), + } + + // First filter excludes tool2 + filter1 := func(_ context.Context, tool *ServerTool) (bool, error) { + return tool.Tool.Name != "tool2", nil + } + + // Second filter excludes tool3 + filter2 := func(_ context.Context, tool *ServerTool) (bool, error) { + return tool.Tool.Name != "tool3", nil + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter1). + WithFilter(filter2). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 2 { + t.Fatalf("Expected 2 tools after multiple filters, got %d", len(available)) + } + + toolNames := make(map[string]bool) + for _, tool := range available { + toolNames[tool.Tool.Name] = true + } + + if !toolNames["tool1"] || !toolNames["tool4"] { + t.Error("Expected tool1 and tool4 to be available") + } + if toolNames["tool2"] || toolNames["tool3"] { + t.Error("tool2 and tool3 should have been filtered out") + } +} + +func TestBuilderFilterError(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + } + + // Filter that returns an error + filter := func(_ context.Context, _ *ServerTool) (bool, error) { + return false, fmt.Errorf("filter error") + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 0 { + t.Errorf("Expected 0 tools when filter returns error, got %d", len(available)) + } +} + +func TestBuilderFilterWithContext(t *testing.T) { + type contextKey string + const scopeKey contextKey = "scope" + + tools := []ServerTool{ + mockTool("public_tool", "toolset1", true), + mockTool("private_tool", "toolset1", true), + } + + // Filter that checks context for scope + filter := func(ctx context.Context, tool *ServerTool) (bool, error) { + scope := ctx.Value(scopeKey) + if scope == "public" && tool.Tool.Name == "private_tool" { + return false, nil + } + return true, nil + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + // With public scope - private_tool should be excluded + ctxPublic := context.WithValue(context.Background(), scopeKey, "public") + availablePublic := reg.AvailableTools(ctxPublic) + if len(availablePublic) != 1 { + t.Fatalf("Expected 1 tool with public scope, got %d", len(availablePublic)) + } + if availablePublic[0].Tool.Name != "public_tool" { + t.Error("Expected only public_tool to be available") + } + + // With private scope - both tools should be available + ctxPrivate := context.WithValue(context.Background(), scopeKey, "private") + availablePrivate := reg.AvailableTools(ctxPrivate) + if len(availablePrivate) != 2 { + t.Errorf("Expected 2 tools with private scope, got %d", len(availablePrivate)) + } +} + +// Tests for interaction between Enabled, feature flags, and filters +func TestEnabledAndFeatureFlagInteraction(t *testing.T) { + // Tool with both Enabled function and feature flag + tool := mockToolWithFlags("complex_tool", "toolset1", true, "my_feature", "") + tool.Enabled = func(_ context.Context) (bool, error) { + return true, nil + } + + // Feature flag not enabled - tool should be excluded despite Enabled returning true + reg1 := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + Build() + available1 := reg1.AvailableTools(context.Background()) + if len(available1) != 0 { + t.Error("Tool should be excluded when feature flag is not enabled") + } + + // Feature flag enabled - tool should be included + checker := func(_ context.Context, flag string) (bool, error) { + return flag == "my_feature", nil + } + reg2 := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFeatureChecker(checker). + Build() + available2 := reg2.AvailableTools(context.Background()) + if len(available2) != 1 { + t.Error("Tool should be included when both Enabled and feature flag pass") + } + + // Enabled returns false - tool should be excluded despite feature flag + tool.Enabled = func(_ context.Context) (bool, error) { + return false, nil + } + reg3 := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFeatureChecker(checker). + Build() + available3 := reg3.AvailableTools(context.Background()) + if len(available3) != 0 { + t.Error("Tool should be excluded when Enabled returns false") + } +} + +func TestEnabledAndBuilderFilterInteraction(t *testing.T) { + tool := mockTool("test_tool", "toolset1", true) + tool.Enabled = func(_ context.Context) (bool, error) { + return true, nil + } + + // Filter that excludes the tool + filter := func(_ context.Context, _ *ServerTool) (bool, error) { + return false, nil + } + + reg := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 0 { + t.Error("Tool should be excluded when filter returns false, despite Enabled returning true") + } +} + +func TestAllFiltersInteraction(t *testing.T) { + // Tool with Enabled, feature flag, and subject to builder filter + tool := mockToolWithFlags("complex_tool", "toolset1", true, "my_feature", "") + tool.Enabled = func(_ context.Context) (bool, error) { + return true, nil + } + + filter := func(_ context.Context, _ *ServerTool) (bool, error) { + return true, nil + } + + checker := func(_ context.Context, flag string) (bool, error) { + return flag == "my_feature", nil + } + + // All conditions pass - tool should be included + reg := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFeatureChecker(checker). + WithFilter(filter). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 1 { + t.Error("Tool should be included when all filters pass") + } + + // Change filter to return false - tool should be excluded + filterFalse := func(_ context.Context, _ *ServerTool) (bool, error) { + return false, nil + } + + reg2 := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFeatureChecker(checker). + WithFilter(filterFalse). + Build() + + available2 := reg2.AvailableTools(context.Background()) + if len(available2) != 0 { + t.Error("Tool should be excluded when any filter fails") + } +} + +// Test FilteredTools method +func TestFilteredTools(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", true), + } + + filter := func(_ context.Context, tool *ServerTool) (bool, error) { + return tool.Tool.Name == "tool1", nil + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + filtered, err := reg.FilteredTools(context.Background()) + if err != nil { + t.Fatalf("FilteredTools returned error: %v", err) + } + + if len(filtered) != 1 { + t.Fatalf("Expected 1 filtered tool, got %d", len(filtered)) + } + + if filtered[0].Tool.Name != "tool1" { + t.Errorf("Expected tool1, got %s", filtered[0].Tool.Name) + } +} + +func TestFilteredToolsMatchesAvailableTools(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", false), + mockTool("tool3", "toolset2", true), + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"toolset1"}). + WithReadOnly(true). + Build() + + ctx := context.Background() + filtered, err := reg.FilteredTools(ctx) + if err != nil { + t.Fatalf("FilteredTools returned error: %v", err) + } + + available := reg.AvailableTools(ctx) + + // Both methods should return the same results + if len(filtered) != len(available) { + t.Errorf("FilteredTools and AvailableTools returned different counts: %d vs %d", + len(filtered), len(available)) + } + + for i := range filtered { + if filtered[i].Tool.Name != available[i].Tool.Name { + t.Errorf("Tool at index %d differs: FilteredTools=%s, AvailableTools=%s", + i, filtered[i].Tool.Name, available[i].Tool.Name) + } + } +} + +func TestFilteringOrder(t *testing.T) { + // Test that filters are applied in the correct order: + // 1. Tool.Enabled + // 2. Feature flags + // 3. Read-only + // 4. Builder filters + // 5. Toolset/additional tools + + callOrder := []string{} + + tool := mockToolWithFlags("test_tool", "toolset1", false, "my_feature", "") + tool.Enabled = func(_ context.Context) (bool, error) { + callOrder = append(callOrder, "Enabled") + return true, nil + } + + filter := func(_ context.Context, _ *ServerTool) (bool, error) { + callOrder = append(callOrder, "Filter") + return true, nil + } + + checker := func(_ context.Context, _ string) (bool, error) { + callOrder = append(callOrder, "FeatureFlag") + return true, nil + } + + reg := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithReadOnly(true). // This will exclude the tool (it's not read-only) + WithFeatureChecker(checker). + WithFilter(filter). + Build() + + _ = reg.AvailableTools(context.Background()) + + // Expected order: Enabled, FeatureFlag, ReadOnly (stops here because it's write tool) + expectedOrder := []string{"Enabled", "FeatureFlag"} + if len(callOrder) != len(expectedOrder) { + t.Errorf("Expected %d checks, got %d: %v", len(expectedOrder), len(callOrder), callOrder) + } + + for i, expected := range expectedOrder { + if i >= len(callOrder) || callOrder[i] != expected { + t.Errorf("At position %d: expected %s, got %v", i, expected, callOrder) + } + } +} diff --git a/pkg/registry/resources.go b/pkg/registry/resources.go new file mode 100644 index 000000000..99e0240c5 --- /dev/null +++ b/pkg/registry/resources.go @@ -0,0 +1,48 @@ +package registry + +import "github.com/modelcontextprotocol/go-sdk/mcp" + +// ResourceHandlerFunc is a function that takes dependencies and returns an MCP resource handler. +// This allows resources to be defined statically while their handlers are generated +// on-demand with the appropriate dependencies. +type ResourceHandlerFunc func(deps any) mcp.ResourceHandler + +// ServerResourceTemplate pairs a resource template with its toolset metadata. +type ServerResourceTemplate struct { + Template mcp.ResourceTemplate + // HandlerFunc generates the handler when given dependencies. + // This allows resources to be passed around without handlers being set up, + // and handlers are only created when needed. + HandlerFunc ResourceHandlerFunc + // Toolset identifies which toolset this resource belongs to + Toolset ToolsetMetadata + // FeatureFlagEnable specifies a feature flag that must be enabled for this resource + // to be available. If set and the flag is not enabled, the resource is omitted. + FeatureFlagEnable string + // FeatureFlagDisable specifies a feature flag that, when enabled, causes this resource + // to be omitted. Used to disable resources when a feature flag is on. + FeatureFlagDisable string +} + +// HasHandler returns true if this resource has a handler function. +func (sr *ServerResourceTemplate) HasHandler() bool { + return sr.HandlerFunc != nil +} + +// Handler returns a resource handler by calling HandlerFunc with the given dependencies. +// Panics if HandlerFunc is nil - all resources should have handlers. +func (sr *ServerResourceTemplate) Handler(deps any) mcp.ResourceHandler { + if sr.HandlerFunc == nil { + panic("HandlerFunc is nil for resource: " + sr.Template.Name) + } + return sr.HandlerFunc(deps) +} + +// NewServerResourceTemplate creates a new ServerResourceTemplate with toolset metadata. +func NewServerResourceTemplate(toolset ToolsetMetadata, resourceTemplate mcp.ResourceTemplate, handlerFn ResourceHandlerFunc) ServerResourceTemplate { + return ServerResourceTemplate{ + Template: resourceTemplate, + HandlerFunc: handlerFn, + Toolset: toolset, + } +} diff --git a/pkg/registry/server_tool.go b/pkg/registry/server_tool.go new file mode 100644 index 000000000..3145b693d --- /dev/null +++ b/pkg/registry/server_tool.go @@ -0,0 +1,115 @@ +package registry + +import ( + "context" + "encoding/json" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// HandlerFunc is a function that takes dependencies and returns an MCP tool handler. +// This allows tools to be defined statically while their handlers are generated +// on-demand with the appropriate dependencies. +// The deps parameter is typed as `any` to avoid circular dependencies - callers +// should define their own typed dependencies struct and type-assert as needed. +type HandlerFunc func(deps any) mcp.ToolHandler + +// ToolsetID is a unique identifier for a toolset. +// Using a distinct type provides compile-time type safety. +type ToolsetID string + +// ToolsetMetadata contains metadata about the toolset a tool belongs to. +type ToolsetMetadata struct { + // ID is the unique identifier for the toolset (e.g., "repos", "issues") + ID ToolsetID + // Description provides a human-readable description of the toolset + Description string + // Default indicates this toolset should be enabled by default + Default bool +} + +// ServerTool represents an MCP tool with metadata and a handler generator function. +// The tool definition is static, while the handler is generated on-demand +// when the tool is registered with a server. +// Tools are now self-describing with their toolset membership and read-only status +// derived from the Tool.Annotations.ReadOnlyHint field. +type ServerTool struct { + // Tool is the MCP tool definition containing name, description, schema, etc. + Tool mcp.Tool + + // Toolset contains metadata about which toolset this tool belongs to. + Toolset ToolsetMetadata + + // HandlerFunc generates the handler when given dependencies. + // This allows tools to be passed around without handlers being set up, + // and handlers are only created when needed. + HandlerFunc HandlerFunc + + // FeatureFlagEnable specifies a feature flag that must be enabled for this tool + // to be available. If set and the flag is not enabled, the tool is omitted. + FeatureFlagEnable string + + // FeatureFlagDisable specifies a feature flag that, when enabled, causes this tool + // to be omitted. Used to disable tools when a feature flag is on. + FeatureFlagDisable string + + // Enabled is an optional function called at build/filter time to determine + // if this tool should be available. If nil, the tool is considered enabled + // (subject to FeatureFlagEnable/FeatureFlagDisable checks). + // The context carries request-scoped information for the consumer to use. + // Returns (enabled, error). On error, the tool should be treated as disabled. + Enabled func(ctx context.Context) (bool, error) +} + +// IsReadOnly returns true if this tool is marked as read-only via annotations. +func (st *ServerTool) IsReadOnly() bool { + return st.Tool.Annotations != nil && st.Tool.Annotations.ReadOnlyHint +} + +// HasHandler returns true if this tool has a handler function. +func (st *ServerTool) HasHandler() bool { + return st.HandlerFunc != nil +} + +// Handler returns a tool handler by calling HandlerFunc with the given dependencies. +// Panics if HandlerFunc is nil - all tools should have handlers. +func (st *ServerTool) Handler(deps any) mcp.ToolHandler { + if st.HandlerFunc == nil { + panic("HandlerFunc is nil for tool: " + st.Tool.Name) + } + return st.HandlerFunc(deps) +} + +// RegisterFunc registers the tool with the server using the provided dependencies. +// Panics if the tool has no handler - all tools should have handlers. +func (st *ServerTool) RegisterFunc(s *mcp.Server, deps any) { + handler := st.Handler(deps) // This will panic if HandlerFunc is nil + s.AddTool(&st.Tool, handler) +} + +// NewServerTool creates a ServerTool from a tool definition, toolset metadata, and a typed handler function. +// The handler function takes dependencies (as any) and returns a typed handler. +// Callers should type-assert deps to their typed dependencies struct. +func NewServerTool[In any, Out any](tool mcp.Tool, toolset ToolsetMetadata, handlerFn func(deps any) mcp.ToolHandlerFor[In, Out]) ServerTool { + return ServerTool{ + Tool: tool, + Toolset: toolset, + HandlerFunc: func(deps any) mcp.ToolHandler { + typedHandler := handlerFn(deps) + return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + var arguments In + if err := json.Unmarshal(req.Params.Arguments, &arguments); err != nil { + return nil, err + } + resp, _, err := typedHandler(ctx, req, arguments) + return resp, err + } + }, + } +} + +// NewServerToolFromHandler creates a ServerTool from a tool definition, toolset metadata, and a raw handler function. +// Use this when you have a handler that already conforms to mcp.ToolHandler. +func NewServerToolFromHandler(tool mcp.Tool, toolset ToolsetMetadata, handlerFn func(deps any) mcp.ToolHandler) ServerTool { + return ServerTool{Tool: tool, Toolset: toolset, HandlerFunc: handlerFn} +} diff --git a/pkg/toolsets/server_tool.go b/pkg/toolsets/server_tool.go deleted file mode 100644 index 3e3e5d9f8..000000000 --- a/pkg/toolsets/server_tool.go +++ /dev/null @@ -1,99 +0,0 @@ -package toolsets - -import ( - "context" - "encoding/json" - - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// HandlerFunc is a function that takes dependencies and returns an MCP tool handler. -// This allows tools to be defined statically while their handlers are generated -// on-demand with the appropriate dependencies. -// The deps parameter is typed as `any` to avoid circular dependencies - callers -// should define their own typed dependencies struct and type-assert as needed. -type HandlerFunc func(deps any) mcp.ToolHandler - -// ServerTool represents an MCP tool with a handler generator function. -// The tool definition is static, while the handler is generated on-demand -// when the tool is registered with a server. -type ServerTool struct { - // Tool is the MCP tool definition containing name, description, schema, etc. - Tool mcp.Tool - - // HandlerFunc generates the handler when given dependencies. - // This allows tools to be passed around without handlers being set up, - // and handlers are only created when needed. - HandlerFunc HandlerFunc -} - -// Handler returns a tool handler by calling HandlerFunc with the given dependencies. -func (st *ServerTool) Handler(deps any) mcp.ToolHandler { - if st.HandlerFunc == nil { - return nil - } - return st.HandlerFunc(deps) -} - -// RegisterFunc registers the tool with the server using the provided dependencies. -func (st *ServerTool) RegisterFunc(s *mcp.Server, deps any) { - handler := st.Handler(deps) - s.AddTool(&st.Tool, handler) -} - -// NewServerTool creates a ServerTool from a tool definition and a typed handler function. -// The handler function takes dependencies (as any) and returns a typed handler. -// Callers should type-assert deps to their typed dependencies struct. -func NewServerTool[In any, Out any](tool mcp.Tool, handlerFn func(deps any) mcp.ToolHandlerFor[In, Out]) ServerTool { - return ServerTool{ - Tool: tool, - HandlerFunc: func(deps any) mcp.ToolHandler { - typedHandler := handlerFn(deps) - return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { - var arguments In - if err := json.Unmarshal(req.Params.Arguments, &arguments); err != nil { - return nil, err - } - resp, _, err := typedHandler(ctx, req, arguments) - return resp, err - } - }, - } -} - -// NewServerToolFromHandler creates a ServerTool from a tool definition and a raw handler function. -// Use this when you have a handler that already conforms to mcp.ToolHandler. -func NewServerToolFromHandler(tool mcp.Tool, handlerFn func(deps any) mcp.ToolHandler) ServerTool { - return ServerTool{Tool: tool, HandlerFunc: handlerFn} -} - -// NewServerToolLegacy creates a ServerTool from a tool definition and an already-bound typed handler. -// This is for backward compatibility during the refactor - the handler doesn't use dependencies. -// Deprecated: Use NewServerTool instead for new code. -func NewServerToolLegacy[In any, Out any](tool mcp.Tool, handler mcp.ToolHandlerFor[In, Out]) ServerTool { - return ServerTool{ - Tool: tool, - HandlerFunc: func(_ any) mcp.ToolHandler { - return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { - var arguments In - if err := json.Unmarshal(req.Params.Arguments, &arguments); err != nil { - return nil, err - } - resp, _, err := handler(ctx, req, arguments) - return resp, err - } - }, - } -} - -// NewServerToolFromHandlerLegacy creates a ServerTool from a tool definition and an already-bound raw handler. -// This is for backward compatibility during the refactor - the handler doesn't use dependencies. -// Deprecated: Use NewServerToolFromHandler instead for new code. -func NewServerToolFromHandlerLegacy(tool mcp.Tool, handler mcp.ToolHandler) ServerTool { - return ServerTool{ - Tool: tool, - HandlerFunc: func(_ any) mcp.ToolHandler { - return handler - }, - } -} diff --git a/pkg/toolsets/toolsets.go b/pkg/toolsets/toolsets.go deleted file mode 100644 index 8502328d5..000000000 --- a/pkg/toolsets/toolsets.go +++ /dev/null @@ -1,372 +0,0 @@ -package toolsets - -import ( - "fmt" - "os" - "strings" - - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -type ToolsetDoesNotExistError struct { - Name string -} - -func (e *ToolsetDoesNotExistError) Error() string { - return fmt.Sprintf("toolset %s does not exist", e.Name) -} - -func (e *ToolsetDoesNotExistError) Is(target error) bool { - if target == nil { - return false - } - if _, ok := target.(*ToolsetDoesNotExistError); ok { - return true - } - return false -} - -func NewToolsetDoesNotExistError(name string) *ToolsetDoesNotExistError { - return &ToolsetDoesNotExistError{Name: name} -} - -// ServerTool is defined in server_tool.go - -type ServerResourceTemplate struct { - Template mcp.ResourceTemplate - Handler mcp.ResourceHandler -} - -func NewServerResourceTemplate(resourceTemplate mcp.ResourceTemplate, handler mcp.ResourceHandler) ServerResourceTemplate { - return ServerResourceTemplate{ - Template: resourceTemplate, - Handler: handler, - } -} - -type ServerPrompt struct { - Prompt mcp.Prompt - Handler mcp.PromptHandler -} - -func NewServerPrompt(prompt mcp.Prompt, handler mcp.PromptHandler) ServerPrompt { - return ServerPrompt{ - Prompt: prompt, - Handler: handler, - } -} - -// Toolset represents a collection of MCP functionality that can be enabled or disabled as a group. -type Toolset struct { - Name string - Description string - Enabled bool - readOnly bool - writeTools []ServerTool - readTools []ServerTool - // deps holds the dependencies for tool handlers (typed as any to avoid circular deps) - deps any - // resources are not tools, but the community seems to be moving towards namespaces as a broader concept - // and in order to have multiple servers running concurrently, we want to avoid overlapping resources too. - resourceTemplates []ServerResourceTemplate - // prompts are also not tools but are namespaced similarly - prompts []ServerPrompt -} - -func (t *Toolset) GetActiveTools() []ServerTool { - if t.Enabled { - if t.readOnly { - return t.readTools - } - return append(t.readTools, t.writeTools...) - } - return nil -} - -func (t *Toolset) GetAvailableTools() []ServerTool { - if t.readOnly { - return t.readTools - } - return append(t.readTools, t.writeTools...) -} - -func (t *Toolset) RegisterTools(s *mcp.Server) { - if !t.Enabled { - return - } - for i := range t.readTools { - t.readTools[i].RegisterFunc(s, t.deps) - } - if !t.readOnly { - for i := range t.writeTools { - t.writeTools[i].RegisterFunc(s, t.deps) - } - } -} - -// SetDependencies sets the dependencies for this toolset's tool handlers. -// The deps parameter is typed as `any` to avoid circular dependencies between packages. -func (t *Toolset) SetDependencies(deps any) *Toolset { - t.deps = deps - return t -} - -func (t *Toolset) AddResourceTemplates(templates ...ServerResourceTemplate) *Toolset { - t.resourceTemplates = append(t.resourceTemplates, templates...) - return t -} - -func (t *Toolset) AddPrompts(prompts ...ServerPrompt) *Toolset { - t.prompts = append(t.prompts, prompts...) - return t -} - -func (t *Toolset) GetActiveResourceTemplates() []ServerResourceTemplate { - if !t.Enabled { - return nil - } - return t.resourceTemplates -} - -func (t *Toolset) GetAvailableResourceTemplates() []ServerResourceTemplate { - return t.resourceTemplates -} - -func (t *Toolset) RegisterResourcesTemplates(s *mcp.Server) { - if !t.Enabled { - return - } - for _, resource := range t.resourceTemplates { - s.AddResourceTemplate(&resource.Template, resource.Handler) - } -} - -func (t *Toolset) RegisterPrompts(s *mcp.Server) { - if !t.Enabled { - return - } - for _, prompt := range t.prompts { - s.AddPrompt(&prompt.Prompt, prompt.Handler) - } -} - -func (t *Toolset) SetReadOnly() { - // Set the toolset to read-only - t.readOnly = true -} - -func (t *Toolset) AddWriteTools(tools ...ServerTool) *Toolset { - // Silently ignore if the toolset is read-only to avoid any breach of that contract - for _, tool := range tools { - if tool.Tool.Annotations.ReadOnlyHint { - panic(fmt.Sprintf("tool (%s) is incorrectly annotated as read-only", tool.Tool.Name)) - } - } - if !t.readOnly { - t.writeTools = append(t.writeTools, tools...) - } - return t -} - -func (t *Toolset) AddReadTools(tools ...ServerTool) *Toolset { - for _, tool := range tools { - if !tool.Tool.Annotations.ReadOnlyHint { - panic(fmt.Sprintf("tool (%s) must be annotated as read-only", tool.Tool.Name)) - } - } - t.readTools = append(t.readTools, tools...) - return t -} - -type ToolsetGroup struct { - Toolsets map[string]*Toolset - deprecatedAliases map[string]string - everythingOn bool - readOnly bool -} - -func NewToolsetGroup(readOnly bool) *ToolsetGroup { - return &ToolsetGroup{ - Toolsets: make(map[string]*Toolset), - deprecatedAliases: make(map[string]string), - everythingOn: false, - readOnly: readOnly, - } -} - -func (tg *ToolsetGroup) AddDeprecatedToolAliases(aliases map[string]string) { - for oldName, newName := range aliases { - tg.deprecatedAliases[oldName] = newName - } -} - -func (tg *ToolsetGroup) AddToolset(ts *Toolset) { - if tg.readOnly { - ts.SetReadOnly() - } - tg.Toolsets[ts.Name] = ts -} - -func NewToolset(name string, description string) *Toolset { - return &Toolset{ - Name: name, - Description: description, - Enabled: false, - readOnly: false, - } -} - -func (tg *ToolsetGroup) IsEnabled(name string) bool { - // If everythingOn is true, all features are enabled - if tg.everythingOn { - return true - } - - feature, exists := tg.Toolsets[name] - if !exists { - return false - } - return feature.Enabled -} - -type EnableToolsetsOptions struct { - ErrorOnUnknown bool -} - -func (tg *ToolsetGroup) EnableToolsets(names []string, options *EnableToolsetsOptions) error { - if options == nil { - options = &EnableToolsetsOptions{ - ErrorOnUnknown: false, - } - } - - // Special case for "all" - for _, name := range names { - if name == "all" { - tg.everythingOn = true - break - } - err := tg.EnableToolset(name) - if err != nil && options.ErrorOnUnknown { - return err - } - } - // Do this after to ensure all toolsets are enabled if "all" is present anywhere in list - if tg.everythingOn { - for name := range tg.Toolsets { - err := tg.EnableToolset(name) - if err != nil && options.ErrorOnUnknown { - return err - } - } - return nil - } - return nil -} - -func (tg *ToolsetGroup) EnableToolset(name string) error { - toolset, exists := tg.Toolsets[name] - if !exists { - return NewToolsetDoesNotExistError(name) - } - toolset.Enabled = true - tg.Toolsets[name] = toolset - return nil -} - -func (tg *ToolsetGroup) RegisterAll(s *mcp.Server) { - for _, toolset := range tg.Toolsets { - toolset.RegisterTools(s) - toolset.RegisterResourcesTemplates(s) - toolset.RegisterPrompts(s) - } -} - -func (tg *ToolsetGroup) GetToolset(name string) (*Toolset, error) { - toolset, exists := tg.Toolsets[name] - if !exists { - return nil, NewToolsetDoesNotExistError(name) - } - return toolset, nil -} - -type ToolDoesNotExistError struct { - Name string -} - -func (e *ToolDoesNotExistError) Error() string { - return fmt.Sprintf("tool %s does not exist", e.Name) -} - -func NewToolDoesNotExistError(name string) *ToolDoesNotExistError { - return &ToolDoesNotExistError{Name: name} -} - -// ResolveToolAliases resolves deprecated tool aliases to their canonical names. -// It logs a warning to stderr for each deprecated alias that is resolved. -// Returns: -// - resolved: tool names with aliases replaced by canonical names -// - aliasesUsed: map of oldName → newName for each alias that was resolved -func (tg *ToolsetGroup) ResolveToolAliases(toolNames []string) (resolved []string, aliasesUsed map[string]string) { - resolved = make([]string, 0, len(toolNames)) - aliasesUsed = make(map[string]string) - for _, toolName := range toolNames { - if canonicalName, isAlias := tg.deprecatedAliases[toolName]; isAlias { - fmt.Fprintf(os.Stderr, "Warning: tool %q is deprecated, use %q instead\n", toolName, canonicalName) - aliasesUsed[toolName] = canonicalName - resolved = append(resolved, canonicalName) - } else { - resolved = append(resolved, toolName) - } - } - return resolved, aliasesUsed -} - -// FindToolByName searches all toolsets (enabled or disabled) for a tool by name. -// Returns the tool, its parent toolset name, and an error if not found. -func (tg *ToolsetGroup) FindToolByName(toolName string) (*ServerTool, string, error) { - for toolsetName, toolset := range tg.Toolsets { - // Check read tools - for _, tool := range toolset.readTools { - if tool.Tool.Name == toolName { - return &tool, toolsetName, nil - } - } - // Check write tools - for _, tool := range toolset.writeTools { - if tool.Tool.Name == toolName { - return &tool, toolsetName, nil - } - } - } - return nil, "", NewToolDoesNotExistError(toolName) -} - -// RegisterSpecificTools registers only the specified tools. -// Respects read-only mode (skips write tools if readOnly=true). -// Returns error if any tool is not found. -func (tg *ToolsetGroup) RegisterSpecificTools(s *mcp.Server, toolNames []string, readOnly bool, deps any) error { - var skippedTools []string - for _, toolName := range toolNames { - tool, _, err := tg.FindToolByName(toolName) - if err != nil { - return fmt.Errorf("tool %s not found: %w", toolName, err) - } - - if !tool.Tool.Annotations.ReadOnlyHint && readOnly { - // Skip write tools in read-only mode - skippedTools = append(skippedTools, toolName) - continue - } - - // Register the tool - tool.RegisterFunc(s, deps) - } - - // Log skipped write tools if any - if len(skippedTools) > 0 { - fmt.Fprintf(os.Stderr, "Write tools skipped due to read-only mode: %s\n", strings.Join(skippedTools, ", ")) - } - - return nil -} diff --git a/pkg/toolsets/toolsets_test.go b/pkg/toolsets/toolsets_test.go deleted file mode 100644 index 66be5cba2..000000000 --- a/pkg/toolsets/toolsets_test.go +++ /dev/null @@ -1,408 +0,0 @@ -package toolsets - -import ( - "context" - "encoding/json" - "errors" - "testing" - - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// mockTool creates a minimal ServerTool for testing -func mockTool(name string, readOnly bool) ServerTool { - return NewServerToolFromHandler( - mcp.Tool{ - Name: name, - Annotations: &mcp.ToolAnnotations{ - ReadOnlyHint: readOnly, - }, - InputSchema: json.RawMessage(`{"type":"object","properties":{}}`), - }, - func(_ any) mcp.ToolHandler { - return func(_ context.Context, _ *mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return nil, nil - } - }, - ) -} - -func TestNewToolsetGroupIsEmptyWithoutEverythingOn(t *testing.T) { - tsg := NewToolsetGroup(false) - if len(tsg.Toolsets) != 0 { - t.Fatalf("Expected Toolsets map to be empty, got %d items", len(tsg.Toolsets)) - } - if tsg.everythingOn { - t.Fatal("Expected everythingOn to be initialized as false") - } -} - -func TestAddToolset(t *testing.T) { - tsg := NewToolsetGroup(false) - - // Test adding a toolset - toolset := NewToolset("test-toolset", "A test toolset") - toolset.Enabled = true - tsg.AddToolset(toolset) - - // Verify toolset was added correctly - if len(tsg.Toolsets) != 1 { - t.Errorf("Expected 1 toolset, got %d", len(tsg.Toolsets)) - } - - toolset, exists := tsg.Toolsets["test-toolset"] - if !exists { - t.Fatal("Feature was not added to the map") - } - - if toolset.Name != "test-toolset" { - t.Errorf("Expected toolset name to be 'test-toolset', got '%s'", toolset.Name) - } - - if toolset.Description != "A test toolset" { - t.Errorf("Expected toolset description to be 'A test toolset', got '%s'", toolset.Description) - } - - if !toolset.Enabled { - t.Error("Expected toolset to be enabled") - } - - // Test adding another toolset - anotherToolset := NewToolset("another-toolset", "Another test toolset") - tsg.AddToolset(anotherToolset) - - if len(tsg.Toolsets) != 2 { - t.Errorf("Expected 2 toolsets, got %d", len(tsg.Toolsets)) - } - - // Test overriding existing toolset - updatedToolset := NewToolset("test-toolset", "Updated description") - tsg.AddToolset(updatedToolset) - - toolset = tsg.Toolsets["test-toolset"] - if toolset.Description != "Updated description" { - t.Errorf("Expected toolset description to be updated to 'Updated description', got '%s'", toolset.Description) - } - - if toolset.Enabled { - t.Error("Expected toolset to be disabled after update") - } -} - -func TestIsEnabled(t *testing.T) { - tsg := NewToolsetGroup(false) - - // Test with non-existent toolset - if tsg.IsEnabled("non-existent") { - t.Error("Expected IsEnabled to return false for non-existent toolset") - } - - // Test with disabled toolset - disabledToolset := NewToolset("disabled-toolset", "A disabled toolset") - tsg.AddToolset(disabledToolset) - if tsg.IsEnabled("disabled-toolset") { - t.Error("Expected IsEnabled to return false for disabled toolset") - } - - // Test with enabled toolset - enabledToolset := NewToolset("enabled-toolset", "An enabled toolset") - enabledToolset.Enabled = true - tsg.AddToolset(enabledToolset) - if !tsg.IsEnabled("enabled-toolset") { - t.Error("Expected IsEnabled to return true for enabled toolset") - } -} - -func TestEnableFeature(t *testing.T) { - tsg := NewToolsetGroup(false) - - // Test enabling non-existent toolset - err := tsg.EnableToolset("non-existent") - if err == nil { - t.Error("Expected error when enabling non-existent toolset") - } - - // Test enabling toolset - testToolset := NewToolset("test-toolset", "A test toolset") - tsg.AddToolset(testToolset) - - if tsg.IsEnabled("test-toolset") { - t.Error("Expected toolset to be disabled initially") - } - - err = tsg.EnableToolset("test-toolset") - if err != nil { - t.Errorf("Expected no error when enabling toolset, got: %v", err) - } - - if !tsg.IsEnabled("test-toolset") { - t.Error("Expected toolset to be enabled after EnableFeature call") - } - - // Test enabling already enabled toolset - err = tsg.EnableToolset("test-toolset") - if err != nil { - t.Errorf("Expected no error when enabling already enabled toolset, got: %v", err) - } -} - -func TestEnableToolsets(t *testing.T) { - tsg := NewToolsetGroup(false) - - // Prepare toolsets - toolset1 := NewToolset("toolset1", "Feature 1") - toolset2 := NewToolset("toolset2", "Feature 2") - tsg.AddToolset(toolset1) - tsg.AddToolset(toolset2) - - // Test enabling multiple toolsets - err := tsg.EnableToolsets([]string{"toolset1", "toolset2"}, &EnableToolsetsOptions{}) - if err != nil { - t.Errorf("Expected no error when enabling toolsets, got: %v", err) - } - - if !tsg.IsEnabled("toolset1") { - t.Error("Expected toolset1 to be enabled") - } - - if !tsg.IsEnabled("toolset2") { - t.Error("Expected toolset2 to be enabled") - } - - // Test with non-existent toolset in the list - err = tsg.EnableToolsets([]string{"toolset1", "non-existent"}, nil) - if err != nil { - t.Errorf("Expected no error when ignoring unknown toolsets, got: %v", err) - } - - err = tsg.EnableToolsets([]string{"toolset1", "non-existent"}, &EnableToolsetsOptions{ - ErrorOnUnknown: false, - }) - if err != nil { - t.Errorf("Expected no error when ignoring unknown toolsets, got: %v", err) - } - - err = tsg.EnableToolsets([]string{"toolset1", "non-existent"}, &EnableToolsetsOptions{ErrorOnUnknown: true}) - if err == nil { - t.Error("Expected error when enabling list with non-existent toolset") - } - if !errors.Is(err, NewToolsetDoesNotExistError("non-existent")) { - t.Errorf("Expected ToolsetDoesNotExistError when enabling non-existent toolset, got: %v", err) - } - - // Test with empty list - err = tsg.EnableToolsets([]string{}, &EnableToolsetsOptions{}) - if err != nil { - t.Errorf("Expected no error with empty toolset list, got: %v", err) - } - - // Test enabling everything through EnableToolsets - tsg = NewToolsetGroup(false) - err = tsg.EnableToolsets([]string{"all"}, &EnableToolsetsOptions{}) - if err != nil { - t.Errorf("Expected no error when enabling 'all', got: %v", err) - } - - if !tsg.everythingOn { - t.Error("Expected everythingOn to be true after enabling 'all' via EnableToolsets") - } -} - -func TestEnableEverything(t *testing.T) { - tsg := NewToolsetGroup(false) - - // Add a disabled toolset - testToolset := NewToolset("test-toolset", "A test toolset") - tsg.AddToolset(testToolset) - - // Verify it's disabled - if tsg.IsEnabled("test-toolset") { - t.Error("Expected toolset to be disabled initially") - } - - // Enable "all" - err := tsg.EnableToolsets([]string{"all"}, &EnableToolsetsOptions{}) - if err != nil { - t.Errorf("Expected no error when enabling 'all', got: %v", err) - } - - // Verify everythingOn was set - if !tsg.everythingOn { - t.Error("Expected everythingOn to be true after enabling 'all'") - } - - // Verify the previously disabled toolset is now enabled - if !tsg.IsEnabled("test-toolset") { - t.Error("Expected toolset to be enabled when everythingOn is true") - } - - // Verify a non-existent toolset is also enabled - if !tsg.IsEnabled("non-existent") { - t.Error("Expected non-existent toolset to be enabled when everythingOn is true") - } -} - -func TestIsEnabledWithEverythingOn(t *testing.T) { - tsg := NewToolsetGroup(false) - - // Enable "all" - err := tsg.EnableToolsets([]string{"all"}, &EnableToolsetsOptions{}) - if err != nil { - t.Errorf("Expected no error when enabling 'all', got: %v", err) - } - - // Test that any toolset name returns true with IsEnabled - if !tsg.IsEnabled("some-toolset") { - t.Error("Expected IsEnabled to return true for any toolset when everythingOn is true") - } - - if !tsg.IsEnabled("another-toolset") { - t.Error("Expected IsEnabled to return true for any toolset when everythingOn is true") - } -} - -func TestToolsetGroup_GetToolset(t *testing.T) { - tsg := NewToolsetGroup(false) - toolset := NewToolset("my-toolset", "desc") - tsg.AddToolset(toolset) - - // Should find the toolset - got, err := tsg.GetToolset("my-toolset") - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if got != toolset { - t.Errorf("expected to get the same toolset instance") - } - - // Should not find a non-existent toolset - _, err = tsg.GetToolset("does-not-exist") - if err == nil { - t.Error("expected error for missing toolset, got nil") - } - if !errors.Is(err, NewToolsetDoesNotExistError("does-not-exist")) { - t.Errorf("expected error to be ToolsetDoesNotExistError, got %v", err) - } -} - -func TestAddDeprecatedToolAliases(t *testing.T) { - tsg := NewToolsetGroup(false) - - // Test adding aliases - tsg.AddDeprecatedToolAliases(map[string]string{ - "old_name": "new_name", - "get_issue": "issue_read", - "create_pr": "pull_request_create", - }) - - if len(tsg.deprecatedAliases) != 3 { - t.Errorf("expected 3 aliases, got %d", len(tsg.deprecatedAliases)) - } - if tsg.deprecatedAliases["old_name"] != "new_name" { - t.Errorf("expected alias 'old_name' -> 'new_name', got '%s'", tsg.deprecatedAliases["old_name"]) - } - if tsg.deprecatedAliases["get_issue"] != "issue_read" { - t.Errorf("expected alias 'get_issue' -> 'issue_read'") - } - if tsg.deprecatedAliases["create_pr"] != "pull_request_create" { - t.Errorf("expected alias 'create_pr' -> 'pull_request_create'") - } -} - -func TestResolveToolAliases(t *testing.T) { - tsg := NewToolsetGroup(false) - tsg.AddDeprecatedToolAliases(map[string]string{ - "get_issue": "issue_read", - "create_pr": "pull_request_create", - }) - - // Test resolving a mix of aliases and canonical names - input := []string{"get_issue", "some_tool", "create_pr"} - resolved, aliasesUsed := tsg.ResolveToolAliases(input) - - // Verify resolved names - if len(resolved) != 3 { - t.Fatalf("expected 3 resolved names, got %d", len(resolved)) - } - if resolved[0] != "issue_read" { - t.Errorf("expected 'issue_read', got '%s'", resolved[0]) - } - if resolved[1] != "some_tool" { - t.Errorf("expected 'some_tool' (unchanged), got '%s'", resolved[1]) - } - if resolved[2] != "pull_request_create" { - t.Errorf("expected 'pull_request_create', got '%s'", resolved[2]) - } - - // Verify aliasesUsed map - if len(aliasesUsed) != 2 { - t.Fatalf("expected 2 aliases used, got %d", len(aliasesUsed)) - } - if aliasesUsed["get_issue"] != "issue_read" { - t.Errorf("expected aliasesUsed['get_issue'] = 'issue_read', got '%s'", aliasesUsed["get_issue"]) - } - if aliasesUsed["create_pr"] != "pull_request_create" { - t.Errorf("expected aliasesUsed['create_pr'] = 'pull_request_create', got '%s'", aliasesUsed["create_pr"]) - } -} - -func TestFindToolByName(t *testing.T) { - tsg := NewToolsetGroup(false) - - // Create a toolset with a tool - toolset := NewToolset("test-toolset", "Test toolset") - toolset.readTools = append(toolset.readTools, mockTool("issue_read", true)) - tsg.AddToolset(toolset) - - // Find by canonical name - tool, toolsetName, err := tsg.FindToolByName("issue_read") - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if tool.Tool.Name != "issue_read" { - t.Errorf("expected tool name 'issue_read', got '%s'", tool.Tool.Name) - } - if toolsetName != "test-toolset" { - t.Errorf("expected toolset name 'test-toolset', got '%s'", toolsetName) - } - - // FindToolByName does NOT resolve aliases - it expects canonical names - _, _, err = tsg.FindToolByName("get_issue") - if err == nil { - t.Error("expected error when using alias directly with FindToolByName") - } -} - -func TestRegisterSpecificTools(t *testing.T) { - tsg := NewToolsetGroup(false) - - // Create a toolset with both read and write tools - toolset := NewToolset("test-toolset", "Test toolset") - toolset.readTools = append(toolset.readTools, mockTool("issue_read", true)) - toolset.writeTools = append(toolset.writeTools, mockTool("issue_write", false)) - tsg.AddToolset(toolset) - - // deps is typed as any in toolsets package (to avoid circular deps) - var deps any - - // Create a real server for testing - server := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "1.0.0"}, nil) - - // Test registering with canonical names - err := tsg.RegisterSpecificTools(server, []string{"issue_read"}, false, deps) - if err != nil { - t.Errorf("expected no error registering tool, got %v", err) - } - - // Test registering write tool in read-only mode (should skip but not error) - err = tsg.RegisterSpecificTools(server, []string{"issue_write"}, true, deps) - if err != nil { - t.Errorf("expected no error when skipping write tool in read-only mode, got %v", err) - } - - // Test registering non-existent tool (should error) - err = tsg.RegisterSpecificTools(server, []string{"nonexistent"}, false, deps) - if err == nil { - t.Error("expected error for non-existent tool") - } -} diff --git a/script/conformance-test b/script/conformance-test new file mode 100755 index 000000000..3ff0a55c2 --- /dev/null +++ b/script/conformance-test @@ -0,0 +1,432 @@ +#!/bin/bash +set -e + +# Conformance test script for comparing MCP server behavior between branches +# Builds both main and current branch, runs various flag combinations, +# and produces a conformance report with timing and diffs. +# +# Output: +# - Progress/status messages go to stderr (for visibility in CI) +# - Final report summary goes to stdout (for piping/capture) + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PROJECT_DIR="$(dirname "$SCRIPT_DIR")" +REPORT_DIR="$PROJECT_DIR/conformance-report" +CURRENT_BRANCH=$(git rev-parse --abbrev-ref HEAD) + +# Colors for output (only used on stderr) +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# Helper to print to stderr +log() { + echo -e "$@" >&2 +} + +log "${BLUE}=== MCP Server Conformance Test ===${NC}" +log "Current branch: $CURRENT_BRANCH" +log "Report directory: $REPORT_DIR" + +# Find the common ancestor +MERGE_BASE=$(git merge-base HEAD origin/main) +log "Comparing against merge-base: $MERGE_BASE" +log "" + +# Create report directory +rm -rf "$REPORT_DIR" +mkdir -p "$REPORT_DIR"/{main,branch,diffs} + +# Build binaries +log "${YELLOW}Building binaries...${NC}" + +log "Building current branch ($CURRENT_BRANCH)..." +go build -o "$REPORT_DIR/branch/github-mcp-server" ./cmd/github-mcp-server +BRANCH_BUILD_OK=$? + +log "Building main branch (using temp worktree at merge-base)..." +TEMP_WORKTREE=$(mktemp -d) +git worktree add --quiet "$TEMP_WORKTREE" "$MERGE_BASE" +(cd "$TEMP_WORKTREE" && go build -o "$REPORT_DIR/main/github-mcp-server" ./cmd/github-mcp-server) +MAIN_BUILD_OK=$? +git worktree remove --force "$TEMP_WORKTREE" + +if [ $BRANCH_BUILD_OK -ne 0 ] || [ $MAIN_BUILD_OK -ne 0 ]; then + log "${RED}Build failed!${NC}" + exit 1 +fi + +log "${GREEN}Both binaries built successfully${NC}" +log "" + +# MCP JSON-RPC messages +INIT_MSG='{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"conformance-test","version":"1.0.0"}}}' +INITIALIZED_MSG='{"jsonrpc":"2.0","method":"notifications/initialized","params":{}}' +LIST_TOOLS_MSG='{"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}}' +LIST_RESOURCES_MSG='{"jsonrpc":"2.0","id":3,"method":"resources/listTemplates","params":{}}' +LIST_PROMPTS_MSG='{"jsonrpc":"2.0","id":4,"method":"prompts/list","params":{}}' + +# Dynamic toolset management tool calls (for dynamic mode testing) +LIST_TOOLSETS_MSG='{"jsonrpc":"2.0","id":10,"method":"tools/call","params":{"name":"list_available_toolsets","arguments":{}}}' +GET_TOOLSET_TOOLS_MSG='{"jsonrpc":"2.0","id":11,"method":"tools/call","params":{"name":"get_toolset_tools","arguments":{"toolset":"repos"}}}' +ENABLE_TOOLSET_MSG='{"jsonrpc":"2.0","id":12,"method":"tools/call","params":{"name":"enable_toolset","arguments":{"toolset":"repos"}}}' +LIST_TOOLSETS_AFTER_MSG='{"jsonrpc":"2.0","id":13,"method":"tools/call","params":{"name":"list_available_toolsets","arguments":{}}}' + +# Function to normalize JSON for comparison +# Sorts all arrays (including nested ones) and formats consistently +# Also handles embedded JSON strings in "text" fields (from tool call responses) +normalize_json() { + local file="$1" + if [ -s "$file" ]; then + # First, try to parse and re-serialize any JSON embedded in text fields + # This handles tool call responses where the result is JSON-in-a-string + jq -S ' + # Function to sort arrays recursively + def deep_sort: + if type == "array" then + [.[] | deep_sort] | sort_by(tostring) + elif type == "object" then + to_entries | map(.value |= deep_sort) | from_entries + else + . + end; + + # Walk the structure, and for any "text" field that looks like JSON array/object, parse and sort it + walk( + if type == "object" and .text and (.text | type == "string") and ((.text | startswith("[")) or (.text | startswith("{"))) then + .text = ((.text | fromjson | deep_sort) | tojson) + else + . + end + ) | deep_sort + ' "$file" 2>/dev/null > "${file}.tmp" && mv "${file}.tmp" "$file" + fi +} + +# Function to run MCP server and capture output with timing +run_mcp_test() { + local binary="$1" + local name="$2" + local flags="$3" + local output_prefix="$4" + + local start_time end_time duration + start_time=$(date +%s.%N) + + # Run the server with all list commands - each response is on its own line + output=$( + ( + echo "$INIT_MSG" + echo "$INITIALIZED_MSG" + echo "$LIST_TOOLS_MSG" + echo "$LIST_RESOURCES_MSG" + echo "$LIST_PROMPTS_MSG" + sleep 0.5 + ) | GITHUB_PERSONAL_ACCESS_TOKEN=1 $binary stdio $flags 2>/dev/null + ) + + end_time=$(date +%s.%N) + duration=$(echo "$end_time - $start_time" | bc) + + # Parse and save each response by matching JSON-RPC id + # Each line is a separate JSON response + echo "$output" | while IFS= read -r line; do + id=$(echo "$line" | jq -r '.id // empty' 2>/dev/null) + case "$id" in + 1) echo "$line" | jq -S '.' > "${output_prefix}_initialize.json" 2>/dev/null ;; + 2) echo "$line" | jq -S '.' > "${output_prefix}_tools.json" 2>/dev/null ;; + 3) echo "$line" | jq -S '.' > "${output_prefix}_resources.json" 2>/dev/null ;; + 4) echo "$line" | jq -S '.' > "${output_prefix}_prompts.json" 2>/dev/null ;; + esac + done + + # Create empty files if not created (in case of errors or missing responses) + touch "${output_prefix}_initialize.json" "${output_prefix}_tools.json" \ + "${output_prefix}_resources.json" "${output_prefix}_prompts.json" + + # Normalize all JSON files for consistent comparison (sorts arrays, keys) + for endpoint in initialize tools resources prompts; do + normalize_json "${output_prefix}_${endpoint}.json" + done + + echo "$duration" +} + +# Function to run MCP server with dynamic tool calls (for dynamic mode testing) +run_mcp_dynamic_test() { + local binary="$1" + local name="$2" + local flags="$3" + local output_prefix="$4" + + local start_time end_time duration + start_time=$(date +%s.%N) + + # Run the server with dynamic tool calls in sequence: + # 1. Initialize + # 2. List available toolsets (before enable) + # 3. Get tools for repos toolset + # 4. Enable repos toolset + # 5. List available toolsets (after enable - should show repos as enabled) + output=$( + ( + echo "$INIT_MSG" + echo "$INITIALIZED_MSG" + echo "$LIST_TOOLSETS_MSG" + sleep 0.1 + echo "$GET_TOOLSET_TOOLS_MSG" + sleep 0.1 + echo "$ENABLE_TOOLSET_MSG" + sleep 0.1 + echo "$LIST_TOOLSETS_AFTER_MSG" + sleep 0.3 + ) | GITHUB_PERSONAL_ACCESS_TOKEN=1 $binary stdio $flags 2>/dev/null + ) + + end_time=$(date +%s.%N) + duration=$(echo "$end_time - $start_time" | bc) + + # Parse and save each response by matching JSON-RPC id + echo "$output" | while IFS= read -r line; do + id=$(echo "$line" | jq -r '.id // empty' 2>/dev/null) + case "$id" in + 1) echo "$line" | jq -S '.' > "${output_prefix}_initialize.json" 2>/dev/null ;; + 10) echo "$line" | jq -S '.' > "${output_prefix}_list_toolsets_before.json" 2>/dev/null ;; + 11) echo "$line" | jq -S '.' > "${output_prefix}_get_toolset_tools.json" 2>/dev/null ;; + 12) echo "$line" | jq -S '.' > "${output_prefix}_enable_toolset.json" 2>/dev/null ;; + 13) echo "$line" | jq -S '.' > "${output_prefix}_list_toolsets_after.json" 2>/dev/null ;; + esac + done + + # Create empty files if not created + touch "${output_prefix}_initialize.json" "${output_prefix}_list_toolsets_before.json" \ + "${output_prefix}_get_toolset_tools.json" "${output_prefix}_enable_toolset.json" \ + "${output_prefix}_list_toolsets_after.json" + + # Normalize all JSON files + for endpoint in initialize list_toolsets_before get_toolset_tools enable_toolset list_toolsets_after; do + normalize_json "${output_prefix}_${endpoint}.json" + done + + echo "$duration" +} + +# Test configurations - array of "name|flags|type" +# type can be "standard" or "dynamic" (for dynamic tool call testing) +declare -a TEST_CONFIGS=( + "default||standard" + "read-only|--read-only|standard" + "dynamic-toolsets|--dynamic-toolsets|standard" + "read-only+dynamic|--read-only --dynamic-toolsets|standard" + "toolsets-repos|--toolsets=repos|standard" + "toolsets-issues|--toolsets=issues|standard" + "toolsets-pull_requests|--toolsets=pull_requests|standard" + "toolsets-repos,issues|--toolsets=repos,issues|standard" + "toolsets-all|--toolsets=all|standard" + "tools-get_me|--tools=get_me|standard" + "tools-get_me,list_issues|--tools=get_me,list_issues|standard" + "toolsets-repos+read-only|--toolsets=repos --read-only|standard" + "toolsets-all+dynamic|--toolsets=all --dynamic-toolsets|standard" + "toolsets-repos+dynamic|--toolsets=repos --dynamic-toolsets|standard" + "toolsets-repos,issues+dynamic|--toolsets=repos,issues --dynamic-toolsets|standard" + "dynamic-tool-calls|--dynamic-toolsets|dynamic" +) + +# Summary arrays +declare -a TEST_NAMES +declare -a MAIN_TIMES +declare -a BRANCH_TIMES +declare -a DIFF_STATUS + +log "${YELLOW}Running conformance tests...${NC}" +log "" + +for config in "${TEST_CONFIGS[@]}"; do + IFS='|' read -r test_name flags test_type <<< "$config" + + log "${BLUE}Test: ${test_name}${NC}" + log " Flags: ${flags:-}" + log " Type: ${test_type}" + + # Create output directories + mkdir -p "$REPORT_DIR/main/$test_name" + mkdir -p "$REPORT_DIR/branch/$test_name" + mkdir -p "$REPORT_DIR/diffs/$test_name" + + if [ "$test_type" = "dynamic" ]; then + # Run dynamic tool call test + main_time=$(run_mcp_dynamic_test "$REPORT_DIR/main/github-mcp-server" "main" "$flags" "$REPORT_DIR/main/$test_name/output") + log " Main: ${main_time}s" + + branch_time=$(run_mcp_dynamic_test "$REPORT_DIR/branch/github-mcp-server" "branch" "$flags" "$REPORT_DIR/branch/$test_name/output") + log " Branch: ${branch_time}s" + + endpoints="initialize list_toolsets_before get_toolset_tools enable_toolset list_toolsets_after" + else + # Run standard test + main_time=$(run_mcp_test "$REPORT_DIR/main/github-mcp-server" "main" "$flags" "$REPORT_DIR/main/$test_name/output") + log " Main: ${main_time}s" + + branch_time=$(run_mcp_test "$REPORT_DIR/branch/github-mcp-server" "branch" "$flags" "$REPORT_DIR/branch/$test_name/output") + log " Branch: ${branch_time}s" + + endpoints="initialize tools resources prompts" + fi + + # Calculate time difference + time_diff=$(echo "$branch_time - $main_time" | bc) + if (( $(echo "$time_diff > 0" | bc -l) )); then + log " Δ Time: ${RED}+${time_diff}s (slower)${NC}" + else + log " Δ Time: ${GREEN}${time_diff}s (faster)${NC}" + fi + + # Generate diffs for each endpoint + has_diff=false + for endpoint in $endpoints; do + main_file="$REPORT_DIR/main/$test_name/output_${endpoint}.json" + branch_file="$REPORT_DIR/branch/$test_name/output_${endpoint}.json" + diff_file="$REPORT_DIR/diffs/$test_name/${endpoint}.diff" + + if ! diff -u "$main_file" "$branch_file" > "$diff_file" 2>/dev/null; then + has_diff=true + lines=$(wc -l < "$diff_file" | tr -d ' ') + log " ${YELLOW}${endpoint}: DIFF (${lines} lines)${NC}" + else + rm -f "$diff_file" # No diff, remove empty file + log " ${GREEN}${endpoint}: OK${NC}" + fi + done + + # Store results + TEST_NAMES+=("$test_name") + MAIN_TIMES+=("$main_time") + BRANCH_TIMES+=("$branch_time") + if [ "$has_diff" = true ]; then + DIFF_STATUS+=("DIFF") + else + DIFF_STATUS+=("OK") + fi + + log "" +done + +# Generate summary report +REPORT_FILE="$REPORT_DIR/CONFORMANCE_REPORT.md" + +cat > "$REPORT_FILE" << EOF +# MCP Server Conformance Report + +Generated: $(date) +Current Branch: $CURRENT_BRANCH +Compared Against: merge-base ($MERGE_BASE) + +## Summary + +| Test | Main Time | Branch Time | Δ Time | Status | +|------|-----------|-------------|--------|--------| +EOF + +total_main=0 +total_branch=0 +diff_count=0 +ok_count=0 + +for i in "${!TEST_NAMES[@]}"; do + name="${TEST_NAMES[$i]}" + main_t="${MAIN_TIMES[$i]}" + branch_t="${BRANCH_TIMES[$i]}" + status="${DIFF_STATUS[$i]}" + + delta=$(echo "$branch_t - $main_t" | bc) + if (( $(echo "$delta > 0" | bc -l) )); then + delta_str="+${delta}s" + else + delta_str="${delta}s" + fi + + if [ "$status" = "DIFF" ]; then + status_str="⚠️ DIFF" + ((diff_count++)) || true + else + status_str="✅ OK" + ((ok_count++)) || true + fi + + total_main=$(echo "$total_main + $main_t" | bc) + total_branch=$(echo "$total_branch + $branch_t" | bc) + + echo "| $name | ${main_t}s | ${branch_t}s | $delta_str | $status_str |" >> "$REPORT_FILE" +done + +total_delta=$(echo "$total_branch - $total_main" | bc) +if (( $(echo "$total_delta > 0" | bc -l) )); then + total_delta_str="+${total_delta}s" +else + total_delta_str="${total_delta}s" +fi + +cat >> "$REPORT_FILE" << EOF +| **TOTAL** | **${total_main}s** | **${total_branch}s** | **$total_delta_str** | **$ok_count OK / $diff_count DIFF** | + +## Statistics + +- **Tests Passed (no diff):** $ok_count +- **Tests with Differences:** $diff_count +- **Total Main Time:** ${total_main}s +- **Total Branch Time:** ${total_branch}s +- **Overall Time Delta:** $total_delta_str + +## Detailed Diffs + +EOF + +# Add diff details to report +for i in "${!TEST_NAMES[@]}"; do + name="${TEST_NAMES[$i]}" + status="${DIFF_STATUS[$i]}" + + if [ "$status" = "DIFF" ]; then + echo "### $name" >> "$REPORT_FILE" + echo "" >> "$REPORT_FILE" + + # Check all possible endpoints + for endpoint in initialize tools resources prompts list_toolsets_before get_toolset_tools enable_toolset list_toolsets_after; do + diff_file="$REPORT_DIR/diffs/$name/${endpoint}.diff" + if [ -f "$diff_file" ] && [ -s "$diff_file" ]; then + echo "#### ${endpoint}" >> "$REPORT_FILE" + echo '```diff' >> "$REPORT_FILE" + cat "$diff_file" >> "$REPORT_FILE" + echo '```' >> "$REPORT_FILE" + echo "" >> "$REPORT_FILE" + fi + done + fi +done + +log "${BLUE}=== Conformance Test Complete ===${NC}" +log "" +log "Report: ${GREEN}$REPORT_FILE${NC}" +log "" + +# Output summary to stdout (for CI capture) +echo "=== Conformance Test Summary ===" +echo "Tests passed: $ok_count" +echo "Tests with diffs: $diff_count" +echo "Total main time: ${total_main}s" +echo "Total branch time: ${total_branch}s" +echo "Time delta: $total_delta_str" + +if [ $diff_count -gt 0 ]; then + log "" + log "${YELLOW}⚠️ Some tests have differences. Review the diffs in:${NC}" + log " $REPORT_DIR/diffs/" + echo "" + echo "RESULT: DIFFERENCES FOUND" + # Don't exit with error - diffs may be intentional improvements +else + echo "" + echo "RESULT: ALL TESTS PASSED" +fi