Skip to content

Commit d206288

Browse files
authored
feat/auth: add src auth token (#1275)
* add TokenRefresher and NewTransport - TokenRefresher centralizes refreshing of OAuth tokens - NewTransport creates an OAuth transport while making sure it is initialized with a TokenRefresher * add cmd `src auth token` that prints access token or oauth token * set refresh window to 5 min
1 parent 882a074 commit d206288

7 files changed

Lines changed: 282 additions & 36 deletions

File tree

cmd/src/auth.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package main
2+
3+
import (
4+
"flag"
5+
"fmt"
6+
)
7+
8+
var authCommands commander
9+
10+
func init() {
11+
usage := `'src auth' provides authentication-related helper commands.
12+
13+
Usage:
14+
15+
src auth command [command options]
16+
17+
The commands are:
18+
19+
token prints the current authentication token
20+
21+
Use "src auth [command] -h" for more information about a command.
22+
`
23+
24+
flagSet := flag.NewFlagSet("auth", flag.ExitOnError)
25+
handler := func(args []string) error {
26+
authCommands.run(flagSet, "src auth", usage, args)
27+
return nil
28+
}
29+
30+
commands = append(commands, &command{
31+
flagSet: flagSet,
32+
handler: handler,
33+
usageFunc: func() {
34+
fmt.Println(usage)
35+
},
36+
})
37+
}

cmd/src/auth_token.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"flag"
6+
"fmt"
7+
8+
"github.com/sourcegraph/sourcegraph/lib/errors"
9+
10+
"github.com/sourcegraph/src-cli/internal/oauth"
11+
)
12+
13+
var (
14+
loadOAuthToken = oauth.LoadToken
15+
newOAuthTokenRefresher = func(token *oauth.Token) oauthTokenRefresher {
16+
return oauth.NewTokenRefresher(token)
17+
}
18+
)
19+
20+
type oauthTokenRefresher interface {
21+
GetToken(ctx context.Context) (oauth.Token, error)
22+
}
23+
24+
func init() {
25+
flagSet := flag.NewFlagSet("token", flag.ExitOnError)
26+
usageFunc := func() {
27+
fmt.Fprintf(flag.CommandLine.Output(), "Usage of 'src auth token':\n")
28+
flagSet.PrintDefaults()
29+
}
30+
31+
handler := func(args []string) error {
32+
if err := flagSet.Parse(args); err != nil {
33+
return err
34+
}
35+
36+
token, err := resolveAuthToken(context.Background(), cfg)
37+
if err != nil {
38+
return err
39+
}
40+
41+
fmt.Println(token)
42+
return nil
43+
}
44+
45+
authCommands = append(authCommands, &command{
46+
flagSet: flagSet,
47+
handler: handler,
48+
usageFunc: usageFunc,
49+
})
50+
}
51+
52+
func resolveAuthToken(ctx context.Context, cfg *config) (string, error) {
53+
if cfg.accessToken != "" {
54+
return cfg.accessToken, nil
55+
}
56+
57+
oauthToken, err := loadOAuthToken(ctx, cfg.endpointURL)
58+
if err != nil {
59+
return "", errors.Wrap(err, "error loading OAuth token; set SRC_ACCESS_TOKEN or run `src login`")
60+
}
61+
62+
token, err := newOAuthTokenRefresher(oauthToken).GetToken(ctx)
63+
if err != nil {
64+
return "", errors.Wrap(err, "refreshing OAuth token")
65+
}
66+
67+
return token.AccessToken, nil
68+
}

