diff --git a/cli/cmd/root.go b/cli/cmd/root.go index 06bd2dfa0..ffacf58e6 100644 --- a/cli/cmd/root.go +++ b/cli/cmd/root.go @@ -280,6 +280,7 @@ func Execute(rootCmd *cobra.Command, stdin io.Reader, stdout io.Writer, stderr i vmUpdateCmd := runCmds.InitVMUpdateCommand(vmCmd) runCmds.InitVMUpdateTTL(vmUpdateCmd) + runCmds.InitVMUpdateRBACPolicy(vmUpdateCmd) vmPortCmd := runCmds.InitVMPort(vmCmd) runCmds.InitVMPortLs(vmPortCmd) @@ -384,6 +385,16 @@ func Execute(rootCmd *cobra.Command, stdin io.Reader, stdout io.Writer, stderr i } } + // If running inside a CMX VM, use the api_url from the MMDS so the CLI + // talks to the same vendor-api instance that issued the token. Only + // applies when REPLICATED_API_ORIGIN is not explicitly set by the user. + if creds.IsCMX && creds.APIOrigin != "" && os.Getenv("REPLICATED_API_ORIGIN") == "" { + platformOrigin = strings.TrimRight(creds.APIOrigin, "/") + if debugFlag { + fmt.Fprintf(os.Stderr, "[DEBUG] Using CMX MMDS vendor API origin: %s\n", platformOrigin) + } + } + if debugFlag { fmt.Fprintf(os.Stderr, "[DEBUG] Platform API origin: %s\n", platformOrigin) } diff --git a/cli/cmd/runner.go b/cli/cmd/runner.go index 7c5b9f617..f23654e32 100644 --- a/cli/cmd/runner.go +++ b/cli/cmd/runner.go @@ -222,8 +222,9 @@ type runnerArgs struct { createVMWaitDuration time.Duration createVMTags []string createVMNetwork string - createVMDryRun bool - createVMPublicKeys []string + createVMDryRun bool + createVMPublicKeys []string + createVMRBACPolicyName string lsVMShowTerminated bool lsVMStartTime string @@ -235,7 +236,8 @@ type runnerArgs struct { removeVMNames []string removeVMDryRun bool - updateVMTTL string + updateVMTTL string + updateVMRBACPolicyName string updateVMName string updateVMID string diff --git a/cli/cmd/vm_create.go b/cli/cmd/vm_create.go index cfba1b9c4..d1ea44e97 100644 --- a/cli/cmd/vm_create.go +++ b/cli/cmd/vm_create.go @@ -74,6 +74,8 @@ replicated vm create --distribution ubuntu --version 20.04 --ssh-public-key ~/.s cmd.Flags().StringArrayVar(&r.args.createVMTags, "tag", []string{}, "Tag to apply to the VM (key=value format, can be specified multiple times)") cmd.Flags().StringArrayVar(&r.args.createVMPublicKeys, "ssh-public-key", []string{}, "Path to SSH public key file to add to the VM (can be specified multiple times)") + cmd.Flags().StringVar(&r.args.createVMRBACPolicyName, "rbac-policy-name", "", "(alpha) Name of the RBAC policy to assign to the VM (enables automatic vendor-api authentication inside the VM)") + cmd.Flags().MarkHidden("rbac-policy-name") cmd.Flags().BoolVar(&r.args.createVMDryRun, "dry-run", false, "Dry run") @@ -103,6 +105,15 @@ func (r *runners) createVM(cmd *cobra.Command, args []string) error { publicKeys = append(publicKeys, publicKey) } + var rbacPolicyID string + if r.args.createVMRBACPolicyName != "" { + p, err := r.kotsAPI.GetPolicyByName(r.args.createVMRBACPolicyName) + if err != nil { + return errors.Wrap(err, "get rbac policy") + } + rbacPolicyID = p.ID + } + opts := kotsclient.CreateVMOpts{ Name: r.args.createVMName, Distribution: r.args.createVMDistribution, @@ -115,6 +126,7 @@ func (r *runners) createVM(cmd *cobra.Command, args []string) error { Tags: tags, PublicKeys: publicKeys, DryRun: r.args.createVMDryRun, + RBACPolicyID: rbacPolicyID, } vms, err := r.createAndWaitForVM(opts) diff --git a/cli/cmd/vm_update_rbac_policy.go b/cli/cmd/vm_update_rbac_policy.go new file mode 100644 index 000000000..d58652901 --- /dev/null +++ b/cli/cmd/vm_update_rbac_policy.go @@ -0,0 +1,72 @@ +package cmd + +import ( + "fmt" + + "github.com/pkg/errors" + "github.com/replicatedhq/replicated/pkg/platformclient" + "github.com/spf13/cobra" +) + +func (r *runners) InitVMUpdateRBACPolicy(parent *cobra.Command) *cobra.Command { + cmd := &cobra.Command{ + Use: "rbac-policy [ID_OR_NAME]", + Hidden: true, + Short: "(alpha) Update the RBAC policy assigned to a VM.", + Long: `(alpha) The 'rbac-policy' command assigns or removes the RBAC policy on a running VM. + +When a policy is assigned, the VM's OIDC client credentials are used by the replicated CLI +inside the VM to authenticate with vendor-api automatically using that policy's permissions. +Pass an empty string to '--rbac-policy-name' to remove the policy from the VM. + +Note: this feature is currently in alpha and requires the cmx_vm_rbac feature flag to be enabled.`, + Example: `# Assign an RBAC policy to a VM by VM ID +replicated vm update rbac-policy aaaaa11 --rbac-policy-name "Read Only" + +# Assign an RBAC policy to a VM by VM name +replicated vm update rbac-policy my-test-vm --rbac-policy-name "Read Only" + +# Remove the RBAC policy from a VM +replicated vm update rbac-policy my-test-vm --rbac-policy-name ""`, + RunE: r.updateVMRBACPolicy, + SilenceUsage: true, + ValidArgsFunction: r.completeVMIDsAndNames, + } + parent.AddCommand(cmd) + + cmd.Flags().StringVar(&r.args.updateVMRBACPolicyName, "rbac-policy-name", "", "(alpha) Name of the RBAC policy to assign to the VM (pass empty string to remove)") + cmd.MarkFlagRequired("rbac-policy-name") + + return cmd +} + +func (r *runners) updateVMRBACPolicy(cmd *cobra.Command, args []string) error { + if len(args) > 0 { + vmID, err := r.getVMIDFromArg(args[0]) + if err != nil { + return errors.Wrap(err, "get vm id from arg") + } + r.args.updateVMID = vmID + } else if err := r.ensureUpdateVMIDArg(args); err != nil { + return errors.Wrap(err, "ensure vm id arg") + } + + var policyID string + if r.args.updateVMRBACPolicyName != "" { + p, err := r.kotsAPI.GetPolicyByName(r.args.updateVMRBACPolicyName) + if err != nil { + return errors.Wrap(err, "get rbac policy") + } + policyID = p.ID + } + + if err := r.kotsAPI.UpdateVMRBACPolicy(r.args.updateVMID, policyID); err != nil { + if errors.Cause(err) == platformclient.ErrForbidden { + return ErrCompatibilityMatrixTermsNotAccepted + } + return errors.Wrap(err, "update vm rbac policy") + } + + fmt.Fprintln(r.w, "RBAC policy updated.") + return nil +} diff --git a/pkg/cmxmetadata/cmxmetadata.go b/pkg/cmxmetadata/cmxmetadata.go new file mode 100644 index 000000000..6e261bbec --- /dev/null +++ b/pkg/cmxmetadata/cmxmetadata.go @@ -0,0 +1,168 @@ +package cmxmetadata + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +const ( + mmdsIPv4Addr = "169.254.170.254" + mmdsPath = "/latest/vendor-api" + mmdsTimeout = 500 * time.Millisecond // fail fast if not in CMX + tokenLeeway = 60 * time.Second // refresh token this early before expiry +) + +// ErrNotAvailable is returned when the CMX metadata service is not reachable. +// This is the normal case when the CLI is not running inside a Firecracker VM. +var ErrNotAvailable = errors.New("CMX metadata service not available") + +// VMMetadata holds the OIDC client credentials provisioned by vendor-api into +// the Firecracker MMDS. +type VMMetadata struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + APIURL string `json:"api_url"` + TokenEndpoint string `json:"token_endpoint"` +} + +// GetVMMetadata attempts to read OIDC credentials from the Firecracker MMDS. +// It returns ErrNotAvailable if the metadata service is not reachable (i.e. +// the CLI is not running inside a CMX VM). +// +// Firecracker MMDS v1 returns a newline-separated key listing when querying a +// nested object path. Sending Accept: application/json causes it to return the +// full JSON subtree instead, which is what we need to parse in one request. +func GetVMMetadata() (*VMMetadata, error) { + client := &http.Client{ + Timeout: mmdsTimeout, + } + + mmdsURL := fmt.Sprintf("http://%s%s", mmdsIPv4Addr, mmdsPath) + req, err := http.NewRequest(http.MethodGet, mmdsURL, nil) + if err != nil { + return nil, ErrNotAvailable + } + req.Header.Set("Accept", "application/json") + + resp, err := client.Do(req) + if err != nil { + return nil, ErrNotAvailable + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, ErrNotAvailable + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, ErrNotAvailable + } + + var meta VMMetadata + if err := json.Unmarshal(body, &meta); err != nil { + return nil, ErrNotAvailable + } + + if meta.ClientID == "" || meta.ClientSecret == "" { + return nil, ErrNotAvailable + } + + return &meta, nil +} + +// tokenCache holds a cached access token along with its expiry time. +type tokenCache struct { + mu sync.Mutex + token string + expiresAt time.Time +} + +// package-level cache shared across all calls within a process lifetime. +var cache = &tokenCache{} + +// GetAccessToken returns a valid access token for the given VMMetadata, using a +// cached token when possible and refreshing it when it is about to expire. +func GetAccessToken(meta *VMMetadata) (string, error) { + cache.mu.Lock() + defer cache.mu.Unlock() + + if cache.token != "" && time.Until(cache.expiresAt) > tokenLeeway { + return cache.token, nil + } + + token, expiresAt, err := exchangeCredentials(meta) + if err != nil { + return "", err + } + + cache.token = token + if expiresAt != nil { + cache.expiresAt = *expiresAt + } + + return token, nil +} + +// tokenResponse is the JSON response from the token endpoint. +type tokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` +} + +// exchangeCredentials performs the client_credentials grant against the token +// endpoint and returns the access token along with its absolute expiry time. +func exchangeCredentials(meta *VMMetadata) (string, *time.Time, error) { + formData := url.Values{} + formData.Set("grant_type", "client_credentials") + formData.Set("client_id", meta.ClientID) + formData.Set("client_secret", meta.ClientSecret) + + client := &http.Client{ + Timeout: 10 * time.Second, + } + + resp, err := client.Post( + meta.TokenEndpoint, + "application/x-www-form-urlencoded", + strings.NewReader(formData.Encode()), + ) + if err != nil { + return "", nil, fmt.Errorf("token exchange request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", nil, fmt.Errorf("reading token response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return "", nil, fmt.Errorf("token endpoint returned status %d: %s", resp.StatusCode, string(body)) + } + + var tokenResp tokenResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return "", nil, fmt.Errorf("parsing token response: %w", err) + } + + if tokenResp.AccessToken == "" { + return "", nil, fmt.Errorf("token endpoint returned empty access_token") + } + + var expiresAt *time.Time + if tokenResp.ExpiresIn > 0 { + t := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + expiresAt = &t + } + + return tokenResp.AccessToken, expiresAt, nil +} diff --git a/pkg/credentials/credentials.go b/pkg/credentials/credentials.go index f666488d5..2979238ec 100644 --- a/pkg/credentials/credentials.go +++ b/pkg/credentials/credentials.go @@ -7,6 +7,7 @@ import ( "path" "path/filepath" + "github.com/replicatedhq/replicated/pkg/cmxmetadata" "github.com/replicatedhq/replicated/pkg/credentials/types" ) @@ -54,6 +55,7 @@ func GetCurrentCredentials() (*types.Credentials, error) { // 2. Named profile (if profileName is provided) // 3. Default profile from config file (if profileName is empty) // 4. Legacy single token from config file (backward compatibility) +// 5. CMX VM metadata service (automatic OIDC auth inside Firecracker VMs) func GetCredentialsWithProfile(profileName string) (*types.Credentials, error) { // Priority 1: Check environment variables first envCredentials, err := getEnvCredentials() @@ -82,6 +84,17 @@ func GetCredentialsWithProfile(profileName string) (*types.Credentials, error) { return configFileCredentials, nil } + // Priority 5: CMX VM metadata service (automatic OIDC auth inside Firecracker VMs). + // This is last so that any explicitly configured token always takes precedence, + // making it easy to override the VM identity during development or testing. + vmMeta, err := cmxmetadata.GetVMMetadata() + if err == nil { + token, err := cmxmetadata.GetAccessToken(vmMeta) + if err == nil { + return &types.Credentials{APIToken: token, IsCMX: true, APIOrigin: vmMeta.APIURL}, nil + } + } + return nil, ErrCredentialsNotFound } diff --git a/pkg/credentials/types/types.go b/pkg/credentials/types/types.go index 569b2fec2..e1ea2b3a8 100644 --- a/pkg/credentials/types/types.go +++ b/pkg/credentials/types/types.go @@ -7,6 +7,12 @@ type Credentials struct { IsEnv bool `json:"-"` IsConfigFile bool `json:"-"` IsProfile bool `json:"-"` + IsCMX bool `json:"-"` + + // APIOrigin is populated when IsCMX is true. It holds the api_url from the + // Firecracker MMDS so the CLI can talk to the same vendor-api instance that + // issued the token, without requiring REPLICATED_API_ORIGIN to be set manually. + APIOrigin string `json:"-"` } // Profile represents a named authentication profile diff --git a/pkg/kotsclient/policy_list.go b/pkg/kotsclient/policy_list.go new file mode 100644 index 000000000..5b1e654f2 --- /dev/null +++ b/pkg/kotsclient/policy_list.go @@ -0,0 +1,39 @@ +package kotsclient + +import ( + "context" + "fmt" + "net/http" + + "github.com/replicatedhq/replicated/pkg/types" +) + +type ListPoliciesResponse struct { + Policies []*types.Policy `json:"policies"` +} + +// ListPolicies returns all RBAC policies for the authenticated team. +func (c *VendorV3Client) ListPolicies() ([]*types.Policy, error) { + resp := ListPoliciesResponse{} + if err := c.DoJSON(context.TODO(), "GET", "/v3/policies", http.StatusOK, nil, &resp); err != nil { + return nil, err + } + return resp.Policies, nil +} + +// GetPolicyByName returns the RBAC policy with the given name, or an error if +// no policy with that name exists for the team. +func (c *VendorV3Client) GetPolicyByName(name string) (*types.Policy, error) { + policies, err := c.ListPolicies() + if err != nil { + return nil, fmt.Errorf("list policies: %w", err) + } + + for _, p := range policies { + if p.Name == name { + return p, nil + } + } + + return nil, fmt.Errorf("policy %q not found", name) +} diff --git a/pkg/kotsclient/vm_create.go b/pkg/kotsclient/vm_create.go index 9c7e8184c..174a692fd 100644 --- a/pkg/kotsclient/vm_create.go +++ b/pkg/kotsclient/vm_create.go @@ -22,6 +22,7 @@ type CreateVMRequest struct { InstanceType string `json:"instance_type"` Tags []types.Tag `json:"tags"` PublicKeys []string `json:"public_keys,omitempty"` + RBACPolicyID string `json:"rbac_policy_id,omitempty"` } type CreateVMResponse struct { @@ -48,6 +49,7 @@ type CreateVMOpts struct { Tags []types.Tag PublicKeys []string DryRun bool + RBACPolicyID string } type CreateVMErrorResponse struct { @@ -77,6 +79,7 @@ func (c *VendorV3Client) CreateVM(opts CreateVMOpts) ([]*types.VM, *CreateVMErro InstanceType: opts.InstanceType, Tags: opts.Tags, PublicKeys: opts.PublicKeys, + RBACPolicyID: opts.RBACPolicyID, } if opts.DryRun { diff --git a/pkg/kotsclient/vm_update_rbac_policy.go b/pkg/kotsclient/vm_update_rbac_policy.go new file mode 100644 index 000000000..e09509618 --- /dev/null +++ b/pkg/kotsclient/vm_update_rbac_policy.go @@ -0,0 +1,19 @@ +package kotsclient + +import ( + "context" + "fmt" + "net/http" +) + +type UpdateVMRBACPolicyRequest struct { + PolicyID string `json:"policy_id"` +} + +func (c *VendorV3Client) UpdateVMRBACPolicy(vmID, policyID string) error { + req := UpdateVMRBACPolicyRequest{ + PolicyID: policyID, + } + endpoint := fmt.Sprintf("/v3/vm/%s/rbac-policy", vmID) + return c.DoJSON(context.TODO(), "PUT", endpoint, http.StatusNoContent, req, nil) +} diff --git a/pkg/types/policy.go b/pkg/types/policy.go new file mode 100644 index 000000000..656f96a7b --- /dev/null +++ b/pkg/types/policy.go @@ -0,0 +1,15 @@ +package types + +import "time" + +// Policy represents a team RBAC policy as returned by the vendor-api. +type Policy struct { + ID string `json:"id"` + TeamID string `json:"teamId"` + Name string `json:"name"` + Description string `json:"description"` + Definition string `json:"definition"` + CreatedAt time.Time `json:"createdAt"` + ModifiedAt *time.Time `json:"modifiedAt"` + ReadOnly bool `json:"readOnly"` +}