Skip to content

Commit 0e5e006

Browse files
authored
bug: only require CI=true and AccessToken when making API requests (#1278)
* check when inCI and whether we required Access Token * fast fail in APIClient when CI=true and no Access Token * use apiClient constructor in search_jobs * fail fast in login / auth token when CI=true and AccessToken not set * add comment * review feedback - rename `RequireAccessToken` to `checkIfCIAccessTokenRequired` - provide more context in error on why it happened * use sentinel error * use lib errors
1 parent 0267380 commit 0e5e006

File tree

8 files changed

+146
-29
lines changed

8 files changed

+146
-29
lines changed

cmd/src/auth_token.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ func init() {
5454
}
5555

5656
func resolveAuthToken(ctx context.Context, cfg *config) (string, error) {
57+
if err := cfg.requireCIAccessToken(); err != nil {
58+
return "", err
59+
}
60+
5761
if cfg.accessToken != "" {
5862
return cfg.accessToken, nil
5963
}

cmd/src/auth_token_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,28 @@ func TestResolveAuthToken(t *testing.T) {
3535
}
3636
})
3737

38+
t.Run("requires access token in CI", func(t *testing.T) {
39+
reset := stubAuthTokenDependencies(t)
40+
defer reset()
41+
42+
loadCalled := false
43+
loadOAuthToken = func(context.Context, *url.URL) (*oauth.Token, error) {
44+
loadCalled = true
45+
return nil, nil
46+
}
47+
48+
_, err := resolveAuthToken(context.Background(), &config{
49+
inCI: true,
50+
endpointURL: mustParseURL(t, "https://example.com"),
51+
})
52+
if err != errCIAccessTokenRequired {
53+
t.Fatalf("err = %v, want %v", err, errCIAccessTokenRequired)
54+
}
55+
if loadCalled {
56+
t.Fatal("expected OAuth token loader not to be called")
57+
}
58+
})
59+
3860
t.Run("uses stored oauth token", func(t *testing.T) {
3961
reset := stubAuthTokenDependencies(t)
4062
defer reset()

cmd/src/login.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ const (
100100
)
101101

102102
func loginCmd(ctx context.Context, p loginParams) error {
103+
if err := p.cfg.requireCIAccessToken(); err != nil {
104+
return err
105+
}
106+
103107
if p.cfg.configFilePath != "" {
104108
fmt.Fprintln(p.out)
105109
fmt.Fprintf(p.out, "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove %s. See https://github.com/sourcegraph/src-cli#readme for more information.\n", p.cfg.configFilePath)

cmd/src/login_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,17 @@ func TestLogin(t *testing.T) {
6161
}
6262
})
6363

64+
t.Run("CI requires access token", func(t *testing.T) {
65+
u := &url.URL{Scheme: "https", Host: "example.com"}
66+
out, err := check(t, &config{endpointURL: u, inCI: true}, u)
67+
if err != errCIAccessTokenRequired {
68+
t.Fatalf("err = %v, want %v", err, errCIAccessTokenRequired)
69+
}
70+
if out != "" {
71+
t.Fatalf("output = %q, want empty output", out)
72+
}
73+
})
74+
6475
t.Run("warning when using config file", func(t *testing.T) {
6576
endpoint := &url.URL{Scheme: "https", Host: "example.com"}
6677
out, err := check(t, &config{endpointURL: endpoint, configFilePath: "f"}, endpoint)

cmd/src/main.go

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ var (
8282

8383
errConfigMerge = errors.New("when using a configuration file, zero or all environment variables must be set")
8484
errConfigAuthorizationConflict = errors.New("when passing an 'Authorization' additional headers, SRC_ACCESS_TOKEN must never be set")
85-
errCIAccessTokenRequired = errors.New("SRC_ACCESS_TOKEN must be set in CI")
85+
errCIAccessTokenRequired = errors.New("CI is true and SRC_ACCESS_TOKEN is not set or empty. When running in CI OAuth tokens cannot be used, only SRC_ACCESS_TOKEN. Either set CI=false or define a SRC_ACCESS_TOKEN")
8686
)
8787

8888
// commands contains all registered subcommands.
@@ -137,6 +137,7 @@ type config struct {
137137
proxyPath string
138138
configFilePath string
139139
endpointURL *url.URL // always non-nil; defaults to https://sourcegraph.com via readConfig
140+
inCI bool
140141
}
141142

142143
// configFromFile holds the config as read from the config file,
@@ -162,16 +163,32 @@ func (c *config) AuthMode() AuthMode {
162163
return AuthModeOAuth
163164
}
164165

166+
func (c *config) InCI() bool {
167+
return c.inCI
168+
}
169+
170+
func (c *config) requireCIAccessToken() error {
171+
// In CI we typically do not have access to the keyring and the machine is also typically headless
172+
// we therefore require SRC_ACCESS_TOKEN to be set when in CI.
173+
// If someone really wants to run with OAuth in CI they can temporarily do CI=false
174+
if c.InCI() && c.AuthMode() != AuthModeAccessToken {
175+
return errCIAccessTokenRequired
176+
}
177+
178+
return nil
179+
}
180+
165181
// apiClient returns an api.Client built from the configuration.
166182
func (c *config) apiClient(flags *api.Flags, out io.Writer) api.Client {
167183
opts := api.ClientOpts{
168-
EndpointURL: c.endpointURL,
169-
AccessToken: c.accessToken,
170-
AdditionalHeaders: c.additionalHeaders,
171-
Flags: flags,
172-
Out: out,
173-
ProxyURL: c.proxyURL,
174-
ProxyPath: c.proxyPath,
184+
EndpointURL: c.endpointURL,
185+
AccessToken: c.accessToken,
186+
AdditionalHeaders: c.additionalHeaders,
187+
Flags: flags,
188+
Out: out,
189+
ProxyURL: c.proxyURL,
190+
ProxyPath: c.proxyPath,
191+
RequireAccessTokenInCI: c.InCI(),
175192
}
176193

177194
// Only use OAuth if we do not have SRC_ACCESS_TOKEN set
@@ -205,6 +222,7 @@ func readConfig() (*config, error) {
205222

206223
var cfgFromFile configFromFile
207224
var cfg config
225+
cfg.inCI = isCI()
208226
var endpointStr string
209227
var proxyStr string
210228
if err == nil {
@@ -312,10 +330,6 @@ func readConfig() (*config, error) {
312330
return nil, errConfigAuthorizationConflict
313331
}
314332

315-
if isCI() && cfg.accessToken == "" {
316-
return nil, errCIAccessTokenRequired
317-
}
318-
319333
return &cfg, nil
320334
}
321335

cmd/src/main_test.go

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package main
22

33
import (
4+
"context"
45
"encoding/json"
6+
"io"
57
"net/url"
68
"os"
79
"path/filepath"
@@ -10,6 +12,7 @@ import (
1012
"github.com/google/go-cmp/cmp"
1113
"github.com/google/go-cmp/cmp/cmpopts"
1214

15+
"github.com/sourcegraph/sourcegraph/lib/errors"
1316
"github.com/sourcegraph/src-cli/internal/api"
1417
)
1518

@@ -325,9 +328,13 @@ func TestReadConfig(t *testing.T) {
325328
wantErr: errConfigAuthorizationConflict.Error(),
326329
},
327330
{
328-
name: "CI requires access token",
329-
envCI: "1",
330-
wantErr: errCIAccessTokenRequired.Error(),
331+
name: "CI does not require access token during config read",
332+
envCI: "1",
333+
want: &config{
334+
endpointURL: &url.URL{Scheme: "https", Host: "sourcegraph.com"},
335+
additionalHeaders: map[string]string{},
336+
inCI: true,
337+
},
331338
},
332339
{
333340
name: "CI allows access token from config file",
@@ -340,6 +347,7 @@ func TestReadConfig(t *testing.T) {
340347
endpointURL: &url.URL{Scheme: "https", Host: "example.com"},
341348
accessToken: "deadbeef",
342349
additionalHeaders: map[string]string{},
350+
inCI: true,
343351
},
344352
},
345353
}
@@ -422,3 +430,36 @@ func TestConfigAuthMode(t *testing.T) {
422430
}
423431
})
424432
}
433+
434+
func TestConfigAPIClientCIAccessTokenGate(t *testing.T) {
435+
endpointURL := &url.URL{Scheme: "https", Host: "example.com"}
436+
437+
t.Run("requires access token in CI", func(t *testing.T) {
438+
client := (&config{endpointURL: endpointURL, inCI: true}).apiClient(nil, io.Discard)
439+
440+
_, err := client.NewHTTPRequest(context.Background(), "GET", ".api/src-cli/version", nil)
441+
if !errors.Is(err, api.ErrCIAccessTokenRequired) {
442+
t.Fatalf("NewHTTPRequest() error = %v, want %v", err, api.ErrCIAccessTokenRequired)
443+
}
444+
})
445+
446+
t.Run("allows access token in CI", func(t *testing.T) {
447+
client := (&config{endpointURL: endpointURL, inCI: true, accessToken: "abc"}).apiClient(nil, io.Discard)
448+
449+
req, err := client.NewHTTPRequest(context.Background(), "GET", ".api/src-cli/version", nil)
450+
if err != nil {
451+
t.Fatalf("NewHTTPRequest() unexpected error: %s", err)
452+
}
453+
if got := req.Header.Get("Authorization"); got != "token abc" {
454+
t.Fatalf("Authorization header = %q, want %q", got, "token abc")
455+
}
456+
})
457+
458+
t.Run("allows oauth mode outside CI", func(t *testing.T) {
459+
client := (&config{endpointURL: endpointURL}).apiClient(nil, io.Discard)
460+
461+
if _, err := client.NewHTTPRequest(context.Background(), "GET", ".api/src-cli/version", nil); err != nil {
462+
t.Fatalf("NewHTTPRequest() unexpected error: %s", err)
463+
}
464+
})
465+
}

cmd/src/search_jobs.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,7 @@ func parseColumns(columnsFlag string) []string {
155155

156156
// createSearchJobsClient creates a reusable API client for search jobs commands
157157
func createSearchJobsClient(out *flag.FlagSet, apiFlags *api.Flags) api.Client {
158-
return api.NewClient(api.ClientOpts{
159-
EndpointURL: cfg.endpointURL,
160-
AccessToken: cfg.accessToken,
161-
Out: out.Output(),
162-
Flags: apiFlags,
163-
})
158+
return cfg.apiClient(apiFlags, out.Output())
164159
}
165160

166161
// parseSearchJobsArgs parses command arguments with the provided flag set

internal/api/api.go

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ import (
1919

2020
"github.com/sourcegraph/src-cli/internal/oauth"
2121
"github.com/sourcegraph/src-cli/internal/version"
22+
23+
"github.com/sourcegraph/sourcegraph/lib/errors"
2224
)
2325

2426
// Client instances provide methods to create API requests.
@@ -71,9 +73,10 @@ type request struct {
7173

7274
// ClientOpts encapsulates the options given to NewClient.
7375
type ClientOpts struct {
74-
EndpointURL *url.URL
75-
AccessToken string
76-
AdditionalHeaders map[string]string
76+
EndpointURL *url.URL
77+
AccessToken string
78+
AdditionalHeaders map[string]string
79+
RequireAccessTokenInCI bool
7780

7881
// Flags are the standard API client flags provided by NewFlags. If nil,
7982
// default values will be used.
@@ -89,6 +92,9 @@ type ClientOpts struct {
8992
OAuthToken *oauth.Token
9093
}
9194

95+
// ErrCIAccessTokenRequired indicates SRC_ACCESS_TOKEN must be set when CI=true.
96+
var ErrCIAccessTokenRequired = errors.New("SRC_ACCESS_TOKEN must be set when CI=true")
97+
9298
func buildTransport(opts ClientOpts, flags *Flags) http.RoundTripper {
9399
var transport http.RoundTripper
94100
{
@@ -109,6 +115,9 @@ func buildTransport(opts ClientOpts, flags *Flags) http.RoundTripper {
109115
transport = tp
110116
}
111117

118+
// not we do not fail here if requireAccessToken is true, because that would
119+
// mean returning an error on construction which we want to avoid for now
120+
// TODO(burmudar): allow returning of an error upon client construction
112121
if opts.AccessToken == "" && opts.OAuthToken != nil {
113122
transport = oauth.NewTransport(transport, opts.OAuthToken)
114123
}
@@ -135,15 +144,24 @@ func NewClient(opts ClientOpts) Client {
135144

136145
return &client{
137146
opts: ClientOpts{
138-
EndpointURL: opts.EndpointURL,
139-
AccessToken: opts.AccessToken,
140-
AdditionalHeaders: opts.AdditionalHeaders,
141-
Flags: flags,
142-
Out: opts.Out,
147+
EndpointURL: opts.EndpointURL,
148+
AccessToken: opts.AccessToken,
149+
AdditionalHeaders: opts.AdditionalHeaders,
150+
RequireAccessTokenInCI: opts.RequireAccessTokenInCI,
151+
Flags: flags,
152+
Out: opts.Out,
143153
},
144154
httpClient: httpClient,
145155
}
146156
}
157+
158+
func (c *client) checkIfCIAccessTokenRequired() error {
159+
if c.opts.RequireAccessTokenInCI && c.opts.AccessToken == "" {
160+
return ErrCIAccessTokenRequired
161+
}
162+
163+
return nil
164+
}
147165
func (c *client) NewQuery(query string) Request {
148166
return c.NewRequest(query, nil)
149167
}
@@ -170,6 +188,10 @@ func (c *client) NewHTTPRequest(ctx context.Context, method, p string, body io.R
170188
}
171189

172190
func (c *client) createHTTPRequest(ctx context.Context, method, p string, body io.Reader) (*http.Request, error) {
191+
if err := c.checkIfCIAccessTokenRequired(); err != nil {
192+
return nil, err
193+
}
194+
173195
// Can't use c.opts.EndpointURL.JoinPath(p) here because `p` could contain a query string
174196
req, err := http.NewRequestWithContext(ctx, method, c.opts.EndpointURL.String()+"/"+p, body)
175197
if err != nil {
@@ -199,6 +221,10 @@ func (c *client) createHTTPRequest(ctx context.Context, method, p string, body i
199221
}
200222

201223
func (r *request) do(ctx context.Context, result any) (bool, error) {
224+
if err := r.client.checkIfCIAccessTokenRequired(); err != nil {
225+
return false, err
226+
}
227+
202228
if *r.client.opts.Flags.getCurl {
203229
curl, err := r.curlCmd()
204230
if err != nil {

0 commit comments

Comments
 (0)