Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 66 additions & 34 deletions pkg/cluster/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package cluster

import (
"bytes"
"crypto/tls"
"encoding/json"
"errors"
Expand All @@ -27,6 +26,7 @@ import (
"net/url"
"os"
"path"
"strings"
"time"

"github.com/golang-jwt/jwt/v5"
Expand Down Expand Up @@ -312,55 +312,87 @@ func CheckStatusCode(res *http.Response) error {
return errors.New(string(body))
}

func getStringClaim(claims jwt.MapClaims, name string) (string, error) {
value, ok := claims[name]
if !ok {
return "", fmt.Errorf("missing %q claim", name)
}

text, ok := value.(string)
if !ok || text == "" {
return "", fmt.Errorf("invalid %q claim", name)
}

return text, nil
}

func (cluster *Cluster) getAccessToken() (string, error) {
token, _ := jwt.Parse(cluster.OIDCRefreshToken, func(token *jwt.Token) (interface{}, error) {
return []byte("AllYourBase"), nil
})
iss, err := token.Claims.GetIssuer()
claims := jwt.MapClaims{}
_, _, err := jwt.NewParser().ParseUnverified(cluster.OIDCRefreshToken, claims)
if err != nil {
return "", fmt.Errorf("invalid OIDC refresh token: %w", err)
}

issuer, err := claims.GetIssuer()
if err != nil {
fmt.Println(err)
return "", fmt.Errorf("invalid OIDC refresh token issuer: %w", err)
}
url := iss + "/protocol/openid-connect/token"

scope, err := getStringClaim(claims, "scope")
if err != nil {
fmt.Println(err)
}
var scope string
var clientId string
//client_id := token.Claims.
if str, ok := token.Claims.(jwt.MapClaims); ok {
scope = str["scope"].(string)
clientId = str["azp"].(string)
} else {
fmt.Println("error")
return "", fmt.Errorf("invalid OIDC refresh token: %w", err)
}

jsonBody := []byte("grant_type=refresh_token&refresh_token=" +
cluster.OIDCRefreshToken +
"&client_id=" + clientId + "&scope=" + scope)
clientID, err := getStringClaim(claims, "azp")
if err != nil {
return "", fmt.Errorf("invalid OIDC refresh token: %w", err)
}

bodyReader := bytes.NewReader(jsonBody)
req, err := http.NewRequest(http.MethodPost, url, bodyReader)
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
tokenURL, err := url.Parse(issuer)
if err != nil {
return "", fmt.Errorf("error at new request: %v", err)
return "", fmt.Errorf("invalid issuer URL in OIDC refresh token: %w", err)
}
var res *http.Response
tokenURL.Path = path.Join(tokenURL.Path, "protocol/openid-connect/token")

form := url.Values{}
form.Set("grant_type", "refresh_token")
form.Set("refresh_token", cluster.OIDCRefreshToken)
form.Set("client_id", clientID)
form.Set("scope", scope)

req, err := http.NewRequest(http.MethodPost, tokenURL.String(), strings.NewReader(form.Encode()))
if err != nil {
return "", fmt.Errorf("error creating OIDC token request: %w", err)
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")

// Defensive timeout to avoid hanging the TUI when the IdP is slow/unreachable.
client := &http.Client{Timeout: 15 * time.Second}
res, err = client.Do(req)
res, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("error in the request : %v", err)
return "", fmt.Errorf("error sending OIDC token request: %w", err)
}
buf := new(bytes.Buffer)
buf.ReadFrom(res.Body)
respBytes := buf.String()
defer res.Body.Close()

respString := string(respBytes)
if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices {
body, readErr := io.ReadAll(io.LimitReader(res.Body, 4096))
if readErr != nil {
return "", fmt.Errorf("OIDC token endpoint returned %s and the error body could not be read: %w", res.Status, readErr)
}
message := strings.TrimSpace(string(body))
if message == "" {
return "", fmt.Errorf("OIDC token endpoint returned %s", res.Status)
}
return "", fmt.Errorf("OIDC token endpoint returned %s: %s", res.Status, message)
}

var rrt ResponseRefreshToken
err = json.Unmarshal([]byte(respString), &rrt)
if err != nil {
return "", fmt.Errorf("error: cannot read the response json: %v", err)
if err := json.NewDecoder(res.Body).Decode(&rrt); err != nil {
return "", fmt.Errorf("cannot decode OIDC token response: %w", err)
}
if rrt.AccessToken == "" {
return "", errors.New("OIDC token response did not include an access_token")
}

return rrt.AccessToken, nil
}
79 changes: 79 additions & 0 deletions pkg/cluster/cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"testing"
"time"

"github.com/golang-jwt/jwt/v5"
"github.com/grycap/oscar/v3/pkg/types"
)

Expand Down Expand Up @@ -258,3 +259,81 @@ func TestGetClusterStatusError(t *testing.T) {
t.Fatalf("expected boom error, got %v", err)
}
}

func TestGetClientSafeInvalidOIDCRefreshToken(t *testing.T) {
c := &Cluster{
Endpoint: "https://cluster.example",
OIDCRefreshToken: "change-me",
SSLVerify: true,
}

_, err := c.GetClientSafe()
if err == nil {
t.Fatalf("expected error, got nil")
}
if !strings.Contains(err.Error(), "unable to get the OIDC token from refresh token") {
t.Fatalf("expected wrapped refresh token error, got %v", err)
}
if !strings.Contains(err.Error(), "invalid OIDC refresh token") {
t.Fatalf("expected invalid token detail, got %v", err)
}
}

func TestGetAccessToken(t *testing.T) {
const (
issuerPath = "/realms/oscar"
expectedAzp = "oscar-cli"
expectedScope = "openid profile"
accessToken = "access-token"
)

var receivedRefreshToken string

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != issuerPath+"/protocol/openid-connect/token" {
http.NotFound(w, r)
return
}
if err := r.ParseForm(); err != nil {
t.Fatalf("parsing form: %v", err)
}
if got := r.Form.Get("grant_type"); got != "refresh_token" {
t.Fatalf("expected refresh_token grant type, got %q", got)
}
if got := r.Form.Get("client_id"); got != expectedAzp {
t.Fatalf("expected client_id %q, got %q", expectedAzp, got)
}
if got := r.Form.Get("scope"); got != expectedScope {
t.Fatalf("expected scope %q, got %q", expectedScope, got)
}
receivedRefreshToken = r.Form.Get("refresh_token")
if err := json.NewEncoder(w).Encode(ResponseRefreshToken{AccessToken: accessToken}); err != nil {
t.Fatalf("encoding token response: %v", err)
}
}))
defer server.Close()

refreshToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"iss": server.URL + issuerPath,
"scope": expectedScope,
"azp": expectedAzp,
}).SignedString([]byte("secret"))
if err != nil {
t.Fatalf("signing refresh token: %v", err)
}

c := &Cluster{
OIDCRefreshToken: refreshToken,
}

got, err := c.getAccessToken()
if err != nil {
t.Fatalf("getAccessToken returned error: %v", err)
}
if got != accessToken {
t.Fatalf("expected access token %q, got %q", accessToken, got)
}
if receivedRefreshToken != refreshToken {
t.Fatalf("expected refresh token to be forwarded unchanged")
}
}
Loading