Skip to content

Commit e0576a7

Browse files
committed
feat/auth: add src auth token (#1275)
* add TokenRefresher and NewTransport - TokenRefresher centralizes refreshing of OAuth tokens - NewTransport creates an OAuth transport while making sure it is initialized with a TokenRefresher * add cmd `src auth token` that prints access token or oauth token * set refresh window to 5 min (cherry picked from commit d206288)
1 parent 841bad0 commit e0576a7

File tree

7 files changed

+282
-36
lines changed

7 files changed

+282
-36
lines changed

cmd/src/auth.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package main
2+
3+
import (
4+
"flag"
5+
"fmt"
6+
)
7+
8+
var authCommands commander
9+
10+
func init() {
11+
usage := `'src auth' provides authentication-related helper commands.
12+
13+
Usage:
14+
15+
src auth command [command options]
16+
17+
The commands are:
18+
19+
token prints the current authentication token
20+
21+
Use "src auth [command] -h" for more information about a command.
22+
`
23+
24+
flagSet := flag.NewFlagSet("auth", flag.ExitOnError)
25+
handler := func(args []string) error {
26+
authCommands.run(flagSet, "src auth", usage, args)
27+
return nil
28+
}
29+
30+
commands = append(commands, &command{
31+
flagSet: flagSet,
32+
handler: handler,
33+
usageFunc: func() {
34+
fmt.Println(usage)
35+
},
36+
})
37+
}

cmd/src/auth_token.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"flag"
6+
"fmt"
7+
8+
"github.com/sourcegraph/sourcegraph/lib/errors"
9+
10+
"github.com/sourcegraph/src-cli/internal/oauth"
11+
)
12+
13+
var (
14+
loadOAuthToken = oauth.LoadToken
15+
newOAuthTokenRefresher = func(token *oauth.Token) oauthTokenRefresher {
16+
return oauth.NewTokenRefresher(token)
17+
}
18+
)
19+
20+
type oauthTokenRefresher interface {
21+
GetToken(ctx context.Context) (oauth.Token, error)
22+
}
23+
24+
func init() {
25+
flagSet := flag.NewFlagSet("token", flag.ExitOnError)
26+
usageFunc := func() {
27+
fmt.Fprintf(flag.CommandLine.Output(), "Usage of 'src auth token':\n")
28+
flagSet.PrintDefaults()
29+
}
30+
31+
handler := func(args []string) error {
32+
if err := flagSet.Parse(args); err != nil {
33+
return err
34+
}
35+
36+
token, err := resolveAuthToken(context.Background(), cfg)
37+
if err != nil {
38+
return err
39+
}
40+
41+
fmt.Println(token)
42+
return nil
43+
}
44+
45+
authCommands = append(authCommands, &command{
46+
flagSet: flagSet,
47+
handler: handler,
48+
usageFunc: usageFunc,
49+
})
50+
}
51+
52+
func resolveAuthToken(ctx context.Context, cfg *config) (string, error) {
53+
if cfg.accessToken != "" {
54+
return cfg.accessToken, nil
55+
}
56+
57+
oauthToken, err := loadOAuthToken(ctx, cfg.endpointURL)
58+
if err != nil {
59+
return "", errors.Wrap(err, "error loading OAuth token; set SRC_ACCESS_TOKEN or run `src login`")
60+
}
61+
62+
token, err := newOAuthTokenRefresher(oauthToken).GetToken(ctx)
63+
if err != nil {
64+
return "", errors.Wrap(err, "refreshing OAuth token")
65+
}
66+
67+
return token.AccessToken, nil
68+
}

