Skip to content

Commit 9af34a9

Browse files
committed
check for OAuth token precense before starting a new OAuth flow
1 parent 51440fb commit 9af34a9

3 files changed

Lines changed: 72 additions & 3 deletions

File tree

cmd/src/login.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ func loginCmd(ctx context.Context, p loginParams) error {
105105
return flow(ctx, p)
106106
}
107107

108-
// selectLoginFlow decides what login flow to run based on configigured AuthMode.
108+
// selectLoginFlow decides what login flow to run based on configured AuthMode.
109109
func selectLoginFlow(_ context.Context, p loginParams) (loginFlowKind, loginFlow) {
110110
endpointArg := cleanEndpoint(p.endpoint)
111111

cmd/src/login_oauth.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313
"github.com/sourcegraph/src-cli/internal/oauth"
1414
)
1515

16+
var loadStoredOAuthToken = oauth.LoadToken
17+
1618
func runOAuthLogin(ctx context.Context, p loginParams) error {
1719
endpointArg := cleanEndpoint(p.endpoint)
1820
client, err := oauthLoginClient(ctx, p, endpointArg)
@@ -32,7 +34,15 @@ func runOAuthLogin(ctx context.Context, p loginParams) error {
3234
return nil
3335
}
3436

37+
// oauthLoginClient returns a api.Client with the OAuth token set. It will check secret storage for a token
38+
// and use it if one is present.
39+
// If no token is found, it will start a OAuth Device flow to get a token and storage in secret storage.
3540
func oauthLoginClient(ctx context.Context, p loginParams, endpoint string) (api.Client, error) {
41+
// if we have a stored token, used it. Otherwise run the device flow
42+
if token, err := loadStoredOAuthToken(ctx, endpoint); err == nil {
43+
return newOAuthAPIClient(p, endpoint, token), nil
44+
}
45+
3646
token, err := runOAuthDeviceFlow(ctx, endpoint, p.out, p.oauthClient)
3747
if err != nil {
3848
return nil, err
@@ -43,6 +53,10 @@ func oauthLoginClient(ctx context.Context, p loginParams, endpoint string) (api.
4353
fmt.Fprintf(p.out, "⚠️ Warning: Failed to store token in keyring store: %q. Continuing with this session only.\n", err)
4454
}
4555

56+
return newOAuthAPIClient(p, endpoint, token), nil
57+
}
58+
59+
func newOAuthAPIClient(p loginParams, endpoint string, token *oauth.Token) api.Client {
4660
return api.NewClient(api.ClientOpts{
4761
Endpoint: endpoint,
4862
AdditionalHeaders: p.cfg.AdditionalHeaders,
@@ -51,7 +65,7 @@ func oauthLoginClient(ctx context.Context, p loginParams, endpoint string) (api.
5165
ProxyURL: p.cfg.ProxyURL,
5266
ProxyPath: p.cfg.ProxyPath,
5367
OAuthToken: token,
54-
}), nil
68+
})
5569
}
5670

5771
func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, client oauth.Client) (*oauth.Token, error) {

cmd/src/login_test.go

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,52 @@ func TestLogin(t *testing.T) {
9898
t.Errorf("got output %q, want %q", out, wantOut)
9999
}
100100
})
101+
102+
t.Run("reuses stored oauth token before device flow", func(t *testing.T) {
103+
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
104+
fmt.Fprintln(w, `{"data":{"currentUser":{"username":"alice"}}}`)
105+
}))
106+
defer s.Close()
107+
108+
restoreStoredOAuthLoader(t, func(context.Context, string) (*oauth.Token, error) {
109+
return &oauth.Token{
110+
Endpoint: s.URL,
111+
ClientID: oauth.DefaultClientID,
112+
AccessToken: "oauth-token",
113+
ExpiresAt: time.Now().Add(time.Hour),
114+
}, nil
115+
})
116+
117+
startCalled := false
118+
var out bytes.Buffer
119+
err := loginCmd(context.Background(), loginParams{
120+
cfg: &config{Endpoint: s.URL},
121+
client: (&config{Endpoint: s.URL}).apiClient(nil, io.Discard),
122+
endpoint: s.URL,
123+
out: &out,
124+
oauthClient: fakeOAuthClient{
125+
startErr: fmt.Errorf("unexpected call to Start"),
126+
startCalled: &startCalled,
127+
},
128+
})
129+
if err != nil {
130+
t.Fatal(err)
131+
}
132+
if startCalled {
133+
t.Fatal("expected stored oauth token to avoid device flow")
134+
}
135+
gotOut := strings.TrimSpace(out.String())
136+
wantOut := "✔︎ Authenticated as alice on $ENDPOINT\n\n\n✔︎ Authenticated with OAuth credentials"
137+
wantOut = strings.ReplaceAll(wantOut, "$ENDPOINT", s.URL)
138+
if gotOut != wantOut {
139+
t.Errorf("got output %q, want %q", gotOut, wantOut)
140+
}
141+
})
101142
}
102143

103144
type fakeOAuthClient struct {
104-
startErr error
145+
startErr error
146+
startCalled *bool
105147
}
106148

107149
func (f fakeOAuthClient) ClientID() string {
@@ -113,6 +155,9 @@ func (f fakeOAuthClient) Discover(context.Context, string) (*oauth.OIDCConfigura
113155
}
114156

115157
func (f fakeOAuthClient) Start(context.Context, string, []string) (*oauth.DeviceAuthResponse, error) {
158+
if f.startCalled != nil {
159+
*f.startCalled = true
160+
}
116161
return nil, f.startErr
117162
}
118163

@@ -158,3 +203,13 @@ func TestSelectLoginFlow(t *testing.T) {
158203
}
159204
})
160205
}
206+
207+
func restoreStoredOAuthLoader(t *testing.T, loader func(context.Context, string) (*oauth.Token, error)) {
208+
t.Helper()
209+
210+
prev := loadStoredOAuthToken
211+
loadStoredOAuthToken = loader
212+
t.Cleanup(func() {
213+
loadStoredOAuthToken = prev
214+
})
215+
}

0 commit comments

Comments
 (0)