Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 19 additions & 21 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ func initConfig() {
retry.WithHTTPClient(baseHTTPClient),
)
if err != nil {
panic(fmt.Sprintf("failed to create retry client: %v", err))
fmt.Fprintf(os.Stderr, "Error: failed to create retry client: %v\n", err)
os.Exit(1)
}

// Initialize token store based on mode
Expand Down Expand Up @@ -387,7 +388,7 @@ func run(ctx context.Context, d tui.Displayer) error {
// Demonstrate automatic refresh on 401
if err := makeAPICallWithAutoRefresh(ctx, &storage, d); err != nil {
// Check if error is due to expired refresh token
if err == ErrRefreshTokenExpired {
if errors.Is(err, ErrRefreshTokenExpired) {
d.ReAuthRequired()
storage, err = performDeviceFlow(ctx, d)
if err != nil {
Expand All @@ -401,7 +402,6 @@ func run(ctx context.Context, d tui.Displayer) error {
d.Fatal(err)
return err
}
d.APICallOK()
} else {
d.APICallFailed(err)
}
Expand Down Expand Up @@ -477,13 +477,13 @@ func requestDeviceCode(ctx context.Context) (*oauth2.DeviceAuthResponse, error)

// performDeviceFlow performs the OAuth device authorization flow
func performDeviceFlow(ctx context.Context, d tui.Displayer) (credstore.Token, error) {
// Only TokenURL and ClientID are used downstream;
// requestDeviceCode() builds its own request directly.
config := &oauth2.Config{
ClientID: clientID,
Endpoint: oauth2.Endpoint{
DeviceAuthURL: serverURL + endpointDeviceCode,
TokenURL: serverURL + endpointToken,
TokenURL: serverURL + endpointToken,
},
Scopes: []string{"read", "write"},
}

// Step 1: Request device code (with retry logic)
Expand Down Expand Up @@ -527,7 +527,7 @@ func performDeviceFlow(ctx context.Context, d tui.Displayer) (credstore.Token, e
}

// pollForTokenWithProgress polls for token while reporting progress via Displayer.
// Implements exponential backoff for slow_down errors per RFC 8628.
// Implements additive backoff (+5s) for slow_down errors per RFC 8628 §3.5.
func pollForTokenWithProgress(
ctx context.Context,
config *oauth2.Config,
Expand All @@ -540,9 +540,8 @@ func pollForTokenWithProgress(
interval = 5 // Default to 5 seconds per RFC 8628
}

// Exponential backoff state
// Backoff state
pollInterval := time.Duration(interval) * time.Second
backoffMultiplier := 1.0

Comment on lines +543 to 545
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This block now implements a linear +5s backoff on slow_down (per RFC 8628), but the surrounding function-level documentation still mentions “exponential backoff”. Please update the doc/comment wording to match the new behavior so future readers aren’t misled.

Copilot uses AI. Check for mistakes.
pollTicker := time.NewTicker(pollInterval)
defer pollTicker.Stop()
Expand Down Expand Up @@ -572,12 +571,8 @@ func pollForTokenWithProgress(
continue

case oauthErrSlowDown:
// Server requests slower polling - increase interval
backoffMultiplier *= 1.5
pollInterval = min(
time.Duration(float64(pollInterval)*backoffMultiplier),
60*time.Second,
)
// Server requests slower polling - add 5s per RFC 8628 §3.5
pollInterval = min(pollInterval+5*time.Second, 60*time.Second)
pollTicker.Reset(pollInterval)
d.PollSlowDown(pollInterval)
continue
Expand Down Expand Up @@ -829,26 +824,27 @@ func makeAPICallWithAutoRefresh(
if err != nil {
return fmt.Errorf("API request failed: %w", err)
}
defer resp.Body.Close()

// If 401, try to refresh and retry
// If 401, drain and close the first response body to allow connection reuse,
// then refresh the token and retry.
if resp.StatusCode == http.StatusUnauthorized {
_, _ = io.Copy(io.Discard, resp.Body)
resp.Body.Close()

Comment on lines +828 to +833
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

resp.Body.Close() is deferred earlier, but the 401 branch also explicitly closes the body after draining it. That can lead to a double-close of the original response body. Consider restructuring so the initial defer resp.Body.Close() is only set for the non-401 path (or removed and replaced with explicit closes) while still draining+closing before the retry to allow connection reuse.

Copilot uses AI. Check for mistakes.
d.AccessTokenRejected()

newStorage, err := refreshAccessToken(ctx, storage.RefreshToken, d)
if err != nil {
// If refresh token is expired, propagate the error to trigger device flow
if err == ErrRefreshTokenExpired {
if errors.Is(err, ErrRefreshTokenExpired) {
return ErrRefreshTokenExpired
}
return fmt.Errorf("refresh failed: %w", err)
}

// Update storage in memory
// Note: newStorage has already been saved to disk by refreshAccessToken()
storage.AccessToken = newStorage.AccessToken
storage.RefreshToken = newStorage.RefreshToken
storage.ExpiresAt = newStorage.ExpiresAt
*storage = newStorage

d.TokenRefreshedRetrying()

Expand All @@ -869,6 +865,8 @@ func makeAPICallWithAutoRefresh(
return fmt.Errorf("retry failed: %w", err)
}
defer resp.Body.Close()
} else {
defer resp.Body.Close()
}

body, err := io.ReadAll(resp.Body)
Expand Down
20 changes: 3 additions & 17 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -271,7 +272,7 @@ func TestValidateTokenResponse(t *testing.T) {
t.Errorf("validateTokenResponse() expected error but got nil")
return
}
if tt.errContains != "" && !contains(err.Error(), tt.errContains) {
if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
t.Errorf(
"validateTokenResponse() error = %v, want error containing %q",
err,
Expand All @@ -285,21 +286,6 @@ func TestValidateTokenResponse(t *testing.T) {
}
}

// contains checks if string s contains substr
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
(len(s) > 0 && len(substr) > 0 && stringContains(s, substr)))
}

func stringContains(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}

func TestRefreshAccessToken_RotationMode(t *testing.T) {
// Save original values
origServerURL := serverURL
Expand Down Expand Up @@ -526,7 +512,7 @@ func TestRefreshAccessToken_ValidationErrors(t *testing.T) {
t.Errorf("refreshAccessToken() expected error but got nil")
return
}
if tt.errContains != "" && !contains(err.Error(), tt.errContains) {
if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
t.Errorf(
"refreshAccessToken() error = %v, want error containing %q",
err,
Expand Down
24 changes: 13 additions & 11 deletions polling_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ func TestPollForToken_SlowDown(t *testing.T) {
slowDownCount := atomic.Int32{}

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts.Add(1)
n := attempts.Add(1)

// Return slow_down for first 2 attempts
if attempts.Load() <= 2 {
// Return slow_down on the first attempt
if n == 1 {
slowDownCount.Add(1)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
Expand All @@ -91,8 +91,8 @@ func TestPollForToken_SlowDown(t *testing.T) {
return
}

// Return authorization_pending after slow_down
if attempts.Load() < 5 {
// Return authorization_pending on second attempt
if n == 2 {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
_ = json.NewEncoder(w).Encode(map[string]string{
Expand All @@ -102,7 +102,7 @@ func TestPollForToken_SlowDown(t *testing.T) {
return
}

// Success
// Success on third attempt
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": testAccessToken,
Expand All @@ -125,7 +125,9 @@ func TestPollForToken_SlowDown(t *testing.T) {
Interval: 1, // 1 second for testing
}

ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
// After 1 slow_down the interval becomes 1+5=6s; with an additional authorization_pending
// before success, the third attempt occurs after ~1s + 6s + 6s ≈ 13s, so use a generous timeout.
ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second)
defer cancel()

token, err := pollForTokenWithProgress(ctx, config, deviceAuth, tui.NoopDisplayer{})
Expand All @@ -137,14 +139,14 @@ func TestPollForToken_SlowDown(t *testing.T) {
t.Errorf("Expected access token 'test-access-token', got '%s'", token.AccessToken)
}

if slowDownCount.Load() < 2 {
t.Errorf("Expected at least 2 slow_down responses, got %d", slowDownCount.Load())
if slowDownCount.Load() < 1 {
t.Errorf("Expected at least 1 slow_down response, got %d", slowDownCount.Load())
}

// Verify that polling continued after slow_down
if attempts.Load() < 5 {
if attempts.Load() < 3 {
t.Errorf(
"Expected at least 5 attempts (2 slow_down + 2 pending + 1 success), got %d",
"Expected at least 3 attempts (1 slow_down + 1 pending + 1 success), got %d",
attempts.Load(),
)
}
Expand Down
Loading