From dbdd3a08b6657102e117fa9aba7324df684ba593 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Wed, 11 Mar 2026 22:26:27 +0800 Subject: [PATCH 1/2] fix(main): bound HTTP response reads and use sync.Once for config init - Limit all HTTP response body reads to 1 MB using io.LimitReader to prevent memory exhaustion - Replace non-thread-safe configInitialized bool with sync.Once for safe concurrent initialization Co-Authored-By: Claude Opus 4.6 --- main.go | 44 ++++++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/main.go b/main.go index a26b871..89d943c 100644 --- a/main.go +++ b/main.go @@ -28,21 +28,24 @@ import ( ) var ( - serverURL string - clientID string - tokenFile string - tokenStoreMode string - flagServerURL *string - flagClientID *string - flagTokenFile *string - flagTokenStore *string - configInitialized bool - retryClient *retry.Client - tokenStore credstore.Store[credstore.Token] + serverURL string + clientID string + tokenFile string + tokenStoreMode string + flagServerURL *string + flagClientID *string + flagTokenFile *string + flagTokenStore *string + configOnce sync.Once + retryClient *retry.Client + tokenStore credstore.Store[credstore.Token] ) const defaultKeyringService = "authgate-device-cli" +// maxResponseBodySize limits HTTP response body reads to prevent memory exhaustion (DoS). +const maxResponseBodySize = 1 << 20 // 1 MB + // Timeout configuration for different operations const ( deviceCodeRequestTimeout = 10 * time.Second @@ -107,11 +110,12 @@ func init() { // initConfig parses flags and initializes configuration // Separated from init() to avoid conflicts with test flag parsing func initConfig() { - if configInitialized { - return - } - configInitialized = true + configOnce.Do(func() { + doInitConfig() + }) +} +func doInitConfig() { flag.Parse() // Priority: flag > env > default @@ -438,7 +442,7 @@ func requestDeviceCode(ctx context.Context) (*oauth2.DeviceAuthResponse, error) } defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize)) if err != nil { return nil, fmt.Errorf("failed to read response: %w", err) } @@ -638,7 +642,7 @@ func exchangeDeviceCode( } defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize)) if err != nil { return nil, fmt.Errorf("failed to read response: %w", err) } @@ -696,7 +700,7 @@ func verifyToken(ctx context.Context, accessToken string, d tui.Displayer) error } defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize)) if err != nil { return fmt.Errorf("failed to read response: %w", err) } @@ -746,7 +750,7 @@ func refreshAccessToken( } defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize)) if err != nil { return credstore.Token{}, fmt.Errorf("failed to read response: %w", err) } @@ -871,7 +875,7 @@ func makeAPICallWithAutoRefresh( defer resp.Body.Close() } - body, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize)) if err != nil { return fmt.Errorf("failed to read response: %w", err) } From b0c76329845cbff67ad91e274ed1fea17e5b59ab Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Fri, 13 Mar 2026 22:39:32 +0800 Subject: [PATCH 2/2] fix(http): detect oversized responses instead of silent truncation - Add readResponseBody helper with explicit size limit detection - Replace 5 inline io.LimitReader calls with the shared helper - Return errResponseTooLarge for responses exceeding 1MB - Add unit tests for boundary, oversized, small, and empty responses - Add end-to-end test for oversized response in requestDeviceCode Co-Authored-By: Claude Opus 4.6 --- main.go | 26 +++++++++++++---- main_test.go | 81 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 5 deletions(-) diff --git a/main.go b/main.go index 89d943c..4c502ec 100644 --- a/main.go +++ b/main.go @@ -46,6 +46,22 @@ const defaultKeyringService = "authgate-device-cli" // maxResponseBodySize limits HTTP response body reads to prevent memory exhaustion (DoS). const maxResponseBodySize = 1 << 20 // 1 MB +// errResponseTooLarge indicates the server returned an oversized response body. +var errResponseTooLarge = errors.New("response body exceeds maximum allowed size") + +// readResponseBody reads the response body up to maxResponseBodySize. +// Returns errResponseTooLarge if the body exceeds the limit. +func readResponseBody(body io.Reader) ([]byte, error) { + data, err := io.ReadAll(io.LimitReader(body, maxResponseBodySize+1)) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + if int64(len(data)) > maxResponseBodySize { + return nil, errResponseTooLarge + } + return data, nil +} + // Timeout configuration for different operations const ( deviceCodeRequestTimeout = 10 * time.Second @@ -442,7 +458,7 @@ func requestDeviceCode(ctx context.Context) (*oauth2.DeviceAuthResponse, error) } defer resp.Body.Close() - body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize)) + body, err := readResponseBody(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response: %w", err) } @@ -642,7 +658,7 @@ func exchangeDeviceCode( } defer resp.Body.Close() - body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize)) + body, err := readResponseBody(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response: %w", err) } @@ -700,7 +716,7 @@ func verifyToken(ctx context.Context, accessToken string, d tui.Displayer) error } defer resp.Body.Close() - body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize)) + body, err := readResponseBody(resp.Body) if err != nil { return fmt.Errorf("failed to read response: %w", err) } @@ -750,7 +766,7 @@ func refreshAccessToken( } defer resp.Body.Close() - body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize)) + body, err := readResponseBody(resp.Body) if err != nil { return credstore.Token{}, fmt.Errorf("failed to read response: %w", err) } @@ -875,7 +891,7 @@ func makeAPICallWithAutoRefresh( defer resp.Body.Close() } - body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize)) + body, err := readResponseBody(resp.Body) if err != nil { return fmt.Errorf("failed to read response: %w", err) } diff --git a/main_test.go b/main_test.go index fe888ae..27066dd 100644 --- a/main_test.go +++ b/main_test.go @@ -1,13 +1,16 @@ package main import ( + "bytes" "context" "encoding/json" + "errors" "fmt" "net/http" "net/http/httptest" "os" "path/filepath" + "strings" "sync" "sync/atomic" "testing" @@ -593,3 +596,81 @@ func TestRequestDeviceCode_WithRetry(t *testing.T) { t.Errorf("Expected 2 attempts (1 retry), got %d", finalCount) } } + +func TestReadResponseBody_ExactlyAtLimit(t *testing.T) { + data := make([]byte, maxResponseBodySize) + body, err := readResponseBody(bytes.NewReader(data)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(body) != int(maxResponseBodySize) { + t.Errorf("expected %d bytes, got %d", maxResponseBodySize, len(body)) + } +} + +func TestReadResponseBody_ExceedsLimit(t *testing.T) { + data := make([]byte, maxResponseBodySize+1) + _, err := readResponseBody(bytes.NewReader(data)) + if !errors.Is(err, errResponseTooLarge) { + t.Errorf("expected errResponseTooLarge, got %v", err) + } +} + +func TestReadResponseBody_SmallBody(t *testing.T) { + expected := "hello world" + body, err := readResponseBody(strings.NewReader(expected)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != expected { + t.Errorf("expected %q, got %q", expected, string(body)) + } +} + +func TestReadResponseBody_EmptyBody(t *testing.T) { + body, err := readResponseBody(strings.NewReader("")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(body) != 0 { + t.Errorf("expected empty body, got %d bytes", len(body)) + } +} + +func TestRequestDeviceCode_OversizedResponse(t *testing.T) { + // Server that returns a response larger than maxResponseBodySize + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + // Write more than maxResponseBodySize + data := make([]byte, maxResponseBodySize+100) + for i := range data { + data[i] = 'a' + } + _, _ = w.Write(data) + })) + defer server.Close() + + oldServerURL := serverURL + serverURL = server.URL + defer func() { serverURL = oldServerURL }() + + oldClient := retryClient + newClient, err := retry.NewBackgroundClient( + retry.WithHTTPClient(server.Client()), + ) + if err != nil { + t.Fatalf("failed to create retry client: %v", err) + } + retryClient = newClient + defer func() { retryClient = oldClient }() + + ctx := context.Background() + _, err = requestDeviceCode(ctx) + if err == nil { + t.Fatal("expected error for oversized response, got nil") + } + if !errors.Is(err, errResponseTooLarge) { + t.Errorf("expected errResponseTooLarge in error chain, got: %v", err) + } +}