Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cmd/github-mcp-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"os"
"strings"
"time"

"github.com/github/github-mcp-server/internal/ghmcp"
"github.com/github/github-mcp-server/pkg/github"
Expand Down Expand Up @@ -84,6 +85,9 @@ var (
return fmt.Errorf("failed to unmarshal toolsets: %w", err)
}

// Pre-compute heartbeat interval
hbInterval, _ := time.ParseDuration(viper.GetString("http_heartbeat_interval"))

httpServerConfig := ghmcp.HttpServerConfig{
Version: version,
Host: viper.GetString("host"),
Expand All @@ -101,6 +105,7 @@ var (
AppPrivateKey: appPrivateKey,
EnableGitHubAppAuth: enableGitHubAppAuth,
InstallationIDHeader: viper.GetString("installation_id_header"),
HeartbeatInterval: hbInterval,
}

return ghmcp.RunHTTPServer(httpServerConfig)
Expand Down Expand Up @@ -133,6 +138,7 @@ func init() {
httpCmd.Flags().String("http-address", ":8080", "HTTP server address to bind to")
httpCmd.Flags().String("http-mcp-path", "/mcp", "HTTP path for MCP endpoint")
httpCmd.Flags().Bool("http-enable-cors", false, "Enable CORS for cross-origin requests")
httpCmd.Flags().String("http-heartbeat-interval", "15s", "Interval for SSE heartbeats on GET listener (e.g., 15s; set 0 to disable)")

// Bind flags to viper
_ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets"))
Expand All @@ -149,6 +155,7 @@ func init() {
_ = viper.BindPFlag("http_address", httpCmd.Flags().Lookup("http-address"))
_ = viper.BindPFlag("http_mcp_path", httpCmd.Flags().Lookup("http-mcp-path"))
_ = viper.BindPFlag("http_enable_cors", httpCmd.Flags().Lookup("http-enable-cors"))
_ = viper.BindPFlag("http_heartbeat_interval", httpCmd.Flags().Lookup("http-heartbeat-interval"))

// Add subcommands
rootCmd.AddCommand(stdioCmd)
Expand Down
11 changes: 10 additions & 1 deletion internal/ghmcp/http_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"github.com/sirupsen/logrus"
)

type HttpServerConfig struct {

Check failure on line 18 in internal/ghmcp/http_server.go

View workflow job for this annotation

GitHub Actions / lint

var-naming: type HttpServerConfig should be HTTPServerConfig (revive)
// Version of the server
Version string

Expand Down Expand Up @@ -66,6 +66,10 @@

// Custom header name to read installation ID from (defaults to "X-GitHub-Installation-ID")
InstallationIDHeader string

// HeartbeatInterval controls how often the server sends SSE heartbeat pings on the GET listener
// Set to 0 to disable heartbeats
HeartbeatInterval time.Duration
}

const installationContextKey = "installation_id"
Expand Down Expand Up @@ -96,7 +100,12 @@
return fmt.Errorf("failed to create MCP server: %w", err)
}

httpServer := server.NewStreamableHTTPServer(ghServer)
// Configure the streamable HTTP server with optional heartbeat pings
var httpOpts []server.StreamableHTTPOption
if cfg.HeartbeatInterval > 0 {
httpOpts = append(httpOpts, server.WithHeartbeatInterval(cfg.HeartbeatInterval))
}
httpServer := server.NewStreamableHTTPServer(ghServer, httpOpts...)

logrusLogger := logrus.New()
if cfg.LogFilePath != "" {
Expand Down Expand Up @@ -133,7 +142,7 @@

mux.Handle(cfg.MCPPath, handler)

srv := &http.Server{

Check failure on line 145 in internal/ghmcp/http_server.go

View workflow job for this annotation

GitHub Actions / lint

G112: Potential Slowloris Attack because ReadHeaderTimeout is not configured in the http.Server (gosec)
Addr: cfg.Address,
Handler: mux,
}
Expand Down Expand Up @@ -215,7 +224,7 @@
return
}

ctx := context.WithValue(r.Context(), installationContextKey, installationID)

Check failure on line 227 in internal/ghmcp/http_server.go

View workflow job for this annotation

GitHub Actions / lint

context-keys-type: should not use basic type untyped string as key in context.WithValue (revive)
r = r.WithContext(ctx)

if logger.GetLevel() == logrus.DebugLevel {
Expand Down
Loading