cmd/src/auth_token_test.go

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net/url"
7+
"testing"
8+
9+
"github.com/sourcegraph/src-cli/internal/oauth"
10+
)
11+
12+
func TestResolveAuthToken(t *testing.T) {
13+
t.Run("uses configured access token before keyring", func(t *testing.T) {
14+
reset := stubAuthTokenDependencies(t)
15+
defer reset()
16+
17+
newRefresherCalled := false
18+
newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher {
19+
newRefresherCalled = true
20+
return fakeOAuthTokenRefresher{}
21+
}
22+
23+
token, err := resolveAuthToken(context.Background(), &config{
24+
accessToken: "access-token",
25+
endpointURL: mustParseURL(t, "https://example.com"),
26+
})
27+
if err != nil {
28+
t.Fatal(err)
29+
}
30+
if token != "access-token" {
31+
t.Fatalf("token = %q, want %q", token, "access-token")
32+
}
33+
if newRefresherCalled {
34+
t.Fatal("expected OAuth token refresher not to be created")
35+
}
36+
})
37+
38+
t.Run("uses stored oauth token", func(t *testing.T) {
39+
reset := stubAuthTokenDependencies(t)
40+
defer reset()
41+
42+
loadOAuthToken = func(context.Context, *url.URL) (*oauth.Token, error) {
43+
return &oauth.Token{
44+
AccessToken: "oauth-token",
45+
}, nil
46+
}
47+
48+
newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher {
49+
return fakeOAuthTokenRefresher{token: oauth.Token{AccessToken: "oauth-token"}}
50+
}
51+
52+
token, err := resolveAuthToken(context.Background(), &config{
53+
endpointURL: mustParseURL(t, "https://example.com"),
54+
})
55+
if err != nil {
56+
t.Fatal(err)
57+
}
58+
if token != "oauth-token" {
59+
t.Fatalf("token = %q, want %q", token, "oauth-token")
60+
}
61+
})
62+
63+
t.Run("refreshes expiring oauth token", func(t *testing.T) {
64+
reset := stubAuthTokenDependencies(t)
65+
defer reset()
66+
67+
loadOAuthToken = func(context.Context, *url.URL) (*oauth.Token, error) {
68+
return &oauth.Token{AccessToken: "old-token"}, nil
69+
}
70+
71+
newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher {
72+
return fakeOAuthTokenRefresher{token: oauth.Token{AccessToken: "new-token"}}
73+
}
74+
75+
token, err := resolveAuthToken(context.Background(), &config{
76+
endpointURL: mustParseURL(t, "https://example.com"),
77+
})
78+
if err != nil {
79+
t.Fatal(err)
80+
}
81+
if token != "new-token" {
82+
t.Fatalf("token = %q, want %q", token, "new-token")
83+
}
84+
})
85+
86+
t.Run("returns refresh error when shared refresh logic fails", func(t *testing.T) {
87+
reset := stubAuthTokenDependencies(t)
88+
defer reset()
89+
90+
loadOAuthToken = func(context.Context, *url.URL) (*oauth.Token, error) {
91+
return &oauth.Token{AccessToken: "old-token"}, nil
92+
}
93+
newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher {
94+
return fakeOAuthTokenRefresher{err: fmt.Errorf("refresh failed")}
95+
}
96+
97+
_, err := resolveAuthToken(context.Background(), &config{
98+
endpointURL: mustParseURL(t, "https://example.com"),
99+
})
100+
if err == nil {
101+
t.Fatal("expected error")
102+
}
103+
})
104+
}
105+
106+
func stubAuthTokenDependencies(t *testing.T) func() {
107+
t.Helper()
108+
109+
prevLoad := loadOAuthToken
110+
prevNewRefresher := newOAuthTokenRefresher
111+
112+
return func() {
113+
loadOAuthToken = prevLoad
114+
newOAuthTokenRefresher = prevNewRefresher
115+
}
116+
}
117+
118+
type fakeOAuthTokenRefresher struct {
119+
token oauth.Token
120+
err error
121+
}
122+
123+
func (r fakeOAuthTokenRefresher) GetToken(context.Context) (oauth.Token, error) {
124+
if r.err != nil {
125+
return oauth.Token{}, r.err
126+
}
127+
return r.token, nil
128+
}

cmd/src/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ The options are:
5050
5151
The commands are:
5252
53+
auth authentication helper commands
5354
api interacts with the Sourcegraph GraphQL API
5455
batch manages batch changes
5556
code-intel manages code intelligence data

internal/api/api.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,7 @@ func buildTransport(opts ClientOpts, flags *Flags) http.RoundTripper {
110110
}
111111

112112
if opts.AccessToken == "" && opts.OAuthToken != nil {
113-
transport = &oauth.Transport{
114-
Base: transport,
115-
Token: opts.OAuthToken,
116-
}
113+
transport = oauth.NewTransport(transport, opts.OAuthToken)
117114
}
118115

