Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 40 additions & 20 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,40 @@ import (
)

var (
serverURL string
clientID string
tokenFile string
tokenStoreMode string
flagServerURL *string
flagClientID *string
flagTokenFile *string
flagTokenStore *string
configInitialized bool
retryClient *retry.Client
tokenStore credstore.Store[credstore.Token]
serverURL string
clientID string
tokenFile string
tokenStoreMode string
flagServerURL *string
flagClientID *string
flagTokenFile *string
flagTokenStore *string
configOnce sync.Once
retryClient *retry.Client
tokenStore credstore.Store[credstore.Token]
)

const defaultKeyringService = "authgate-device-cli"

// maxResponseBodySize limits HTTP response body reads to prevent memory exhaustion (DoS).
const maxResponseBodySize = 1 << 20 // 1 MB
Comment on lines +46 to +47
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change introduces a new security boundary (maxResponseBodySize) but there are no tests exercising the oversized-response path (e.g., server returns >1MB, client should fail in a well-defined way). Since this file already has extensive HTTP/flow tests, consider adding a test case that returns a payload just over the limit and asserts the resulting error (ideally a dedicated "response too large" error if you add explicit truncation detection).

Copilot uses AI. Check for mistakes.

// errResponseTooLarge indicates the server returned an oversized response body.
var errResponseTooLarge = errors.New("response body exceeds maximum allowed size")

// readResponseBody reads the response body up to maxResponseBodySize.
// Returns errResponseTooLarge if the body exceeds the limit.
func readResponseBody(body io.Reader) ([]byte, error) {
data, err := io.ReadAll(io.LimitReader(body, maxResponseBodySize+1))
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if int64(len(data)) > maxResponseBodySize {
return nil, errResponseTooLarge
}
return data, nil
}

// Timeout configuration for different operations
const (
deviceCodeRequestTimeout = 10 * time.Second
Expand Down Expand Up @@ -107,11 +126,12 @@ func init() {
// initConfig parses flags and initializes configuration
// Separated from init() to avoid conflicts with test flag parsing
func initConfig() {
if configInitialized {
return
}
configInitialized = true
configOnce.Do(func() {
doInitConfig()
})
}

func doInitConfig() {
flag.Parse()

// Priority: flag > env > default
Expand Down Expand Up @@ -438,7 +458,7 @@ func requestDeviceCode(ctx context.Context) (*oauth2.DeviceAuthResponse, error)
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
body, err := readResponseBody(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
Expand Down Expand Up @@ -638,7 +658,7 @@ func exchangeDeviceCode(
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
body, err := readResponseBody(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
Expand Down Expand Up @@ -696,7 +716,7 @@ func verifyToken(ctx context.Context, accessToken string, d tui.Displayer) error
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
body, err := readResponseBody(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
}
Expand Down Expand Up @@ -746,7 +766,7 @@ func refreshAccessToken(
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
body, err := readResponseBody(resp.Body)
if err != nil {
return credstore.Token{}, fmt.Errorf("failed to read response: %w", err)
}
Expand Down Expand Up @@ -871,7 +891,7 @@ func makeAPICallWithAutoRefresh(
defer resp.Body.Close()
}

body, err := io.ReadAll(resp.Body)
body, err := readResponseBody(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
}
Expand Down
81 changes: 81 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package main

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -593,3 +596,81 @@ func TestRequestDeviceCode_WithRetry(t *testing.T) {
t.Errorf("Expected 2 attempts (1 retry), got %d", finalCount)
}
}

func TestReadResponseBody_ExactlyAtLimit(t *testing.T) {
data := make([]byte, maxResponseBodySize)
body, err := readResponseBody(bytes.NewReader(data))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(body) != int(maxResponseBodySize) {
t.Errorf("expected %d bytes, got %d", maxResponseBodySize, len(body))
}
}

func TestReadResponseBody_ExceedsLimit(t *testing.T) {
data := make([]byte, maxResponseBodySize+1)
_, err := readResponseBody(bytes.NewReader(data))
if !errors.Is(err, errResponseTooLarge) {
t.Errorf("expected errResponseTooLarge, got %v", err)
}
}

func TestReadResponseBody_SmallBody(t *testing.T) {
expected := "hello world"
body, err := readResponseBody(strings.NewReader(expected))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(body) != expected {
t.Errorf("expected %q, got %q", expected, string(body))
}
}

func TestReadResponseBody_EmptyBody(t *testing.T) {
body, err := readResponseBody(strings.NewReader(""))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(body) != 0 {
t.Errorf("expected empty body, got %d bytes", len(body))
}
}

func TestRequestDeviceCode_OversizedResponse(t *testing.T) {
// Server that returns a response larger than maxResponseBodySize
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
// Write more than maxResponseBodySize
data := make([]byte, maxResponseBodySize+100)
for i := range data {
data[i] = 'a'
}
_, _ = w.Write(data)
}))
defer server.Close()

oldServerURL := serverURL
serverURL = server.URL
defer func() { serverURL = oldServerURL }()

oldClient := retryClient
newClient, err := retry.NewBackgroundClient(
retry.WithHTTPClient(server.Client()),
)
if err != nil {
t.Fatalf("failed to create retry client: %v", err)
}
retryClient = newClient
defer func() { retryClient = oldClient }()

ctx := context.Background()
_, err = requestDeviceCode(ctx)
if err == nil {
t.Fatal("expected error for oversized response, got nil")
}
if !errors.Is(err, errResponseTooLarge) {
t.Errorf("expected errResponseTooLarge in error chain, got: %v", err)
}
}
Loading