@@ -98,10 +98,52 @@ func TestLogin(t *testing.T) {
9898 t .Errorf ("got output %q, want %q" , out , wantOut )
9999 }
100100 })
101+
102+ t .Run ("reuses stored oauth token before device flow" , func (t * testing.T ) {
103+ s := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
104+ fmt .Fprintln (w , `{"data":{"currentUser":{"username":"alice"}}}` )
105+ }))
106+ defer s .Close ()
107+
108+ restoreStoredOAuthLoader (t , func (context.Context , string ) (* oauth.Token , error ) {
109+ return & oauth.Token {
110+ Endpoint : s .URL ,
111+ ClientID : oauth .DefaultClientID ,
112+ AccessToken : "oauth-token" ,
113+ ExpiresAt : time .Now ().Add (time .Hour ),
114+ }, nil
115+ })
116+
117+ startCalled := false
118+ var out bytes.Buffer
119+ err := loginCmd (context .Background (), loginParams {
120+ cfg : & config {Endpoint : s .URL },
121+ client : (& config {Endpoint : s .URL }).apiClient (nil , io .Discard ),
122+ endpoint : s .URL ,
123+ out : & out ,
124+ oauthClient : fakeOAuthClient {
125+ startErr : fmt .Errorf ("unexpected call to Start" ),
126+ startCalled : & startCalled ,
127+ },
128+ })
129+ if err != nil {
130+ t .Fatal (err )
131+ }
132+ if startCalled {
133+ t .Fatal ("expected stored oauth token to avoid device flow" )
134+ }
135+ gotOut := strings .TrimSpace (out .String ())
136+ wantOut := "✔︎ Authenticated as alice on $ENDPOINT\n \n \n ✔︎ Authenticated with OAuth credentials"
137+ wantOut = strings .ReplaceAll (wantOut , "$ENDPOINT" , s .URL )
138+ if gotOut != wantOut {
139+ t .Errorf ("got output %q, want %q" , gotOut , wantOut )
140+ }
141+ })
101142}
102143
103144type fakeOAuthClient struct {
104- startErr error
145+ startErr error
146+ startCalled * bool
105147}
106148
107149func (f fakeOAuthClient ) ClientID () string {
@@ -113,6 +155,9 @@ func (f fakeOAuthClient) Discover(context.Context, string) (*oauth.OIDCConfigura
113155}
114156
115157func (f fakeOAuthClient ) Start (context.Context , string , []string ) (* oauth.DeviceAuthResponse , error ) {
158+ if f .startCalled != nil {
159+ * f .startCalled = true
160+ }
116161 return nil , f .startErr
117162}
118163
@@ -158,3 +203,13 @@ func TestSelectLoginFlow(t *testing.T) {
158203 }
159204 })
160205}
206+
207+ func restoreStoredOAuthLoader (t * testing.T , loader func (context.Context , string ) (* oauth.Token , error )) {
208+ t .Helper ()
209+
210+ prev := loadStoredOAuthToken
211+ loadStoredOAuthToken = loader
212+ t .Cleanup (func () {
213+ loadStoredOAuthToken = prev
214+ })
215+ }
0 commit comments