119116
return transport

internal/oauth/http_transport.go

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,24 @@ var _ http.Transport
1313

1414
var _ http.RoundTripper = (*Transport)(nil)
1515

16+
const defaultRefreshWindow = 5 * time.Minute
17+
1618
type Transport struct {
17-
Base http.RoundTripper
18-
//Token is a OAuth token (which has a refresh token) that should be used during roundtrip to automatically
19-
//refresh the OAuth access token once the current one has expired or is soon to expire
20-
Token *Token
19+
Base http.RoundTripper
20+
refresher *TokenRefresher
21+
}
2122

22-
//mu is a mutex that should be acquired whenever token used
23-
mu sync.Mutex
23+
type TokenRefresher struct {
24+
token *Token
25+
mu sync.Mutex
26+
}
27+
28+
func NewTokenRefresher(token *Token) *TokenRefresher {
29+
return &TokenRefresher{token: token}
30+
}
31+
32+
func NewTransport(base http.RoundTripper, token *Token) *Transport {
33+
return &Transport{Base: base, refresher: NewTokenRefresher(token)}
2434
}
2535

2636
// storeRefreshedTokenFn is the function the transport should use to persist the token - mainly used during
@@ -30,8 +40,7 @@ var storeRefreshedTokenFn = StoreToken
3040
// RoundTrip implements http.RoundTripper.
3141
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
3242
ctx := req.Context()
33-
34-
token, err := t.getToken(ctx)
43+
token, err := t.refresher.GetToken(ctx)
3544
if err != nil {
3645
return nil, err
3746
}
@@ -45,36 +54,40 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
4554
return http.DefaultTransport.RoundTrip(req2)
4655
}
4756

48-
// getToken returns a value copy of the token. If the token has expired or expiring soon it will be refreshed before returning.
57+
// GetToken returns a value copy of the token. If the token has expired or expiring soon it will be refreshed before returning.
4958
// Once the token is refreshed, the in-memory token is updated and a best effort is made to store the token.
5059
//
5160
// If storing the token fails, no error is returned. An error is only returned if refreshing the token
5261
// fails.
53-
func (t *Transport) getToken(ctx context.Context) (Token, error) {
54-
t.mu.Lock()
55-
defer t.mu.Unlock()
62+
func (r *TokenRefresher) GetToken(ctx context.Context) (Token, error) {
63+
r.mu.Lock()
64+
defer r.mu.Unlock()
5665

57-
prevToken := t.Token
58-
token, err := maybeRefresh(ctx, t.Token)
66+
prevToken := r.token
67+
token, err := maybeRefreshToken(ctx, r.token)
5968
if err != nil {
6069
return Token{}, err
6170
}
62-
t.Token = token
71+
r.token = token
6372
if token != prevToken {
6473
// Try to save the token.
6574
// If we fail let the request continue with the in-memory token
6675
_ = storeRefreshedTokenFn(ctx, token)
6776
}
6877

69-
return *t.Token, nil
78+
return *r.token, nil
7079
}
7180

72-
// maybeRefresh conditionally refreshes the token. If the token has expired or is expriing in the next 30s
73-
// it will be refreshed and the updated token will be returned. Otherwise, no refresh occurs and the original
74-
// token is returned.
75-
func maybeRefresh(ctx context.Context, token *Token) (*Token, error) {
81+
// maybeRefreshToken conditionally refreshes the token. If the token has expired or is
82+
// expiring within the default refresh window, it will be refreshed and the updated token returned.
83+
// Otherwise, no refresh occurs and the original token is returned.
84+
func maybeRefreshToken(ctx context.Context, token *Token) (*Token, error) {
85+
if token == nil {
86+
return nil, errors.New("token is nil")
87+
}
88+
7689
// token has NOT expired and is NOT about to expire in 30s
77-
if !(token.HasExpired() || token.ExpiringIn(time.Duration(30)*time.Second)) {
90+
if !(token.HasExpired() || token.ExpiringIn(defaultRefreshWindow)) {
7891
return token, nil
7992
}
8093
client := NewClient(token.ClientID)

0 commit comments

Comments
 (0)