diff --git a/main.go b/main.go index a26b871..de4d563 100644 --- a/main.go +++ b/main.go @@ -181,7 +181,8 @@ func initConfig() { retry.WithHTTPClient(baseHTTPClient), ) if err != nil { - panic(fmt.Sprintf("failed to create retry client: %v", err)) + fmt.Fprintf(os.Stderr, "Error: failed to create retry client: %v\n", err) + os.Exit(1) } // Initialize token store based on mode @@ -387,7 +388,7 @@ func run(ctx context.Context, d tui.Displayer) error { // Demonstrate automatic refresh on 401 if err := makeAPICallWithAutoRefresh(ctx, &storage, d); err != nil { // Check if error is due to expired refresh token - if err == ErrRefreshTokenExpired { + if errors.Is(err, ErrRefreshTokenExpired) { d.ReAuthRequired() storage, err = performDeviceFlow(ctx, d) if err != nil { @@ -401,7 +402,6 @@ func run(ctx context.Context, d tui.Displayer) error { d.Fatal(err) return err } - d.APICallOK() } else { d.APICallFailed(err) } @@ -477,13 +477,13 @@ func requestDeviceCode(ctx context.Context) (*oauth2.DeviceAuthResponse, error) // performDeviceFlow performs the OAuth device authorization flow func performDeviceFlow(ctx context.Context, d tui.Displayer) (credstore.Token, error) { + // Only TokenURL and ClientID are used downstream; + // requestDeviceCode() builds its own request directly. config := &oauth2.Config{ ClientID: clientID, Endpoint: oauth2.Endpoint{ - DeviceAuthURL: serverURL + endpointDeviceCode, - TokenURL: serverURL + endpointToken, + TokenURL: serverURL + endpointToken, }, - Scopes: []string{"read", "write"}, } // Step 1: Request device code (with retry logic) @@ -527,7 +527,7 @@ func performDeviceFlow(ctx context.Context, d tui.Displayer) (credstore.Token, e } // pollForTokenWithProgress polls for token while reporting progress via Displayer. -// Implements exponential backoff for slow_down errors per RFC 8628. +// Implements additive backoff (+5s) for slow_down errors per RFC 8628 §3.5. func pollForTokenWithProgress( ctx context.Context, config *oauth2.Config, @@ -540,9 +540,8 @@ func pollForTokenWithProgress( interval = 5 // Default to 5 seconds per RFC 8628 } - // Exponential backoff state + // Backoff state pollInterval := time.Duration(interval) * time.Second - backoffMultiplier := 1.0 pollTicker := time.NewTicker(pollInterval) defer pollTicker.Stop() @@ -572,12 +571,8 @@ func pollForTokenWithProgress( continue case oauthErrSlowDown: - // Server requests slower polling - increase interval - backoffMultiplier *= 1.5 - pollInterval = min( - time.Duration(float64(pollInterval)*backoffMultiplier), - 60*time.Second, - ) + // Server requests slower polling - add 5s per RFC 8628 §3.5 + pollInterval = min(pollInterval+5*time.Second, 60*time.Second) pollTicker.Reset(pollInterval) d.PollSlowDown(pollInterval) continue @@ -829,16 +824,19 @@ func makeAPICallWithAutoRefresh( if err != nil { return fmt.Errorf("API request failed: %w", err) } - defer resp.Body.Close() - // If 401, try to refresh and retry + // If 401, drain and close the first response body to allow connection reuse, + // then refresh the token and retry. if resp.StatusCode == http.StatusUnauthorized { + _, _ = io.Copy(io.Discard, resp.Body) + resp.Body.Close() + d.AccessTokenRejected() newStorage, err := refreshAccessToken(ctx, storage.RefreshToken, d) if err != nil { // If refresh token is expired, propagate the error to trigger device flow - if err == ErrRefreshTokenExpired { + if errors.Is(err, ErrRefreshTokenExpired) { return ErrRefreshTokenExpired } return fmt.Errorf("refresh failed: %w", err) @@ -846,9 +844,7 @@ func makeAPICallWithAutoRefresh( // Update storage in memory // Note: newStorage has already been saved to disk by refreshAccessToken() - storage.AccessToken = newStorage.AccessToken - storage.RefreshToken = newStorage.RefreshToken - storage.ExpiresAt = newStorage.ExpiresAt + *storage = newStorage d.TokenRefreshedRetrying() @@ -869,6 +865,8 @@ func makeAPICallWithAutoRefresh( return fmt.Errorf("retry failed: %w", err) } defer resp.Body.Close() + } else { + defer resp.Body.Close() } body, err := io.ReadAll(resp.Body) diff --git a/main_test.go b/main_test.go index fe888ae..2221e00 100644 --- a/main_test.go +++ b/main_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "os" "path/filepath" + "strings" "sync" "sync/atomic" "testing" @@ -271,7 +272,7 @@ func TestValidateTokenResponse(t *testing.T) { t.Errorf("validateTokenResponse() expected error but got nil") return } - if tt.errContains != "" && !contains(err.Error(), tt.errContains) { + if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { t.Errorf( "validateTokenResponse() error = %v, want error containing %q", err, @@ -285,21 +286,6 @@ func TestValidateTokenResponse(t *testing.T) { } } -// contains checks if string s contains substr -func contains(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || len(substr) == 0 || - (len(s) > 0 && len(substr) > 0 && stringContains(s, substr))) -} - -func stringContains(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false -} - func TestRefreshAccessToken_RotationMode(t *testing.T) { // Save original values origServerURL := serverURL @@ -526,7 +512,7 @@ func TestRefreshAccessToken_ValidationErrors(t *testing.T) { t.Errorf("refreshAccessToken() expected error but got nil") return } - if tt.errContains != "" && !contains(err.Error(), tt.errContains) { + if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { t.Errorf( "refreshAccessToken() error = %v, want error containing %q", err, diff --git a/polling_test.go b/polling_test.go index 02e9d28..71a12db 100644 --- a/polling_test.go +++ b/polling_test.go @@ -77,10 +77,10 @@ func TestPollForToken_SlowDown(t *testing.T) { slowDownCount := atomic.Int32{} server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - attempts.Add(1) + n := attempts.Add(1) - // Return slow_down for first 2 attempts - if attempts.Load() <= 2 { + // Return slow_down on the first attempt + if n == 1 { slowDownCount.Add(1) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) @@ -91,8 +91,8 @@ func TestPollForToken_SlowDown(t *testing.T) { return } - // Return authorization_pending after slow_down - if attempts.Load() < 5 { + // Return authorization_pending on second attempt + if n == 2 { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) _ = json.NewEncoder(w).Encode(map[string]string{ @@ -102,7 +102,7 @@ func TestPollForToken_SlowDown(t *testing.T) { return } - // Success + // Success on third attempt w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(map[string]any{ "access_token": testAccessToken, @@ -125,7 +125,9 @@ func TestPollForToken_SlowDown(t *testing.T) { Interval: 1, // 1 second for testing } - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + // After 1 slow_down the interval becomes 1+5=6s; with an additional authorization_pending + // before success, the third attempt occurs after ~1s + 6s + 6s ≈ 13s, so use a generous timeout. + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second) defer cancel() token, err := pollForTokenWithProgress(ctx, config, deviceAuth, tui.NoopDisplayer{}) @@ -137,14 +139,14 @@ func TestPollForToken_SlowDown(t *testing.T) { t.Errorf("Expected access token 'test-access-token', got '%s'", token.AccessToken) } - if slowDownCount.Load() < 2 { - t.Errorf("Expected at least 2 slow_down responses, got %d", slowDownCount.Load()) + if slowDownCount.Load() < 1 { + t.Errorf("Expected at least 1 slow_down response, got %d", slowDownCount.Load()) } // Verify that polling continued after slow_down - if attempts.Load() < 5 { + if attempts.Load() < 3 { t.Errorf( - "Expected at least 5 attempts (2 slow_down + 2 pending + 1 success), got %d", + "Expected at least 3 attempts (1 slow_down + 1 pending + 1 success), got %d", attempts.Load(), ) }