cmd/src/auth_token_test.go

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net/url"
7+
"testing"
8+
9+
"github.com/sourcegraph/src-cli/internal/oauth"
10+
)
11+
12+
func TestResolveAuthToken(t *testing.T) {
13+
t.Run("uses configured access token before keyring", func(t *testing.T) {
14+
reset := stubAuthTokenDependencies(t)
15+
defer reset()
16+
17+
newRefresherCalled := false
18+
newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher {
19+
newRefresherCalled = true
20+
return fakeOAuthTokenRefresher{}
21+
}
22+
23+
token, err := resolveAuthToken(context.Background(), &config{
24+
accessToken: "access-token",
25+
endpointURL: mustParseURL(t, "https://example.com"),
26+
})
27+
if err != nil {
28+
t.Fatal(err)
29+
}
30+
if token != "access-token" {
31+
t.Fatalf("token = %q, want %q", token, "access-token")
32+
}
33+
if newRefresherCalled {
34+
t.Fatal("expected OAuth token refresher not to be created")
35+
}
36+
})
37+
38+
t.Run("uses stored oauth token", func(t *testing.T) {
39+
reset := stubAuthTokenDependencies(t)
40+
defer reset()
41+
42+
loadOAuthToken = func(context.Context, *url.URL) (*oauth.Token, error) {
43+
return &oauth.Token{
44+
AccessToken: "oauth-token",
45+
}, nil
46+
}
47+
48+
newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher {
49+
return fakeOAuthTokenRefresher{token: oauth.Token{AccessToken: "oauth-token"}}
50+
}
51+
52+
token, err := resolveAuthToken(context.Background(), &config{
53+
endpointURL: mustParseURL(t, "https://example.com"),
54+
})
55+
if err != nil {
56+
t.Fatal(err)
57+
}
58+
if token != "oauth-token" {
59+
t.Fatalf("token = %q, want %q", token, "oauth-token")
60+
}
61+
})
62+
63+
t.Run("refreshes expiring oauth token", func(t *testing.T) {
64+
reset := stubAuthTokenDependencies(t)
65+
defer reset()
66+
67+
loadOAuthToken = func(context.Context, *url.URL) (*oauth.Token, error) {
68+
return &oauth.Token{AccessToken: "old-token"}, nil
69+
}
70+
71+
newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher {
72+
return fakeOAuthTokenRefresher{token: oauth.Token{AccessToken: "new-token"}}
73+
}
74+
75+
token, err := resolveAuthToken(context.Background(), &config{
76+
endpointURL: mustParseURL(t, "https://example.com"),
77+
})
78+
if err != nil {
79+
t.Fatal(err)
80+
}
81+
if token != "new-token" {
82+
t.Fatalf("token = %q, want %q", token, "new-token")
83+
}
84+
})
85+
86+
t.Run("returns refresh error when shared refresh logic fails", func(t *testing.T) {
87+
reset := stubAuthTokenDependencies(t)
88+
defer reset()
89+
90+
loadOAuthToken = func(context.Context, *url.URL) (*oauth.Token, error) {
91+
return &oauth.Token{AccessToken: "old-token"}, nil
92+
}
93+
newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher {
94+
return fakeOAuthTokenRefresher{err: fmt.Errorf("refresh failed")}
95+
}
96+
97+
_, err := resolveAuthToken(context.Background(), &config{
98+
endpointURL: mustParseURL(t, "https://example.com"),
99+
})
100+
if err == nil {
101+
t.Fatal("expected error")
102+
}
103+
})
104+
}
105+
106+
func stubAuthTokenDependencies(t *testing.T) func() {
107+
t.Helper()
108+
109+
prevLoad := loadOAuthToken
110+
prevNewRefresher := newOAuthTokenRefresher
111+
112+
return func() {
113+
loadOAuthToken = prevLoad
114+
newOAuthTokenRefresher = prevNewRefresher
115+
}
116+
}
117+
118+
type fakeOAuthTokenRefresher struct {
119+
token oauth.Token
120+
err error
121+
}
122+
123+
func (r fakeOAuthTokenRefresher) GetToken(context.Context) (oauth.Token, error) {
124+
if r.err != nil {
125+
return oauth.Token{}, r.err
126+
}
127+
return r.token, nil
128+
}

cmd/src/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ The options are:
5050
5151
The commands are:
5252
53+
auth authentication helper commands
5354
api interacts with the Sourcegraph GraphQL API
5455
batch manages batch changes
5556
code-intel manages code intelligence data

internal/api/api.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,7 @@ func buildTransport(opts ClientOpts, flags *Flags) http.RoundTripper {
111111
}
112112

113113
if opts.AccessToken == "" && opts.OAuthToken != nil {
114-
transport = &oauth.Transport{
115-
Base: transport,
116-
Token: opts.OAuthToken,
117-
}
114+
transport = oauth.NewTransport(transport, opts.OAuthToken)
118115
}
119116

