From 370ebca56795563a469fdab307b1c662fff2b05b Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Thu, 22 Jan 2026 11:53:08 +0000 Subject: [PATCH 01/38] initial oauth metadata implementation --- cmd/github-mcp-server/main.go | 3 + pkg/http/handler.go | 12 +- pkg/http/headers/headers.go | 5 + pkg/http/middleware/token.go | 14 +- pkg/http/oauth/oauth.go | 225 ++++++++++++++++++++ pkg/http/oauth/protected_resource.json.tmpl | 20 ++ pkg/http/server.go | 19 +- 7 files changed, 294 insertions(+), 4 deletions(-) create mode 100644 pkg/http/oauth/oauth.go create mode 100644 pkg/http/oauth/protected_resource.json.tmpl diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index b4ae54717..4014c0804 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -99,6 +99,7 @@ var ( Version: version, Host: viper.GetString("host"), Port: viper.GetInt("port"), + BaseURL: viper.GetString("base-url"), ExportTranslations: viper.GetBool("export-translations"), EnableCommandLogging: viper.GetBool("enable-command-logging"), LogFilePath: viper.GetString("log-file"), @@ -130,6 +131,7 @@ func init() { rootCmd.PersistentFlags().Bool("lockdown-mode", false, "Enable lockdown mode") rootCmd.PersistentFlags().Duration("repo-access-cache-ttl", 5*time.Minute, "Override the repo access cache TTL (e.g. 1m, 0s to disable)") rootCmd.PersistentFlags().Int("port", 8082, "HTTP server port") + rootCmd.PersistentFlags().String("base-url", "", "Base URL where this server is publicly accessible (for OAuth resource metadata)") // Bind flag to viper _ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets")) @@ -145,6 +147,7 @@ func init() { _ = viper.BindPFlag("lockdown-mode", rootCmd.PersistentFlags().Lookup("lockdown-mode")) _ = viper.BindPFlag("repo-access-cache-ttl", rootCmd.PersistentFlags().Lookup("repo-access-cache-ttl")) _ = viper.BindPFlag("port", rootCmd.PersistentFlags().Lookup("port")) + _ = viper.BindPFlag("base-url", rootCmd.PersistentFlags().Lookup("base-url")) // Add subcommands rootCmd.AddCommand(stdioCmd) diff --git a/pkg/http/handler.go b/pkg/http/handler.go index f2fcb531f..b4d2f0524 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -9,6 +9,7 @@ import ( "github.com/github/github-mcp-server/pkg/github" "github.com/github/github-mcp-server/pkg/http/headers" "github.com/github/github-mcp-server/pkg/http/middleware" + "github.com/github/github-mcp-server/pkg/http/oauth" "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/translations" "github.com/go-chi/chi/v5" @@ -25,11 +26,13 @@ type HTTPMcpHandler struct { t translations.TranslationHelperFunc githubMcpServerFactory GitHubMCPServerFactoryFunc inventoryFactoryFunc InventoryFactoryFunc + oauthCfg *oauth.Config } type HTTPMcpHandlerOptions struct { GitHubMcpServerFactory GitHubMCPServerFactoryFunc InventoryFactory InventoryFactoryFunc + OAuthConfig *oauth.Config } type HTTPMcpHandlerOption func(*HTTPMcpHandlerOptions) @@ -46,6 +49,12 @@ func WithInventoryFactory(f InventoryFactoryFunc) HTTPMcpHandlerOption { } } +func WithOAuthConfig(cfg *oauth.Config) HTTPMcpHandlerOption { + return func(o *HTTPMcpHandlerOptions) { + o.OAuthConfig = cfg + } +} + func NewHTTPMcpHandler(cfg *HTTPServerConfig, deps github.ToolDependencies, t translations.TranslationHelperFunc, @@ -73,6 +82,7 @@ func NewHTTPMcpHandler(cfg *HTTPServerConfig, t: t, githubMcpServerFactory: githubMcpServerFactory, inventoryFactoryFunc: inventoryFactory, + oauthCfg: opts.OAuthConfig, } } @@ -101,7 +111,7 @@ func (h *HTTPMcpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { Stateless: true, }) - middleware.ExtractUserToken()(mcpHandler).ServeHTTP(w, r) + middleware.ExtractUserToken(h.oauthCfg)(mcpHandler).ServeHTTP(w, r) } func DefaultGitHubMCPServerFactory(ctx context.Context, _ *http.Request, deps github.ToolDependencies, inventory *inventory.Inventory, cfg *github.MCPServerConfig) (*mcp.Server, error) { diff --git a/pkg/http/headers/headers.go b/pkg/http/headers/headers.go index b73104c34..83bdb8fde 100644 --- a/pkg/http/headers/headers.go +++ b/pkg/http/headers/headers.go @@ -21,6 +21,11 @@ const ( // RealIPHeader is a standard HTTP Header used to indicate the real IP address of the client. RealIPHeader = "X-Real-IP" + // ForwardedHostHeader is a standard HTTP Header for preserving the original Host header when proxying. + ForwardedHostHeader = "X-Forwarded-Host" + // ForwardedProtoHeader is a standard HTTP Header for preserving the original protocol when proxying. + ForwardedProtoHeader = "X-Forwarded-Proto" + // RequestHmacHeader is used to authenticate requests to the Raw API. RequestHmacHeader = "Request-Hmac" diff --git a/pkg/http/middleware/token.go b/pkg/http/middleware/token.go index c2e5c6382..93b93279e 100644 --- a/pkg/http/middleware/token.go +++ b/pkg/http/middleware/token.go @@ -10,6 +10,7 @@ import ( ghcontext "github.com/github/github-mcp-server/pkg/context" httpheaders "github.com/github/github-mcp-server/pkg/http/headers" "github.com/github/github-mcp-server/pkg/http/mark" + "github.com/github/github-mcp-server/pkg/http/oauth" ) type authType int @@ -40,14 +41,14 @@ var supportedThirdPartyTokenPrefixes = []string{ // were 40 characters long and only contained the characters a-f and 0-9. var oldPatternRegexp = regexp.MustCompile(`\A[a-f0-9]{40}\z`) -func ExtractUserToken() func(next http.Handler) http.Handler { +func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, token, err := parseAuthorizationHeader(r) if err != nil { // For missing Authorization header, return 401 with WWW-Authenticate header per MCP spec if errors.Is(err, errMissingAuthorizationHeader) { - // sendAuthChallenge(w, r, cfg, obsv) + sendAuthChallenge(w, r, oauthCfg) return } // For other auth errors (bad format, unsupported), return 400 @@ -63,6 +64,15 @@ func ExtractUserToken() func(next http.Handler) http.Handler { }) } } + +// sendAuthChallenge sends a 401 Unauthorized response with WWW-Authenticate header +// containing the OAuth protected resource metadata URL as per RFC 6750 and MCP spec. +func sendAuthChallenge(w http.ResponseWriter, r *http.Request, oauthCfg *oauth.Config) { + resourceMetadataURL := oauth.BuildResourceMetadataURL(r, oauthCfg, "mcp") + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata=%q`, resourceMetadataURL)) + http.Error(w, "Unauthorized", http.StatusUnauthorized) +} + func parseAuthorizationHeader(req *http.Request) (authType authType, token string, _ error) { authHeader := req.Header.Get(httpheaders.AuthorizationHeader) if authHeader == "" { diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go new file mode 100644 index 000000000..a504d348e --- /dev/null +++ b/pkg/http/oauth/oauth.go @@ -0,0 +1,225 @@ +// Package oauth provides OAuth 2.0 Protected Resource Metadata (RFC 9728) support +// for the GitHub MCP Server HTTP mode. +package oauth + +import ( + "bytes" + _ "embed" + "fmt" + "html" + "net/http" + "strings" + "text/template" + + "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/go-chi/chi/v5" +) + +const ( + // OAuthProtectedResourcePrefix is the well-known path prefix for OAuth protected resource metadata. + OAuthProtectedResourcePrefix = "/.well-known/oauth-protected-resource" + + // DefaultAuthorizationServer is GitHub's OAuth authorization server. + DefaultAuthorizationServer = "https://github.com/login/oauth" +) + +//go:embed protected_resource.json.tmpl +var protectedResourceTemplate []byte + +// SupportedScopes lists all OAuth scopes that may be required by MCP tools. +var SupportedScopes = []string{ + "repo", + "read:org", + "read:user", + "user:email", + "read:packages", + "write:packages", + "read:project", + "project", + "gist", + "notifications", + "workflow", + "codespace", +} + +// Config holds the OAuth configuration for the MCP server. +type Config struct { + // BaseURL is the publicly accessible URL where this server is hosted. + // This is used to construct the OAuth resource URL. + // Example: "https://mcp.example.com" + BaseURL string + + // AuthorizationServer is the OAuth authorization server URL. + // Defaults to GitHub's OAuth server if not specified. + AuthorizationServer string + + // ResourcePath is the resource path suffix (e.g., "/mcp"). + // If empty, defaults to "/" + ResourcePath string +} + +// ProtectedResourceData contains the data needed to build an OAuth protected resource response. +type ProtectedResourceData struct { + ResourceURL string + AuthorizationServer string +} + +// AuthHandler handles OAuth-related HTTP endpoints. +type AuthHandler struct { + cfg *Config + protectedResourceTemplate *template.Template +} + +// NewAuthHandler creates a new OAuth auth handler. +func NewAuthHandler(cfg *Config) (*AuthHandler, error) { + if cfg == nil { + cfg = &Config{} + } + + // Default authorization server to GitHub + if cfg.AuthorizationServer == "" { + cfg.AuthorizationServer = DefaultAuthorizationServer + } + + tmpl, err := template.New("protected-resource").Parse(string(protectedResourceTemplate)) + if err != nil { + return nil, fmt.Errorf("failed to parse protected resource template: %w", err) + } + + return &AuthHandler{ + cfg: cfg, + protectedResourceTemplate: tmpl, + }, nil +} + +// routePatterns defines the route patterns for OAuth protected resource metadata. +var routePatterns = []string{ + "", // Root: /.well-known/oauth-protected-resource + "/readonly", // Read-only mode + "/x/{toolset}", + "/x/{toolset}/readonly", +} + +// RegisterRoutes registers the OAuth protected resource metadata routes. +func (h *AuthHandler) RegisterRoutes(r chi.Router) { + for _, pattern := range routePatterns { + for _, route := range h.routesForPattern(pattern) { + path := OAuthProtectedResourcePrefix + route + r.Get(path, h.handleProtectedResource) + r.Options(path, h.handleProtectedResource) // CORS support + } + } +} + +// routesForPattern generates route variants for a given pattern. +func (h *AuthHandler) routesForPattern(pattern string) []string { + routes := []string{ + pattern, + pattern + "/", + pattern + "/mcp", + pattern + "/mcp/", + } + return routes +} + +// handleProtectedResource handles requests for OAuth protected resource metadata. +func (h *AuthHandler) handleProtectedResource(w http.ResponseWriter, r *http.Request) { + // Extract the resource path from the URL + resourcePath := strings.TrimPrefix(r.URL.Path, OAuthProtectedResourcePrefix) + if resourcePath == "" || resourcePath == "/" { + resourcePath = "/" + } else { + resourcePath = strings.TrimPrefix(resourcePath, "/") + } + + data, err := h.GetProtectedResourceData(r, html.EscapeString(resourcePath)) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + var buf bytes.Buffer + if err := h.protectedResourceTemplate.Execute(&buf, data); err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Set CORS headers + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(buf.Bytes()) +} + +// GetProtectedResourceData builds the OAuth protected resource data for a request. +func (h *AuthHandler) GetProtectedResourceData(r *http.Request, resourcePath string) (*ProtectedResourceData, error) { + host, scheme := GetEffectiveHostAndScheme(r, h.cfg) + + // Build the resource URL + var resourceURL string + if h.cfg.BaseURL != "" { + // Use configured base URL + baseURL := strings.TrimSuffix(h.cfg.BaseURL, "/") + if resourcePath == "/" { + resourceURL = baseURL + "/" + } else { + resourceURL = baseURL + "/" + resourcePath + } + } else { + // Derive from request + if resourcePath == "/" { + resourceURL = fmt.Sprintf("%s://%s/", scheme, host) + } else { + resourceURL = fmt.Sprintf("%s://%s/%s", scheme, host, resourcePath) + } + } + + return &ProtectedResourceData{ + ResourceURL: resourceURL, + AuthorizationServer: h.cfg.AuthorizationServer, + }, nil +} + +// GetEffectiveHostAndScheme returns the effective host and scheme for a request. +// It checks X-Forwarded-Host and X-Forwarded-Proto headers first (set by proxies), +// then falls back to the request's Host and TLS state. +func GetEffectiveHostAndScheme(r *http.Request, cfg *Config) (host, scheme string) { + // Check for forwarded headers first (typically set by reverse proxies) + if forwardedHost := r.Header.Get(headers.ForwardedHostHeader); forwardedHost != "" { + host = forwardedHost + } else { + host = r.Host + } + + // Determine scheme + switch { + case r.Header.Get(headers.ForwardedProtoHeader) != "": + scheme = strings.ToLower(r.Header.Get(headers.ForwardedProtoHeader)) + case r.TLS != nil: + scheme = "https" + default: + // Default to HTTPS in production scenarios + scheme = "https" + } + + return host, scheme +} + +// BuildResourceMetadataURL constructs the full URL to the OAuth protected resource metadata endpoint. +func BuildResourceMetadataURL(r *http.Request, cfg *Config, resourcePath string) string { + host, scheme := GetEffectiveHostAndScheme(r, cfg) + + if cfg != nil && cfg.BaseURL != "" { + baseURL := strings.TrimSuffix(cfg.BaseURL, "/") + return baseURL + OAuthProtectedResourcePrefix + "/" + strings.TrimPrefix(resourcePath, "/") + } + + path := OAuthProtectedResourcePrefix + if resourcePath != "" && resourcePath != "/" { + path = path + "/" + strings.TrimPrefix(resourcePath, "/") + } + + return fmt.Sprintf("%s://%s%s", scheme, host, path) +} diff --git a/pkg/http/oauth/protected_resource.json.tmpl b/pkg/http/oauth/protected_resource.json.tmpl new file mode 100644 index 000000000..7a9257404 --- /dev/null +++ b/pkg/http/oauth/protected_resource.json.tmpl @@ -0,0 +1,20 @@ +{ + "resource_name": "GitHub MCP Server", + "resource": "{{.ResourceURL}}", + "authorization_servers": ["{{.AuthorizationServer}}"], + "bearer_methods_supported": ["header"], + "scopes_supported": [ + "repo", + "read:org", + "read:user", + "user:email", + "read:packages", + "write:packages", + "read:project", + "project", + "gist", + "notifications", + "workflow", + "codespace" + ] +} diff --git a/pkg/http/server.go b/pkg/http/server.go index ac9a35c08..d35e874c4 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -12,6 +12,7 @@ import ( "time" "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/http/oauth" "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" @@ -28,6 +29,11 @@ type HTTPServerConfig struct { // Port to listen on (default: 8082) Port int + // BaseURL is the publicly accessible URL of this server for OAuth resource metadata. + // Example: "https://mcp.example.com" + // If not set, the server will derive the URL from incoming request headers. + BaseURL string + // ExportTranslations indicates if we should export translations // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#i18n--overriding-descriptions ExportTranslations bool @@ -97,7 +103,18 @@ func RunHTTPServer(cfg HTTPServerConfig) error { r := chi.NewRouter() - handler := NewHTTPMcpHandler(&cfg, deps, t, logger) + // Register OAuth protected resource metadata endpoints + oauthCfg := &oauth.Config{ + BaseURL: cfg.BaseURL, + } + oauthHandler, err := oauth.NewAuthHandler(oauthCfg) + if err != nil { + return fmt.Errorf("failed to create OAuth handler: %w", err) + } + oauthHandler.RegisterRoutes(r) + logger.Info("OAuth protected resource endpoints registered", "baseURL", cfg.BaseURL) + + handler := NewHTTPMcpHandler(&cfg, deps, t, logger, WithOAuthConfig(oauthCfg)) handler.RegisterRoutes(r) addr := fmt.Sprintf(":%d", cfg.Port) From 0a1b70136ce0d266ad52c63fe103b6060abf1e9a Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Thu, 22 Jan 2026 12:23:38 +0000 Subject: [PATCH 02/38] add nolint for GetEffectiveHostAndScheme --- pkg/http/oauth/oauth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index a504d348e..9d75041de 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -185,7 +185,7 @@ func (h *AuthHandler) GetProtectedResourceData(r *http.Request, resourcePath str // GetEffectiveHostAndScheme returns the effective host and scheme for a request. // It checks X-Forwarded-Host and X-Forwarded-Proto headers first (set by proxies), // then falls back to the request's Host and TLS state. -func GetEffectiveHostAndScheme(r *http.Request, cfg *Config) (host, scheme string) { +func GetEffectiveHostAndScheme(r *http.Request, cfg *Config) (host, scheme string) { //nolint:revive // parameters are required by http.oauth.BuildResourceMetadataURL signature // Check for forwarded headers first (typically set by reverse proxies) if forwardedHost := r.Header.Get(headers.ForwardedHostHeader); forwardedHost != "" { host = forwardedHost From afda19b709eba0f247862dcc97ad15b20554010d Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Fri, 23 Jan 2026 11:04:47 +0000 Subject: [PATCH 03/38] remove CAPI reference --- pkg/http/headers/headers.go | 4 +++ pkg/http/oauth/oauth.go | 56 +++++++++++++++++++++++++------------ 2 files changed, 42 insertions(+), 18 deletions(-) diff --git a/pkg/http/headers/headers.go b/pkg/http/headers/headers.go index 83bdb8fde..1e0d3be47 100644 --- a/pkg/http/headers/headers.go +++ b/pkg/http/headers/headers.go @@ -26,6 +26,10 @@ const ( // ForwardedProtoHeader is a standard HTTP Header for preserving the original protocol when proxying. ForwardedProtoHeader = "X-Forwarded-Proto" + // OriginalPathHeader is set to preserve the original request path + // before the /mcp prefix was stripped during proxying. + OriginalPathHeader = "X-GitHub-Original-Path" + // RequestHmacHeader is used to authenticate requests to the Raw API. RequestHmacHeader = "Request-Hmac" diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index 9d75041de..893012415 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -8,6 +8,7 @@ import ( "fmt" "html" "net/http" + "net/url" "strings" "text/template" @@ -112,14 +113,16 @@ func (h *AuthHandler) RegisterRoutes(r chi.Router) { } // routesForPattern generates route variants for a given pattern. +// GitHub strips the /mcp prefix before forwarding, so we register both variants: +// - With /mcp prefix: for direct access or when GitHub doesn't strip +// - Without /mcp prefix: for when GitHub has stripped the prefix func (h *AuthHandler) routesForPattern(pattern string) []string { - routes := []string{ + return []string{ pattern, + "/mcp" + pattern, pattern + "/", - pattern + "/mcp", - pattern + "/mcp/", + "/mcp" + pattern + "/", } - return routes } // handleProtectedResource handles requests for OAuth protected resource metadata. @@ -153,26 +156,43 @@ func (h *AuthHandler) handleProtectedResource(w http.ResponseWriter, r *http.Req _, _ = w.Write(buf.Bytes()) } +// GetEffectiveResourcePath returns the resource path for OAuth protected resource URLs. +// It checks for the X-GitHub-Original-Path header set by copilot-api (CAPI), which contains +// the exact path the client requested before the /mcp prefix was stripped. +// If the header is not present (e.g., direct access or older CAPI versions), it falls back to +// restoring the /mcp prefix. +func GetEffectiveResourcePath(r *http.Request) string { + // Check for the original path header from copilot-api (preferred method) + if originalPath := r.Header.Get(headers.OriginalPathHeader); originalPath != "" { + return originalPath + } + + // Fallback: copilot-api strips /mcp prefix, so we need to restore it for the external URL + if r.URL.Path == "/" { + return "/mcp" + } + return "/mcp" + r.URL.Path +} + // GetProtectedResourceData builds the OAuth protected resource data for a request. func (h *AuthHandler) GetProtectedResourceData(r *http.Request, resourcePath string) (*ProtectedResourceData, error) { host, scheme := GetEffectiveHostAndScheme(r, h.cfg) - // Build the resource URL - var resourceURL string + // Build the base URL + baseURL := fmt.Sprintf("%s://%s", scheme, host) if h.cfg.BaseURL != "" { - // Use configured base URL - baseURL := strings.TrimSuffix(h.cfg.BaseURL, "/") - if resourcePath == "/" { - resourceURL = baseURL + "/" - } else { - resourceURL = baseURL + "/" + resourcePath - } + baseURL = strings.TrimSuffix(h.cfg.BaseURL, "/") + } + + // Build the resource URL using url.JoinPath for proper path handling + var resourceURL string + var err error + if resourcePath == "/" { + resourceURL = baseURL + "/" } else { - // Derive from request - if resourcePath == "/" { - resourceURL = fmt.Sprintf("%s://%s/", scheme, host) - } else { - resourceURL = fmt.Sprintf("%s://%s/%s", scheme, host, resourcePath) + resourceURL, err = url.JoinPath(baseURL, resourcePath) + if err != nil { + return nil, fmt.Errorf("failed to build resource URL: %w", err) } } From 97859a1282894e59ddcd833aafde97f75febbb55 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Fri, 23 Jan 2026 11:07:53 +0000 Subject: [PATCH 04/38] remove nonsensical example URL --- pkg/http/oauth/oauth.go | 1 - pkg/http/server.go | 1 - 2 files changed, 2 deletions(-) diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index 893012415..ce20d05d6 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -47,7 +47,6 @@ var SupportedScopes = []string{ type Config struct { // BaseURL is the publicly accessible URL where this server is hosted. // This is used to construct the OAuth resource URL. - // Example: "https://mcp.example.com" BaseURL string // AuthorizationServer is the OAuth authorization server URL. diff --git a/pkg/http/server.go b/pkg/http/server.go index d35e874c4..76adb8948 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -30,7 +30,6 @@ type HTTPServerConfig struct { Port int // BaseURL is the publicly accessible URL of this server for OAuth resource metadata. - // Example: "https://mcp.example.com" // If not set, the server will derive the URL from incoming request headers. BaseURL string From f8f109cec611a57908bf57ccf2c0f886a1632e30 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Fri, 23 Jan 2026 11:10:01 +0000 Subject: [PATCH 05/38] anonymize --- pkg/http/oauth/oauth.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index ce20d05d6..f24db6786 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -156,17 +156,17 @@ func (h *AuthHandler) handleProtectedResource(w http.ResponseWriter, r *http.Req } // GetEffectiveResourcePath returns the resource path for OAuth protected resource URLs. -// It checks for the X-GitHub-Original-Path header set by copilot-api (CAPI), which contains +// It checks for the X-GitHub-Original-Path header set by GitHub, which contains // the exact path the client requested before the /mcp prefix was stripped. -// If the header is not present (e.g., direct access or older CAPI versions), it falls back to +// If the header is not present, it falls back to // restoring the /mcp prefix. func GetEffectiveResourcePath(r *http.Request) string { - // Check for the original path header from copilot-api (preferred method) + // Check for the original path header from GitHub (preferred method) if originalPath := r.Header.Get(headers.OriginalPathHeader); originalPath != "" { return originalPath } - // Fallback: copilot-api strips /mcp prefix, so we need to restore it for the external URL + // Fallback: GitHub strips /mcp prefix, so we need to restore it for the external URL if r.URL.Path == "/" { return "/mcp" } From 9f308b349a8b4d46f69110a0753784f5525eebaf Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Fri, 23 Jan 2026 15:54:11 +0000 Subject: [PATCH 06/38] add oauth tests --- pkg/http/oauth/oauth_test.go | 677 +++++++++++++++++++++++++++++++++++ 1 file changed, 677 insertions(+) create mode 100644 pkg/http/oauth/oauth_test.go diff --git a/pkg/http/oauth/oauth_test.go b/pkg/http/oauth/oauth_test.go new file mode 100644 index 000000000..035f5c35b --- /dev/null +++ b/pkg/http/oauth/oauth_test.go @@ -0,0 +1,677 @@ +package oauth + +import ( + "crypto/tls" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewAuthHandler(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + expectedAuthServer string + expectedResourcePath string + }{ + { + name: "nil config uses defaults", + cfg: nil, + expectedAuthServer: DefaultAuthorizationServer, + expectedResourcePath: "", + }, + { + name: "empty config uses defaults", + cfg: &Config{}, + expectedAuthServer: DefaultAuthorizationServer, + expectedResourcePath: "", + }, + { + name: "custom authorization server", + cfg: &Config{ + AuthorizationServer: "https://custom.example.com/oauth", + }, + expectedAuthServer: "https://custom.example.com/oauth", + expectedResourcePath: "", + }, + { + name: "custom base URL and resource path", + cfg: &Config{ + BaseURL: "https://example.com", + ResourcePath: "/mcp", + }, + expectedAuthServer: DefaultAuthorizationServer, + expectedResourcePath: "/mcp", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(tc.cfg) + require.NoError(t, err) + require.NotNil(t, handler) + + assert.Equal(t, tc.expectedAuthServer, handler.cfg.AuthorizationServer) + assert.NotNil(t, handler.protectedResourceTemplate) + }) + } +} + +func TestGetEffectiveHostAndScheme(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setupRequest func() *http.Request + cfg *Config + expectedHost string + expectedScheme string + }{ + { + name: "basic request without forwarding headers", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "https", // defaults to https + }, + { + name: "request with X-Forwarded-Host header", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "internal.example.com" + req.Header.Set(headers.ForwardedHostHeader, "public.example.com") + return req + }, + cfg: &Config{}, + expectedHost: "public.example.com", + expectedScheme: "https", + }, + { + name: "request with X-Forwarded-Proto header", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + req.Header.Set(headers.ForwardedProtoHeader, "http") + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "http", + }, + { + name: "request with both forwarding headers", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "internal.example.com" + req.Header.Set(headers.ForwardedHostHeader, "public.example.com") + req.Header.Set(headers.ForwardedProtoHeader, "https") + return req + }, + cfg: &Config{}, + expectedHost: "public.example.com", + expectedScheme: "https", + }, + { + name: "request with TLS", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + req.TLS = &tls.ConnectionState{} + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "https", + }, + { + name: "X-Forwarded-Proto takes precedence over TLS", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + req.TLS = &tls.ConnectionState{} + req.Header.Set(headers.ForwardedProtoHeader, "http") + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "http", + }, + { + name: "scheme is lowercased", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + req.Header.Set(headers.ForwardedProtoHeader, "HTTPS") + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "https", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + req := tc.setupRequest() + host, scheme := GetEffectiveHostAndScheme(req, tc.cfg) + + assert.Equal(t, tc.expectedHost, host) + assert.Equal(t, tc.expectedScheme, scheme) + }) + } +} + +func TestGetEffectiveResourcePath(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setupRequest func() *http.Request + expectedPath string + }{ + { + name: "root path without original path header", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + return req + }, + expectedPath: "/mcp", + }, + { + name: "non-root path without original path header", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/readonly", nil) + return req + }, + expectedPath: "/mcp/readonly", + }, + { + name: "with X-GitHub-Original-Path header", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/readonly", nil) + req.Header.Set(headers.OriginalPathHeader, "/mcp/x/repos/readonly") + return req + }, + expectedPath: "/mcp/x/repos/readonly", + }, + { + name: "original path header takes precedence", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/something-else", nil) + req.Header.Set(headers.OriginalPathHeader, "/mcp/custom/path") + return req + }, + expectedPath: "/mcp/custom/path", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + req := tc.setupRequest() + path := GetEffectiveResourcePath(req) + + assert.Equal(t, tc.expectedPath, path) + }) + } +} + +func TestGetProtectedResourceData(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + setupRequest func() *http.Request + resourcePath string + expectedResourceURL string + expectedAuthServer string + expectError bool + }{ + { + name: "basic request with root resource path", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/", + expectedResourceURL: "https://api.example.com/", + expectedAuthServer: DefaultAuthorizationServer, + }, + { + name: "basic request with custom resource path", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp", + expectedResourceURL: "https://api.example.com/mcp", + expectedAuthServer: DefaultAuthorizationServer, + }, + { + name: "with custom base URL", + cfg: &Config{ + BaseURL: "https://custom.example.com", + }, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp", + expectedResourceURL: "https://custom.example.com/mcp", + expectedAuthServer: DefaultAuthorizationServer, + }, + { + name: "with custom authorization server", + cfg: &Config{ + AuthorizationServer: "https://auth.example.com/oauth", + }, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp", + expectedResourceURL: "https://api.example.com/mcp", + expectedAuthServer: "https://auth.example.com/oauth", + }, + { + name: "base URL with trailing slash is trimmed", + cfg: &Config{ + BaseURL: "https://custom.example.com/", + }, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp", + expectedResourceURL: "https://custom.example.com/mcp", + expectedAuthServer: DefaultAuthorizationServer, + }, + { + name: "nested resource path", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp/x/repos", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp/x/repos", + expectedResourceURL: "https://api.example.com/mcp/x/repos", + expectedAuthServer: DefaultAuthorizationServer, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(tc.cfg) + require.NoError(t, err) + + req := tc.setupRequest() + data, err := handler.GetProtectedResourceData(req, tc.resourcePath) + + if tc.expectError { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tc.expectedResourceURL, data.ResourceURL) + assert.Equal(t, tc.expectedAuthServer, data.AuthorizationServer) + }) + } +} + +func TestBuildResourceMetadataURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + setupRequest func() *http.Request + resourcePath string + expectedURL string + }{ + { + name: "root path", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/", + expectedURL: "https://api.example.com/.well-known/oauth-protected-resource", + }, + { + name: "with custom resource path", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp", + expectedURL: "https://api.example.com/.well-known/oauth-protected-resource/mcp", + }, + { + name: "with base URL config", + cfg: &Config{ + BaseURL: "https://custom.example.com", + }, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp", + expectedURL: "https://custom.example.com/.well-known/oauth-protected-resource/mcp", + }, + { + name: "with forwarded headers", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "internal.example.com" + req.Header.Set(headers.ForwardedHostHeader, "public.example.com") + req.Header.Set(headers.ForwardedProtoHeader, "https") + return req + }, + resourcePath: "/mcp", + expectedURL: "https://public.example.com/.well-known/oauth-protected-resource/mcp", + }, + { + name: "nil config uses request host", + cfg: nil, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "", + expectedURL: "https://api.example.com/.well-known/oauth-protected-resource", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + req := tc.setupRequest() + url := BuildResourceMetadataURL(req, tc.cfg, tc.resourcePath) + + assert.Equal(t, tc.expectedURL, url) + }) + } +} + +func TestHandleProtectedResource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + path string + host string + method string + expectedStatusCode int + expectedScopes []string + validateResponse func(t *testing.T, body map[string]any) + }{ + { + name: "GET request returns protected resource metadata", + cfg: &Config{}, + path: OAuthProtectedResourcePrefix, + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + expectedScopes: SupportedScopes, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + assert.Equal(t, "GitHub MCP Server", body["resource_name"]) + assert.Contains(t, body["resource"], "api.example.com") + + authServers, ok := body["authorization_servers"].([]any) + require.True(t, ok) + require.Len(t, authServers, 1) + assert.Equal(t, DefaultAuthorizationServer, authServers[0]) + }, + }, + { + name: "OPTIONS request for CORS", + cfg: &Config{}, + path: OAuthProtectedResourcePrefix, + host: "api.example.com", + method: http.MethodOptions, + expectedStatusCode: http.StatusOK, + }, + { + name: "path with /mcp suffix", + cfg: &Config{}, + path: OAuthProtectedResourcePrefix + "/mcp", + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + assert.Contains(t, body["resource"], "/mcp") + }, + }, + { + name: "path with /readonly suffix", + cfg: &Config{}, + path: OAuthProtectedResourcePrefix + "/readonly", + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + assert.Contains(t, body["resource"], "/readonly") + }, + }, + { + name: "custom authorization server in response", + cfg: &Config{ + AuthorizationServer: "https://custom.auth.example.com/oauth", + }, + path: OAuthProtectedResourcePrefix, + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + authServers, ok := body["authorization_servers"].([]any) + require.True(t, ok) + require.Len(t, authServers, 1) + assert.Equal(t, "https://custom.auth.example.com/oauth", authServers[0]) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(tc.cfg) + require.NoError(t, err) + + router := chi.NewRouter() + handler.RegisterRoutes(router) + + req := httptest.NewRequest(tc.method, tc.path, nil) + req.Host = tc.host + + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectedStatusCode, rec.Code) + + // Check CORS headers + assert.Equal(t, "*", rec.Header().Get("Access-Control-Allow-Origin")) + assert.Contains(t, rec.Header().Get("Access-Control-Allow-Methods"), "GET") + assert.Contains(t, rec.Header().Get("Access-Control-Allow-Methods"), "OPTIONS") + + if tc.method == http.MethodGet && tc.validateResponse != nil { + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + + var body map[string]any + err := json.Unmarshal(rec.Body.Bytes(), &body) + require.NoError(t, err) + + tc.validateResponse(t, body) + + // Verify scopes if expected + if tc.expectedScopes != nil { + scopes, ok := body["scopes_supported"].([]any) + require.True(t, ok) + assert.Len(t, scopes, len(tc.expectedScopes)) + } + } + }) + } +} + +func TestRegisterRoutes(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(&Config{}) + require.NoError(t, err) + + router := chi.NewRouter() + handler.RegisterRoutes(router) + + // List of expected routes that should be registered + expectedRoutes := []string{ + OAuthProtectedResourcePrefix, + OAuthProtectedResourcePrefix + "/", + OAuthProtectedResourcePrefix + "/mcp", + OAuthProtectedResourcePrefix + "/mcp/", + OAuthProtectedResourcePrefix + "/readonly", + OAuthProtectedResourcePrefix + "/readonly/", + OAuthProtectedResourcePrefix + "/mcp/readonly", + OAuthProtectedResourcePrefix + "/mcp/readonly/", + OAuthProtectedResourcePrefix + "/x/repos", + OAuthProtectedResourcePrefix + "/mcp/x/repos", + } + + for _, route := range expectedRoutes { + t.Run("route:"+route, func(t *testing.T) { + // Test GET + req := httptest.NewRequest(http.MethodGet, route, nil) + req.Host = "api.example.com" + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "GET %s should return 200", route) + + // Test OPTIONS (CORS) + req = httptest.NewRequest(http.MethodOptions, route, nil) + req.Host = "api.example.com" + rec = httptest.NewRecorder() + router.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "OPTIONS %s should return 200", route) + }) + } +} + +func TestSupportedScopes(t *testing.T) { + t.Parallel() + + // Verify all expected scopes are present + expectedScopes := []string{ + "repo", + "read:org", + "read:user", + "user:email", + "read:packages", + "write:packages", + "read:project", + "project", + "gist", + "notifications", + "workflow", + "codespace", + } + + assert.Equal(t, expectedScopes, SupportedScopes) +} + +func TestProtectedResourceResponseFormat(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(&Config{}) + require.NoError(t, err) + + router := chi.NewRouter() + handler.RegisterRoutes(router) + + req := httptest.NewRequest(http.MethodGet, OAuthProtectedResourcePrefix, nil) + req.Host = "api.example.com" + + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var response map[string]any + err = json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + + // Verify all required RFC 9728 fields are present + assert.Contains(t, response, "resource") + assert.Contains(t, response, "authorization_servers") + assert.Contains(t, response, "bearer_methods_supported") + assert.Contains(t, response, "scopes_supported") + + // Verify resource name (optional but we include it) + assert.Contains(t, response, "resource_name") + assert.Equal(t, "GitHub MCP Server", response["resource_name"]) + + // Verify bearer_methods_supported contains "header" + bearerMethods, ok := response["bearer_methods_supported"].([]any) + require.True(t, ok) + assert.Contains(t, bearerMethods, "header") + + // Verify authorization_servers is an array with GitHub OAuth + authServers, ok := response["authorization_servers"].([]any) + require.True(t, ok) + assert.Len(t, authServers, 1) + assert.Equal(t, DefaultAuthorizationServer, authServers[0]) +} + +func TestOAuthProtectedResourcePrefix(t *testing.T) { + t.Parallel() + + // RFC 9728 specifies this well-known path + assert.Equal(t, "/.well-known/oauth-protected-resource", OAuthProtectedResourcePrefix) +} + +func TestDefaultAuthorizationServer(t *testing.T) { + t.Parallel() + + assert.Equal(t, "https://github.com/login/oauth", DefaultAuthorizationServer) +} From 50227bf780d41289b47949a087cb2ed4d513d842 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Mon, 26 Jan 2026 15:53:58 +0000 Subject: [PATCH 07/38] replace custom protected resource metadata handler with our own --- pkg/http/oauth/oauth.go | 75 +++++++-------------- pkg/http/oauth/oauth_test.go | 46 ++++++++----- pkg/http/oauth/protected_resource.json.tmpl | 20 ------ 3 files changed, 55 insertions(+), 86 deletions(-) delete mode 100644 pkg/http/oauth/protected_resource.json.tmpl diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index f24db6786..7710dd581 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -3,17 +3,16 @@ package oauth import ( - "bytes" - _ "embed" "fmt" - "html" "net/http" "net/url" "strings" - "text/template" + + "github.com/modelcontextprotocol/go-sdk/auth" "github.com/github/github-mcp-server/pkg/http/headers" "github.com/go-chi/chi/v5" + "github.com/modelcontextprotocol/go-sdk/oauthex" ) const ( @@ -24,9 +23,6 @@ const ( DefaultAuthorizationServer = "https://github.com/login/oauth" ) -//go:embed protected_resource.json.tmpl -var protectedResourceTemplate []byte - // SupportedScopes lists all OAuth scopes that may be required by MCP tools. var SupportedScopes = []string{ "repo", @@ -66,8 +62,7 @@ type ProtectedResourceData struct { // AuthHandler handles OAuth-related HTTP endpoints. type AuthHandler struct { - cfg *Config - protectedResourceTemplate *template.Template + cfg *Config } // NewAuthHandler creates a new OAuth auth handler. @@ -81,14 +76,8 @@ func NewAuthHandler(cfg *Config) (*AuthHandler, error) { cfg.AuthorizationServer = DefaultAuthorizationServer } - tmpl, err := template.New("protected-resource").Parse(string(protectedResourceTemplate)) - if err != nil { - return nil, fmt.Errorf("failed to parse protected resource template: %w", err) - } - return &AuthHandler{ - cfg: cfg, - protectedResourceTemplate: tmpl, + cfg: cfg, }, nil } @@ -96,6 +85,7 @@ func NewAuthHandler(cfg *Config) (*AuthHandler, error) { var routePatterns = []string{ "", // Root: /.well-known/oauth-protected-resource "/readonly", // Read-only mode + "/insiders", // Insiders mode "/x/{toolset}", "/x/{toolset}/readonly", } @@ -105,12 +95,30 @@ func (h *AuthHandler) RegisterRoutes(r chi.Router) { for _, pattern := range routePatterns { for _, route := range h.routesForPattern(pattern) { path := OAuthProtectedResourcePrefix + route - r.Get(path, h.handleProtectedResource) - r.Options(path, h.handleProtectedResource) // CORS support + + // Build metadata for this specific resource path + metadata := h.buildMetadata(route) + r.Handle(path, auth.ProtectedResourceMetadataHandler(metadata)) } } } +func (h *AuthHandler) buildMetadata(resourcePath string) *oauthex.ProtectedResourceMetadata { + baseURL := strings.TrimSuffix(h.cfg.BaseURL, "/") + resourceURL := baseURL + if resourcePath != "" && resourcePath != "/" { + resourceURL = baseURL + resourcePath + } + + return &oauthex.ProtectedResourceMetadata{ + Resource: resourceURL, + AuthorizationServers: []string{h.cfg.AuthorizationServer}, + ResourceName: "GitHub MCP Server", + ScopesSupported: SupportedScopes, + BearerMethodsSupported: []string{"header"}, + } +} + // routesForPattern generates route variants for a given pattern. // GitHub strips the /mcp prefix before forwarding, so we register both variants: // - With /mcp prefix: for direct access or when GitHub doesn't strip @@ -124,37 +132,6 @@ func (h *AuthHandler) routesForPattern(pattern string) []string { } } -// handleProtectedResource handles requests for OAuth protected resource metadata. -func (h *AuthHandler) handleProtectedResource(w http.ResponseWriter, r *http.Request) { - // Extract the resource path from the URL - resourcePath := strings.TrimPrefix(r.URL.Path, OAuthProtectedResourcePrefix) - if resourcePath == "" || resourcePath == "/" { - resourcePath = "/" - } else { - resourcePath = strings.TrimPrefix(resourcePath, "/") - } - - data, err := h.GetProtectedResourceData(r, html.EscapeString(resourcePath)) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - var buf bytes.Buffer - if err := h.protectedResourceTemplate.Execute(&buf, data); err != nil { - http.Error(w, "Internal server error", http.StatusInternalServerError) - return - } - - // Set CORS headers - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write(buf.Bytes()) -} - // GetEffectiveResourcePath returns the resource path for OAuth protected resource URLs. // It checks for the X-GitHub-Original-Path header set by GitHub, which contains // the exact path the client requested before the /mcp prefix was stripped. diff --git a/pkg/http/oauth/oauth_test.go b/pkg/http/oauth/oauth_test.go index 035f5c35b..2bff37363 100644 --- a/pkg/http/oauth/oauth_test.go +++ b/pkg/http/oauth/oauth_test.go @@ -62,7 +62,6 @@ func TestNewAuthHandler(t *testing.T) { require.NotNil(t, handler) assert.Equal(t, tc.expectedAuthServer, handler.cfg.AuthorizationServer) - assert.NotNil(t, handler.protectedResourceTemplate) }) } } @@ -444,8 +443,10 @@ func TestHandleProtectedResource(t *testing.T) { validateResponse func(t *testing.T, body map[string]any) }{ { - name: "GET request returns protected resource metadata", - cfg: &Config{}, + name: "GET request returns protected resource metadata", + cfg: &Config{ + BaseURL: "https://api.example.com", + }, path: OAuthProtectedResourcePrefix, host: "api.example.com", method: http.MethodGet, @@ -454,7 +455,7 @@ func TestHandleProtectedResource(t *testing.T) { validateResponse: func(t *testing.T, body map[string]any) { t.Helper() assert.Equal(t, "GitHub MCP Server", body["resource_name"]) - assert.Contains(t, body["resource"], "api.example.com") + assert.Equal(t, "https://api.example.com", body["resource"]) authServers, ok := body["authorization_servers"].([]any) require.True(t, ok) @@ -463,40 +464,47 @@ func TestHandleProtectedResource(t *testing.T) { }, }, { - name: "OPTIONS request for CORS", - cfg: &Config{}, + name: "OPTIONS request for CORS preflight", + cfg: &Config{ + BaseURL: "https://api.example.com", + }, path: OAuthProtectedResourcePrefix, host: "api.example.com", method: http.MethodOptions, - expectedStatusCode: http.StatusOK, + expectedStatusCode: http.StatusNoContent, }, { - name: "path with /mcp suffix", - cfg: &Config{}, + name: "path with /mcp suffix", + cfg: &Config{ + BaseURL: "https://api.example.com", + }, path: OAuthProtectedResourcePrefix + "/mcp", host: "api.example.com", method: http.MethodGet, expectedStatusCode: http.StatusOK, validateResponse: func(t *testing.T, body map[string]any) { t.Helper() - assert.Contains(t, body["resource"], "/mcp") + assert.Equal(t, "https://api.example.com/mcp", body["resource"]) }, }, { - name: "path with /readonly suffix", - cfg: &Config{}, + name: "path with /readonly suffix", + cfg: &Config{ + BaseURL: "https://api.example.com", + }, path: OAuthProtectedResourcePrefix + "/readonly", host: "api.example.com", method: http.MethodGet, expectedStatusCode: http.StatusOK, validateResponse: func(t *testing.T, body map[string]any) { t.Helper() - assert.Contains(t, body["resource"], "/readonly") + assert.Equal(t, "https://api.example.com/readonly", body["resource"]) }, }, { name: "custom authorization server in response", cfg: &Config{ + BaseURL: "https://api.example.com", AuthorizationServer: "https://custom.auth.example.com/oauth", }, path: OAuthProtectedResourcePrefix, @@ -559,7 +567,9 @@ func TestHandleProtectedResource(t *testing.T) { func TestRegisterRoutes(t *testing.T) { t.Parallel() - handler, err := NewAuthHandler(&Config{}) + handler, err := NewAuthHandler(&Config{ + BaseURL: "https://api.example.com", + }) require.NoError(t, err) router := chi.NewRouter() @@ -588,12 +598,12 @@ func TestRegisterRoutes(t *testing.T) { router.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code, "GET %s should return 200", route) - // Test OPTIONS (CORS) + // Test OPTIONS (CORS preflight) req = httptest.NewRequest(http.MethodOptions, route, nil) req.Host = "api.example.com" rec = httptest.NewRecorder() router.ServeHTTP(rec, req) - assert.Equal(t, http.StatusOK, rec.Code, "OPTIONS %s should return 200", route) + assert.Equal(t, http.StatusNoContent, rec.Code, "OPTIONS %s should return 204", route) }) } } @@ -623,7 +633,9 @@ func TestSupportedScopes(t *testing.T) { func TestProtectedResourceResponseFormat(t *testing.T) { t.Parallel() - handler, err := NewAuthHandler(&Config{}) + handler, err := NewAuthHandler(&Config{ + BaseURL: "https://api.example.com", + }) require.NoError(t, err) router := chi.NewRouter() diff --git a/pkg/http/oauth/protected_resource.json.tmpl b/pkg/http/oauth/protected_resource.json.tmpl deleted file mode 100644 index 7a9257404..000000000 --- a/pkg/http/oauth/protected_resource.json.tmpl +++ /dev/null @@ -1,20 +0,0 @@ -{ - "resource_name": "GitHub MCP Server", - "resource": "{{.ResourceURL}}", - "authorization_servers": ["{{.AuthorizationServer}}"], - "bearer_methods_supported": ["header"], - "scopes_supported": [ - "repo", - "read:org", - "read:user", - "user:email", - "read:packages", - "write:packages", - "read:project", - "project", - "gist", - "notifications", - "workflow", - "codespace" - ] -} From a3135d9fa2b4a6aaeae876c8ac36406bddcefce6 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Mon, 26 Jan 2026 16:35:08 +0000 Subject: [PATCH 08/38] remove unused header --- pkg/http/middleware/token.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/http/middleware/token.go b/pkg/http/middleware/token.go index 93b93279e..e09026f24 100644 --- a/pkg/http/middleware/token.go +++ b/pkg/http/middleware/token.go @@ -25,7 +25,6 @@ var ( errMissingAuthorizationHeader = fmt.Errorf("%w: missing required Authorization header", mark.ErrBadRequest) errBadAuthorizationHeader = fmt.Errorf("%w: Authorization header is badly formatted", mark.ErrBadRequest) errUnsupportedAuthorizationHeader = fmt.Errorf("%w: unsupported Authorization header", mark.ErrBadRequest) - errMissingTokenInfoHeader = fmt.Errorf("%w: missing required token info header", mark.ErrBadRequest) ) var supportedThirdPartyTokenPrefixes = []string{ From 1ce01df6f42c2e96b6645c90b2d11d7a263c2f75 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Mon, 26 Jan 2026 16:39:03 +0000 Subject: [PATCH 09/38] Update pkg/http/oauth/oauth.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- pkg/http/oauth/oauth.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index 7710dd581..8ef26d9c6 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -9,7 +9,6 @@ import ( "strings" "github.com/modelcontextprotocol/go-sdk/auth" - "github.com/github/github-mcp-server/pkg/http/headers" "github.com/go-chi/chi/v5" "github.com/modelcontextprotocol/go-sdk/oauthex" From 4fc6c3aa2d56b8c7600c7a447b8a12fdd4653bbb Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Wed, 28 Jan 2026 13:25:17 +0000 Subject: [PATCH 10/38] pass oauth config to mcp handler for token extraction --- pkg/http/server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/http/server.go b/pkg/http/server.go index 2ff942d80..02e74b352 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -114,7 +114,7 @@ func RunHTTPServer(cfg HTTPServerConfig) error { oauthHandler.RegisterRoutes(r) logger.Info("OAuth protected resource endpoints registered", "baseURL", cfg.BaseURL) - handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger) + handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, WithOAuthConfig(oauthCfg)) handler.RegisterRoutes(r) addr := fmt.Sprintf(":%d", cfg.Port) From b0bddbfd58ecb2103cdefd5dd40704b565280abb Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Wed, 28 Jan 2026 14:04:17 +0000 Subject: [PATCH 11/38] chore: retrigger ci From 6c5102abe2ba30b89ef32b88d50be8849a47c67e Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Wed, 28 Jan 2026 14:13:08 +0000 Subject: [PATCH 12/38] align types with base branch --- pkg/http/handler.go | 34 +++++++++++++++++----------------- pkg/http/server.go | 4 ++-- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/pkg/http/handler.go b/pkg/http/handler.go index 671ebe3a0..66cdf44ca 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -19,9 +19,9 @@ import ( type InventoryFactoryFunc func(r *http.Request) (*inventory.Inventory, error) type GitHubMCPServerFactoryFunc func(r *http.Request, deps github.ToolDependencies, inventory *inventory.Inventory, cfg *github.MCPServerConfig) (*mcp.Server, error) -type HTTPMcpHandler struct { +type Handler struct { ctx context.Context - config *HTTPServerConfig + config *ServerConfig deps github.ToolDependencies logger *slog.Logger t translations.TranslationHelperFunc @@ -30,40 +30,40 @@ type HTTPMcpHandler struct { oauthCfg *oauth.Config } -type HTTPMcpHandlerOptions struct { +type HandlerOptions struct { GitHubMcpServerFactory GitHubMCPServerFactoryFunc InventoryFactory InventoryFactoryFunc OAuthConfig *oauth.Config } -type HTTPMcpHandlerOption func(*HTTPMcpHandlerOptions) +type HandlerOption func(*HandlerOptions) -func WithGitHubMCPServerFactory(f GitHubMCPServerFactoryFunc) HTTPMcpHandlerOption { - return func(o *HTTPMcpHandlerOptions) { +func WithGitHubMCPServerFactory(f GitHubMCPServerFactoryFunc) HandlerOption { + return func(o *HandlerOptions) { o.GitHubMcpServerFactory = f } } -func WithInventoryFactory(f InventoryFactoryFunc) HTTPMcpHandlerOption { - return func(o *HTTPMcpHandlerOptions) { +func WithInventoryFactory(f InventoryFactoryFunc) HandlerOption { + return func(o *HandlerOptions) { o.InventoryFactory = f } } -func WithOAuthConfig(cfg *oauth.Config) HTTPMcpHandlerOption { - return func(o *HTTPMcpHandlerOptions) { +func WithOAuthConfig(cfg *oauth.Config) HandlerOption { + return func(o *HandlerOptions) { o.OAuthConfig = cfg } } func NewHTTPMcpHandler( ctx context.Context, - cfg *HTTPServerConfig, + cfg *ServerConfig, deps github.ToolDependencies, t translations.TranslationHelperFunc, logger *slog.Logger, - options ...HTTPMcpHandlerOption) *HTTPMcpHandler { - opts := &HTTPMcpHandlerOptions{} + options ...HandlerOption) *Handler { + opts := &HandlerOptions{} for _, o := range options { o(opts) } @@ -78,7 +78,7 @@ func NewHTTPMcpHandler( inventoryFactory = DefaultInventoryFactory(cfg, t, nil) } - return &HTTPMcpHandler{ + return &Handler{ ctx: ctx, config: cfg, deps: deps, @@ -92,7 +92,7 @@ func NewHTTPMcpHandler( // RegisterRoutes registers the routes for the MCP server // URL-based values take precedence over header-based values -func (h *HTTPMcpHandler) RegisterRoutes(r chi.Router) { +func (h *Handler) RegisterRoutes(r chi.Router) { r.Use(middleware.WithRequestConfig) r.Mount("/", h) @@ -119,7 +119,7 @@ func withToolset(next http.Handler) http.Handler { }) } -func (h *HTTPMcpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { inventory, err := h.inventoryFactoryFunc(r) if err != nil { w.WriteHeader(http.StatusInternalServerError) @@ -157,7 +157,7 @@ func DefaultGitHubMCPServerFactory(r *http.Request, deps github.ToolDependencies }, deps, inventory) } -func DefaultInventoryFactory(cfg *HTTPServerConfig, t translations.TranslationHelperFunc, staticChecker inventory.FeatureFlagChecker) InventoryFactoryFunc { +func DefaultInventoryFactory(cfg *ServerConfig, t translations.TranslationHelperFunc, staticChecker inventory.FeatureFlagChecker) InventoryFactoryFunc { return func(r *http.Request) (*inventory.Inventory, error) { b := github.NewInventory(t).WithDeprecatedAliases(github.DeprecatedToolAliases) diff --git a/pkg/http/server.go b/pkg/http/server.go index 02e74b352..55bd611a5 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -19,7 +19,7 @@ import ( "github.com/go-chi/chi/v5" ) -type HTTPServerConfig struct { +type ServerConfig struct { // Version of the server Version string @@ -53,7 +53,7 @@ type HTTPServerConfig struct { RepoAccessCacheTTL *time.Duration } -func RunHTTPServer(cfg HTTPServerConfig) error { +func RunHTTPServer(cfg ServerConfig) error { // Create app context ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() From 3daa5c3d8c7bb465f1c434b8603889390ba92713 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Wed, 28 Jan 2026 14:13:35 +0000 Subject: [PATCH 13/38] update more types --- cmd/github-mcp-server/main.go | 2 +- pkg/http/oauth/oauth.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 7dd5cb3db..5c77b51e1 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -96,7 +96,7 @@ var ( Short: "Start HTTP server", Long: `Start an HTTP server that listens for MCP requests over HTTP.`, RunE: func(_ *cobra.Command, _ []string) error { - httpConfig := ghhttp.HTTPServerConfig{ + httpConfig := ghhttp.ServerConfig{ Version: version, Host: viper.GetString("host"), Port: viper.GetInt("port"), diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index 8ef26d9c6..8934a21c6 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -8,9 +8,9 @@ import ( "net/url" "strings" - "github.com/modelcontextprotocol/go-sdk/auth" "github.com/github/github-mcp-server/pkg/http/headers" "github.com/go-chi/chi/v5" + "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/oauthex" ) From e3c565a86721ae9d140ced1d249ea45b7debbf89 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Thu, 22 Jan 2026 11:53:08 +0000 Subject: [PATCH 14/38] initial oauth metadata implementation --- cmd/github-mcp-server/main.go | 3 + pkg/http/handler.go | 12 +- pkg/http/headers/headers.go | 5 + pkg/http/middleware/token.go | 14 +- pkg/http/oauth/oauth.go | 225 ++++++++++++++++++++ pkg/http/oauth/protected_resource.json.tmpl | 20 ++ pkg/http/server.go | 19 +- 7 files changed, 294 insertions(+), 4 deletions(-) create mode 100644 pkg/http/oauth/oauth.go create mode 100644 pkg/http/oauth/protected_resource.json.tmpl diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 7ff132229..fbca6ccff 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -101,6 +101,7 @@ var ( Version: version, Host: viper.GetString("host"), Port: viper.GetInt("port"), + BaseURL: viper.GetString("base-url"), ExportTranslations: viper.GetBool("export-translations"), EnableCommandLogging: viper.GetBool("enable-command-logging"), LogFilePath: viper.GetString("log-file"), @@ -135,6 +136,7 @@ func init() { rootCmd.PersistentFlags().Bool("insiders", false, "Enable insiders features") rootCmd.PersistentFlags().Duration("repo-access-cache-ttl", 5*time.Minute, "Override the repo access cache TTL (e.g. 1m, 0s to disable)") rootCmd.PersistentFlags().Int("port", 8082, "HTTP server port") + rootCmd.PersistentFlags().String("base-url", "", "Base URL where this server is publicly accessible (for OAuth resource metadata)") // Bind flag to viper _ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets")) @@ -151,6 +153,7 @@ func init() { _ = viper.BindPFlag("insiders", rootCmd.PersistentFlags().Lookup("insiders")) _ = viper.BindPFlag("repo-access-cache-ttl", rootCmd.PersistentFlags().Lookup("repo-access-cache-ttl")) _ = viper.BindPFlag("port", rootCmd.PersistentFlags().Lookup("port")) + _ = viper.BindPFlag("base-url", rootCmd.PersistentFlags().Lookup("base-url")) // Add subcommands rootCmd.AddCommand(stdioCmd) diff --git a/pkg/http/handler.go b/pkg/http/handler.go index 5b889b36b..3a14bd624 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -9,6 +9,7 @@ import ( "github.com/github/github-mcp-server/pkg/github" "github.com/github/github-mcp-server/pkg/http/headers" "github.com/github/github-mcp-server/pkg/http/middleware" + "github.com/github/github-mcp-server/pkg/http/oauth" "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/translations" "github.com/go-chi/chi/v5" @@ -26,11 +27,13 @@ type Handler struct { t translations.TranslationHelperFunc githubMcpServerFactory GitHubMCPServerFactoryFunc inventoryFactoryFunc InventoryFactoryFunc + oauthCfg *oauth.Config } type HandlerOptions struct { GitHubMcpServerFactory GitHubMCPServerFactoryFunc InventoryFactory InventoryFactoryFunc + OAuthConfig *oauth.Config } type HandlerOption func(*HandlerOptions) @@ -47,6 +50,12 @@ func WithInventoryFactory(f InventoryFactoryFunc) HandlerOption { } } +func WithOAuthConfig(cfg *oauth.Config) HandlerOption { + return func(o *HandlerOptions) { + o.OAuthConfig = cfg + } +} + func NewHTTPMcpHandler( ctx context.Context, cfg *ServerConfig, @@ -77,6 +86,7 @@ func NewHTTPMcpHandler( t: t, githubMcpServerFactory: githubMcpServerFactory, inventoryFactoryFunc: inventoryFactory, + oauthCfg: opts.OAuthConfig, } } @@ -134,7 +144,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { Stateless: true, }) - middleware.ExtractUserToken()(mcpHandler).ServeHTTP(w, r) + middleware.ExtractUserToken(h.oauthCfg)(mcpHandler).ServeHTTP(w, r) } func DefaultGitHubMCPServerFactory(r *http.Request, deps github.ToolDependencies, inventory *inventory.Inventory, cfg *github.MCPServerConfig) (*mcp.Server, error) { diff --git a/pkg/http/headers/headers.go b/pkg/http/headers/headers.go index a1580cf96..8c389c828 100644 --- a/pkg/http/headers/headers.go +++ b/pkg/http/headers/headers.go @@ -21,6 +21,11 @@ const ( // RealIPHeader is a standard HTTP Header used to indicate the real IP address of the client. RealIPHeader = "X-Real-IP" + // ForwardedHostHeader is a standard HTTP Header for preserving the original Host header when proxying. + ForwardedHostHeader = "X-Forwarded-Host" + // ForwardedProtoHeader is a standard HTTP Header for preserving the original protocol when proxying. + ForwardedProtoHeader = "X-Forwarded-Proto" + // RequestHmacHeader is used to authenticate requests to the Raw API. RequestHmacHeader = "Request-Hmac" diff --git a/pkg/http/middleware/token.go b/pkg/http/middleware/token.go index 6369abf14..e09026f24 100644 --- a/pkg/http/middleware/token.go +++ b/pkg/http/middleware/token.go @@ -10,6 +10,7 @@ import ( ghcontext "github.com/github/github-mcp-server/pkg/context" httpheaders "github.com/github/github-mcp-server/pkg/http/headers" "github.com/github/github-mcp-server/pkg/http/mark" + "github.com/github/github-mcp-server/pkg/http/oauth" ) type authType int @@ -39,14 +40,14 @@ var supportedThirdPartyTokenPrefixes = []string{ // were 40 characters long and only contained the characters a-f and 0-9. var oldPatternRegexp = regexp.MustCompile(`\A[a-f0-9]{40}\z`) -func ExtractUserToken() func(next http.Handler) http.Handler { +func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, token, err := parseAuthorizationHeader(r) if err != nil { // For missing Authorization header, return 401 with WWW-Authenticate header per MCP spec if errors.Is(err, errMissingAuthorizationHeader) { - // sendAuthChallenge(w, r, cfg, obsv) + sendAuthChallenge(w, r, oauthCfg) return } // For other auth errors (bad format, unsupported), return 400 @@ -62,6 +63,15 @@ func ExtractUserToken() func(next http.Handler) http.Handler { }) } } + +// sendAuthChallenge sends a 401 Unauthorized response with WWW-Authenticate header +// containing the OAuth protected resource metadata URL as per RFC 6750 and MCP spec. +func sendAuthChallenge(w http.ResponseWriter, r *http.Request, oauthCfg *oauth.Config) { + resourceMetadataURL := oauth.BuildResourceMetadataURL(r, oauthCfg, "mcp") + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata=%q`, resourceMetadataURL)) + http.Error(w, "Unauthorized", http.StatusUnauthorized) +} + func parseAuthorizationHeader(req *http.Request) (authType authType, token string, _ error) { authHeader := req.Header.Get(httpheaders.AuthorizationHeader) if authHeader == "" { diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go new file mode 100644 index 000000000..a504d348e --- /dev/null +++ b/pkg/http/oauth/oauth.go @@ -0,0 +1,225 @@ +// Package oauth provides OAuth 2.0 Protected Resource Metadata (RFC 9728) support +// for the GitHub MCP Server HTTP mode. +package oauth + +import ( + "bytes" + _ "embed" + "fmt" + "html" + "net/http" + "strings" + "text/template" + + "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/go-chi/chi/v5" +) + +const ( + // OAuthProtectedResourcePrefix is the well-known path prefix for OAuth protected resource metadata. + OAuthProtectedResourcePrefix = "/.well-known/oauth-protected-resource" + + // DefaultAuthorizationServer is GitHub's OAuth authorization server. + DefaultAuthorizationServer = "https://github.com/login/oauth" +) + +//go:embed protected_resource.json.tmpl +var protectedResourceTemplate []byte + +// SupportedScopes lists all OAuth scopes that may be required by MCP tools. +var SupportedScopes = []string{ + "repo", + "read:org", + "read:user", + "user:email", + "read:packages", + "write:packages", + "read:project", + "project", + "gist", + "notifications", + "workflow", + "codespace", +} + +// Config holds the OAuth configuration for the MCP server. +type Config struct { + // BaseURL is the publicly accessible URL where this server is hosted. + // This is used to construct the OAuth resource URL. + // Example: "https://mcp.example.com" + BaseURL string + + // AuthorizationServer is the OAuth authorization server URL. + // Defaults to GitHub's OAuth server if not specified. + AuthorizationServer string + + // ResourcePath is the resource path suffix (e.g., "/mcp"). + // If empty, defaults to "/" + ResourcePath string +} + +// ProtectedResourceData contains the data needed to build an OAuth protected resource response. +type ProtectedResourceData struct { + ResourceURL string + AuthorizationServer string +} + +// AuthHandler handles OAuth-related HTTP endpoints. +type AuthHandler struct { + cfg *Config + protectedResourceTemplate *template.Template +} + +// NewAuthHandler creates a new OAuth auth handler. +func NewAuthHandler(cfg *Config) (*AuthHandler, error) { + if cfg == nil { + cfg = &Config{} + } + + // Default authorization server to GitHub + if cfg.AuthorizationServer == "" { + cfg.AuthorizationServer = DefaultAuthorizationServer + } + + tmpl, err := template.New("protected-resource").Parse(string(protectedResourceTemplate)) + if err != nil { + return nil, fmt.Errorf("failed to parse protected resource template: %w", err) + } + + return &AuthHandler{ + cfg: cfg, + protectedResourceTemplate: tmpl, + }, nil +} + +// routePatterns defines the route patterns for OAuth protected resource metadata. +var routePatterns = []string{ + "", // Root: /.well-known/oauth-protected-resource + "/readonly", // Read-only mode + "/x/{toolset}", + "/x/{toolset}/readonly", +} + +// RegisterRoutes registers the OAuth protected resource metadata routes. +func (h *AuthHandler) RegisterRoutes(r chi.Router) { + for _, pattern := range routePatterns { + for _, route := range h.routesForPattern(pattern) { + path := OAuthProtectedResourcePrefix + route + r.Get(path, h.handleProtectedResource) + r.Options(path, h.handleProtectedResource) // CORS support + } + } +} + +// routesForPattern generates route variants for a given pattern. +func (h *AuthHandler) routesForPattern(pattern string) []string { + routes := []string{ + pattern, + pattern + "/", + pattern + "/mcp", + pattern + "/mcp/", + } + return routes +} + +// handleProtectedResource handles requests for OAuth protected resource metadata. +func (h *AuthHandler) handleProtectedResource(w http.ResponseWriter, r *http.Request) { + // Extract the resource path from the URL + resourcePath := strings.TrimPrefix(r.URL.Path, OAuthProtectedResourcePrefix) + if resourcePath == "" || resourcePath == "/" { + resourcePath = "/" + } else { + resourcePath = strings.TrimPrefix(resourcePath, "/") + } + + data, err := h.GetProtectedResourceData(r, html.EscapeString(resourcePath)) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + var buf bytes.Buffer + if err := h.protectedResourceTemplate.Execute(&buf, data); err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Set CORS headers + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(buf.Bytes()) +} + +// GetProtectedResourceData builds the OAuth protected resource data for a request. +func (h *AuthHandler) GetProtectedResourceData(r *http.Request, resourcePath string) (*ProtectedResourceData, error) { + host, scheme := GetEffectiveHostAndScheme(r, h.cfg) + + // Build the resource URL + var resourceURL string + if h.cfg.BaseURL != "" { + // Use configured base URL + baseURL := strings.TrimSuffix(h.cfg.BaseURL, "/") + if resourcePath == "/" { + resourceURL = baseURL + "/" + } else { + resourceURL = baseURL + "/" + resourcePath + } + } else { + // Derive from request + if resourcePath == "/" { + resourceURL = fmt.Sprintf("%s://%s/", scheme, host) + } else { + resourceURL = fmt.Sprintf("%s://%s/%s", scheme, host, resourcePath) + } + } + + return &ProtectedResourceData{ + ResourceURL: resourceURL, + AuthorizationServer: h.cfg.AuthorizationServer, + }, nil +} + +// GetEffectiveHostAndScheme returns the effective host and scheme for a request. +// It checks X-Forwarded-Host and X-Forwarded-Proto headers first (set by proxies), +// then falls back to the request's Host and TLS state. +func GetEffectiveHostAndScheme(r *http.Request, cfg *Config) (host, scheme string) { + // Check for forwarded headers first (typically set by reverse proxies) + if forwardedHost := r.Header.Get(headers.ForwardedHostHeader); forwardedHost != "" { + host = forwardedHost + } else { + host = r.Host + } + + // Determine scheme + switch { + case r.Header.Get(headers.ForwardedProtoHeader) != "": + scheme = strings.ToLower(r.Header.Get(headers.ForwardedProtoHeader)) + case r.TLS != nil: + scheme = "https" + default: + // Default to HTTPS in production scenarios + scheme = "https" + } + + return host, scheme +} + +// BuildResourceMetadataURL constructs the full URL to the OAuth protected resource metadata endpoint. +func BuildResourceMetadataURL(r *http.Request, cfg *Config, resourcePath string) string { + host, scheme := GetEffectiveHostAndScheme(r, cfg) + + if cfg != nil && cfg.BaseURL != "" { + baseURL := strings.TrimSuffix(cfg.BaseURL, "/") + return baseURL + OAuthProtectedResourcePrefix + "/" + strings.TrimPrefix(resourcePath, "/") + } + + path := OAuthProtectedResourcePrefix + if resourcePath != "" && resourcePath != "/" { + path = path + "/" + strings.TrimPrefix(resourcePath, "/") + } + + return fmt.Sprintf("%s://%s%s", scheme, host, path) +} diff --git a/pkg/http/oauth/protected_resource.json.tmpl b/pkg/http/oauth/protected_resource.json.tmpl new file mode 100644 index 000000000..7a9257404 --- /dev/null +++ b/pkg/http/oauth/protected_resource.json.tmpl @@ -0,0 +1,20 @@ +{ + "resource_name": "GitHub MCP Server", + "resource": "{{.ResourceURL}}", + "authorization_servers": ["{{.AuthorizationServer}}"], + "bearer_methods_supported": ["header"], + "scopes_supported": [ + "repo", + "read:org", + "read:user", + "user:email", + "read:packages", + "write:packages", + "read:project", + "project", + "gist", + "notifications", + "workflow", + "codespace" + ] +} diff --git a/pkg/http/server.go b/pkg/http/server.go index 33fe23d14..c6047b43c 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -12,6 +12,7 @@ import ( "time" "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/http/oauth" "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" @@ -28,6 +29,11 @@ type ServerConfig struct { // Port to listen on (default: 8082) Port int + // BaseURL is the publicly accessible URL of this server for OAuth resource metadata. + // Example: "https://mcp.example.com" + // If not set, the server will derive the URL from incoming request headers. + BaseURL string + // ExportTranslations indicates if we should export translations // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#i18n--overriding-descriptions ExportTranslations bool @@ -95,7 +101,18 @@ func RunHTTPServer(cfg ServerConfig) error { r := chi.NewRouter() - handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger) + // Register OAuth protected resource metadata endpoints + oauthCfg := &oauth.Config{ + BaseURL: cfg.BaseURL, + } + oauthHandler, err := oauth.NewAuthHandler(oauthCfg) + if err != nil { + return fmt.Errorf("failed to create OAuth handler: %w", err) + } + oauthHandler.RegisterRoutes(r) + logger.Info("OAuth protected resource endpoints registered", "baseURL", cfg.BaseURL) + + handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, WithOAuthConfig(oauthCfg)) handler.RegisterRoutes(r) addr := fmt.Sprintf(":%d", cfg.Port) From f768eda6327a26a763e465a9ad93044c022e8fbc Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Thu, 22 Jan 2026 12:23:38 +0000 Subject: [PATCH 15/38] add nolint for GetEffectiveHostAndScheme --- pkg/http/oauth/oauth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index a504d348e..9d75041de 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -185,7 +185,7 @@ func (h *AuthHandler) GetProtectedResourceData(r *http.Request, resourcePath str // GetEffectiveHostAndScheme returns the effective host and scheme for a request. // It checks X-Forwarded-Host and X-Forwarded-Proto headers first (set by proxies), // then falls back to the request's Host and TLS state. -func GetEffectiveHostAndScheme(r *http.Request, cfg *Config) (host, scheme string) { +func GetEffectiveHostAndScheme(r *http.Request, cfg *Config) (host, scheme string) { //nolint:revive // parameters are required by http.oauth.BuildResourceMetadataURL signature // Check for forwarded headers first (typically set by reverse proxies) if forwardedHost := r.Header.Get(headers.ForwardedHostHeader); forwardedHost != "" { host = forwardedHost From 68e1f5002278d36b02db0d5b6c6b1ebbecbf2bb0 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Fri, 23 Jan 2026 11:04:47 +0000 Subject: [PATCH 16/38] remove CAPI reference --- pkg/http/headers/headers.go | 4 +++ pkg/http/oauth/oauth.go | 56 +++++++++++++++++++++++++------------ 2 files changed, 42 insertions(+), 18 deletions(-) diff --git a/pkg/http/headers/headers.go b/pkg/http/headers/headers.go index 8c389c828..c9846fa97 100644 --- a/pkg/http/headers/headers.go +++ b/pkg/http/headers/headers.go @@ -26,6 +26,10 @@ const ( // ForwardedProtoHeader is a standard HTTP Header for preserving the original protocol when proxying. ForwardedProtoHeader = "X-Forwarded-Proto" + // OriginalPathHeader is set to preserve the original request path + // before the /mcp prefix was stripped during proxying. + OriginalPathHeader = "X-GitHub-Original-Path" + // RequestHmacHeader is used to authenticate requests to the Raw API. RequestHmacHeader = "Request-Hmac" diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index 9d75041de..893012415 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -8,6 +8,7 @@ import ( "fmt" "html" "net/http" + "net/url" "strings" "text/template" @@ -112,14 +113,16 @@ func (h *AuthHandler) RegisterRoutes(r chi.Router) { } // routesForPattern generates route variants for a given pattern. +// GitHub strips the /mcp prefix before forwarding, so we register both variants: +// - With /mcp prefix: for direct access or when GitHub doesn't strip +// - Without /mcp prefix: for when GitHub has stripped the prefix func (h *AuthHandler) routesForPattern(pattern string) []string { - routes := []string{ + return []string{ pattern, + "/mcp" + pattern, pattern + "/", - pattern + "/mcp", - pattern + "/mcp/", + "/mcp" + pattern + "/", } - return routes } // handleProtectedResource handles requests for OAuth protected resource metadata. @@ -153,26 +156,43 @@ func (h *AuthHandler) handleProtectedResource(w http.ResponseWriter, r *http.Req _, _ = w.Write(buf.Bytes()) } +// GetEffectiveResourcePath returns the resource path for OAuth protected resource URLs. +// It checks for the X-GitHub-Original-Path header set by copilot-api (CAPI), which contains +// the exact path the client requested before the /mcp prefix was stripped. +// If the header is not present (e.g., direct access or older CAPI versions), it falls back to +// restoring the /mcp prefix. +func GetEffectiveResourcePath(r *http.Request) string { + // Check for the original path header from copilot-api (preferred method) + if originalPath := r.Header.Get(headers.OriginalPathHeader); originalPath != "" { + return originalPath + } + + // Fallback: copilot-api strips /mcp prefix, so we need to restore it for the external URL + if r.URL.Path == "/" { + return "/mcp" + } + return "/mcp" + r.URL.Path +} + // GetProtectedResourceData builds the OAuth protected resource data for a request. func (h *AuthHandler) GetProtectedResourceData(r *http.Request, resourcePath string) (*ProtectedResourceData, error) { host, scheme := GetEffectiveHostAndScheme(r, h.cfg) - // Build the resource URL - var resourceURL string + // Build the base URL + baseURL := fmt.Sprintf("%s://%s", scheme, host) if h.cfg.BaseURL != "" { - // Use configured base URL - baseURL := strings.TrimSuffix(h.cfg.BaseURL, "/") - if resourcePath == "/" { - resourceURL = baseURL + "/" - } else { - resourceURL = baseURL + "/" + resourcePath - } + baseURL = strings.TrimSuffix(h.cfg.BaseURL, "/") + } + + // Build the resource URL using url.JoinPath for proper path handling + var resourceURL string + var err error + if resourcePath == "/" { + resourceURL = baseURL + "/" } else { - // Derive from request - if resourcePath == "/" { - resourceURL = fmt.Sprintf("%s://%s/", scheme, host) - } else { - resourceURL = fmt.Sprintf("%s://%s/%s", scheme, host, resourcePath) + resourceURL, err = url.JoinPath(baseURL, resourcePath) + if err != nil { + return nil, fmt.Errorf("failed to build resource URL: %w", err) } } From 67b821c044c3de89abbebfa0ad59f0857816d584 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Fri, 23 Jan 2026 11:07:53 +0000 Subject: [PATCH 17/38] remove nonsensical example URL --- pkg/http/oauth/oauth.go | 1 - pkg/http/server.go | 1 - 2 files changed, 2 deletions(-) diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index 893012415..ce20d05d6 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -47,7 +47,6 @@ var SupportedScopes = []string{ type Config struct { // BaseURL is the publicly accessible URL where this server is hosted. // This is used to construct the OAuth resource URL. - // Example: "https://mcp.example.com" BaseURL string // AuthorizationServer is the OAuth authorization server URL. diff --git a/pkg/http/server.go b/pkg/http/server.go index c6047b43c..180cb75b5 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -30,7 +30,6 @@ type ServerConfig struct { Port int // BaseURL is the publicly accessible URL of this server for OAuth resource metadata. - // Example: "https://mcp.example.com" // If not set, the server will derive the URL from incoming request headers. BaseURL string From 7c9005091660d144804b532fa47e820b2165c694 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Fri, 23 Jan 2026 11:10:01 +0000 Subject: [PATCH 18/38] anonymize --- pkg/http/oauth/oauth.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index ce20d05d6..f24db6786 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -156,17 +156,17 @@ func (h *AuthHandler) handleProtectedResource(w http.ResponseWriter, r *http.Req } // GetEffectiveResourcePath returns the resource path for OAuth protected resource URLs. -// It checks for the X-GitHub-Original-Path header set by copilot-api (CAPI), which contains +// It checks for the X-GitHub-Original-Path header set by GitHub, which contains // the exact path the client requested before the /mcp prefix was stripped. -// If the header is not present (e.g., direct access or older CAPI versions), it falls back to +// If the header is not present, it falls back to // restoring the /mcp prefix. func GetEffectiveResourcePath(r *http.Request) string { - // Check for the original path header from copilot-api (preferred method) + // Check for the original path header from GitHub (preferred method) if originalPath := r.Header.Get(headers.OriginalPathHeader); originalPath != "" { return originalPath } - // Fallback: copilot-api strips /mcp prefix, so we need to restore it for the external URL + // Fallback: GitHub strips /mcp prefix, so we need to restore it for the external URL if r.URL.Path == "/" { return "/mcp" } From 78f1a82e206bc9ed7cc3a24288be784d1c5bbb5f Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Fri, 23 Jan 2026 15:54:11 +0000 Subject: [PATCH 19/38] add oauth tests --- pkg/http/oauth/oauth_test.go | 677 +++++++++++++++++++++++++++++++++++ 1 file changed, 677 insertions(+) create mode 100644 pkg/http/oauth/oauth_test.go diff --git a/pkg/http/oauth/oauth_test.go b/pkg/http/oauth/oauth_test.go new file mode 100644 index 000000000..035f5c35b --- /dev/null +++ b/pkg/http/oauth/oauth_test.go @@ -0,0 +1,677 @@ +package oauth + +import ( + "crypto/tls" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewAuthHandler(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + expectedAuthServer string + expectedResourcePath string + }{ + { + name: "nil config uses defaults", + cfg: nil, + expectedAuthServer: DefaultAuthorizationServer, + expectedResourcePath: "", + }, + { + name: "empty config uses defaults", + cfg: &Config{}, + expectedAuthServer: DefaultAuthorizationServer, + expectedResourcePath: "", + }, + { + name: "custom authorization server", + cfg: &Config{ + AuthorizationServer: "https://custom.example.com/oauth", + }, + expectedAuthServer: "https://custom.example.com/oauth", + expectedResourcePath: "", + }, + { + name: "custom base URL and resource path", + cfg: &Config{ + BaseURL: "https://example.com", + ResourcePath: "/mcp", + }, + expectedAuthServer: DefaultAuthorizationServer, + expectedResourcePath: "/mcp", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(tc.cfg) + require.NoError(t, err) + require.NotNil(t, handler) + + assert.Equal(t, tc.expectedAuthServer, handler.cfg.AuthorizationServer) + assert.NotNil(t, handler.protectedResourceTemplate) + }) + } +} + +func TestGetEffectiveHostAndScheme(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setupRequest func() *http.Request + cfg *Config + expectedHost string + expectedScheme string + }{ + { + name: "basic request without forwarding headers", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "https", // defaults to https + }, + { + name: "request with X-Forwarded-Host header", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "internal.example.com" + req.Header.Set(headers.ForwardedHostHeader, "public.example.com") + return req + }, + cfg: &Config{}, + expectedHost: "public.example.com", + expectedScheme: "https", + }, + { + name: "request with X-Forwarded-Proto header", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + req.Header.Set(headers.ForwardedProtoHeader, "http") + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "http", + }, + { + name: "request with both forwarding headers", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "internal.example.com" + req.Header.Set(headers.ForwardedHostHeader, "public.example.com") + req.Header.Set(headers.ForwardedProtoHeader, "https") + return req + }, + cfg: &Config{}, + expectedHost: "public.example.com", + expectedScheme: "https", + }, + { + name: "request with TLS", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + req.TLS = &tls.ConnectionState{} + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "https", + }, + { + name: "X-Forwarded-Proto takes precedence over TLS", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + req.TLS = &tls.ConnectionState{} + req.Header.Set(headers.ForwardedProtoHeader, "http") + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "http", + }, + { + name: "scheme is lowercased", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + req.Header.Set(headers.ForwardedProtoHeader, "HTTPS") + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "https", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + req := tc.setupRequest() + host, scheme := GetEffectiveHostAndScheme(req, tc.cfg) + + assert.Equal(t, tc.expectedHost, host) + assert.Equal(t, tc.expectedScheme, scheme) + }) + } +} + +func TestGetEffectiveResourcePath(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setupRequest func() *http.Request + expectedPath string + }{ + { + name: "root path without original path header", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + return req + }, + expectedPath: "/mcp", + }, + { + name: "non-root path without original path header", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/readonly", nil) + return req + }, + expectedPath: "/mcp/readonly", + }, + { + name: "with X-GitHub-Original-Path header", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/readonly", nil) + req.Header.Set(headers.OriginalPathHeader, "/mcp/x/repos/readonly") + return req + }, + expectedPath: "/mcp/x/repos/readonly", + }, + { + name: "original path header takes precedence", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/something-else", nil) + req.Header.Set(headers.OriginalPathHeader, "/mcp/custom/path") + return req + }, + expectedPath: "/mcp/custom/path", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + req := tc.setupRequest() + path := GetEffectiveResourcePath(req) + + assert.Equal(t, tc.expectedPath, path) + }) + } +} + +func TestGetProtectedResourceData(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + setupRequest func() *http.Request + resourcePath string + expectedResourceURL string + expectedAuthServer string + expectError bool + }{ + { + name: "basic request with root resource path", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/", + expectedResourceURL: "https://api.example.com/", + expectedAuthServer: DefaultAuthorizationServer, + }, + { + name: "basic request with custom resource path", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp", + expectedResourceURL: "https://api.example.com/mcp", + expectedAuthServer: DefaultAuthorizationServer, + }, + { + name: "with custom base URL", + cfg: &Config{ + BaseURL: "https://custom.example.com", + }, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp", + expectedResourceURL: "https://custom.example.com/mcp", + expectedAuthServer: DefaultAuthorizationServer, + }, + { + name: "with custom authorization server", + cfg: &Config{ + AuthorizationServer: "https://auth.example.com/oauth", + }, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp", + expectedResourceURL: "https://api.example.com/mcp", + expectedAuthServer: "https://auth.example.com/oauth", + }, + { + name: "base URL with trailing slash is trimmed", + cfg: &Config{ + BaseURL: "https://custom.example.com/", + }, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp", + expectedResourceURL: "https://custom.example.com/mcp", + expectedAuthServer: DefaultAuthorizationServer, + }, + { + name: "nested resource path", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp/x/repos", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp/x/repos", + expectedResourceURL: "https://api.example.com/mcp/x/repos", + expectedAuthServer: DefaultAuthorizationServer, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(tc.cfg) + require.NoError(t, err) + + req := tc.setupRequest() + data, err := handler.GetProtectedResourceData(req, tc.resourcePath) + + if tc.expectError { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tc.expectedResourceURL, data.ResourceURL) + assert.Equal(t, tc.expectedAuthServer, data.AuthorizationServer) + }) + } +} + +func TestBuildResourceMetadataURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + setupRequest func() *http.Request + resourcePath string + expectedURL string + }{ + { + name: "root path", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/", + expectedURL: "https://api.example.com/.well-known/oauth-protected-resource", + }, + { + name: "with custom resource path", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp", + expectedURL: "https://api.example.com/.well-known/oauth-protected-resource/mcp", + }, + { + name: "with base URL config", + cfg: &Config{ + BaseURL: "https://custom.example.com", + }, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp", + expectedURL: "https://custom.example.com/.well-known/oauth-protected-resource/mcp", + }, + { + name: "with forwarded headers", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "internal.example.com" + req.Header.Set(headers.ForwardedHostHeader, "public.example.com") + req.Header.Set(headers.ForwardedProtoHeader, "https") + return req + }, + resourcePath: "/mcp", + expectedURL: "https://public.example.com/.well-known/oauth-protected-resource/mcp", + }, + { + name: "nil config uses request host", + cfg: nil, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "", + expectedURL: "https://api.example.com/.well-known/oauth-protected-resource", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + req := tc.setupRequest() + url := BuildResourceMetadataURL(req, tc.cfg, tc.resourcePath) + + assert.Equal(t, tc.expectedURL, url) + }) + } +} + +func TestHandleProtectedResource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + path string + host string + method string + expectedStatusCode int + expectedScopes []string + validateResponse func(t *testing.T, body map[string]any) + }{ + { + name: "GET request returns protected resource metadata", + cfg: &Config{}, + path: OAuthProtectedResourcePrefix, + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + expectedScopes: SupportedScopes, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + assert.Equal(t, "GitHub MCP Server", body["resource_name"]) + assert.Contains(t, body["resource"], "api.example.com") + + authServers, ok := body["authorization_servers"].([]any) + require.True(t, ok) + require.Len(t, authServers, 1) + assert.Equal(t, DefaultAuthorizationServer, authServers[0]) + }, + }, + { + name: "OPTIONS request for CORS", + cfg: &Config{}, + path: OAuthProtectedResourcePrefix, + host: "api.example.com", + method: http.MethodOptions, + expectedStatusCode: http.StatusOK, + }, + { + name: "path with /mcp suffix", + cfg: &Config{}, + path: OAuthProtectedResourcePrefix + "/mcp", + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + assert.Contains(t, body["resource"], "/mcp") + }, + }, + { + name: "path with /readonly suffix", + cfg: &Config{}, + path: OAuthProtectedResourcePrefix + "/readonly", + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + assert.Contains(t, body["resource"], "/readonly") + }, + }, + { + name: "custom authorization server in response", + cfg: &Config{ + AuthorizationServer: "https://custom.auth.example.com/oauth", + }, + path: OAuthProtectedResourcePrefix, + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + authServers, ok := body["authorization_servers"].([]any) + require.True(t, ok) + require.Len(t, authServers, 1) + assert.Equal(t, "https://custom.auth.example.com/oauth", authServers[0]) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(tc.cfg) + require.NoError(t, err) + + router := chi.NewRouter() + handler.RegisterRoutes(router) + + req := httptest.NewRequest(tc.method, tc.path, nil) + req.Host = tc.host + + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectedStatusCode, rec.Code) + + // Check CORS headers + assert.Equal(t, "*", rec.Header().Get("Access-Control-Allow-Origin")) + assert.Contains(t, rec.Header().Get("Access-Control-Allow-Methods"), "GET") + assert.Contains(t, rec.Header().Get("Access-Control-Allow-Methods"), "OPTIONS") + + if tc.method == http.MethodGet && tc.validateResponse != nil { + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + + var body map[string]any + err := json.Unmarshal(rec.Body.Bytes(), &body) + require.NoError(t, err) + + tc.validateResponse(t, body) + + // Verify scopes if expected + if tc.expectedScopes != nil { + scopes, ok := body["scopes_supported"].([]any) + require.True(t, ok) + assert.Len(t, scopes, len(tc.expectedScopes)) + } + } + }) + } +} + +func TestRegisterRoutes(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(&Config{}) + require.NoError(t, err) + + router := chi.NewRouter() + handler.RegisterRoutes(router) + + // List of expected routes that should be registered + expectedRoutes := []string{ + OAuthProtectedResourcePrefix, + OAuthProtectedResourcePrefix + "/", + OAuthProtectedResourcePrefix + "/mcp", + OAuthProtectedResourcePrefix + "/mcp/", + OAuthProtectedResourcePrefix + "/readonly", + OAuthProtectedResourcePrefix + "/readonly/", + OAuthProtectedResourcePrefix + "/mcp/readonly", + OAuthProtectedResourcePrefix + "/mcp/readonly/", + OAuthProtectedResourcePrefix + "/x/repos", + OAuthProtectedResourcePrefix + "/mcp/x/repos", + } + + for _, route := range expectedRoutes { + t.Run("route:"+route, func(t *testing.T) { + // Test GET + req := httptest.NewRequest(http.MethodGet, route, nil) + req.Host = "api.example.com" + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "GET %s should return 200", route) + + // Test OPTIONS (CORS) + req = httptest.NewRequest(http.MethodOptions, route, nil) + req.Host = "api.example.com" + rec = httptest.NewRecorder() + router.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "OPTIONS %s should return 200", route) + }) + } +} + +func TestSupportedScopes(t *testing.T) { + t.Parallel() + + // Verify all expected scopes are present + expectedScopes := []string{ + "repo", + "read:org", + "read:user", + "user:email", + "read:packages", + "write:packages", + "read:project", + "project", + "gist", + "notifications", + "workflow", + "codespace", + } + + assert.Equal(t, expectedScopes, SupportedScopes) +} + +func TestProtectedResourceResponseFormat(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(&Config{}) + require.NoError(t, err) + + router := chi.NewRouter() + handler.RegisterRoutes(router) + + req := httptest.NewRequest(http.MethodGet, OAuthProtectedResourcePrefix, nil) + req.Host = "api.example.com" + + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var response map[string]any + err = json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + + // Verify all required RFC 9728 fields are present + assert.Contains(t, response, "resource") + assert.Contains(t, response, "authorization_servers") + assert.Contains(t, response, "bearer_methods_supported") + assert.Contains(t, response, "scopes_supported") + + // Verify resource name (optional but we include it) + assert.Contains(t, response, "resource_name") + assert.Equal(t, "GitHub MCP Server", response["resource_name"]) + + // Verify bearer_methods_supported contains "header" + bearerMethods, ok := response["bearer_methods_supported"].([]any) + require.True(t, ok) + assert.Contains(t, bearerMethods, "header") + + // Verify authorization_servers is an array with GitHub OAuth + authServers, ok := response["authorization_servers"].([]any) + require.True(t, ok) + assert.Len(t, authServers, 1) + assert.Equal(t, DefaultAuthorizationServer, authServers[0]) +} + +func TestOAuthProtectedResourcePrefix(t *testing.T) { + t.Parallel() + + // RFC 9728 specifies this well-known path + assert.Equal(t, "/.well-known/oauth-protected-resource", OAuthProtectedResourcePrefix) +} + +func TestDefaultAuthorizationServer(t *testing.T) { + t.Parallel() + + assert.Equal(t, "https://github.com/login/oauth", DefaultAuthorizationServer) +} From e2699c847c69bc775b47cc1770ba273006b8f380 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Mon, 26 Jan 2026 15:53:58 +0000 Subject: [PATCH 20/38] replace custom protected resource metadata handler with our own --- pkg/http/oauth/oauth.go | 75 +++++++-------------- pkg/http/oauth/oauth_test.go | 46 ++++++++----- pkg/http/oauth/protected_resource.json.tmpl | 20 ------ 3 files changed, 55 insertions(+), 86 deletions(-) delete mode 100644 pkg/http/oauth/protected_resource.json.tmpl diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index f24db6786..7710dd581 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -3,17 +3,16 @@ package oauth import ( - "bytes" - _ "embed" "fmt" - "html" "net/http" "net/url" "strings" - "text/template" + + "github.com/modelcontextprotocol/go-sdk/auth" "github.com/github/github-mcp-server/pkg/http/headers" "github.com/go-chi/chi/v5" + "github.com/modelcontextprotocol/go-sdk/oauthex" ) const ( @@ -24,9 +23,6 @@ const ( DefaultAuthorizationServer = "https://github.com/login/oauth" ) -//go:embed protected_resource.json.tmpl -var protectedResourceTemplate []byte - // SupportedScopes lists all OAuth scopes that may be required by MCP tools. var SupportedScopes = []string{ "repo", @@ -66,8 +62,7 @@ type ProtectedResourceData struct { // AuthHandler handles OAuth-related HTTP endpoints. type AuthHandler struct { - cfg *Config - protectedResourceTemplate *template.Template + cfg *Config } // NewAuthHandler creates a new OAuth auth handler. @@ -81,14 +76,8 @@ func NewAuthHandler(cfg *Config) (*AuthHandler, error) { cfg.AuthorizationServer = DefaultAuthorizationServer } - tmpl, err := template.New("protected-resource").Parse(string(protectedResourceTemplate)) - if err != nil { - return nil, fmt.Errorf("failed to parse protected resource template: %w", err) - } - return &AuthHandler{ - cfg: cfg, - protectedResourceTemplate: tmpl, + cfg: cfg, }, nil } @@ -96,6 +85,7 @@ func NewAuthHandler(cfg *Config) (*AuthHandler, error) { var routePatterns = []string{ "", // Root: /.well-known/oauth-protected-resource "/readonly", // Read-only mode + "/insiders", // Insiders mode "/x/{toolset}", "/x/{toolset}/readonly", } @@ -105,12 +95,30 @@ func (h *AuthHandler) RegisterRoutes(r chi.Router) { for _, pattern := range routePatterns { for _, route := range h.routesForPattern(pattern) { path := OAuthProtectedResourcePrefix + route - r.Get(path, h.handleProtectedResource) - r.Options(path, h.handleProtectedResource) // CORS support + + // Build metadata for this specific resource path + metadata := h.buildMetadata(route) + r.Handle(path, auth.ProtectedResourceMetadataHandler(metadata)) } } } +func (h *AuthHandler) buildMetadata(resourcePath string) *oauthex.ProtectedResourceMetadata { + baseURL := strings.TrimSuffix(h.cfg.BaseURL, "/") + resourceURL := baseURL + if resourcePath != "" && resourcePath != "/" { + resourceURL = baseURL + resourcePath + } + + return &oauthex.ProtectedResourceMetadata{ + Resource: resourceURL, + AuthorizationServers: []string{h.cfg.AuthorizationServer}, + ResourceName: "GitHub MCP Server", + ScopesSupported: SupportedScopes, + BearerMethodsSupported: []string{"header"}, + } +} + // routesForPattern generates route variants for a given pattern. // GitHub strips the /mcp prefix before forwarding, so we register both variants: // - With /mcp prefix: for direct access or when GitHub doesn't strip @@ -124,37 +132,6 @@ func (h *AuthHandler) routesForPattern(pattern string) []string { } } -// handleProtectedResource handles requests for OAuth protected resource metadata. -func (h *AuthHandler) handleProtectedResource(w http.ResponseWriter, r *http.Request) { - // Extract the resource path from the URL - resourcePath := strings.TrimPrefix(r.URL.Path, OAuthProtectedResourcePrefix) - if resourcePath == "" || resourcePath == "/" { - resourcePath = "/" - } else { - resourcePath = strings.TrimPrefix(resourcePath, "/") - } - - data, err := h.GetProtectedResourceData(r, html.EscapeString(resourcePath)) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - var buf bytes.Buffer - if err := h.protectedResourceTemplate.Execute(&buf, data); err != nil { - http.Error(w, "Internal server error", http.StatusInternalServerError) - return - } - - // Set CORS headers - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write(buf.Bytes()) -} - // GetEffectiveResourcePath returns the resource path for OAuth protected resource URLs. // It checks for the X-GitHub-Original-Path header set by GitHub, which contains // the exact path the client requested before the /mcp prefix was stripped. diff --git a/pkg/http/oauth/oauth_test.go b/pkg/http/oauth/oauth_test.go index 035f5c35b..2bff37363 100644 --- a/pkg/http/oauth/oauth_test.go +++ b/pkg/http/oauth/oauth_test.go @@ -62,7 +62,6 @@ func TestNewAuthHandler(t *testing.T) { require.NotNil(t, handler) assert.Equal(t, tc.expectedAuthServer, handler.cfg.AuthorizationServer) - assert.NotNil(t, handler.protectedResourceTemplate) }) } } @@ -444,8 +443,10 @@ func TestHandleProtectedResource(t *testing.T) { validateResponse func(t *testing.T, body map[string]any) }{ { - name: "GET request returns protected resource metadata", - cfg: &Config{}, + name: "GET request returns protected resource metadata", + cfg: &Config{ + BaseURL: "https://api.example.com", + }, path: OAuthProtectedResourcePrefix, host: "api.example.com", method: http.MethodGet, @@ -454,7 +455,7 @@ func TestHandleProtectedResource(t *testing.T) { validateResponse: func(t *testing.T, body map[string]any) { t.Helper() assert.Equal(t, "GitHub MCP Server", body["resource_name"]) - assert.Contains(t, body["resource"], "api.example.com") + assert.Equal(t, "https://api.example.com", body["resource"]) authServers, ok := body["authorization_servers"].([]any) require.True(t, ok) @@ -463,40 +464,47 @@ func TestHandleProtectedResource(t *testing.T) { }, }, { - name: "OPTIONS request for CORS", - cfg: &Config{}, + name: "OPTIONS request for CORS preflight", + cfg: &Config{ + BaseURL: "https://api.example.com", + }, path: OAuthProtectedResourcePrefix, host: "api.example.com", method: http.MethodOptions, - expectedStatusCode: http.StatusOK, + expectedStatusCode: http.StatusNoContent, }, { - name: "path with /mcp suffix", - cfg: &Config{}, + name: "path with /mcp suffix", + cfg: &Config{ + BaseURL: "https://api.example.com", + }, path: OAuthProtectedResourcePrefix + "/mcp", host: "api.example.com", method: http.MethodGet, expectedStatusCode: http.StatusOK, validateResponse: func(t *testing.T, body map[string]any) { t.Helper() - assert.Contains(t, body["resource"], "/mcp") + assert.Equal(t, "https://api.example.com/mcp", body["resource"]) }, }, { - name: "path with /readonly suffix", - cfg: &Config{}, + name: "path with /readonly suffix", + cfg: &Config{ + BaseURL: "https://api.example.com", + }, path: OAuthProtectedResourcePrefix + "/readonly", host: "api.example.com", method: http.MethodGet, expectedStatusCode: http.StatusOK, validateResponse: func(t *testing.T, body map[string]any) { t.Helper() - assert.Contains(t, body["resource"], "/readonly") + assert.Equal(t, "https://api.example.com/readonly", body["resource"]) }, }, { name: "custom authorization server in response", cfg: &Config{ + BaseURL: "https://api.example.com", AuthorizationServer: "https://custom.auth.example.com/oauth", }, path: OAuthProtectedResourcePrefix, @@ -559,7 +567,9 @@ func TestHandleProtectedResource(t *testing.T) { func TestRegisterRoutes(t *testing.T) { t.Parallel() - handler, err := NewAuthHandler(&Config{}) + handler, err := NewAuthHandler(&Config{ + BaseURL: "https://api.example.com", + }) require.NoError(t, err) router := chi.NewRouter() @@ -588,12 +598,12 @@ func TestRegisterRoutes(t *testing.T) { router.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code, "GET %s should return 200", route) - // Test OPTIONS (CORS) + // Test OPTIONS (CORS preflight) req = httptest.NewRequest(http.MethodOptions, route, nil) req.Host = "api.example.com" rec = httptest.NewRecorder() router.ServeHTTP(rec, req) - assert.Equal(t, http.StatusOK, rec.Code, "OPTIONS %s should return 200", route) + assert.Equal(t, http.StatusNoContent, rec.Code, "OPTIONS %s should return 204", route) }) } } @@ -623,7 +633,9 @@ func TestSupportedScopes(t *testing.T) { func TestProtectedResourceResponseFormat(t *testing.T) { t.Parallel() - handler, err := NewAuthHandler(&Config{}) + handler, err := NewAuthHandler(&Config{ + BaseURL: "https://api.example.com", + }) require.NoError(t, err) router := chi.NewRouter() diff --git a/pkg/http/oauth/protected_resource.json.tmpl b/pkg/http/oauth/protected_resource.json.tmpl deleted file mode 100644 index 7a9257404..000000000 --- a/pkg/http/oauth/protected_resource.json.tmpl +++ /dev/null @@ -1,20 +0,0 @@ -{ - "resource_name": "GitHub MCP Server", - "resource": "{{.ResourceURL}}", - "authorization_servers": ["{{.AuthorizationServer}}"], - "bearer_methods_supported": ["header"], - "scopes_supported": [ - "repo", - "read:org", - "read:user", - "user:email", - "read:packages", - "write:packages", - "read:project", - "project", - "gist", - "notifications", - "workflow", - "codespace" - ] -} From 9c21eedde022e0dbbd2445ed634941adabc20b86 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Mon, 26 Jan 2026 16:39:03 +0000 Subject: [PATCH 21/38] Update pkg/http/oauth/oauth.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- pkg/http/oauth/oauth.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index 7710dd581..8ef26d9c6 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -9,7 +9,6 @@ import ( "strings" "github.com/modelcontextprotocol/go-sdk/auth" - "github.com/github/github-mcp-server/pkg/http/headers" "github.com/go-chi/chi/v5" "github.com/modelcontextprotocol/go-sdk/oauthex" From 49191a90764d5f3c7ca704f0c6fe73d75b04fe91 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Wed, 28 Jan 2026 14:04:17 +0000 Subject: [PATCH 22/38] chore: retrigger ci From 03a508204613ca5089822da629c9111cda67547f Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Wed, 28 Jan 2026 14:13:35 +0000 Subject: [PATCH 23/38] update more types --- pkg/http/oauth/oauth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index 8ef26d9c6..8934a21c6 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -8,9 +8,9 @@ import ( "net/url" "strings" - "github.com/modelcontextprotocol/go-sdk/auth" "github.com/github/github-mcp-server/pkg/http/headers" "github.com/go-chi/chi/v5" + "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/oauthex" ) From 97092a057738dc5ea287e53c7e95ebe68be6ecae Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Wed, 28 Jan 2026 16:08:34 +0000 Subject: [PATCH 24/38] remove CAPI specific header --- pkg/http/headers/headers.go | 4 ---- pkg/http/oauth/oauth.go | 17 +++-------------- pkg/http/oauth/oauth_test.go | 30 ++++++++++++------------------ 3 files changed, 15 insertions(+), 36 deletions(-) diff --git a/pkg/http/headers/headers.go b/pkg/http/headers/headers.go index c9846fa97..8c389c828 100644 --- a/pkg/http/headers/headers.go +++ b/pkg/http/headers/headers.go @@ -26,10 +26,6 @@ const ( // ForwardedProtoHeader is a standard HTTP Header for preserving the original protocol when proxying. ForwardedProtoHeader = "X-Forwarded-Proto" - // OriginalPathHeader is set to preserve the original request path - // before the /mcp prefix was stripped during proxying. - OriginalPathHeader = "X-GitHub-Original-Path" - // RequestHmacHeader is used to authenticate requests to the Raw API. RequestHmacHeader = "Request-Hmac" diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index 8934a21c6..112e4c568 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -132,21 +132,10 @@ func (h *AuthHandler) routesForPattern(pattern string) []string { } // GetEffectiveResourcePath returns the resource path for OAuth protected resource URLs. -// It checks for the X-GitHub-Original-Path header set by GitHub, which contains -// the exact path the client requested before the /mcp prefix was stripped. -// If the header is not present, it falls back to -// restoring the /mcp prefix. +// It uses the request's URL path directly. For deployments where a prefix like /mcp +// is stripped by a proxy, the proxy should set the BaseURL config appropriately. func GetEffectiveResourcePath(r *http.Request) string { - // Check for the original path header from GitHub (preferred method) - if originalPath := r.Header.Get(headers.OriginalPathHeader); originalPath != "" { - return originalPath - } - - // Fallback: GitHub strips /mcp prefix, so we need to restore it for the external URL - if r.URL.Path == "/" { - return "/mcp" - } - return "/mcp" + r.URL.Path + return r.URL.Path } // GetProtectedResourceData builds the OAuth protected resource data for a request. diff --git a/pkg/http/oauth/oauth_test.go b/pkg/http/oauth/oauth_test.go index 2bff37363..30a516c9b 100644 --- a/pkg/http/oauth/oauth_test.go +++ b/pkg/http/oauth/oauth_test.go @@ -185,38 +185,32 @@ func TestGetEffectiveResourcePath(t *testing.T) { expectedPath string }{ { - name: "root path without original path header", + name: "root path", setupRequest: func() *http.Request { - req := httptest.NewRequest(http.MethodGet, "/", nil) - return req + return httptest.NewRequest(http.MethodGet, "/", nil) }, - expectedPath: "/mcp", + expectedPath: "/", }, { - name: "non-root path without original path header", + name: "mcp path", setupRequest: func() *http.Request { - req := httptest.NewRequest(http.MethodGet, "/readonly", nil) - return req + return httptest.NewRequest(http.MethodGet, "/mcp", nil) }, - expectedPath: "/mcp/readonly", + expectedPath: "/mcp", }, { - name: "with X-GitHub-Original-Path header", + name: "readonly path", setupRequest: func() *http.Request { - req := httptest.NewRequest(http.MethodGet, "/readonly", nil) - req.Header.Set(headers.OriginalPathHeader, "/mcp/x/repos/readonly") - return req + return httptest.NewRequest(http.MethodGet, "/readonly", nil) }, - expectedPath: "/mcp/x/repos/readonly", + expectedPath: "/readonly", }, { - name: "original path header takes precedence", + name: "nested path", setupRequest: func() *http.Request { - req := httptest.NewRequest(http.MethodGet, "/something-else", nil) - req.Header.Set(headers.OriginalPathHeader, "/mcp/custom/path") - return req + return httptest.NewRequest(http.MethodGet, "/mcp/x/repos", nil) }, - expectedPath: "/mcp/custom/path", + expectedPath: "/mcp/x/repos", }, } From cfea76203d940995645e45f0d38b1afeaf467403 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Wed, 28 Jan 2026 16:11:24 +0000 Subject: [PATCH 25/38] restore mcp path specific logic --- pkg/http/oauth/oauth.go | 9 ++++++--- pkg/http/oauth/oauth_test.go | 17 +++++------------ 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index 112e4c568..a96322b8e 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -132,10 +132,13 @@ func (h *AuthHandler) routesForPattern(pattern string) []string { } // GetEffectiveResourcePath returns the resource path for OAuth protected resource URLs. -// It uses the request's URL path directly. For deployments where a prefix like /mcp -// is stripped by a proxy, the proxy should set the BaseURL config appropriately. +// Since proxies may strip the /mcp prefix before forwarding requests, this function +// restores the prefix for the external-facing URL. func GetEffectiveResourcePath(r *http.Request) string { - return r.URL.Path + if r.URL.Path == "/" { + return "/mcp" + } + return "/mcp" + r.URL.Path } // GetProtectedResourceData builds the OAuth protected resource data for a request. diff --git a/pkg/http/oauth/oauth_test.go b/pkg/http/oauth/oauth_test.go index 30a516c9b..3a5188c72 100644 --- a/pkg/http/oauth/oauth_test.go +++ b/pkg/http/oauth/oauth_test.go @@ -185,30 +185,23 @@ func TestGetEffectiveResourcePath(t *testing.T) { expectedPath string }{ { - name: "root path", + name: "root path restores /mcp prefix", setupRequest: func() *http.Request { return httptest.NewRequest(http.MethodGet, "/", nil) }, - expectedPath: "/", - }, - { - name: "mcp path", - setupRequest: func() *http.Request { - return httptest.NewRequest(http.MethodGet, "/mcp", nil) - }, expectedPath: "/mcp", }, { - name: "readonly path", + name: "non-root path adds /mcp prefix", setupRequest: func() *http.Request { return httptest.NewRequest(http.MethodGet, "/readonly", nil) }, - expectedPath: "/readonly", + expectedPath: "/mcp/readonly", }, { - name: "nested path", + name: "nested path adds /mcp prefix", setupRequest: func() *http.Request { - return httptest.NewRequest(http.MethodGet, "/mcp/x/repos", nil) + return httptest.NewRequest(http.MethodGet, "/x/repos", nil) }, expectedPath: "/mcp/x/repos", }, From 199e62c685f46f3053cac91431e4f750e4aef3b0 Mon Sep 17 00:00:00 2001 From: Adam Holt Date: Thu, 29 Jan 2026 11:01:05 +0100 Subject: [PATCH 26/38] WIP --- cmd/github-mcp-server/main.go | 9 +- pkg/context/mcp_info.go | 39 ++++++ pkg/context/token.go | 27 ++-- pkg/github/dependencies.go | 12 +- pkg/github/server.go | 16 ++- pkg/http/middleware/mcp_parse.go | 126 ++++++++++++++++++ pkg/http/middleware/scope_challenge.go | 173 +++++++++++++++++++++++++ pkg/http/middleware/token.go | 77 +---------- pkg/utils/token.go | 82 ++++++++++++ 9 files changed, 470 insertions(+), 91 deletions(-) create mode 100644 pkg/context/mcp_info.go create mode 100644 pkg/http/middleware/mcp_parse.go create mode 100644 pkg/http/middleware/scope_challenge.go create mode 100644 pkg/utils/token.go diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index fbca6ccff..4a778f8bd 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -136,7 +136,10 @@ func init() { rootCmd.PersistentFlags().Bool("insiders", false, "Enable insiders features") rootCmd.PersistentFlags().Duration("repo-access-cache-ttl", 5*time.Minute, "Override the repo access cache TTL (e.g. 1m, 0s to disable)") rootCmd.PersistentFlags().Int("port", 8082, "HTTP server port") - rootCmd.PersistentFlags().String("base-url", "", "Base URL where this server is publicly accessible (for OAuth resource metadata)") + + // Add port flag to http command + httpCmd.PersistentFlags().Int("port", 8082, "HTTP server port") + httpCmd.PersistentFlags().String("base-url", "", "Base URL where this server is publicly accessible (for OAuth resource metadata)") // Bind flag to viper _ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets")) @@ -152,8 +155,8 @@ func init() { _ = viper.BindPFlag("lockdown-mode", rootCmd.PersistentFlags().Lookup("lockdown-mode")) _ = viper.BindPFlag("insiders", rootCmd.PersistentFlags().Lookup("insiders")) _ = viper.BindPFlag("repo-access-cache-ttl", rootCmd.PersistentFlags().Lookup("repo-access-cache-ttl")) - _ = viper.BindPFlag("port", rootCmd.PersistentFlags().Lookup("port")) - _ = viper.BindPFlag("base-url", rootCmd.PersistentFlags().Lookup("base-url")) + _ = viper.BindPFlag("port", httpCmd.PersistentFlags().Lookup("port")) + _ = viper.BindPFlag("base-url", httpCmd.PersistentFlags().Lookup("base-url")) // Add subcommands rootCmd.AddCommand(stdioCmd) diff --git a/pkg/context/mcp_info.go b/pkg/context/mcp_info.go new file mode 100644 index 000000000..d93cc8e81 --- /dev/null +++ b/pkg/context/mcp_info.go @@ -0,0 +1,39 @@ +package context + +import "context" + +type mcpMethodInfoCtx string + +var mcpMethodInfoCtxKey mcpMethodInfoCtx = "mcpmethodinfo" + +// MCPMethodInfo contains pre-parsed MCP method information extracted from the JSON-RPC request. +// This is populated early in the request lifecycle to enable: +// - Inventory filtering via ForMCPRequest (only register needed tools/resources/prompts) +// - Avoiding duplicate JSON parsing in middlewares (secret-scanning, scope-challenge) +// - Performance optimization for per-request server creation +type MCPMethodInfo struct { + // Method is the MCP method being called (e.g., "tools/call", "tools/list", "initialize") + Method string + // ItemName is the name of the specific item being accessed (tool name, resource URI, prompt name) + // Only populated for call/get methods (tools/call, prompts/get, resources/read) + ItemName string + // Owner is the repository owner from tool call arguments, if present + Owner string + // Repo is the repository name from tool call arguments, if present + Repo string + // Arguments contains the raw tool arguments for tools/call requests + Arguments map[string]any +} + +// ContextWithMCPMethodInfo stores the MCPMethodInfo in the context. +func ContextWithMCPMethodInfo(ctx context.Context, info *MCPMethodInfo) context.Context { + return context.WithValue(ctx, mcpMethodInfoCtxKey, info) +} + +// MCPMethod retrieves the MCPMethodInfo from the context. +func MCPMethod(ctx context.Context) (*MCPMethodInfo, bool) { + if info, ok := ctx.Value(mcpMethodInfoCtxKey).(*MCPMethodInfo); ok { + return info, true + } + return nil, false +} diff --git a/pkg/context/token.go b/pkg/context/token.go index dd303f029..0c86e38ab 100644 --- a/pkg/context/token.go +++ b/pkg/context/token.go @@ -1,19 +1,30 @@ package context -import "context" +import ( + "context" + + "github.com/github/github-mcp-server/pkg/utils" +) // tokenCtxKey is a context key for authentication token information -type tokenCtxKey struct{} +type tokenCtx string + +var tokenCtxKey tokenCtx = "tokenctx" + +type TokenInfo struct { + Token string + TokenType utils.TokenType +} // WithTokenInfo adds TokenInfo to the context -func WithTokenInfo(ctx context.Context, token string) context.Context { - return context.WithValue(ctx, tokenCtxKey{}, token) +func WithTokenInfo(ctx context.Context, token string, tokenType utils.TokenType) context.Context { + return context.WithValue(ctx, tokenCtxKey, TokenInfo{Token: token, TokenType: tokenType}) } // GetTokenInfo retrieves the authentication token from the context -func GetTokenInfo(ctx context.Context) (string, bool) { - if token, ok := ctx.Value(tokenCtxKey{}).(string); ok { - return token, true +func GetTokenInfo(ctx context.Context) (TokenInfo, bool) { + if tokenInfo, ok := ctx.Value(tokenCtxKey).(TokenInfo); ok { + return tokenInfo, true } - return "", false + return TokenInfo{}, false } diff --git a/pkg/github/dependencies.go b/pkg/github/dependencies.go index bdcafe933..028499b8f 100644 --- a/pkg/github/dependencies.go +++ b/pkg/github/dependencies.go @@ -283,7 +283,11 @@ func (d *RequestDeps) GetClient(ctx context.Context) (*gogithub.Client, error) { } // extract the token from the context - token, _ := ghcontext.GetTokenInfo(ctx) + tokenInfo, ok := ghcontext.GetTokenInfo(ctx) + if !ok { + return nil, fmt.Errorf("no token info in context") + } + token := tokenInfo.Token baseRestURL, err := d.apiHosts.BaseRESTURL(ctx) if err != nil { @@ -309,7 +313,11 @@ func (d *RequestDeps) GetGQLClient(ctx context.Context) (*githubv4.Client, error } // extract the token from the context - token, _ := ghcontext.GetTokenInfo(ctx) + tokenInfo, ok := ghcontext.GetTokenInfo(ctx) + if !ok { + return nil, fmt.Errorf("no token info in context") + } + token := tokenInfo.Token // Construct GraphQL client // We use NewEnterpriseClient unconditionally since we already parsed the API host diff --git a/pkg/github/server.go b/pkg/github/server.go index fddd85123..203dcabbd 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -8,6 +8,7 @@ import ( "strings" "time" + ghcontext "github.com/github/github-mcp-server/pkg/context" gherrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/octicons" @@ -73,10 +74,10 @@ type MCPServerConfig struct { type MCPServerOption func(*mcp.ServerOptions) -func NewMCPServer(ctx context.Context, cfg *MCPServerConfig, deps ToolDependencies, inventory *inventory.Inventory) (*mcp.Server, error) { +func NewMCPServer(ctx context.Context, cfg *MCPServerConfig, deps ToolDependencies, inv *inventory.Inventory) (*mcp.Server, error) { // Create the MCP server serverOpts := &mcp.ServerOptions{ - Instructions: inventory.Instructions(), + Instructions: inv.Instructions(), Logger: cfg.Logger, CompletionHandler: CompletionsHandler(deps.GetClient), } @@ -102,20 +103,25 @@ func NewMCPServer(ctx context.Context, cfg *MCPServerConfig, deps ToolDependenci ghServer.AddReceivingMiddleware(addGitHubAPIErrorToContext) ghServer.AddReceivingMiddleware(InjectDepsMiddleware(deps)) - if unrecognized := inventory.UnrecognizedToolsets(); len(unrecognized) > 0 { + if unrecognized := inv.UnrecognizedToolsets(); len(unrecognized) > 0 { cfg.Logger.Warn("Warning: unrecognized toolsets ignored", "toolsets", strings.Join(unrecognized, ", ")) } + invToUse := inv + if methodInfo, ok := ghcontext.MCPMethod(ctx); ok && methodInfo != nil { + invToUse = inv.ForMCPRequest(methodInfo.Method, methodInfo.ItemName) + } + // Register GitHub tools/resources/prompts from the inventory. // In dynamic mode with no explicit toolsets, this is a no-op since enabledToolsets // is empty - users enable toolsets at runtime via the dynamic tools below (but can // enable toolsets or tools explicitly that do need registration). - inventory.RegisterAll(ctx, ghServer, deps) + invToUse.RegisterAll(ctx, ghServer, deps) // Register dynamic toolset management tools (enable/disable) - these are separate // meta-tools that control the inventory, not part of the inventory itself if cfg.DynamicToolsets { - registerDynamicTools(ghServer, inventory, deps, cfg.Translator) + registerDynamicTools(ghServer, invToUse, deps, cfg.Translator) } return ghServer, nil diff --git a/pkg/http/middleware/mcp_parse.go b/pkg/http/middleware/mcp_parse.go new file mode 100644 index 000000000..efff53a17 --- /dev/null +++ b/pkg/http/middleware/mcp_parse.go @@ -0,0 +1,126 @@ +package middleware + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + + ghcontext "github.com/github/github-mcp-server/pkg/context" +) + +// mcpJSONRPCRequest represents the structure of an MCP JSON-RPC request. +// We only parse the fields needed for routing and optimization. +type mcpJSONRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params struct { + // For tools/call + Name string `json:"name,omitempty"` + Arguments json.RawMessage `json:"arguments,omitempty"` + // For prompts/get + // Name is shared with tools/call + // For resources/read + URI string `json:"uri,omitempty"` + } `json:"params"` +} + +// WithMCPParse creates a middleware that parses MCP JSON-RPC requests early in the +// request lifecycle and stores the parsed information in the request context. +// This enables: +// - Registry filtering via ForMCPRequest (only register needed tools/resources/prompts) +// - Avoiding duplicate JSON parsing in downstream middlewares +// - Access to owner/repo for secret-scanning middleware +// +// The middleware reads the request body, parses it, restores the body for downstream +// handlers, and stores the parsed MCPMethodInfo in the request context. +func WithMCPParse() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Skip health check endpoints + if r.URL.Path == "/_ping" { + next.ServeHTTP(w, r) + return + } + + // Only parse POST requests (MCP uses JSON-RPC over POST) + if r.Method != http.MethodPost { + next.ServeHTTP(w, r) + return + } + + // Read the request body + body, err := io.ReadAll(r.Body) + if err != nil { + // Log but continue - don't block requests on parse errors + next.ServeHTTP(w, r) + return + } + + // Restore the body for downstream handlers + r.Body = io.NopCloser(bytes.NewReader(body)) + + // Skip empty bodies + if len(body) == 0 { + next.ServeHTTP(w, r) + return + } + + // Parse the JSON-RPC request + var mcpReq mcpJSONRPCRequest + err = json.Unmarshal(body, &mcpReq) + if err != nil { + // Log but continue - could be a non-MCP request or malformed JSON + next.ServeHTTP(w, r) + return + } + + // Skip if not a valid JSON-RPC 2.0 request + if mcpReq.JSONRPC != "2.0" || mcpReq.Method == "" { + next.ServeHTTP(w, r) + return + } + + // Build the MCPMethodInfo + methodInfo := &ghcontext.MCPMethodInfo{ + Method: mcpReq.Method, + } + + // Extract item name based on method type + + switch mcpReq.Method { + case "tools/call": + methodInfo.ItemName = mcpReq.Params.Name + // Parse arguments if present + if len(mcpReq.Params.Arguments) > 0 { + var args map[string]any + err := json.Unmarshal(mcpReq.Params.Arguments, &args) + if err == nil { + methodInfo.Arguments = args + // Extract owner and repo if present + if owner, ok := args["owner"].(string); ok { + methodInfo.Owner = owner + } + if repo, ok := args["repo"].(string); ok { + methodInfo.Repo = repo + } + } + } + case "prompts/get": + methodInfo.ItemName = mcpReq.Params.Name + case "resources/read": + methodInfo.ItemName = mcpReq.Params.URI + default: + // Whatever + } + + // Store the parsed info in context + ctx = ghcontext.ContextWithMCPMethodInfo(ctx, methodInfo) + + next.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(fn) + } +} diff --git a/pkg/http/middleware/scope_challenge.go b/pkg/http/middleware/scope_challenge.go new file mode 100644 index 000000000..435e5cebe --- /dev/null +++ b/pkg/http/middleware/scope_challenge.go @@ -0,0 +1,173 @@ +package middleware + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/http/oauth" + "github.com/github/github-mcp-server/pkg/utils" +) + +// FetchScopesFromGitHubAPI fetches the OAuth scopes from the GitHub API by making +// a HEAD request and reading the X-OAuth-Scopes header. This is used as a fallback +// when scopes are not provided in the token info header. +func FetchScopesFromGitHubAPI(ctx context.Context, token string, apiUrls *utils.APIHost) ([]string, error) { + baseUrl, err := apiUrls.BaseRESTURL(ctx) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodHead, strings.TrimSuffix(baseUrl.String(), "/")+"/user", http.NoBody) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+token) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + scopeHeader := resp.Header.Get("X-OAuth-Scopes") + if scopeHeader == "" { + return []string{}, nil + } + + // Parse comma-separated scopes and trim whitespace + rawScopes := strings.Split(scopeHeader, ",") + parsedScopes := make([]string, 0, len(rawScopes)) + for _, s := range rawScopes { + trimmed := strings.TrimSpace(s) + if trimmed != "" { + parsedScopes = append(parsedScopes, trimmed) + } + } + return parsedScopes, nil +} + +// WithScopeChallenge creates a new middleware that determines if an OAuth request contains sufficient scopes to +// complete the request and returns a scope challenge if not. +func WithScopeChallenge(oauthCfg *oauth.Config, apiUrls *utils.APIHost) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Skip health check endpoints + if r.URL.Path == "/_ping" { + next.ServeHTTP(w, r) + return + } + + // Get user from context + tokenInfo, ok := ghcontext.GetTokenInfo(ctx) + if !ok { + next.ServeHTTP(w, r) + return + } + + // Only check OAuth tokens - scope challenge allows OAuth apps to request additional scopes + if tokenInfo.TokenType != utils.TokenTypeOAuthAccessToken { + next.ServeHTTP(w, r) + return + } + + // Try to use pre-parsed MCP method info first (performance optimization) + // This avoids re-parsing the JSON body if WithMCPParse middleware ran earlier + var toolName string + if methodInfo, ok := ghcontext.MCPMethod(ctx); ok && methodInfo != nil { + // Only check tools/call requests + if methodInfo.Method != "tools/call" { + next.ServeHTTP(w, r) + return + } + toolName = methodInfo.ItemName + } else { + // Fallback: parse the request body directly + body, err := io.ReadAll(r.Body) + if err != nil { + next.ServeHTTP(w, r) + return + } + r.Body = io.NopCloser(bytes.NewReader(body)) + + var mcpRequest struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params struct { + Name string `json:"name,omitempty"` + Arguments map[string]any `json:"arguments,omitempty"` + } `json:"params"` + } + + err = json.Unmarshal(body, &mcpRequest) + if err != nil { + next.ServeHTTP(w, r) + return + } + + // Only check tools/call requests + if mcpRequest.Method != "tools/call" { + next.ServeHTTP(w, r) + return + } + + toolName = mcpRequest.Params.Name + } + toolScopeInfo, err := scopes.GetToolScopeInfo(toolName) + if err != nil { + next.ServeHTTP(w, r) + return + } + + // If tool not found in scope map, allow the request + if toolScopeInfo == nil { + next.ServeHTTP(w, r) + return + } + + // Get OAuth scopes from GitHub API + activeScopes, err := FetchScopesFromGitHubAPI(ctx, tokenInfo.Token, apiUrls) + + // Check if user has the required scopes + if toolScopeInfo.HasAcceptedScope(activeScopes...) { + next.ServeHTTP(w, r) + return + } + + // User lacks required scopes - get the scopes they need + requiredScopes := toolScopeInfo.GetRequiredScopesSlice() + + // Build the resource metadata URL using the shared utility + // GetEffectiveResourcePath returns the original path (e.g., /mcp or /mcp/x/all) + // which is used to construct the well-known OAuth protected resource URL + resourcePath := oauth.GetEffectiveResourcePath(r) + resourceMetadataURL := oauth.BuildResourceMetadataURL(r, oauthCfg, resourcePath) + + // Build recommended scopes: existing scopes + required scopes + recommendedScopes := make([]string, 0, len(activeScopes)+len(requiredScopes)) + recommendedScopes = append(recommendedScopes, activeScopes...) + recommendedScopes = append(recommendedScopes, requiredScopes...) + + // Build the WWW-Authenticate header value + wwwAuthenticateHeader := fmt.Sprintf(`Bearer error="insufficient_scope", scope=%q, resource_metadata=%q, error_description=%q`, + strings.Join(recommendedScopes, " "), + resourceMetadataURL, + "Additional scopes required: "+strings.Join(requiredScopes, ", "), + ) + + // Send scope challenge response with the superset of existing and required scopes + w.Header().Set("WWW-Authenticate", wwwAuthenticateHeader) + http.Error(w, "Forbidden: insufficient scopes", http.StatusForbidden) + } + return http.HandlerFunc(fn) + } +} diff --git a/pkg/http/middleware/token.go b/pkg/http/middleware/token.go index e09026f24..ce952b073 100644 --- a/pkg/http/middleware/token.go +++ b/pkg/http/middleware/token.go @@ -4,49 +4,19 @@ import ( "errors" "fmt" "net/http" - "regexp" - "strings" ghcontext "github.com/github/github-mcp-server/pkg/context" - httpheaders "github.com/github/github-mcp-server/pkg/http/headers" - "github.com/github/github-mcp-server/pkg/http/mark" "github.com/github/github-mcp-server/pkg/http/oauth" + "github.com/github/github-mcp-server/pkg/utils" ) -type authType int - -const ( - authTypeUnknown authType = iota - authTypeIDE - authTypeGhToken -) - -var ( - errMissingAuthorizationHeader = fmt.Errorf("%w: missing required Authorization header", mark.ErrBadRequest) - errBadAuthorizationHeader = fmt.Errorf("%w: Authorization header is badly formatted", mark.ErrBadRequest) - errUnsupportedAuthorizationHeader = fmt.Errorf("%w: unsupported Authorization header", mark.ErrBadRequest) -) - -var supportedThirdPartyTokenPrefixes = []string{ - "ghp_", // Personal access token (classic) - "github_pat_", // Fine-grained personal access token - "gho_", // OAuth access token - "ghu_", // User access token for a GitHub App - "ghs_", // Installation access token for a GitHub App (a.k.a. server-to-server token) -} - -// oldPatternRegexp is the regular expression for the old pattern of the token. -// Until 2021, GitHub API tokens did not have an identifiable prefix. They -// were 40 characters long and only contained the characters a-f and 0-9. -var oldPatternRegexp = regexp.MustCompile(`\A[a-f0-9]{40}\z`) - func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, token, err := parseAuthorizationHeader(r) + tokenType, token, err := utils.ParseAuthorizationHeader(r) if err != nil { // For missing Authorization header, return 401 with WWW-Authenticate header per MCP spec - if errors.Is(err, errMissingAuthorizationHeader) { + if errors.Is(err, utils.ErrMissingAuthorizationHeader) { sendAuthChallenge(w, r, oauthCfg) return } @@ -56,7 +26,7 @@ func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handl } ctx := r.Context() - ctx = ghcontext.WithTokenInfo(ctx, token) + ctx = ghcontext.WithTokenInfo(ctx, token, tokenType) r = r.WithContext(ctx) next.ServeHTTP(w, r) @@ -71,42 +41,3 @@ func sendAuthChallenge(w http.ResponseWriter, r *http.Request, oauthCfg *oauth.C w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata=%q`, resourceMetadataURL)) http.Error(w, "Unauthorized", http.StatusUnauthorized) } - -func parseAuthorizationHeader(req *http.Request) (authType authType, token string, _ error) { - authHeader := req.Header.Get(httpheaders.AuthorizationHeader) - if authHeader == "" { - return 0, "", errMissingAuthorizationHeader - } - - switch { - // decrypt dotcom token and set it as token - case strings.HasPrefix(authHeader, "GitHub-Bearer "): - return 0, "", errUnsupportedAuthorizationHeader - default: - // support both "Bearer" and "bearer" to conform to api.github.com - if len(authHeader) > 7 && strings.EqualFold(authHeader[:7], "Bearer ") { - token = authHeader[7:] - } else { - token = authHeader - } - } - - // Do a naïve check for a colon in the token - currently, only the IDE token has a colon in it. - // ex: tid=1;exp=25145314523;chat=1: - if strings.Contains(token, ":") { - return authTypeIDE, token, nil - } - - for _, prefix := range supportedThirdPartyTokenPrefixes { - if strings.HasPrefix(token, prefix) { - return authTypeGhToken, token, nil - } - } - - matchesOldTokenPattern := oldPatternRegexp.MatchString(token) - if matchesOldTokenPattern { - return authTypeGhToken, token, nil - } - - return 0, "", errBadAuthorizationHeader -} diff --git a/pkg/utils/token.go b/pkg/utils/token.go new file mode 100644 index 000000000..62cacbf4d --- /dev/null +++ b/pkg/utils/token.go @@ -0,0 +1,82 @@ +package utils + +import ( + "fmt" + "net/http" + "regexp" + "strings" + + httpheaders "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/github/github-mcp-server/pkg/http/mark" +) + +type TokenType int + +const ( + TokenTypeUnknown TokenType = iota + TokenTypePersonalAccessToken + TokenTypeFineGrainedPersonalAccessToken + TokenTypeOAuthAccessToken + TokenTypeUserToServerGitHubAppToken + TokenTypeServerToServerGitHubAppToken + TokenTypeIDEToken +) + +var supportedThirdPartyTokenPrefixes = map[string]TokenType{ + "ghp_": TokenTypePersonalAccessToken, // Personal access token (classic) + "github_pat_": TokenTypeFineGrainedPersonalAccessToken, // Fine-grained personal access token + "gho_": TokenTypeOAuthAccessToken, // OAuth access token + "ghu_": TokenTypeUserToServerGitHubAppToken, // User access token for a GitHub App + "ghs_": TokenTypeServerToServerGitHubAppToken, // Installation access token for a GitHub App (a.k.a. server-to-server token) +} + +var ( + ErrMissingAuthorizationHeader = fmt.Errorf("%w: missing required Authorization header", mark.ErrBadRequest) + ErrBadAuthorizationHeader = fmt.Errorf("%w: Authorization header is badly formatted", mark.ErrBadRequest) + ErrUnsupportedAuthorizationHeader = fmt.Errorf("%w: unsupported Authorization header", mark.ErrBadRequest) +) + +// oldPatternRegexp is the regular expression for the old pattern of the token. +// Until 2021, GitHub API tokens did not have an identifiable prefix. They +// were 40 characters long and only contained the characters a-f and 0-9. +var oldPatternRegexp = regexp.MustCompile(`\A[a-f0-9]{40}\z`) + +// ParseAuthorizationHeader parses the Authorization header from the HTTP request +func ParseAuthorizationHeader(req *http.Request) (tokenType TokenType, token string, _ error) { + authHeader := req.Header.Get(httpheaders.AuthorizationHeader) + if authHeader == "" { + return 0, "", ErrMissingAuthorizationHeader + } + + switch { + // decrypt dotcom token and set it as token + case strings.HasPrefix(authHeader, "GitHub-Bearer "): + return 0, "", ErrUnsupportedAuthorizationHeader + default: + // support both "Bearer" and "bearer" to conform to api.github.com + if len(authHeader) > 7 && strings.EqualFold(authHeader[:7], "Bearer ") { + token = authHeader[7:] + } else { + token = authHeader + } + } + + // Do a naïve check for a colon in the token - currently, only the IDE token has a colon in it. + // ex: tid=1;exp=25145314523;chat=1: + if strings.Contains(token, ":") { + return TokenTypeIDEToken, token, nil + } + + for prefix, tokenType := range supportedThirdPartyTokenPrefixes { + if strings.HasPrefix(token, prefix) { + return tokenType, token, nil + } + } + + matchesOldTokenPattern := oldPatternRegexp.MatchString(token) + if matchesOldTokenPattern { + return TokenTypePersonalAccessToken, token, nil + } + + return 0, "", ErrBadAuthorizationHeader +} From 840b41e1e4891fde2985cf1a7d2ce687ec4676eb Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Thu, 29 Jan 2026 13:32:41 +0000 Subject: [PATCH 27/38] implement better resource path handling for OAuth server --- cmd/github-mcp-server/main.go | 3 + pkg/http/handler.go | 13 ++- pkg/http/middleware/token.go | 3 +- pkg/http/oauth/oauth.go | 208 ++++++++++++++++++++-------------- pkg/http/oauth/oauth_test.go | 161 ++++++++------------------ pkg/http/server.go | 7 +- 6 files changed, 193 insertions(+), 202 deletions(-) diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index fbca6ccff..386733bf6 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -102,6 +102,7 @@ var ( Host: viper.GetString("host"), Port: viper.GetInt("port"), BaseURL: viper.GetString("base-url"), + ResourcePath: viper.GetString("resource-path"), ExportTranslations: viper.GetBool("export-translations"), EnableCommandLogging: viper.GetBool("enable-command-logging"), LogFilePath: viper.GetString("log-file"), @@ -137,6 +138,7 @@ func init() { rootCmd.PersistentFlags().Duration("repo-access-cache-ttl", 5*time.Minute, "Override the repo access cache TTL (e.g. 1m, 0s to disable)") rootCmd.PersistentFlags().Int("port", 8082, "HTTP server port") rootCmd.PersistentFlags().String("base-url", "", "Base URL where this server is publicly accessible (for OAuth resource metadata)") + rootCmd.PersistentFlags().String("resource-path", "", "Externally visible base path for the HTTP server (for OAuth resource metadata)") // Bind flag to viper _ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets")) @@ -154,6 +156,7 @@ func init() { _ = viper.BindPFlag("repo-access-cache-ttl", rootCmd.PersistentFlags().Lookup("repo-access-cache-ttl")) _ = viper.BindPFlag("port", rootCmd.PersistentFlags().Lookup("port")) _ = viper.BindPFlag("base-url", rootCmd.PersistentFlags().Lookup("base-url")) + _ = viper.BindPFlag("resource-path", rootCmd.PersistentFlags().Lookup("resource-path")) // Add subcommands rootCmd.AddCommand(stdioCmd) diff --git a/pkg/http/handler.go b/pkg/http/handler.go index 3a14bd624..9378f6df1 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -93,13 +93,16 @@ func NewHTTPMcpHandler( // RegisterRoutes registers the routes for the MCP server // URL-based values take precedence over header-based values func (h *Handler) RegisterRoutes(r chi.Router) { - r.Use(middleware.WithRequestConfig) + mcpRouter := chi.NewRouter() + mcpRouter.Use(middleware.WithRequestConfig) - r.Mount("/", h) + mcpRouter.Mount("/", h) // Mount readonly and toolset routes - r.With(withToolset).Mount("/x/{toolset}", h) - r.With(withReadonly, withToolset).Mount("/x/{toolset}/readonly", h) - r.With(withReadonly).Mount("/readonly", h) + mcpRouter.With(withToolset).Mount("/x/{toolset}", h) + mcpRouter.With(withReadonly, withToolset).Mount("/x/{toolset}/readonly", h) + mcpRouter.With(withReadonly).Mount("/readonly", h) + + r.Mount("/", mcpRouter) } // withReadonly is middleware that sets readonly mode in the request context diff --git a/pkg/http/middleware/token.go b/pkg/http/middleware/token.go index e09026f24..26973a548 100644 --- a/pkg/http/middleware/token.go +++ b/pkg/http/middleware/token.go @@ -67,7 +67,8 @@ func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handl // sendAuthChallenge sends a 401 Unauthorized response with WWW-Authenticate header // containing the OAuth protected resource metadata URL as per RFC 6750 and MCP spec. func sendAuthChallenge(w http.ResponseWriter, r *http.Request, oauthCfg *oauth.Config) { - resourceMetadataURL := oauth.BuildResourceMetadataURL(r, oauthCfg, "mcp") + resourcePath := oauth.ResolveResourcePath(r, oauthCfg) + resourceMetadataURL := oauth.BuildResourceMetadataURL(r, oauthCfg, resourcePath) w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata=%q`, resourceMetadataURL)) http.Error(w, "Unauthorized", http.StatusUnauthorized) } diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index a96322b8e..f8fdd75fc 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -3,14 +3,13 @@ package oauth import ( + "encoding/json" "fmt" "net/http" - "net/url" "strings" "github.com/github/github-mcp-server/pkg/http/headers" "github.com/go-chi/chi/v5" - "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/oauthex" ) @@ -48,17 +47,12 @@ type Config struct { // Defaults to GitHub's OAuth server if not specified. AuthorizationServer string - // ResourcePath is the resource path suffix (e.g., "/mcp"). - // If empty, defaults to "/" + // ResourcePath is the externally visible base path for the MCP server (e.g., "/mcp"). + // This is used to restore the original path when a proxy strips a base path before forwarding. + // If empty, requests are treated as already using the external path. ResourcePath string } -// ProtectedResourceData contains the data needed to build an OAuth protected resource response. -type ProtectedResourceData struct { - ResourceURL string - AuthorizationServer string -} - // AuthHandler handles OAuth-related HTTP endpoints. type AuthHandler struct { cfg *Config @@ -94,28 +88,42 @@ func (h *AuthHandler) RegisterRoutes(r chi.Router) { for _, pattern := range routePatterns { for _, route := range h.routesForPattern(pattern) { path := OAuthProtectedResourcePrefix + route - - // Build metadata for this specific resource path - metadata := h.buildMetadata(route) - r.Handle(path, auth.ProtectedResourceMetadataHandler(metadata)) + r.Handle(path, h.metadataHandler()) } } } -func (h *AuthHandler) buildMetadata(resourcePath string) *oauthex.ProtectedResourceMetadata { - baseURL := strings.TrimSuffix(h.cfg.BaseURL, "/") - resourceURL := baseURL - if resourcePath != "" && resourcePath != "/" { - resourceURL = baseURL + resourcePath - } +func (h *AuthHandler) metadataHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // CORS headers for browser-based clients + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") - return &oauthex.ProtectedResourceMetadata{ - Resource: resourceURL, - AuthorizationServers: []string{h.cfg.AuthorizationServer}, - ResourceName: "GitHub MCP Server", - ScopesSupported: SupportedScopes, - BearerMethodsSupported: []string{"header"}, - } + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + resourcePath := resolveResourcePath( + strings.TrimPrefix(r.URL.Path, OAuthProtectedResourcePrefix), + h.cfg.ResourcePath, + ) + resourceURL := h.buildResourceURL(r, resourcePath) + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(&oauthex.ProtectedResourceMetadata{ + Resource: resourceURL, + AuthorizationServers: []string{h.cfg.AuthorizationServer}, + ResourceName: "GitHub MCP Server", + ScopesSupported: SupportedScopes, + BearerMethodsSupported: []string{"header"}, + }) + }) } // routesForPattern generates route variants for a given pattern. @@ -123,90 +131,122 @@ func (h *AuthHandler) buildMetadata(resourcePath string) *oauthex.ProtectedResou // - With /mcp prefix: for direct access or when GitHub doesn't strip // - Without /mcp prefix: for when GitHub has stripped the prefix func (h *AuthHandler) routesForPattern(pattern string) []string { - return []string{ - pattern, - "/mcp" + pattern, - pattern + "/", - "/mcp" + pattern + "/", + basePaths := []string{""} + if basePath := normalizeBasePath(h.cfg.ResourcePath); basePath != "" { + basePaths = append(basePaths, basePath) + } else { + basePaths = append(basePaths, "/mcp") } + + routes := make([]string, 0, len(basePaths)*2) + for _, basePath := range basePaths { + routes = append(routes, joinRoute(basePath, pattern)) + routes = append(routes, joinRoute(basePath, pattern)+"/") + } + + return routes +} + +// resolveResourcePath returns the externally visible resource path, +// restoring the configured base path when proxies strip it before forwarding. +func resolveResourcePath(path, basePath string) string { + if path == "" { + path = "/" + } + base := normalizeBasePath(basePath) + if base == "" { + return path + } + if path == "/" { + return base + } + if path == base || strings.HasPrefix(path, base+"/") { + return path + } + return base + path } -// GetEffectiveResourcePath returns the resource path for OAuth protected resource URLs. -// Since proxies may strip the /mcp prefix before forwarding requests, this function -// restores the prefix for the external-facing URL. -func GetEffectiveResourcePath(r *http.Request) string { - if r.URL.Path == "/" { - return "/mcp" +// ResolveResourcePath returns the externally visible resource path for a request. +// Exported for use by middleware. +func ResolveResourcePath(r *http.Request, cfg *Config) string { + basePath := "" + if cfg != nil { + basePath = cfg.ResourcePath } - return "/mcp" + r.URL.Path + return resolveResourcePath(r.URL.Path, basePath) } -// GetProtectedResourceData builds the OAuth protected resource data for a request. -func (h *AuthHandler) GetProtectedResourceData(r *http.Request, resourcePath string) (*ProtectedResourceData, error) { +// buildResourceURL constructs the full resource URL for OAuth metadata. +func (h *AuthHandler) buildResourceURL(r *http.Request, resourcePath string) string { host, scheme := GetEffectiveHostAndScheme(r, h.cfg) - - // Build the base URL baseURL := fmt.Sprintf("%s://%s", scheme, host) if h.cfg.BaseURL != "" { baseURL = strings.TrimSuffix(h.cfg.BaseURL, "/") } - - // Build the resource URL using url.JoinPath for proper path handling - var resourceURL string - var err error - if resourcePath == "/" { - resourceURL = baseURL + "/" - } else { - resourceURL, err = url.JoinPath(baseURL, resourcePath) - if err != nil { - return nil, fmt.Errorf("failed to build resource URL: %w", err) - } + if resourcePath == "" { + resourcePath = "/" } - - return &ProtectedResourceData{ - ResourceURL: resourceURL, - AuthorizationServer: h.cfg.AuthorizationServer, - }, nil + if !strings.HasPrefix(resourcePath, "/") { + resourcePath = "/" + resourcePath + } + return baseURL + resourcePath } // GetEffectiveHostAndScheme returns the effective host and scheme for a request. -// It checks X-Forwarded-Host and X-Forwarded-Proto headers first (set by proxies), -// then falls back to the request's Host and TLS state. -func GetEffectiveHostAndScheme(r *http.Request, cfg *Config) (host, scheme string) { //nolint:revive // parameters are required by http.oauth.BuildResourceMetadataURL signature - // Check for forwarded headers first (typically set by reverse proxies) - if forwardedHost := r.Header.Get(headers.ForwardedHostHeader); forwardedHost != "" { - host = forwardedHost +func GetEffectiveHostAndScheme(r *http.Request, cfg *Config) (host, scheme string) { //nolint:revive + if fh := r.Header.Get(headers.ForwardedHostHeader); fh != "" { + host = fh } else { host = r.Host } - - // Determine scheme - switch { - case r.Header.Get(headers.ForwardedProtoHeader) != "": - scheme = strings.ToLower(r.Header.Get(headers.ForwardedProtoHeader)) - case r.TLS != nil: - scheme = "https" - default: - // Default to HTTPS in production scenarios - scheme = "https" + if host == "" { + host = "localhost" } - - return host, scheme + if fp := r.Header.Get(headers.ForwardedProtoHeader); fp != "" { + scheme = strings.ToLower(fp) + } else { + scheme = "https" // Default to HTTPS + } + return } // BuildResourceMetadataURL constructs the full URL to the OAuth protected resource metadata endpoint. func BuildResourceMetadataURL(r *http.Request, cfg *Config, resourcePath string) string { host, scheme := GetEffectiveHostAndScheme(r, cfg) - + suffix := "" + if resourcePath != "" && resourcePath != "/" { + if !strings.HasPrefix(resourcePath, "/") { + suffix = "/" + resourcePath + } else { + suffix = resourcePath + } + } if cfg != nil && cfg.BaseURL != "" { - baseURL := strings.TrimSuffix(cfg.BaseURL, "/") - return baseURL + OAuthProtectedResourcePrefix + "/" + strings.TrimPrefix(resourcePath, "/") + return strings.TrimSuffix(cfg.BaseURL, "/") + OAuthProtectedResourcePrefix + suffix } + return fmt.Sprintf("%s://%s%s%s", scheme, host, OAuthProtectedResourcePrefix, suffix) +} - path := OAuthProtectedResourcePrefix - if resourcePath != "" && resourcePath != "/" { - path = path + "/" + strings.TrimPrefix(resourcePath, "/") +func normalizeBasePath(path string) string { + trimmed := strings.TrimSpace(path) + if trimmed == "" || trimmed == "/" { + return "" } + if !strings.HasPrefix(trimmed, "/") { + trimmed = "/" + trimmed + } + return strings.TrimSuffix(trimmed, "/") +} - return fmt.Sprintf("%s://%s%s", scheme, host, path) +func joinRoute(basePath, pattern string) string { + if basePath == "" { + return pattern + } + if pattern == "" { + return basePath + } + if strings.HasSuffix(basePath, "/") { + return strings.TrimSuffix(basePath, "/") + pattern + } + return basePath + pattern } diff --git a/pkg/http/oauth/oauth_test.go b/pkg/http/oauth/oauth_test.go index 3a5188c72..ffee5408b 100644 --- a/pkg/http/oauth/oauth_test.go +++ b/pkg/http/oauth/oauth_test.go @@ -176,138 +176,62 @@ func TestGetEffectiveHostAndScheme(t *testing.T) { } } -func TestGetEffectiveResourcePath(t *testing.T) { +func TestResolveResourcePath(t *testing.T) { t.Parallel() tests := []struct { name string + cfg *Config setupRequest func() *http.Request expectedPath string }{ { - name: "root path restores /mcp prefix", - setupRequest: func() *http.Request { - return httptest.NewRequest(http.MethodGet, "/", nil) - }, - expectedPath: "/mcp", - }, - { - name: "non-root path adds /mcp prefix", - setupRequest: func() *http.Request { - return httptest.NewRequest(http.MethodGet, "/readonly", nil) - }, - expectedPath: "/mcp/readonly", - }, - { - name: "nested path adds /mcp prefix", + name: "no base path uses request path", + cfg: &Config{}, setupRequest: func() *http.Request { return httptest.NewRequest(http.MethodGet, "/x/repos", nil) }, - expectedPath: "/mcp/x/repos", + expectedPath: "/x/repos", }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - req := tc.setupRequest() - path := GetEffectiveResourcePath(req) - - assert.Equal(t, tc.expectedPath, path) - }) - } -} - -func TestGetProtectedResourceData(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - cfg *Config - setupRequest func() *http.Request - resourcePath string - expectedResourceURL string - expectedAuthServer string - expectError bool - }{ { - name: "basic request with root resource path", - cfg: &Config{}, - setupRequest: func() *http.Request { - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Host = "api.example.com" - return req + name: "base path restored for root", + cfg: &Config{ + ResourcePath: "/mcp", }, - resourcePath: "/", - expectedResourceURL: "https://api.example.com/", - expectedAuthServer: DefaultAuthorizationServer, - }, - { - name: "basic request with custom resource path", - cfg: &Config{}, setupRequest: func() *http.Request { - req := httptest.NewRequest(http.MethodGet, "/mcp", nil) - req.Host = "api.example.com" - return req + return httptest.NewRequest(http.MethodGet, "/", nil) }, - resourcePath: "/mcp", - expectedResourceURL: "https://api.example.com/mcp", - expectedAuthServer: DefaultAuthorizationServer, + expectedPath: "/mcp", }, { - name: "with custom base URL", + name: "base path restored for nested", cfg: &Config{ - BaseURL: "https://custom.example.com", + ResourcePath: "/mcp", }, setupRequest: func() *http.Request { - req := httptest.NewRequest(http.MethodGet, "/mcp", nil) - req.Host = "api.example.com" - return req + return httptest.NewRequest(http.MethodGet, "/readonly", nil) }, - resourcePath: "/mcp", - expectedResourceURL: "https://custom.example.com/mcp", - expectedAuthServer: DefaultAuthorizationServer, + expectedPath: "/mcp/readonly", }, { - name: "with custom authorization server", + name: "base path preserved when already present", cfg: &Config{ - AuthorizationServer: "https://auth.example.com/oauth", + ResourcePath: "/mcp", }, setupRequest: func() *http.Request { - req := httptest.NewRequest(http.MethodGet, "/mcp", nil) - req.Host = "api.example.com" - return req + return httptest.NewRequest(http.MethodGet, "/mcp/readonly/", nil) }, - resourcePath: "/mcp", - expectedResourceURL: "https://api.example.com/mcp", - expectedAuthServer: "https://auth.example.com/oauth", + expectedPath: "/mcp/readonly/", }, { - name: "base URL with trailing slash is trimmed", + name: "custom base path restored", cfg: &Config{ - BaseURL: "https://custom.example.com/", + ResourcePath: "/api", }, setupRequest: func() *http.Request { - req := httptest.NewRequest(http.MethodGet, "/mcp", nil) - req.Host = "api.example.com" - return req - }, - resourcePath: "/mcp", - expectedResourceURL: "https://custom.example.com/mcp", - expectedAuthServer: DefaultAuthorizationServer, - }, - { - name: "nested resource path", - cfg: &Config{}, - setupRequest: func() *http.Request { - req := httptest.NewRequest(http.MethodGet, "/mcp/x/repos", nil) - req.Host = "api.example.com" - return req + return httptest.NewRequest(http.MethodGet, "/x/repos", nil) }, - resourcePath: "/mcp/x/repos", - expectedResourceURL: "https://api.example.com/mcp/x/repos", - expectedAuthServer: DefaultAuthorizationServer, + expectedPath: "/api/x/repos", }, } @@ -315,20 +239,10 @@ func TestGetProtectedResourceData(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - handler, err := NewAuthHandler(tc.cfg) - require.NoError(t, err) - req := tc.setupRequest() - data, err := handler.GetProtectedResourceData(req, tc.resourcePath) - - if tc.expectError { - require.Error(t, err) - return - } + path := ResolveResourcePath(req, tc.cfg) - require.NoError(t, err) - assert.Equal(t, tc.expectedResourceURL, data.ResourceURL) - assert.Equal(t, tc.expectedAuthServer, data.AuthorizationServer) + assert.Equal(t, tc.expectedPath, path) }) } } @@ -354,6 +268,17 @@ func TestBuildResourceMetadataURL(t *testing.T) { resourcePath: "/", expectedURL: "https://api.example.com/.well-known/oauth-protected-resource", }, + { + name: "resource path preserves trailing slash", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp/", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp/", + expectedURL: "https://api.example.com/.well-known/oauth-protected-resource/mcp/", + }, { name: "with custom resource path", cfg: &Config{}, @@ -442,7 +367,7 @@ func TestHandleProtectedResource(t *testing.T) { validateResponse: func(t *testing.T, body map[string]any) { t.Helper() assert.Equal(t, "GitHub MCP Server", body["resource_name"]) - assert.Equal(t, "https://api.example.com", body["resource"]) + assert.Equal(t, "https://api.example.com/", body["resource"]) authServers, ok := body["authorization_servers"].([]any) require.True(t, ok) @@ -488,6 +413,20 @@ func TestHandleProtectedResource(t *testing.T) { assert.Equal(t, "https://api.example.com/readonly", body["resource"]) }, }, + { + name: "path with trailing slash", + cfg: &Config{ + BaseURL: "https://api.example.com", + }, + path: OAuthProtectedResourcePrefix + "/mcp/", + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + assert.Equal(t, "https://api.example.com/mcp/", body["resource"]) + }, + }, { name: "custom authorization server in response", cfg: &Config{ diff --git a/pkg/http/server.go b/pkg/http/server.go index 180cb75b5..17c6d7c0a 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -33,6 +33,10 @@ type ServerConfig struct { // If not set, the server will derive the URL from incoming request headers. BaseURL string + // ResourcePath is the externally visible base path for this server (e.g., "/mcp"). + // This is used to restore the original path when a proxy strips a base path before forwarding. + ResourcePath string + // ExportTranslations indicates if we should export translations // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#i18n--overriding-descriptions ExportTranslations bool @@ -102,7 +106,8 @@ func RunHTTPServer(cfg ServerConfig) error { // Register OAuth protected resource metadata endpoints oauthCfg := &oauth.Config{ - BaseURL: cfg.BaseURL, + BaseURL: cfg.BaseURL, + ResourcePath: cfg.ResourcePath, } oauthHandler, err := oauth.NewAuthHandler(oauthCfg) if err != nil { From 203ebb3a0cccb454d0b477914d1f32d6d6efb8c5 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Thu, 29 Jan 2026 14:48:51 +0000 Subject: [PATCH 28/38] return auth handler to lib version --- pkg/http/oauth/oauth.go | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index f8fdd75fc..01f36078b 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -3,13 +3,13 @@ package oauth import ( - "encoding/json" "fmt" "net/http" "strings" "github.com/github/github-mcp-server/pkg/http/headers" "github.com/go-chi/chi/v5" + "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/oauthex" ) @@ -95,34 +95,21 @@ func (h *AuthHandler) RegisterRoutes(r chi.Router) { func (h *AuthHandler) metadataHandler() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // CORS headers for browser-based clients - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") - - if r.Method == http.MethodOptions { - w.WriteHeader(http.StatusNoContent) - return - } - if r.Method != http.MethodGet { - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - resourcePath := resolveResourcePath( strings.TrimPrefix(r.URL.Path, OAuthProtectedResourcePrefix), h.cfg.ResourcePath, ) resourceURL := h.buildResourceURL(r, resourcePath) - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(&oauthex.ProtectedResourceMetadata{ + metadata := &oauthex.ProtectedResourceMetadata{ Resource: resourceURL, AuthorizationServers: []string{h.cfg.AuthorizationServer}, ResourceName: "GitHub MCP Server", ScopesSupported: SupportedScopes, BearerMethodsSupported: []string{"header"}, - }) + } + + auth.ProtectedResourceMetadataHandler(metadata).ServeHTTP(w, r) }) } From 39903255ee471c8e1c59d59e7845ef70ee3da5d8 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Thu, 29 Jan 2026 15:12:08 +0000 Subject: [PATCH 29/38] rename to base-path flag --- cmd/github-mcp-server/main.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 386733bf6..31457edc0 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -102,7 +102,7 @@ var ( Host: viper.GetString("host"), Port: viper.GetInt("port"), BaseURL: viper.GetString("base-url"), - ResourcePath: viper.GetString("resource-path"), + ResourcePath: viper.GetString("base-path"), ExportTranslations: viper.GetBool("export-translations"), EnableCommandLogging: viper.GetBool("enable-command-logging"), LogFilePath: viper.GetString("log-file"), @@ -138,7 +138,7 @@ func init() { rootCmd.PersistentFlags().Duration("repo-access-cache-ttl", 5*time.Minute, "Override the repo access cache TTL (e.g. 1m, 0s to disable)") rootCmd.PersistentFlags().Int("port", 8082, "HTTP server port") rootCmd.PersistentFlags().String("base-url", "", "Base URL where this server is publicly accessible (for OAuth resource metadata)") - rootCmd.PersistentFlags().String("resource-path", "", "Externally visible base path for the HTTP server (for OAuth resource metadata)") + rootCmd.PersistentFlags().String("base-path", "", "Externally visible base path for the HTTP server (for OAuth resource metadata)") // Bind flag to viper _ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets")) @@ -156,7 +156,7 @@ func init() { _ = viper.BindPFlag("repo-access-cache-ttl", rootCmd.PersistentFlags().Lookup("repo-access-cache-ttl")) _ = viper.BindPFlag("port", rootCmd.PersistentFlags().Lookup("port")) _ = viper.BindPFlag("base-url", rootCmd.PersistentFlags().Lookup("base-url")) - _ = viper.BindPFlag("resource-path", rootCmd.PersistentFlags().Lookup("resource-path")) + _ = viper.BindPFlag("base-path", rootCmd.PersistentFlags().Lookup("base-path")) // Add subcommands rootCmd.AddCommand(stdioCmd) From 4b690f5702b615f59ceb31bb3b1b412f2f260f1e Mon Sep 17 00:00:00 2001 From: Adam Holt Date: Thu, 29 Jan 2026 17:01:01 +0100 Subject: [PATCH 30/38] Add scopes challenge middleware to HTTP handler and initialize global tool scope map --- cmd/github-mcp-server/main.go | 1 - pkg/http/handler.go | 14 ++- pkg/http/middleware/scope_challenge.go | 9 +- pkg/http/oauth/oauth.go | 3 + pkg/http/server.go | 48 +++++++-- pkg/scopes/map.go | 129 +++++++++++++++++++++++++ 6 files changed, 190 insertions(+), 14 deletions(-) create mode 100644 pkg/scopes/map.go diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 4a778f8bd..117df7c35 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -135,7 +135,6 @@ func init() { rootCmd.PersistentFlags().Bool("lockdown-mode", false, "Enable lockdown mode") rootCmd.PersistentFlags().Bool("insiders", false, "Enable insiders features") rootCmd.PersistentFlags().Duration("repo-access-cache-ttl", 5*time.Minute, "Override the repo access cache TTL (e.g. 1m, 0s to disable)") - rootCmd.PersistentFlags().Int("port", 8082, "HTTP server port") // Add port flag to http command httpCmd.PersistentFlags().Int("port", 8082, "HTTP server port") diff --git a/pkg/http/handler.go b/pkg/http/handler.go index 3a14bd624..d872974c1 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -90,11 +90,19 @@ func NewHTTPMcpHandler( } } +func (h *Handler) RegisterMiddleware(r chi.Router) { + r.Use( + middleware.ExtractUserToken(h.oauthCfg), + middleware.WithRequestConfig, + middleware.WithScopeChallenge(h.oauthCfg), + ) + + r.Use(middleware.WithScopeChallenge(h.oauthCfg)) +} + // RegisterRoutes registers the routes for the MCP server // URL-based values take precedence over header-based values func (h *Handler) RegisterRoutes(r chi.Router) { - r.Use(middleware.WithRequestConfig) - r.Mount("/", h) // Mount readonly and toolset routes r.With(withToolset).Mount("/x/{toolset}", h) @@ -144,7 +152,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { Stateless: true, }) - middleware.ExtractUserToken(h.oauthCfg)(mcpHandler).ServeHTTP(w, r) + mcpHandler.ServeHTTP(w, r) } func DefaultGitHubMCPServerFactory(r *http.Request, deps github.ToolDependencies, inventory *inventory.Inventory, cfg *github.MCPServerConfig) (*mcp.Server, error) { diff --git a/pkg/http/middleware/scope_challenge.go b/pkg/http/middleware/scope_challenge.go index 435e5cebe..a97880375 100644 --- a/pkg/http/middleware/scope_challenge.go +++ b/pkg/http/middleware/scope_challenge.go @@ -12,14 +12,15 @@ import ( ghcontext "github.com/github/github-mcp-server/pkg/context" "github.com/github/github-mcp-server/pkg/http/oauth" + "github.com/github/github-mcp-server/pkg/scopes" "github.com/github/github-mcp-server/pkg/utils" ) // FetchScopesFromGitHubAPI fetches the OAuth scopes from the GitHub API by making // a HEAD request and reading the X-OAuth-Scopes header. This is used as a fallback // when scopes are not provided in the token info header. -func FetchScopesFromGitHubAPI(ctx context.Context, token string, apiUrls *utils.APIHost) ([]string, error) { - baseUrl, err := apiUrls.BaseRESTURL(ctx) +func FetchScopesFromGitHubAPI(ctx context.Context, token string, apiHost utils.APIHostResolver) ([]string, error) { + baseUrl, err := apiHost.BaseRESTURL(ctx) if err != nil { return nil, err } @@ -56,7 +57,7 @@ func FetchScopesFromGitHubAPI(ctx context.Context, token string, apiUrls *utils. // WithScopeChallenge creates a new middleware that determines if an OAuth request contains sufficient scopes to // complete the request and returns a scope challenge if not. -func WithScopeChallenge(oauthCfg *oauth.Config, apiUrls *utils.APIHost) func(http.Handler) http.Handler { +func WithScopeChallenge(oauthCfg *oauth.Config) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -135,7 +136,7 @@ func WithScopeChallenge(oauthCfg *oauth.Config, apiUrls *utils.APIHost) func(htt } // Get OAuth scopes from GitHub API - activeScopes, err := FetchScopesFromGitHubAPI(ctx, tokenInfo.Token, apiUrls) + activeScopes, err := FetchScopesFromGitHubAPI(ctx, tokenInfo.Token, oauthCfg.ApiHosts) // Check if user has the required scopes if toolScopeInfo.HasAcceptedScope(activeScopes...) { diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index 8934a21c6..f4ba3ab0d 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/github/github-mcp-server/pkg/utils" "github.com/go-chi/chi/v5" "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/oauthex" @@ -40,6 +41,8 @@ var SupportedScopes = []string{ // Config holds the OAuth configuration for the MCP server. type Config struct { + ApiHosts utils.APIHostResolver + // BaseURL is the publicly accessible URL where this server is hosted. // This is used to construct the OAuth resource URL. BaseURL string diff --git a/pkg/http/server.go b/pkg/http/server.go index 180cb75b5..2d57ab0be 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -13,7 +13,9 @@ import ( "github.com/github/github-mcp-server/pkg/github" "github.com/github/github-mcp-server/pkg/http/oauth" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/lockdown" + "github.com/github/github-mcp-server/pkg/scopes" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/go-chi/chi/v5" @@ -98,21 +100,39 @@ func RunHTTPServer(cfg ServerConfig) error { nil, ) - r := chi.NewRouter() + // Initialize the global tool scope map + err = initGlobalToolScopeMap(t) + if err != nil { + return fmt.Errorf("failed to initialize tool scope map: %w", err) + } // Register OAuth protected resource metadata endpoints oauthCfg := &oauth.Config{ - BaseURL: cfg.BaseURL, + BaseURL: cfg.BaseURL, + ApiHosts: apiHost, } + + r := chi.NewRouter() + handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, WithOAuthConfig(oauthCfg)) oauthHandler, err := oauth.NewAuthHandler(oauthCfg) if err != nil { return fmt.Errorf("failed to create OAuth handler: %w", err) } - oauthHandler.RegisterRoutes(r) - logger.Info("OAuth protected resource endpoints registered", "baseURL", cfg.BaseURL) - handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, WithOAuthConfig(oauthCfg)) - handler.RegisterRoutes(r) + r.Group(func(r chi.Router) { + // Register Middleware First, needs to be before route registration + handler.RegisterMiddleware(r) + + // Register MCP server routes + handler.RegisterRoutes(r) + logger.Info("MCP server routes registered") + }) + + r.Group(func(r chi.Router) { + // Register OAuth protected resource metadata endpoints + oauthHandler.RegisterRoutes(r) + logger.Info("OAuth protected resource endpoints registered", "baseURL", cfg.BaseURL) + }) addr := fmt.Sprintf(":%d", cfg.Port) httpSvr := http.Server{ @@ -144,3 +164,19 @@ func RunHTTPServer(cfg ServerConfig) error { logger.Info("server stopped gracefully") return nil } + +func initGlobalToolScopeMap(t translations.TranslationHelperFunc) error { + // Build inventory with all tools to extract scope information + inv, err := inventory.NewBuilder(). + SetTools(github.AllTools(t)). + Build() + + if err != nil { + return fmt.Errorf("failed to build inventory for tool scope map: %w", err) + } + + // Initialize the global scope map + scopes.SetToolScopeMapFromInventory(inv) + + return nil +} diff --git a/pkg/scopes/map.go b/pkg/scopes/map.go new file mode 100644 index 000000000..3c9833834 --- /dev/null +++ b/pkg/scopes/map.go @@ -0,0 +1,129 @@ +package scopes + +import "github.com/github/github-mcp-server/pkg/inventory" + +// ToolScopeMap maps tool names to their scope requirements. +type ToolScopeMap map[string]*ToolScopeInfo + +// ToolScopeInfo contains scope information for a single tool. +type ToolScopeInfo struct { + // RequiredScopes contains the scopes that are directly required by this tool. + RequiredScopes []string + + // AcceptedScopes contains all scopes that satisfy the requirements (including parent scopes). + AcceptedScopes []string +} + +// globalToolScopeMap is populated from inventory when SetToolScopeMapFromInventory is called +var globalToolScopeMap ToolScopeMap + +// SetToolScopeMapFromInventory builds and stores a tool scope map from an inventory. +// This should be called after building the inventory to make scopes available for middleware. +func SetToolScopeMapFromInventory(inv *inventory.Inventory) { + globalToolScopeMap = GetToolScopeMapFromInventory(inv) +} + +// SetGlobalToolScopeMap sets the global tool scope map directly. +// This is useful for testing when you don't have a full inventory. +func SetGlobalToolScopeMap(m ToolScopeMap) { + globalToolScopeMap = m +} + +// GetToolScopeMap returns the global tool scope map. +// Returns an empty map if SetToolScopeMapFromInventory hasn't been called yet. +func GetToolScopeMap() (ToolScopeMap, error) { + if globalToolScopeMap == nil { + return make(ToolScopeMap), nil + } + return globalToolScopeMap, nil +} + +// GetToolScopeInfo returns scope information for a specific tool from the global scope map. +func GetToolScopeInfo(toolName string) (*ToolScopeInfo, error) { + m, err := GetToolScopeMap() + if err != nil { + return nil, err + } + return m[toolName], nil +} + +// GetToolScopeMapFromInventory builds a tool scope map from an inventory. +// This extracts scope information from ServerTool.RequiredScopes and ServerTool.AcceptedScopes. +func GetToolScopeMapFromInventory(inv *inventory.Inventory) ToolScopeMap { + result := make(ToolScopeMap) + + // Get all tools from the inventory (both enabled and disabled) + // We need all tools for scope checking purposes + allTools := inv.AllTools() + for i := range allTools { + tool := &allTools[i] + if len(tool.RequiredScopes) > 0 || len(tool.AcceptedScopes) > 0 { + result[tool.Tool.Name] = &ToolScopeInfo{ + RequiredScopes: tool.RequiredScopes, + AcceptedScopes: tool.AcceptedScopes, + } + } + } + + return result +} + +// HasAcceptedScope checks if any of the provided user scopes satisfy the tool's requirements. +func (t *ToolScopeInfo) HasAcceptedScope(userScopes ...string) bool { + if t == nil || len(t.AcceptedScopes) == 0 { + return true // No scopes required + } + + userScopeSet := make(map[string]bool) + for _, scope := range userScopes { + userScopeSet[scope] = true + } + + for _, scope := range t.AcceptedScopes { + if userScopeSet[scope] { + return true + } + } + return false +} + +// MissingScopes returns the required scopes that are not present in the user's scopes. +func (t *ToolScopeInfo) MissingScopes(userScopes ...string) []string { + if t == nil || len(t.RequiredScopes) == 0 { + return nil + } + + // Create a set of user scopes for O(1) lookup + userScopeSet := make(map[string]bool, len(userScopes)) + for _, s := range userScopes { + userScopeSet[s] = true + } + + // Check if any accepted scope is present + hasAccepted := false + for _, scope := range t.AcceptedScopes { + if userScopeSet[scope] { + hasAccepted = true + break + } + } + + if hasAccepted { + return nil // User has sufficient scopes + } + + // Return required scopes as the minimum needed + missing := make([]string, len(t.RequiredScopes)) + copy(missing, t.RequiredScopes) + return missing +} + +// GetRequiredScopesSlice returns the required scopes as a slice of strings. +func (t *ToolScopeInfo) GetRequiredScopesSlice() []string { + if t == nil { + return nil + } + scopes := make([]string, len(t.RequiredScopes)) + copy(scopes, t.RequiredScopes) + return scopes +} From 4d679fddeb9cc3369096c848a443807fa7d01604 Mon Sep 17 00:00:00 2001 From: Adam Holt Date: Thu, 29 Jan 2026 17:21:34 +0100 Subject: [PATCH 31/38] Flags on the http command --- cmd/github-mcp-server/main.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 81b54f084..75fcffcf2 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -138,9 +138,9 @@ func init() { rootCmd.PersistentFlags().Duration("repo-access-cache-ttl", 5*time.Minute, "Override the repo access cache TTL (e.g. 1m, 0s to disable)") // HTTP command flags - rootCmd.PersistentFlags().Int("port", 8082, "HTTP server port") - rootCmd.PersistentFlags().String("base-url", "", "Base URL where this server is publicly accessible (for OAuth resource metadata)") - rootCmd.PersistentFlags().String("base-path", "", "Externally visible base path for the HTTP server (for OAuth resource metadata)") + httpCmd.PersistentFlags().Int("port", 8082, "HTTP server port") + httpCmd.PersistentFlags().String("base-url", "", "Base URL where this server is publicly accessible (for OAuth resource metadata)") + httpCmd.PersistentFlags().String("base-path", "", "Externally visible base path for the HTTP server (for OAuth resource metadata)") // Bind flag to viper _ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets")) From 9a338d704e40b45bf6921f45a16af08bca8f0a86 Mon Sep 17 00:00:00 2001 From: Adam Holt Date: Fri, 30 Jan 2026 10:46:19 +0100 Subject: [PATCH 32/38] Add tests for scope maps --- pkg/scopes/map_test.go | 194 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 194 insertions(+) create mode 100644 pkg/scopes/map_test.go diff --git a/pkg/scopes/map_test.go b/pkg/scopes/map_test.go new file mode 100644 index 000000000..5f33cdda2 --- /dev/null +++ b/pkg/scopes/map_test.go @@ -0,0 +1,194 @@ +package scopes + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetToolScopeMap(t *testing.T) { + // Reset and set up a test map + SetGlobalToolScopeMap(ToolScopeMap{ + "test_tool": &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + }) + + m, err := GetToolScopeMap() + require.NoError(t, err) + require.NotNil(t, m) + require.Greater(t, len(m), 0, "expected at least one tool in the scope map") + + testTool, ok := m["test_tool"] + require.True(t, ok, "expected test_tool to be in the scope map") + assert.Contains(t, testTool.RequiredScopes, "read:org") + assert.Contains(t, testTool.AcceptedScopes, "read:org") + assert.Contains(t, testTool.AcceptedScopes, "admin:org") +} + +func TestGetToolScopeInfo(t *testing.T) { + // Set up test scope map + SetGlobalToolScopeMap(ToolScopeMap{ + "search_orgs": &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + }) + + info, err := GetToolScopeInfo("search_orgs") + require.NoError(t, err) + require.NotNil(t, info) + + // Non-existent tool should return nil + info, err = GetToolScopeInfo("nonexistent_tool") + require.NoError(t, err) + assert.Nil(t, info) +} + +func TestToolScopeInfo_HasAcceptedScope(t *testing.T) { + testCases := []struct { + name string + scopeInfo *ToolScopeInfo + userScopes []string + expected bool + }{ + { + name: "has exact required scope", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + userScopes: []string{"read:org"}, + expected: true, + }, + { + name: "has parent scope (admin:org grants read:org)", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + userScopes: []string{"admin:org"}, + expected: true, + }, + { + name: "has parent scope (write:org grants read:org)", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + userScopes: []string{"write:org"}, + expected: true, + }, + { + name: "missing required scope", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + userScopes: []string{"repo"}, + expected: false, + }, + { + name: "no scope required", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{}, + AcceptedScopes: []string{}, + }, + userScopes: []string{}, + expected: true, + }, + { + name: "nil scope info", + scopeInfo: nil, + userScopes: []string{}, + expected: true, + }, + { + name: "repo scope for tool requiring repo", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"repo"}, + AcceptedScopes: []string{"repo"}, + }, + userScopes: []string{"repo"}, + expected: true, + }, + { + name: "missing repo scope", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"repo"}, + AcceptedScopes: []string{"repo"}, + }, + userScopes: []string{"public_repo"}, + expected: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := tc.scopeInfo.HasAcceptedScope(tc.userScopes...) + assert.Equal(t, tc.expected, result) + }) + } +} + +func TestToolScopeInfo_MissingScopes(t *testing.T) { + testCases := []struct { + name string + scopeInfo *ToolScopeInfo + userScopes []string + expectedLen int + expectedScopes []string + }{ + { + name: "has required scope - no missing", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + userScopes: []string{"read:org"}, + expectedLen: 0, + expectedScopes: nil, + }, + { + name: "missing scope", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{"read:org"}, + AcceptedScopes: []string{"read:org", "write:org", "admin:org"}, + }, + userScopes: []string{"repo"}, + expectedLen: 1, + expectedScopes: []string{"read:org"}, + }, + { + name: "no scope required - no missing", + scopeInfo: &ToolScopeInfo{ + RequiredScopes: []string{}, + AcceptedScopes: []string{}, + }, + userScopes: []string{}, + expectedLen: 0, + expectedScopes: nil, + }, + { + name: "nil scope info - no missing", + scopeInfo: nil, + userScopes: []string{}, + expectedLen: 0, + expectedScopes: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + missing := tc.scopeInfo.MissingScopes(tc.userScopes...) + assert.Len(t, missing, tc.expectedLen) + if tc.expectedScopes != nil { + for _, expected := range tc.expectedScopes { + assert.Contains(t, missing, expected) + } + } + }) + } +} From 7f6e0e87016734216567c5fc662a2db427a0bd12 Mon Sep 17 00:00:00 2001 From: Adam Holt Date: Fri, 30 Jan 2026 15:10:12 +0100 Subject: [PATCH 33/38] Add scope challenge & pat filtering based on token scopes --- cmd/github-mcp-server/main.go | 3 ++ internal/ghmcp/server.go | 7 +--- pkg/context/token.go | 19 ++++++--- pkg/http/handler.go | 58 ++++++++++++++++++++++++-- pkg/http/handler_test.go | 27 ++++++++++++ pkg/http/middleware/scope_challenge.go | 8 +++- pkg/http/middleware/token.go | 5 ++- pkg/http/server.go | 14 ++++++- pkg/scopes/fetcher.go | 23 +++++++--- pkg/scopes/fetcher_test.go | 32 +++++++++++--- pkg/utils/api.go | 9 ++++ 11 files changed, 175 insertions(+), 30 deletions(-) diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 75fcffcf2..60812a56b 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -109,6 +109,7 @@ var ( ContentWindowSize: viper.GetInt("content-window-size"), LockdownMode: viper.GetBool("lockdown-mode"), RepoAccessCacheTTL: &ttl, + ScopeChallenge: viper.GetBool("scope-challenge"), } return ghhttp.RunHTTPServer(httpConfig) @@ -141,6 +142,7 @@ func init() { httpCmd.PersistentFlags().Int("port", 8082, "HTTP server port") httpCmd.PersistentFlags().String("base-url", "", "Base URL where this server is publicly accessible (for OAuth resource metadata)") httpCmd.PersistentFlags().String("base-path", "", "Externally visible base path for the HTTP server (for OAuth resource metadata)") + httpCmd.PersistentFlags().Bool("scope-challenge", false, "Enable OAuth scope challenge responses and tool filtering based on token scopes") // Bind flag to viper _ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets")) @@ -161,6 +163,7 @@ func init() { _ = viper.BindPFlag("port", httpCmd.PersistentFlags().Lookup("port")) _ = viper.BindPFlag("base-url", httpCmd.PersistentFlags().Lookup("base-url")) _ = viper.BindPFlag("base-path", httpCmd.PersistentFlags().Lookup("base-path")) + _ = viper.BindPFlag("scope-challenge", httpCmd.PersistentFlags().Lookup("scope-challenge")) // Add subcommands rootCmd.AddCommand(stdioCmd) diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index a4316e6c9..b46787af3 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -366,13 +366,8 @@ func fetchTokenScopesForHost(ctx context.Context, token, host string) ([]string, return nil, fmt.Errorf("failed to parse API host: %w", err) } - baseRestURL, err := apiHost.BaseRESTURL(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get base REST URL: %w", err) - } - fetcher := scopes.NewFetcher(scopes.FetcherOptions{ - APIHost: baseRestURL.String(), + APIHost: apiHost, }) return fetcher.FetchTokenScopes(ctx, token) diff --git a/pkg/context/token.go b/pkg/context/token.go index 0c86e38ab..e34879b9c 100644 --- a/pkg/context/token.go +++ b/pkg/context/token.go @@ -14,17 +14,26 @@ var tokenCtxKey tokenCtx = "tokenctx" type TokenInfo struct { Token string TokenType utils.TokenType + Scopes []string } // WithTokenInfo adds TokenInfo to the context -func WithTokenInfo(ctx context.Context, token string, tokenType utils.TokenType) context.Context { - return context.WithValue(ctx, tokenCtxKey, TokenInfo{Token: token, TokenType: tokenType}) +func WithTokenInfo(ctx context.Context, tokenInfo *TokenInfo) context.Context { + return context.WithValue(ctx, tokenCtxKey, tokenInfo) +} + +func SetTokenScopes(ctx context.Context, scopes []string) context.Context { + if tokenInfo, ok := GetTokenInfo(ctx); ok { + tokenInfo.Scopes = scopes + return WithTokenInfo(ctx, tokenInfo) + } + return ctx } // GetTokenInfo retrieves the authentication token from the context -func GetTokenInfo(ctx context.Context) (TokenInfo, bool) { - if tokenInfo, ok := ctx.Value(tokenCtxKey).(TokenInfo); ok { +func GetTokenInfo(ctx context.Context) (*TokenInfo, bool) { + if tokenInfo, ok := ctx.Value(tokenCtxKey).(*TokenInfo); ok { return tokenInfo, true } - return TokenInfo{}, false + return nil, false } diff --git a/pkg/http/handler.go b/pkg/http/handler.go index d872974c1..6eb6cd5e2 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -11,7 +11,9 @@ import ( "github.com/github/github-mcp-server/pkg/http/middleware" "github.com/github/github-mcp-server/pkg/http/oauth" "github.com/github/github-mcp-server/pkg/inventory" + "github.com/github/github-mcp-server/pkg/scopes" "github.com/github/github-mcp-server/pkg/translations" + "github.com/github/github-mcp-server/pkg/utils" "github.com/go-chi/chi/v5" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -24,20 +26,29 @@ type Handler struct { config *ServerConfig deps github.ToolDependencies logger *slog.Logger + apiHosts utils.APIHostResolver t translations.TranslationHelperFunc githubMcpServerFactory GitHubMCPServerFactoryFunc inventoryFactoryFunc InventoryFactoryFunc oauthCfg *oauth.Config + scopeFetcher scopes.FetcherInterface } type HandlerOptions struct { GitHubMcpServerFactory GitHubMCPServerFactoryFunc InventoryFactory InventoryFactoryFunc OAuthConfig *oauth.Config + ScopeFetcher scopes.FetcherInterface } type HandlerOption func(*HandlerOptions) +func WithScopeFetcher(f scopes.FetcherInterface) HandlerOption { + return func(o *HandlerOptions) { + o.ScopeFetcher = f + } +} + func WithGitHubMCPServerFactory(f GitHubMCPServerFactoryFunc) HandlerOption { return func(o *HandlerOptions) { o.GitHubMcpServerFactory = f @@ -62,6 +73,7 @@ func NewHTTPMcpHandler( deps github.ToolDependencies, t translations.TranslationHelperFunc, logger *slog.Logger, + apiHost utils.APIHostResolver, options ...HandlerOption) *Handler { opts := &HandlerOptions{} for _, o := range options { @@ -75,7 +87,14 @@ func NewHTTPMcpHandler( inventoryFactory := opts.InventoryFactory if inventoryFactory == nil { - inventoryFactory = DefaultInventoryFactory(cfg, t, nil) + inventoryFactory = DefaultInventoryFactory(cfg, t, nil, opts.ScopeFetcher) + } + + scopeFetcher := opts.ScopeFetcher + if scopeFetcher == nil { + scopeFetcher = scopes.NewFetcher(scopes.FetcherOptions{ + APIHost: apiHost, + }) } return &Handler{ @@ -83,10 +102,12 @@ func NewHTTPMcpHandler( config: cfg, deps: deps, logger: logger, + apiHosts: apiHost, t: t, githubMcpServerFactory: githubMcpServerFactory, inventoryFactoryFunc: inventoryFactory, oauthCfg: opts.OAuthConfig, + scopeFetcher: scopeFetcher, } } @@ -94,10 +115,11 @@ func (h *Handler) RegisterMiddleware(r chi.Router) { r.Use( middleware.ExtractUserToken(h.oauthCfg), middleware.WithRequestConfig, - middleware.WithScopeChallenge(h.oauthCfg), ) - r.Use(middleware.WithScopeChallenge(h.oauthCfg)) + if h.config.ScopeChallenge { + r.Use(middleware.WithScopeChallenge(h.oauthCfg, h.scopeFetcher)) + } } // RegisterRoutes registers the routes for the MCP server @@ -159,7 +181,7 @@ func DefaultGitHubMCPServerFactory(r *http.Request, deps github.ToolDependencies return github.NewMCPServer(r.Context(), cfg, deps, inventory) } -func DefaultInventoryFactory(_ *ServerConfig, t translations.TranslationHelperFunc, staticChecker inventory.FeatureFlagChecker) InventoryFactoryFunc { +func DefaultInventoryFactory(cfg *ServerConfig, t translations.TranslationHelperFunc, staticChecker inventory.FeatureFlagChecker, scopeFetcher scopes.FetcherInterface) InventoryFactoryFunc { return func(r *http.Request) (*inventory.Inventory, error) { b := github.NewInventory(t).WithDeprecatedAliases(github.DeprecatedToolAliases) @@ -170,6 +192,11 @@ func DefaultInventoryFactory(_ *ServerConfig, t translations.TranslationHelperFu } b = InventoryFiltersForRequest(r, b) + + if cfg.ScopeChallenge { + b = b.WithFilter(ScopeChallengeFilter(r, scopeFetcher)) + } + b.WithServerInstructions() return b.Build() @@ -198,3 +225,26 @@ func InventoryFiltersForRequest(r *http.Request, builder *inventory.Builder) *in return builder } + +func ScopeChallengeFilter(r *http.Request, fetcher scopes.FetcherInterface) inventory.ToolFilter { + ctx := r.Context() + + tokenInfo, ok := ghcontext.GetTokenInfo(ctx) + if !ok || tokenInfo == nil { + return nil + } + + // Fetch token scopes for scope-based tool filtering (PAT tokens only) + // Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header. + // Fine-grained PATs and other token types don't support this, so we skip filtering. + if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken { + scopesList, err := fetcher.FetchTokenScopes(ctx, tokenInfo.Token) + if err != nil { + return nil + } + + return github.CreateToolScopeFilter(scopesList) + } + + return nil +} diff --git a/pkg/http/handler_test.go b/pkg/http/handler_test.go index 3db764a85..47a6dd621 100644 --- a/pkg/http/handler_test.go +++ b/pkg/http/handler_test.go @@ -12,7 +12,9 @@ import ( "github.com/github/github-mcp-server/pkg/github" "github.com/github/github-mcp-server/pkg/http/headers" "github.com/github/github-mcp-server/pkg/inventory" + "github.com/github/github-mcp-server/pkg/scopes" "github.com/github/github-mcp-server/pkg/translations" + "github.com/github/github-mcp-server/pkg/utils" "github.com/go-chi/chi/v5" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/stretchr/testify/assert" @@ -32,6 +34,20 @@ func mockTool(name, toolsetID string, readOnly bool) inventory.ServerTool { } } +type allScopesFetcher struct{} + +func (f allScopesFetcher) FetchTokenScopes(_ context.Context, _ string) ([]string, error) { + return []string{ + string(scopes.Repo), + string(scopes.WriteOrg), + string(scopes.User), + string(scopes.Gist), + string(scopes.Notifications), + }, nil +} + +var _ scopes.FetcherInterface = allScopesFetcher{} + func TestInventoryFiltersForRequest(t *testing.T) { tools := []inventory.ServerTool{ mockTool("get_file_contents", "repos", true), @@ -230,6 +246,8 @@ func TestHTTPHandlerRoutes(t *testing.T) { t.Run(tt.name, func(t *testing.T) { var capturedInventory *inventory.Inventory + apiHost := utils.NewDefaultAPIHostResolver() + // Create inventory factory that captures the built inventory inventoryFactory := func(r *http.Request) (*inventory.Inventory, error) { builder := inventory.NewBuilder(). @@ -249,6 +267,8 @@ func TestHTTPHandlerRoutes(t *testing.T) { return mcp.NewServer(&mcp.Implementation{Name: "test", Version: "0.0.1"}, nil), nil } + allScopesFetcher := allScopesFetcher{} + // Create handler with our factories handler := NewHTTPMcpHandler( context.Background(), @@ -256,16 +276,23 @@ func TestHTTPHandlerRoutes(t *testing.T) { nil, // deps not needed for this test translations.NullTranslationHelper, slog.Default(), + apiHost, WithInventoryFactory(inventoryFactory), WithGitHubMCPServerFactory(mcpServerFactory), + WithScopeFetcher(allScopesFetcher), ) // Create router and register routes r := chi.NewRouter() + handler.RegisterMiddleware(r) handler.RegisterRoutes(r) // Create request req := httptest.NewRequest(http.MethodPost, tt.path, nil) + + // Ensure we're setting Authorization header for token context + req.Header.Set(headers.AuthorizationHeader, "Bearer ghp_testtoken") + for k, v := range tt.headers { req.Header.Set(k, v) } diff --git a/pkg/http/middleware/scope_challenge.go b/pkg/http/middleware/scope_challenge.go index e44714564..66edbefc3 100644 --- a/pkg/http/middleware/scope_challenge.go +++ b/pkg/http/middleware/scope_challenge.go @@ -57,7 +57,7 @@ func FetchScopesFromGitHubAPI(ctx context.Context, token string, apiHost utils.A // WithScopeChallenge creates a new middleware that determines if an OAuth request contains sufficient scopes to // complete the request and returns a scope challenge if not. -func WithScopeChallenge(oauthCfg *oauth.Config) func(http.Handler) http.Handler { +func WithScopeChallenge(oauthCfg *oauth.Config, scopeFetcher scopes.FetcherInterface) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -136,7 +136,11 @@ func WithScopeChallenge(oauthCfg *oauth.Config) func(http.Handler) http.Handler } // Get OAuth scopes from GitHub API - activeScopes, err := FetchScopesFromGitHubAPI(ctx, tokenInfo.Token, oauthCfg.ApiHosts) + activeScopes, err := scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token) + if err != nil { + next.ServeHTTP(w, r) + return + } // Check if user has the required scopes if toolScopeInfo.HasAcceptedScope(activeScopes...) { diff --git a/pkg/http/middleware/token.go b/pkg/http/middleware/token.go index 4de9679fb..c362ea201 100644 --- a/pkg/http/middleware/token.go +++ b/pkg/http/middleware/token.go @@ -26,7 +26,10 @@ func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handl } ctx := r.Context() - ctx = ghcontext.WithTokenInfo(ctx, token, tokenType) + ctx = ghcontext.WithTokenInfo(ctx, &ghcontext.TokenInfo{ + Token: token, + TokenType: tokenType, + }) r = r.WithContext(ctx) next.ServeHTTP(w, r) diff --git a/pkg/http/server.go b/pkg/http/server.go index e93092248..ee7ad3c2f 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -57,6 +57,10 @@ type ServerConfig struct { // RepoAccessCacheTTL overrides the default TTL for repository access cache entries. RepoAccessCacheTTL *time.Duration + + // ScopeChallenge indicates if we should return OAuth scope challenges, and if we should perform + // tool filtering based on token scopes. + ScopeChallenge bool } func RunHTTPServer(cfg ServerConfig) error { @@ -117,8 +121,16 @@ func RunHTTPServer(cfg ServerConfig) error { ResourcePath: cfg.ResourcePath, } + severOptions := []HandlerOption{} + if cfg.ScopeChallenge { + scopeFetcher := scopes.NewFetcher(scopes.FetcherOptions{ + APIHost: apiHost, + }) + severOptions = append(severOptions, WithScopeFetcher(scopeFetcher)) + } + r := chi.NewRouter() - handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, WithOAuthConfig(oauthCfg)) + handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, apiHost, append(severOptions, WithOAuthConfig(oauthCfg))...) oauthHandler, err := oauth.NewAuthHandler(oauthCfg) if err != nil { return fmt.Errorf("failed to create OAuth handler: %w", err) diff --git a/pkg/scopes/fetcher.go b/pkg/scopes/fetcher.go index 48e000179..cdd94fd0d 100644 --- a/pkg/scopes/fetcher.go +++ b/pkg/scopes/fetcher.go @@ -7,6 +7,8 @@ import ( "net/url" "strings" "time" + + "github.com/github/github-mcp-server/pkg/utils" ) // OAuthScopesHeader is the HTTP response header containing the token's OAuth scopes. @@ -23,14 +25,18 @@ type FetcherOptions struct { // APIHost is the GitHub API host (e.g., "https://api.github.com"). // Defaults to "https://api.github.com" if empty. - APIHost string + APIHost utils.APIHostResolver +} + +type FetcherInterface interface { + FetchTokenScopes(ctx context.Context, token string) ([]string, error) } // Fetcher retrieves token scopes from GitHub's API. // It uses an HTTP HEAD request to minimize bandwidth since we only need headers. type Fetcher struct { client *http.Client - apiHost string + apiHost utils.APIHostResolver } // NewFetcher creates a new scope fetcher with the given options. @@ -41,8 +47,8 @@ func NewFetcher(opts FetcherOptions) *Fetcher { } apiHost := opts.APIHost - if apiHost == "" { - apiHost = "https://api.github.com" + if apiHost == nil { + apiHost = utils.NewDefaultAPIHostResolver() } return &Fetcher{ @@ -61,8 +67,13 @@ func NewFetcher(opts FetcherOptions) *Fetcher { // Note: Fine-grained PATs don't return the X-OAuth-Scopes header, so an empty // slice is returned for those tokens. func (f *Fetcher) FetchTokenScopes(ctx context.Context, token string) ([]string, error) { + apiHostUrl, err := f.apiHost.BaseRESTURL(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get API host URL: %w", err) + } + // Use a lightweight endpoint that requires authentication - endpoint, err := url.JoinPath(f.apiHost, "/") + endpoint, err := url.JoinPath(apiHostUrl.String(), "/") if err != nil { return nil, fmt.Errorf("failed to construct API URL: %w", err) } @@ -120,6 +131,6 @@ func FetchTokenScopes(ctx context.Context, token string) ([]string, error) { // FetchTokenScopesWithHost is a convenience function that creates a fetcher // for a specific API host and fetches the token scopes. -func FetchTokenScopesWithHost(ctx context.Context, token, apiHost string) ([]string, error) { +func FetchTokenScopesWithHost(ctx context.Context, token string, apiHost utils.APIHostResolver) ([]string, error) { return NewFetcher(FetcherOptions{APIHost: apiHost}).FetchTokenScopes(ctx, token) } diff --git a/pkg/scopes/fetcher_test.go b/pkg/scopes/fetcher_test.go index 13feab5b0..5efc77c46 100644 --- a/pkg/scopes/fetcher_test.go +++ b/pkg/scopes/fetcher_test.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "net/http/httptest" + "net/url" "testing" "time" @@ -11,6 +12,23 @@ import ( "github.com/stretchr/testify/require" ) +type testApiHostResolver struct { + baseURL string +} + +func (t testApiHostResolver) BaseRESTURL(ctx context.Context) (*url.URL, error) { + return url.Parse(t.baseURL) +} +func (t testApiHostResolver) GraphqlURL(ctx context.Context) (*url.URL, error) { + return nil, nil +} +func (t testApiHostResolver) UploadURL(ctx context.Context) (*url.URL, error) { + return nil, nil +} +func (t testApiHostResolver) RawURL(ctx context.Context) (*url.URL, error) { + return nil, nil +} + func TestParseScopeHeader(t *testing.T) { tests := []struct { name string @@ -148,7 +166,7 @@ func TestFetcher_FetchTokenScopes(t *testing.T) { defer server.Close() fetcher := NewFetcher(FetcherOptions{ - APIHost: server.URL, + APIHost: testApiHostResolver{baseURL: server.URL}, }) scopes, err := fetcher.FetchTokenScopes(context.Background(), "test-token") @@ -170,7 +188,9 @@ func TestFetcher_DefaultOptions(t *testing.T) { fetcher := NewFetcher(FetcherOptions{}) // Verify default API host is set - assert.Equal(t, "https://api.github.com", fetcher.apiHost) + apiURL, err := fetcher.apiHost.BaseRESTURL(context.Background()) + require.NoError(t, err) + assert.Equal(t, "https://api.github.com/", apiURL.String()) // Verify default HTTP client is set with timeout assert.NotNil(t, fetcher.client) @@ -189,10 +209,12 @@ func TestFetcher_CustomHTTPClient(t *testing.T) { func TestFetcher_CustomAPIHost(t *testing.T) { fetcher := NewFetcher(FetcherOptions{ - APIHost: "https://api.github.enterprise.com", + APIHost: testApiHostResolver{baseURL: "https://api.github.enterprise.com"}, }) - assert.Equal(t, "https://api.github.enterprise.com", fetcher.apiHost) + apiURL, err := fetcher.apiHost.BaseRESTURL(context.Background()) + require.NoError(t, err) + assert.Equal(t, "https://api.github.enterprise.com", apiURL.String()) } func TestFetcher_ContextCancellation(t *testing.T) { @@ -203,7 +225,7 @@ func TestFetcher_ContextCancellation(t *testing.T) { defer server.Close() fetcher := NewFetcher(FetcherOptions{ - APIHost: server.URL, + APIHost: testApiHostResolver{baseURL: server.URL}, }) ctx, cancel := context.WithCancel(context.Background()) diff --git a/pkg/utils/api.go b/pkg/utils/api.go index 4a33f1dd2..426f0b547 100644 --- a/pkg/utils/api.go +++ b/pkg/utils/api.go @@ -25,6 +25,15 @@ type APIHost struct { var _ APIHostResolver = APIHost{} +func NewDefaultAPIHostResolver() APIHostResolver { + a, err := newDotcomHost() + if err != nil { + // This should never happen + panic(fmt.Sprintf("failed to create default API host resolver: %v", err)) + } + return a +} + func NewAPIHost(s string) (APIHostResolver, error) { a, err := parseAPIHost(s) From 2b016e5d1b47971a6ec047757bbfe8a6feefbcd4 Mon Sep 17 00:00:00 2001 From: Adam Holt Date: Fri, 30 Jan 2026 15:50:12 +0100 Subject: [PATCH 34/38] Add scope filtering if challenge is enabled --- pkg/http/handler.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pkg/http/handler.go b/pkg/http/handler.go index e6a994328..29399e4b3 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -211,7 +211,7 @@ func DefaultInventoryFactory(cfg *ServerConfig, t translations.TranslationHelper b = InventoryFiltersForRequest(r, b) if cfg.ScopeChallenge { - b = b.WithFilter(ScopeChallengeFilter(r, scopeFetcher)) + b = ScopeChallengeFilter(b, r, scopeFetcher) } b.WithServerInstructions() @@ -243,12 +243,12 @@ func InventoryFiltersForRequest(r *http.Request, builder *inventory.Builder) *in return builder } -func ScopeChallengeFilter(r *http.Request, fetcher scopes.FetcherInterface) inventory.ToolFilter { +func ScopeChallengeFilter(b *inventory.Builder, r *http.Request, fetcher scopes.FetcherInterface) *inventory.Builder { ctx := r.Context() tokenInfo, ok := ghcontext.GetTokenInfo(ctx) if !ok || tokenInfo == nil { - return nil + return b } // Fetch token scopes for scope-based tool filtering (PAT tokens only) @@ -257,11 +257,11 @@ func ScopeChallengeFilter(r *http.Request, fetcher scopes.FetcherInterface) inve if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken { scopesList, err := fetcher.FetchTokenScopes(ctx, tokenInfo.Token) if err != nil { - return nil + return b } - return github.CreateToolScopeFilter(scopesList) + return b.WithFilter(github.CreateToolScopeFilter(scopesList)) } - return nil + return b } From f6d433715555a289668ff98b8cc7c3b2527cba22 Mon Sep 17 00:00:00 2001 From: Adam Holt Date: Fri, 30 Jan 2026 16:43:42 +0100 Subject: [PATCH 35/38] Linter fixes and renamed scope challenge to PAT scope filter --- pkg/context/mcp_info.go | 4 ++-- pkg/context/token.go | 12 ++++++------ pkg/http/handler.go | 10 +++++----- pkg/http/middleware/mcp_parse.go | 2 +- pkg/http/middleware/scope_challenge.go | 4 ++-- pkg/http/oauth/oauth.go | 3 --- pkg/http/server.go | 1 - pkg/scopes/fetcher.go | 4 ++-- pkg/scopes/fetcher_test.go | 8 ++++---- 9 files changed, 22 insertions(+), 26 deletions(-) diff --git a/pkg/context/mcp_info.go b/pkg/context/mcp_info.go index d93cc8e81..ce5505682 100644 --- a/pkg/context/mcp_info.go +++ b/pkg/context/mcp_info.go @@ -25,8 +25,8 @@ type MCPMethodInfo struct { Arguments map[string]any } -// ContextWithMCPMethodInfo stores the MCPMethodInfo in the context. -func ContextWithMCPMethodInfo(ctx context.Context, info *MCPMethodInfo) context.Context { +// WithMCPMethodInfo stores the MCPMethodInfo in the context. +func WithMCPMethodInfo(ctx context.Context, info *MCPMethodInfo) context.Context { return context.WithValue(ctx, mcpMethodInfoCtxKey, info) } diff --git a/pkg/context/token.go b/pkg/context/token.go index e34879b9c..27f276740 100644 --- a/pkg/context/token.go +++ b/pkg/context/token.go @@ -12,9 +12,10 @@ type tokenCtx string var tokenCtxKey tokenCtx = "tokenctx" type TokenInfo struct { - Token string - TokenType utils.TokenType - Scopes []string + Token string + TokenType utils.TokenType + ScopesFetched bool + Scopes []string } // WithTokenInfo adds TokenInfo to the context @@ -22,12 +23,11 @@ func WithTokenInfo(ctx context.Context, tokenInfo *TokenInfo) context.Context { return context.WithValue(ctx, tokenCtxKey, tokenInfo) } -func SetTokenScopes(ctx context.Context, scopes []string) context.Context { +func SetTokenScopes(ctx context.Context, scopes []string) { if tokenInfo, ok := GetTokenInfo(ctx); ok { tokenInfo.Scopes = scopes - return WithTokenInfo(ctx, tokenInfo) + tokenInfo.ScopesFetched = true } - return ctx } // GetTokenInfo retrieves the authentication token from the context diff --git a/pkg/http/handler.go b/pkg/http/handler.go index 29399e4b3..55ea322c6 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -209,10 +209,7 @@ func DefaultInventoryFactory(cfg *ServerConfig, t translations.TranslationHelper WithFeatureChecker(featureChecker) b = InventoryFiltersForRequest(r, b) - - if cfg.ScopeChallenge { - b = ScopeChallengeFilter(b, r, scopeFetcher) - } + b = PATScopeFilter(b, r, scopeFetcher) b.WithServerInstructions() @@ -243,7 +240,7 @@ func InventoryFiltersForRequest(r *http.Request, builder *inventory.Builder) *in return builder } -func ScopeChallengeFilter(b *inventory.Builder, r *http.Request, fetcher scopes.FetcherInterface) *inventory.Builder { +func PATScopeFilter(b *inventory.Builder, r *http.Request, fetcher scopes.FetcherInterface) *inventory.Builder { ctx := r.Context() tokenInfo, ok := ghcontext.GetTokenInfo(ctx) @@ -260,6 +257,9 @@ func ScopeChallengeFilter(b *inventory.Builder, r *http.Request, fetcher scopes. return b } + // Store fetched scopes in context for downstream use + ghcontext.SetTokenScopes(ctx, scopesList) + return b.WithFilter(github.CreateToolScopeFilter(scopesList)) } diff --git a/pkg/http/middleware/mcp_parse.go b/pkg/http/middleware/mcp_parse.go index efff53a17..c82616b27 100644 --- a/pkg/http/middleware/mcp_parse.go +++ b/pkg/http/middleware/mcp_parse.go @@ -117,7 +117,7 @@ func WithMCPParse() func(http.Handler) http.Handler { } // Store the parsed info in context - ctx = ghcontext.ContextWithMCPMethodInfo(ctx, methodInfo) + ctx = ghcontext.WithMCPMethodInfo(ctx, methodInfo) next.ServeHTTP(w, r.WithContext(ctx)) } diff --git a/pkg/http/middleware/scope_challenge.go b/pkg/http/middleware/scope_challenge.go index 66edbefc3..dab9cf003 100644 --- a/pkg/http/middleware/scope_challenge.go +++ b/pkg/http/middleware/scope_challenge.go @@ -20,12 +20,12 @@ import ( // a HEAD request and reading the X-OAuth-Scopes header. This is used as a fallback // when scopes are not provided in the token info header. func FetchScopesFromGitHubAPI(ctx context.Context, token string, apiHost utils.APIHostResolver) ([]string, error) { - baseUrl, err := apiHost.BaseRESTURL(ctx) + baseURL, err := apiHost.BaseRESTURL(ctx) if err != nil { return nil, err } - req, err := http.NewRequestWithContext(ctx, http.MethodHead, strings.TrimSuffix(baseUrl.String(), "/")+"/user", http.NoBody) + req, err := http.NewRequestWithContext(ctx, http.MethodHead, strings.TrimSuffix(baseURL.String(), "/")+"/user", http.NoBody) if err != nil { return nil, err } diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index 269b615c7..ecdcf95ab 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -8,7 +8,6 @@ import ( "strings" "github.com/github/github-mcp-server/pkg/http/headers" - "github.com/github/github-mcp-server/pkg/utils" "github.com/go-chi/chi/v5" "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/oauthex" @@ -40,8 +39,6 @@ var SupportedScopes = []string{ // Config holds the OAuth configuration for the MCP server. type Config struct { - ApiHosts utils.APIHostResolver - // BaseURL is the publicly accessible URL where this server is hosted. // This is used to construct the OAuth resource URL. BaseURL string diff --git a/pkg/http/server.go b/pkg/http/server.go index c7ba9203e..065961352 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -128,7 +128,6 @@ func RunHTTPServer(cfg ServerConfig) error { // Register OAuth protected resource metadata endpoints oauthCfg := &oauth.Config{ BaseURL: cfg.BaseURL, - ApiHosts: apiHost, ResourcePath: cfg.ResourcePath, } diff --git a/pkg/scopes/fetcher.go b/pkg/scopes/fetcher.go index cdd94fd0d..80c0bb645 100644 --- a/pkg/scopes/fetcher.go +++ b/pkg/scopes/fetcher.go @@ -67,13 +67,13 @@ func NewFetcher(opts FetcherOptions) *Fetcher { // Note: Fine-grained PATs don't return the X-OAuth-Scopes header, so an empty // slice is returned for those tokens. func (f *Fetcher) FetchTokenScopes(ctx context.Context, token string) ([]string, error) { - apiHostUrl, err := f.apiHost.BaseRESTURL(ctx) + apiHostURL, err := f.apiHost.BaseRESTURL(ctx) if err != nil { return nil, fmt.Errorf("failed to get API host URL: %w", err) } // Use a lightweight endpoint that requires authentication - endpoint, err := url.JoinPath(apiHostUrl.String(), "/") + endpoint, err := url.JoinPath(apiHostURL.String(), "/") if err != nil { return nil, fmt.Errorf("failed to construct API URL: %w", err) } diff --git a/pkg/scopes/fetcher_test.go b/pkg/scopes/fetcher_test.go index 5efc77c46..7f942a6e5 100644 --- a/pkg/scopes/fetcher_test.go +++ b/pkg/scopes/fetcher_test.go @@ -16,16 +16,16 @@ type testApiHostResolver struct { baseURL string } -func (t testApiHostResolver) BaseRESTURL(ctx context.Context) (*url.URL, error) { +func (t testApiHostResolver) BaseRESTURL(_ context.Context) (*url.URL, error) { return url.Parse(t.baseURL) } -func (t testApiHostResolver) GraphqlURL(ctx context.Context) (*url.URL, error) { +func (t testApiHostResolver) GraphqlURL(_ context.Context) (*url.URL, error) { return nil, nil } -func (t testApiHostResolver) UploadURL(ctx context.Context) (*url.URL, error) { +func (t testApiHostResolver) UploadURL(_ context.Context) (*url.URL, error) { return nil, nil } -func (t testApiHostResolver) RawURL(ctx context.Context) (*url.URL, error) { +func (t testApiHostResolver) RawURL(_ context.Context) (*url.URL, error) { return nil, nil } From 2b1b9bb29407894bf700fc13c3642b64d763bf49 Mon Sep 17 00:00:00 2001 From: Adam Holt Date: Fri, 30 Jan 2026 17:00:57 +0100 Subject: [PATCH 36/38] Linter issues. --- pkg/http/handler.go | 2 +- pkg/scopes/fetcher_test.go | 16 ++++++++-------- pkg/utils/api.go | 2 +- pkg/utils/token.go | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pkg/http/handler.go b/pkg/http/handler.go index 55ea322c6..9ebd98892 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -202,7 +202,7 @@ func DefaultGitHubMCPServerFactory(r *http.Request, deps github.ToolDependencies } // DefaultInventoryFactory creates the default inventory factory for HTTP mode -func DefaultInventoryFactory(cfg *ServerConfig, t translations.TranslationHelperFunc, featureChecker inventory.FeatureFlagChecker, scopeFetcher scopes.FetcherInterface) InventoryFactoryFunc { +func DefaultInventoryFactory(_ *ServerConfig, t translations.TranslationHelperFunc, featureChecker inventory.FeatureFlagChecker, scopeFetcher scopes.FetcherInterface) InventoryFactoryFunc { return func(r *http.Request) (*inventory.Inventory, error) { b := github.NewInventory(t). WithDeprecatedAliases(github.DeprecatedToolAliases). diff --git a/pkg/scopes/fetcher_test.go b/pkg/scopes/fetcher_test.go index 7f942a6e5..3ea969236 100644 --- a/pkg/scopes/fetcher_test.go +++ b/pkg/scopes/fetcher_test.go @@ -12,20 +12,20 @@ import ( "github.com/stretchr/testify/require" ) -type testApiHostResolver struct { +type testAPIHostResolver struct { baseURL string } -func (t testApiHostResolver) BaseRESTURL(_ context.Context) (*url.URL, error) { +func (t testAPIHostResolver) BaseRESTURL(_ context.Context) (*url.URL, error) { return url.Parse(t.baseURL) } -func (t testApiHostResolver) GraphqlURL(_ context.Context) (*url.URL, error) { +func (t testAPIHostResolver) GraphqlURL(_ context.Context) (*url.URL, error) { return nil, nil } -func (t testApiHostResolver) UploadURL(_ context.Context) (*url.URL, error) { +func (t testAPIHostResolver) UploadURL(_ context.Context) (*url.URL, error) { return nil, nil } -func (t testApiHostResolver) RawURL(_ context.Context) (*url.URL, error) { +func (t testAPIHostResolver) RawURL(_ context.Context) (*url.URL, error) { return nil, nil } @@ -166,7 +166,7 @@ func TestFetcher_FetchTokenScopes(t *testing.T) { defer server.Close() fetcher := NewFetcher(FetcherOptions{ - APIHost: testApiHostResolver{baseURL: server.URL}, + APIHost: testAPIHostResolver{baseURL: server.URL}, }) scopes, err := fetcher.FetchTokenScopes(context.Background(), "test-token") @@ -209,7 +209,7 @@ func TestFetcher_CustomHTTPClient(t *testing.T) { func TestFetcher_CustomAPIHost(t *testing.T) { fetcher := NewFetcher(FetcherOptions{ - APIHost: testApiHostResolver{baseURL: "https://api.github.enterprise.com"}, + APIHost: testAPIHostResolver{baseURL: "https://api.github.enterprise.com"}, }) apiURL, err := fetcher.apiHost.BaseRESTURL(context.Background()) @@ -225,7 +225,7 @@ func TestFetcher_ContextCancellation(t *testing.T) { defer server.Close() fetcher := NewFetcher(FetcherOptions{ - APIHost: testApiHostResolver{baseURL: server.URL}, + APIHost: testAPIHostResolver{baseURL: server.URL}, }) ctx, cancel := context.WithCancel(context.Background()) diff --git a/pkg/utils/api.go b/pkg/utils/api.go index 426f0b547..19e908fc8 100644 --- a/pkg/utils/api.go +++ b/pkg/utils/api.go @@ -1,4 +1,4 @@ -package utils +package utils //nolint:revive //TODO: figure out a better name for this package import ( "context" diff --git a/pkg/utils/token.go b/pkg/utils/token.go index 62cacbf4d..fa3423942 100644 --- a/pkg/utils/token.go +++ b/pkg/utils/token.go @@ -1,4 +1,4 @@ -package utils +package utils //nolint:revive //TODO: figure out a better name for this package import ( "fmt" From dad1e71b30504c2bac37beaa0f1fb63f2561fb58 Mon Sep 17 00:00:00 2001 From: Adam Holt Date: Fri, 30 Jan 2026 17:21:27 +0100 Subject: [PATCH 37/38] Remove unsused methoodod FetchScopesFromGitHubAPI, store active scopes in context --- pkg/http/middleware/scope_challenge.go | 44 ++------------------------ 1 file changed, 3 insertions(+), 41 deletions(-) diff --git a/pkg/http/middleware/scope_challenge.go b/pkg/http/middleware/scope_challenge.go index dab9cf003..da2f06752 100644 --- a/pkg/http/middleware/scope_challenge.go +++ b/pkg/http/middleware/scope_challenge.go @@ -2,13 +2,11 @@ package middleware import ( "bytes" - "context" "encoding/json" "fmt" "io" "net/http" "strings" - "time" ghcontext "github.com/github/github-mcp-server/pkg/context" "github.com/github/github-mcp-server/pkg/http/oauth" @@ -16,45 +14,6 @@ import ( "github.com/github/github-mcp-server/pkg/utils" ) -// FetchScopesFromGitHubAPI fetches the OAuth scopes from the GitHub API by making -// a HEAD request and reading the X-OAuth-Scopes header. This is used as a fallback -// when scopes are not provided in the token info header. -func FetchScopesFromGitHubAPI(ctx context.Context, token string, apiHost utils.APIHostResolver) ([]string, error) { - baseURL, err := apiHost.BaseRESTURL(ctx) - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, http.MethodHead, strings.TrimSuffix(baseURL.String(), "/")+"/user", http.NoBody) - if err != nil { - return nil, err - } - req.Header.Set("Authorization", "Bearer "+token) - - client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - scopeHeader := resp.Header.Get("X-OAuth-Scopes") - if scopeHeader == "" { - return []string{}, nil - } - - // Parse comma-separated scopes and trim whitespace - rawScopes := strings.Split(scopeHeader, ",") - parsedScopes := make([]string, 0, len(rawScopes)) - for _, s := range rawScopes { - trimmed := strings.TrimSpace(s) - if trimmed != "" { - parsedScopes = append(parsedScopes, trimmed) - } - } - return parsedScopes, nil -} - // WithScopeChallenge creates a new middleware that determines if an OAuth request contains sufficient scopes to // complete the request and returns a scope challenge if not. func WithScopeChallenge(oauthCfg *oauth.Config, scopeFetcher scopes.FetcherInterface) func(http.Handler) http.Handler { @@ -142,6 +101,9 @@ func WithScopeChallenge(oauthCfg *oauth.Config, scopeFetcher scopes.FetcherInter return } + // Store active scopes in context for downstream use + ghcontext.SetTokenScopes(ctx, activeScopes) + // Check if user has the required scopes if toolScopeInfo.HasAcceptedScope(activeScopes...) { next.ServeHTTP(w, r) From ce05b87d6b5820d72462f8313c0bbe45e4af2532 Mon Sep 17 00:00:00 2001 From: Adam Holt Date: Fri, 30 Jan 2026 17:57:36 +0100 Subject: [PATCH 38/38] Require an API host when creating scope fetchers --- internal/ghmcp/server.go | 4 +--- pkg/http/handler.go | 5 ++--- pkg/http/handler_test.go | 3 ++- pkg/http/server.go | 10 ++++------ pkg/scopes/fetcher.go | 16 ++++++++-------- pkg/scopes/fetcher_test.go | 24 +++++++++++------------- pkg/utils/api.go | 9 --------- 7 files changed, 28 insertions(+), 43 deletions(-) diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index cbc23546b..7ffb457ce 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -366,9 +366,7 @@ func fetchTokenScopesForHost(ctx context.Context, token, host string) ([]string, return nil, fmt.Errorf("failed to parse API host: %w", err) } - fetcher := scopes.NewFetcher(scopes.FetcherOptions{ - APIHost: apiHost, - }) + fetcher := scopes.NewFetcher(apiHost, scopes.FetcherOptions{}) return fetcher.FetchTokenScopes(ctx, token) } diff --git a/pkg/http/handler.go b/pkg/http/handler.go index 9ebd98892..c529f7405 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -93,9 +93,7 @@ func NewHTTPMcpHandler( scopeFetcher := opts.ScopeFetcher if scopeFetcher == nil { - scopeFetcher = scopes.NewFetcher(scopes.FetcherOptions{ - APIHost: apiHost, - }) + scopeFetcher = scopes.NewFetcher(apiHost, scopes.FetcherOptions{}) } inventoryFactory := opts.InventoryFactory @@ -121,6 +119,7 @@ func (h *Handler) RegisterMiddleware(r chi.Router) { r.Use( middleware.ExtractUserToken(h.oauthCfg), middleware.WithRequestConfig, + middleware.WithMCPParse(), ) if h.config.ScopeChallenge { diff --git a/pkg/http/handler_test.go b/pkg/http/handler_test.go index 136981c39..c92075569 100644 --- a/pkg/http/handler_test.go +++ b/pkg/http/handler_test.go @@ -276,7 +276,8 @@ func TestHTTPHandlerRoutes(t *testing.T) { // Create feature checker that reads from context (same as production) featureChecker := createHTTPFeatureChecker() - apiHost := utils.NewDefaultAPIHostResolver() + apiHost, err := utils.NewAPIHost("https://api.github.com") + require.NoError(t, err) // Create inventory factory that captures the built inventory inventoryFactory := func(r *http.Request) (*inventory.Inventory, error) { diff --git a/pkg/http/server.go b/pkg/http/server.go index 065961352..7a7ab46de 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -131,16 +131,14 @@ func RunHTTPServer(cfg ServerConfig) error { ResourcePath: cfg.ResourcePath, } - severOptions := []HandlerOption{} + serverOptions := []HandlerOption{} if cfg.ScopeChallenge { - scopeFetcher := scopes.NewFetcher(scopes.FetcherOptions{ - APIHost: apiHost, - }) - severOptions = append(severOptions, WithScopeFetcher(scopeFetcher)) + scopeFetcher := scopes.NewFetcher(apiHost, scopes.FetcherOptions{}) + serverOptions = append(serverOptions, WithScopeFetcher(scopeFetcher)) } r := chi.NewRouter() - handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, apiHost, append(severOptions, WithFeatureChecker(featureChecker), WithOAuthConfig(oauthCfg))...) + handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, apiHost, append(serverOptions, WithFeatureChecker(featureChecker), WithOAuthConfig(oauthCfg))...) oauthHandler, err := oauth.NewAuthHandler(oauthCfg) if err != nil { return fmt.Errorf("failed to create OAuth handler: %w", err) diff --git a/pkg/scopes/fetcher.go b/pkg/scopes/fetcher.go index 80c0bb645..458eaf7b7 100644 --- a/pkg/scopes/fetcher.go +++ b/pkg/scopes/fetcher.go @@ -40,17 +40,12 @@ type Fetcher struct { } // NewFetcher creates a new scope fetcher with the given options. -func NewFetcher(opts FetcherOptions) *Fetcher { +func NewFetcher(apiHost utils.APIHostResolver, opts FetcherOptions) *Fetcher { client := opts.HTTPClient if client == nil { client = &http.Client{Timeout: DefaultFetchTimeout} } - apiHost := opts.APIHost - if apiHost == nil { - apiHost = utils.NewDefaultAPIHostResolver() - } - return &Fetcher{ client: client, apiHost: apiHost, @@ -126,11 +121,16 @@ func ParseScopeHeader(header string) []string { // FetchTokenScopes is a convenience function that creates a default fetcher // and fetches the token scopes. func FetchTokenScopes(ctx context.Context, token string) ([]string, error) { - return NewFetcher(FetcherOptions{}).FetchTokenScopes(ctx, token) + apiHost, err := utils.NewAPIHost("https://api.github.com/") + if err != nil { + return nil, fmt.Errorf("failed to create default API host: %w", err) + } + + return NewFetcher(apiHost, FetcherOptions{}).FetchTokenScopes(ctx, token) } // FetchTokenScopesWithHost is a convenience function that creates a fetcher // for a specific API host and fetches the token scopes. func FetchTokenScopesWithHost(ctx context.Context, token string, apiHost utils.APIHostResolver) ([]string, error) { - return NewFetcher(FetcherOptions{APIHost: apiHost}).FetchTokenScopes(ctx, token) + return NewFetcher(apiHost, FetcherOptions{}).FetchTokenScopes(ctx, token) } diff --git a/pkg/scopes/fetcher_test.go b/pkg/scopes/fetcher_test.go index 3ea969236..2d887d7a8 100644 --- a/pkg/scopes/fetcher_test.go +++ b/pkg/scopes/fetcher_test.go @@ -164,10 +164,8 @@ func TestFetcher_FetchTokenScopes(t *testing.T) { t.Run(tt.name, func(t *testing.T) { server := httptest.NewServer(tt.handler) defer server.Close() - - fetcher := NewFetcher(FetcherOptions{ - APIHost: testAPIHostResolver{baseURL: server.URL}, - }) + apiHost := testAPIHostResolver{baseURL: server.URL} + fetcher := NewFetcher(apiHost, FetcherOptions{}) scopes, err := fetcher.FetchTokenScopes(context.Background(), "test-token") @@ -185,12 +183,13 @@ func TestFetcher_FetchTokenScopes(t *testing.T) { } func TestFetcher_DefaultOptions(t *testing.T) { - fetcher := NewFetcher(FetcherOptions{}) + apiHost := testAPIHostResolver{baseURL: "https://api.github.com"} + fetcher := NewFetcher(apiHost, FetcherOptions{}) // Verify default API host is set apiURL, err := fetcher.apiHost.BaseRESTURL(context.Background()) require.NoError(t, err) - assert.Equal(t, "https://api.github.com/", apiURL.String()) + assert.Equal(t, "https://api.github.com", apiURL.String()) // Verify default HTTP client is set with timeout assert.NotNil(t, fetcher.client) @@ -200,7 +199,8 @@ func TestFetcher_DefaultOptions(t *testing.T) { func TestFetcher_CustomHTTPClient(t *testing.T) { customClient := &http.Client{Timeout: 5 * time.Second} - fetcher := NewFetcher(FetcherOptions{ + apiHost := testAPIHostResolver{baseURL: "https://api.github.com"} + fetcher := NewFetcher(apiHost, FetcherOptions{ HTTPClient: customClient, }) @@ -208,9 +208,8 @@ func TestFetcher_CustomHTTPClient(t *testing.T) { } func TestFetcher_CustomAPIHost(t *testing.T) { - fetcher := NewFetcher(FetcherOptions{ - APIHost: testAPIHostResolver{baseURL: "https://api.github.enterprise.com"}, - }) + apiHost := testAPIHostResolver{baseURL: "https://api.github.enterprise.com"} + fetcher := NewFetcher(apiHost, FetcherOptions{}) apiURL, err := fetcher.apiHost.BaseRESTURL(context.Background()) require.NoError(t, err) @@ -224,9 +223,8 @@ func TestFetcher_ContextCancellation(t *testing.T) { })) defer server.Close() - fetcher := NewFetcher(FetcherOptions{ - APIHost: testAPIHostResolver{baseURL: server.URL}, - }) + apiHost := testAPIHostResolver{baseURL: server.URL} + fetcher := NewFetcher(apiHost, FetcherOptions{}) ctx, cancel := context.WithCancel(context.Background()) cancel() // Cancel immediately diff --git a/pkg/utils/api.go b/pkg/utils/api.go index 19e908fc8..24abf7342 100644 --- a/pkg/utils/api.go +++ b/pkg/utils/api.go @@ -25,15 +25,6 @@ type APIHost struct { var _ APIHostResolver = APIHost{} -func NewDefaultAPIHostResolver() APIHostResolver { - a, err := newDotcomHost() - if err != nil { - // This should never happen - panic(fmt.Sprintf("failed to create default API host resolver: %v", err)) - } - return a -} - func NewAPIHost(s string) (APIHostResolver, error) { a, err := parseAPIHost(s)