diff --git a/authentication/openshift.go b/authentication/openshift.go index 3944f3ac..0aa793fe 100644 --- a/authentication/openshift.go +++ b/authentication/openshift.go @@ -77,6 +77,7 @@ type OpenShiftAuthenticator struct { oauth2Config oauth2.Config cookieName string handler http.Handler + oauthEnabled bool } //nolint:funlen @@ -139,42 +140,25 @@ func newOpenshiftAuthenticator(c map[string]interface{}, tenant string, MaxRetries: 0, // Retry indefinitely. }) - var authURL *url.URL - - var tokenURL *url.URL - - for b.Reset(); b.Ongoing(); { - authURL, tokenURL, err = openshift.DiscoverOAuth(client) - if err != nil { - level.Error(logger).Log( - "tenant", tenant, - "msg", errors.Wrap(err, "unable to auto discover OpenShift OAuth endpoints")) - registrationRetryCount.WithLabelValues(tenant, OpenShiftAuthenticatorType).Inc() - b.Wait() - - continue - } - - break - } + authURL, tokenURL, oauthEnabled := discoverOAuthEndpoints(client, logger, tenant, registrationRetryCount, b) var clientID string - var clientSecret string - for b.Reset(); b.Ongoing(); { - clientID, clientSecret, err = openshift.DiscoverCredentials(config.ServiceAccount) - if err != nil { - level.Error(logger).Log( - "tenant", tenant, - "msg", errors.Wrap(err, "unable to read serviceaccount credentials")) - registrationRetryCount.WithLabelValues(tenant, OpenShiftAuthenticatorType).Inc() - b.Wait() + if oauthEnabled { + for b.Reset(); b.Ongoing(); { + clientID, clientSecret, err = openshift.DiscoverCredentials(config.ServiceAccount) + if err != nil { + level.Error(logger).Log( + "tenant", tenant, + "msg", errors.Wrap(err, "unable to read serviceaccount credentials")) + registrationRetryCount.WithLabelValues(tenant, OpenShiftAuthenticatorType).Inc() + b.Wait() - continue + continue + } + break } - - break } authOpts := openshift.DelegatingAuthenticationOptions{ @@ -221,7 +205,13 @@ func newOpenshiftAuthenticator(c map[string]interface{}, tenant string, client: client, config: config, cookieName: fmt.Sprintf("observatorium_%s", tenant), - oauth2Config: oauth2.Config{ + oauthEnabled: oauthEnabled, + } + + r := chi.NewRouter() + r.Use(tracing.WithChiRoutePattern) + if oauthEnabled { + osAuthenticator.oauth2Config = oauth2.Config{ ClientID: clientID, ClientSecret: clientSecret, Endpoint: oauth2.Endpoint{ @@ -235,13 +225,12 @@ func newOpenshiftAuthenticator(c map[string]interface{}, tenant string, defaultOAuthScopeListProjects, }, RedirectURL: config.RedirectURL, - }, + } + + r.Handle(loginRoute, osAuthenticator.openshiftLoginHandler()) + r.Handle(callbackRoute, osAuthenticator.openshiftCallbackHandler()) } - r := chi.NewRouter() - r.Use(tracing.WithChiRoutePattern) - r.Handle(loginRoute, osAuthenticator.openshiftLoginHandler()) - r.Handle(callbackRoute, osAuthenticator.openshiftCallbackHandler()) osAuthenticator.handler = r return osAuthenticator, nil @@ -442,6 +431,13 @@ func (a OpenShiftAuthenticator) Middleware() Middleware { // when users went through the OAuth2 flow supported by this // provider. Observatorium stores a self-signed JWT token on a // cookie per tenant to identify the subject of incoming requests. + if !a.oauthEnabled { + msg := "OAuth authentication not available" + level.Debug(a.logger).Log("msg", msg) + httperr.PrometheusAPIError(w, msg, http.StatusUnauthorized) + return + } + cookie, err := r.Cookie(a.cookieName) if err != nil { tenant, ok := GetTenant(r.Context()) @@ -541,3 +537,24 @@ func (a OpenShiftAuthenticator) GRPCMiddleware() grpc.StreamServerInterceptor { func (a OpenShiftAuthenticator) Handler() (string, http.Handler) { return "/openshift/{tenant}", a.handler } + +func discoverOAuthEndpoints(client *http.Client, logger log.Logger, tenant string, registrationRetryCount *prometheus.CounterVec, b *backoff.Backoff) (*url.URL, *url.URL, bool) { + var authURL, tokenURL *url.URL + var err error + for b.Reset(); b.Ongoing(); { + authURL, tokenURL, err = openshift.DiscoverOAuth(client) + if err != nil { + if errors.Is(err, openshift.ErrOAuthServerNotFound) { + return nil, nil, false + } + level.Error(logger).Log( + "tenant", tenant, + "msg", errors.Wrap(err, "unable to auto discover OpenShift OAuth endpoints")) + registrationRetryCount.WithLabelValues(tenant, OpenShiftAuthenticatorType).Inc() + b.Wait() + continue + } + break + } + return authURL, tokenURL, true +} diff --git a/authentication/openshift/discovery.go b/authentication/openshift/discovery.go index 22b34848..f748208b 100644 --- a/authentication/openshift/discovery.go +++ b/authentication/openshift/discovery.go @@ -8,10 +8,12 @@ import ( "net/url" "os" "strings" + + "github.com/pkg/errors" ) const ( - oauthWellKnownPath = "/.well-known/oauth-authorization-server" + OauthWellKnownPath = "/.well-known/oauth-authorization-server" // ServiceAccountNamespacePath is the path to the default serviceaccount namespace. ServiceAccountNamespacePath = "/var/run/secrets/kubernetes.io/serviceaccount/namespace" @@ -21,6 +23,8 @@ const ( ServiceAccountCAPath = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt" ) +var ErrOAuthServerNotFound = errors.Errorf("OAuth server not found") + // GetServiceAccountCACert returns the PEM-encoded CA certificate currently mounted. func GetServiceAccountCACert() ([]byte, error) { rawCA, err := os.ReadFile(ServiceAccountCAPath) @@ -54,7 +58,7 @@ func DiscoverCredentials(name string) (string, string, error) { // DiscoverOAuth return the authorization and token endpoints of the OpenShift OAuth server. // Returns an error if requesting the `/.well-known/oauth-authorization-server` fails. func DiscoverOAuth(client *http.Client) (authURL, tokenURL *url.URL, err error) { - oauthURL := toKubeAPIURLWithPath(oauthWellKnownPath) + oauthURL := toKubeAPIURLWithPath(OauthWellKnownPath) req, err := http.NewRequest(http.MethodGet, oauthURL.String(), nil) if err != nil { @@ -73,6 +77,9 @@ func DiscoverOAuth(client *http.Client) (authURL, tokenURL *url.URL, err error) } if resp.StatusCode < 200 || resp.StatusCode >= 300 { + if resp.StatusCode == 404 { + return nil, nil, ErrOAuthServerNotFound + } return nil, nil, fmt.Errorf("got %d %s", resp.StatusCode, body) } diff --git a/authentication/openshift_test.go b/authentication/openshift_test.go new file mode 100644 index 00000000..64691b51 --- /dev/null +++ b/authentication/openshift_test.go @@ -0,0 +1,147 @@ +package authentication + +import ( + "context" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/efficientgo/core/backoff" + "github.com/go-chi/chi/v5" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" + + "github.com/observatorium/api/authentication/openshift" + "github.com/observatorium/api/logger" +) + +// redirectTransport redirects all requests to the target host. +type redirectTransport struct { + targetHost string + transport http.RoundTripper +} + +func (rt *redirectTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Redirect request to mock server while keeping the path + req.URL.Host = rt.targetHost + req.URL.Scheme = "http" + return rt.transport.RoundTrip(req) +} + +func TestDiscoverOAuthEndpoints_OAuthEnabled(t *testing.T) { + tenant := "tenant" + logger := logger.NewLogger("warn", logger.LogFormatLogfmt, "") + r := chi.NewMux() + + r.Get(openshift.OauthWellKnownPath, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte(`{ + "authorization_endpoint": "https://oauth.example.com/authorize", + "token_endpoint": "https://oauth.example.com/token" + }`)); err != nil { + t.Fatalf("failed to write response %v", err) + } + }) + + mockAPIServer := httptest.NewServer(r) + defer mockAPIServer.Close() + + mockURL, err := url.Parse(mockAPIServer.URL) + if err != nil { + t.Fatalf("failed to parse mock server URL: %v", err) + } + + // Split host and port for KUBERNETES env vars + host, _, err := net.SplitHostPort(mockURL.Host) + if err != nil { + t.Fatalf("failed to parse mock server address: %v", err) + } + + t.Setenv("KUBERNETES_SERVICE_HOST", host) + + retryCounter := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "test_retries_total", + Help: "Total number of OAuth discovery retries", + }, + []string{"tenant", "type"}, + ) + + client := &http.Client{ + Transport: &redirectTransport{ + targetHost: mockURL.Host, + transport: http.DefaultTransport, + }, + } + + b := backoff.New(context.TODO(), backoff.Config{ + Min: 500 * time.Millisecond, + Max: 5 * time.Second, + MaxRetries: 0, // Retry indefinitely. + }) + + authURL, tokenURL, oauthEnabled := discoverOAuthEndpoints(client, logger, tenant, retryCounter, b) + + assert.NotNil(t, authURL) + assert.NotNil(t, tokenURL) + assert.True(t, oauthEnabled) + assert.Equal(t, authURL.String(), "https://oauth.example.com/authorize") + assert.Equal(t, tokenURL.String(), "https://oauth.example.com/token") +} + +func TestDiscoverOAuthEndpoints_OAuthDisabled(t *testing.T) { + tenant := "tenant" + logger := logger.NewLogger("warn", logger.LogFormatLogfmt, "") + r := chi.NewMux() + + r.Get(openshift.OauthWellKnownPath, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + if _, err := w.Write([]byte("404 page not found")); err != nil { + t.Fatalf("failed to write response %v", err) + } + }) + + mockAPIServer := httptest.NewServer(r) + defer mockAPIServer.Close() + + mockURL, err := url.Parse(mockAPIServer.URL) + if err != nil { + t.Fatalf("failed to parse mock server URL: %v", err) + } + + // Split host and port for KUBERNETES env vars + host, _, err := net.SplitHostPort(mockURL.Host) + if err != nil { + t.Fatalf("failed to parse mock server address: %v", err) + } + + t.Setenv("KUBERNETES_SERVICE_HOST", host) + + retryCounter := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "test_retries_total", + Help: "Total number of OAuth discovery retries", + }, + []string{"tenant", "type"}, + ) + + client := &http.Client{ + Transport: &redirectTransport{ + targetHost: mockURL.Host, + transport: http.DefaultTransport, + }, + } + + b := backoff.New(context.TODO(), backoff.Config{ + Min: 500 * time.Millisecond, + Max: 5 * time.Second, + MaxRetries: 0, // Retry indefinitely. + }) + + authURL, tokenURL, oauthEnabled := discoverOAuthEndpoints(client, logger, tenant, retryCounter, b) + + assert.Nil(t, authURL) + assert.Nil(t, tokenURL) + assert.False(t, oauthEnabled) +}