120117
return transport

internal/oauth/http_transport.go

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,24 @@ var _ http.Transport
1111

1212
var _ http.RoundTripper = (*Transport)(nil)
1313

14+
const defaultRefreshWindow = 5 * time.Minute
15+
1416
type Transport struct {
15-
Base http.RoundTripper
16-
//Token is a OAuth token (which has a refresh token) that should be used during roundtrip to automatically
17-
//refresh the OAuth access token once the current one has expired or is soon to expire
18-
Token *Token
17+
Base http.RoundTripper
18+
refresher *TokenRefresher
19+
}
1920

20-
//mu is a mutex that should be acquired whenever token used
21-
mu sync.Mutex
21+
type TokenRefresher struct {
22+
token *Token
23+
mu sync.Mutex
24+
}
25+
26+
func NewTokenRefresher(token *Token) *TokenRefresher {
27+
return &TokenRefresher{token: token}
28+
}
29+
30+
func NewTransport(base http.RoundTripper, token *Token) *Transport {
31+
return &Transport{Base: base, refresher: NewTokenRefresher(token)}
2232
}
2333

2434
// storeRefreshedTokenFn is the function the transport should use to persist the token - mainly used during
@@ -28,8 +38,7 @@ var storeRefreshedTokenFn = StoreToken
2838
// RoundTrip implements http.RoundTripper.
2939
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
3040
ctx := req.Context()
31-
32-
token, err := t.getToken(ctx)
41+
token, err := t.refresher.GetToken(ctx)
3342
if err != nil {
3443
return nil, err
3544
}
@@ -43,35 +52,39 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
4352
return http.DefaultTransport.RoundTrip(req2)
4453
}
4554

46-
// getToken returns a value copy of the token. If the token has expired or expiring soon it will be refreshed before returning.
55+
// GetToken returns a value copy of the token. If the token has expired or expiring soon it will be refreshed before returning.
4756
// Once the token is refreshed, the in-memory token is updated and a best effort is made to store the token.
4857
//
4958
// If storing the token fails, no error is returned. An error is only returned if refreshing the token
5059
// fails.
51-
func (t *Transport) getToken(ctx context.Context) (Token, error) {
52-
t.mu.Lock()
53-
defer t.mu.Unlock()
60+
func (r *TokenRefresher) GetToken(ctx context.Context) (Token, error) {
61+
r.mu.Lock()
62+
defer r.mu.Unlock()
5463

55-
prevToken := t.Token
56-
token, err := maybeRefresh(ctx, t.Token)
64+
prevToken := r.token
65+
token, err := maybeRefreshToken(ctx, r.token)
5766
if err != nil {
5867
return Token{}, err
5968
}
60-
t.Token = token
69+
r.token = token
6170
if token != prevToken {
6271
// try to save the token if we fail let the request continue with in memory token
6372
_ = storeRefreshedTokenFn(ctx, token)
6473
}
6574

66-
return *t.Token, nil
75+
return *r.token, nil
6776
}
6877

69-
// maybeRefresh conditionally refreshes the token. If the token has expired or is expriing in the next 30s
70-
// it will be refreshed and the updated token will be returned. Otherwise, no refresh occurs and the original
71-
// token is returned.
72-
func maybeRefresh(ctx context.Context, token *Token) (*Token, error) {
78+
// maybeRefreshToken conditionally refreshes the token. If the token has expired or is
79+
// expiring within the default refresh window, it will be refreshed and the updated token returned.
80+
// Otherwise, no refresh occurs and the original token is returned.
81+
func maybeRefreshToken(ctx context.Context, token *Token) (*Token, error) {
82+
if token == nil {
83+
return nil, errors.New("token is nil")
84+
}
85+
7386
// token has NOT expired and is NOT about to expire in 30s
74-
if !(token.HasExpired() || token.ExpiringIn(time.Duration(30)*time.Second)) {
87+
if !(token.HasExpired() || token.ExpiringIn(defaultRefreshWindow)) {
7588
return token, nil
7689
}
7790
client := NewClient(token.ClientID)

0 commit comments

Comments
 (0)