diff --git a/main.go b/main.go index a26b871..4c502ec 100644 --- a/main.go +++ b/main.go @@ -28,21 +28,40 @@ import ( ) var ( - serverURL string - clientID string - tokenFile string - tokenStoreMode string - flagServerURL *string - flagClientID *string - flagTokenFile *string - flagTokenStore *string - configInitialized bool - retryClient *retry.Client - tokenStore credstore.Store[credstore.Token] + serverURL string + clientID string + tokenFile string + tokenStoreMode string + flagServerURL *string + flagClientID *string + flagTokenFile *string + flagTokenStore *string + configOnce sync.Once + retryClient *retry.Client + tokenStore credstore.Store[credstore.Token] ) const defaultKeyringService = "authgate-device-cli" +// maxResponseBodySize limits HTTP response body reads to prevent memory exhaustion (DoS). +const maxResponseBodySize = 1 << 20 // 1 MB + +// errResponseTooLarge indicates the server returned an oversized response body. +var errResponseTooLarge = errors.New("response body exceeds maximum allowed size") + +// readResponseBody reads the response body up to maxResponseBodySize. +// Returns errResponseTooLarge if the body exceeds the limit. +func readResponseBody(body io.Reader) ([]byte, error) { + data, err := io.ReadAll(io.LimitReader(body, maxResponseBodySize+1)) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + if int64(len(data)) > maxResponseBodySize { + return nil, errResponseTooLarge + } + return data, nil +} + // Timeout configuration for different operations const ( deviceCodeRequestTimeout = 10 * time.Second @@ -107,11 +126,12 @@ func init() { // initConfig parses flags and initializes configuration // Separated from init() to avoid conflicts with test flag parsing func initConfig() { - if configInitialized { - return - } - configInitialized = true + configOnce.Do(func() { + doInitConfig() + }) +} +func doInitConfig() { flag.Parse() // Priority: flag > env > default @@ -438,7 +458,7 @@ func requestDeviceCode(ctx context.Context) (*oauth2.DeviceAuthResponse, error) } defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) + body, err := readResponseBody(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response: %w", err) } @@ -638,7 +658,7 @@ func exchangeDeviceCode( } defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) + body, err := readResponseBody(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response: %w", err) } @@ -696,7 +716,7 @@ func verifyToken(ctx context.Context, accessToken string, d tui.Displayer) error } defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) + body, err := readResponseBody(resp.Body) if err != nil { return fmt.Errorf("failed to read response: %w", err) } @@ -746,7 +766,7 @@ func refreshAccessToken( } defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) + body, err := readResponseBody(resp.Body) if err != nil { return credstore.Token{}, fmt.Errorf("failed to read response: %w", err) } @@ -871,7 +891,7 @@ func makeAPICallWithAutoRefresh( defer resp.Body.Close() } - body, err := io.ReadAll(resp.Body) + body, err := readResponseBody(resp.Body) if err != nil { return fmt.Errorf("failed to read response: %w", err) } diff --git a/main_test.go b/main_test.go index fe888ae..27066dd 100644 --- a/main_test.go +++ b/main_test.go @@ -1,13 +1,16 @@ package main import ( + "bytes" "context" "encoding/json" + "errors" "fmt" "net/http" "net/http/httptest" "os" "path/filepath" + "strings" "sync" "sync/atomic" "testing" @@ -593,3 +596,81 @@ func TestRequestDeviceCode_WithRetry(t *testing.T) { t.Errorf("Expected 2 attempts (1 retry), got %d", finalCount) } } + +func TestReadResponseBody_ExactlyAtLimit(t *testing.T) { + data := make([]byte, maxResponseBodySize) + body, err := readResponseBody(bytes.NewReader(data)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(body) != int(maxResponseBodySize) { + t.Errorf("expected %d bytes, got %d", maxResponseBodySize, len(body)) + } +} + +func TestReadResponseBody_ExceedsLimit(t *testing.T) { + data := make([]byte, maxResponseBodySize+1) + _, err := readResponseBody(bytes.NewReader(data)) + if !errors.Is(err, errResponseTooLarge) { + t.Errorf("expected errResponseTooLarge, got %v", err) + } +} + +func TestReadResponseBody_SmallBody(t *testing.T) { + expected := "hello world" + body, err := readResponseBody(strings.NewReader(expected)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != expected { + t.Errorf("expected %q, got %q", expected, string(body)) + } +} + +func TestReadResponseBody_EmptyBody(t *testing.T) { + body, err := readResponseBody(strings.NewReader("")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(body) != 0 { + t.Errorf("expected empty body, got %d bytes", len(body)) + } +} + +func TestRequestDeviceCode_OversizedResponse(t *testing.T) { + // Server that returns a response larger than maxResponseBodySize + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + // Write more than maxResponseBodySize + data := make([]byte, maxResponseBodySize+100) + for i := range data { + data[i] = 'a' + } + _, _ = w.Write(data) + })) + defer server.Close() + + oldServerURL := serverURL + serverURL = server.URL + defer func() { serverURL = oldServerURL }() + + oldClient := retryClient + newClient, err := retry.NewBackgroundClient( + retry.WithHTTPClient(server.Client()), + ) + if err != nil { + t.Fatalf("failed to create retry client: %v", err) + } + retryClient = newClient + defer func() { retryClient = oldClient }() + + ctx := context.Background() + _, err = requestDeviceCode(ctx) + if err == nil { + t.Fatal("expected error for oversized response, got nil") + } + if !errors.Is(err, errResponseTooLarge) { + t.Errorf("expected errResponseTooLarge in error chain, got: %v", err) + } +}