Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions controllers/object_controls.go
Original file line number Diff line number Diff line change
Expand Up @@ -3298,16 +3298,12 @@ func resolveDriverTag(n ClusterPolicyController, driverSpec interface{}) (string
return image, nil
}

func getOSName(osTag string) string {
// Extract base OS ID by stripping version suffix from osTag
// Examples: "rhel10" -> "rhel", "ubuntu22.04" -> "ubuntu", "rocky9" -> "rocky"
osID := strings.TrimRight(osTag, "0123456789.")
return osID
}

// getRepoConfigPath returns the standard OS specific path for repository configuration files.
func (n ClusterPolicyController) getRepoConfigPath() (string, error) {
osID := getOSName(n.gpuNodeOSTag)
osID := n.gpuNodeOSRelease
if osID == "" {
return "", fmt.Errorf("GPU node OS name is empty")
}
if path, ok := RepoConfigPathMap[osID]; ok {
return path, nil
}
Expand All @@ -3316,7 +3312,10 @@ func (n ClusterPolicyController) getRepoConfigPath() (string, error) {

// getCertConfigPath returns the standard OS specific path for ssl keys/certificates.
func (n ClusterPolicyController) getCertConfigPath() (string, error) {
osID := getOSName(n.gpuNodeOSTag)
osID := n.gpuNodeOSRelease
if osID == "" {
return "", fmt.Errorf("GPU node OS name is empty")
}
if path, ok := CertConfigPathMap[osID]; ok {
return path, nil
}
Expand All @@ -3326,7 +3325,10 @@ func (n ClusterPolicyController) getCertConfigPath() (string, error) {
// getSubscriptionPathsToVolumeSources returns the MountPathToVolumeSource map containing all
// OS-specific subscription/entitlement paths that need to be mounted in the container.
func (n ClusterPolicyController) getSubscriptionPathsToVolumeSources() (MountPathToVolumeSource, error) {
osID := getOSName(n.gpuNodeOSTag)
osID := n.gpuNodeOSRelease
if osID == "" {
return nil, fmt.Errorf("GPU node OS name is empty")
}
if pathToVolumeSource, ok := SubscriptionPathMap[osID]; ok {
return pathToVolumeSource, nil
}
Expand Down Expand Up @@ -3594,7 +3596,10 @@ func transformDriverContainer(obj *appsv1.DaemonSet, config *gpuv1.ClusterPolicy
}
}

osID := getOSName(n.gpuNodeOSTag)
osID := n.gpuNodeOSRelease
if osID == "" {
return fmt.Errorf("ERROR: failed to determine GPU node OS name")
}
// set up subscription entitlements for RHEL(using K8s with a non-CRIO runtime) and SLES
if (osID == "rhel" && n.openshift == "" && n.runtime != gpuv1.CRIO) || osID == "sles" || osID == "sl-micro" {
n.logger.Info("Mounting subscriptions into the driver container", "OS", osID)
Expand Down
3 changes: 2 additions & 1 deletion controllers/object_controls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,14 @@ func setup() error {
if gpuNodeCount == 0 {
return fmt.Errorf("no gpu nodes in mock cluster")
}
gpuNodeOSTag, err := clusterPolicyController.getGPUNodeOSTag()
gpuNodeOSRelease, gpuNodeOSTag, err := clusterPolicyController.getGPUNodeOSInfo()
if err != nil {
return fmt.Errorf("unable to get GPU node tag: %w", err)
}

clusterPolicyController.hasGPUNodes = gpuNodeCount != 0
clusterPolicyController.hasNFDLabels = hasNFDLabels
clusterPolicyController.gpuNodeOSRelease = gpuNodeOSRelease
clusterPolicyController.gpuNodeOSTag = gpuNodeOSTag

// setup kernelVersionMap for pre-compiled driver tests
Expand Down
57 changes: 40 additions & 17 deletions controllers/state_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,12 @@ type ClusterPolicyController struct {
openshift string
ocpDriverToolkit OpenShiftDriverToolkit

runtime gpuv1.Runtime
gpuNodeOSTag string
hasGPUNodes bool
hasNFDLabels bool
sandboxEnabled bool
runtime gpuv1.Runtime
gpuNodeOSTag string
gpuNodeOSRelease string
hasGPUNodes bool
hasNFDLabels bool
sandboxEnabled bool
}

func addState(n *ClusterPolicyController, path string) {
Expand Down Expand Up @@ -637,7 +638,7 @@ func getRuntimeString(node corev1.Node) (gpuv1.Runtime, error) {
return runtime, nil
}

func (n *ClusterPolicyController) getGPUNodeOSTag() (string, error) {
func (n *ClusterPolicyController) getGPUNodeOSInfo() (string, string, error) {
ctx := n.ctx
opts := []client.ListOption{
client.MatchingLabels(map[string]string{commonGPULabelKey: commonGPULabelValue}),
Expand All @@ -646,34 +647,55 @@ func (n *ClusterPolicyController) getGPUNodeOSTag() (string, error) {
nodeList := &corev1.NodeList{}
err := n.client.List(ctx, nodeList, opts...)
if err != nil {
return "", fmt.Errorf("unable to list nodes with GPU present: %w", err)
return "", "", fmt.Errorf("unable to list nodes with GPU present: %w", err)
}
if len(nodeList.Items) == 0 {
return "", fmt.Errorf("no nodes found with GPU present")
return "", "", fmt.Errorf("no nodes found with GPU present")
}

labels := nodeList.Items[0].Labels
osName, ok := labels[nfdOSReleaseIDLabelKey]
if !ok {
return "", fmt.Errorf("unable to retrieve OS name from label %s", nfdOSReleaseIDLabelKey)
return "", "", fmt.Errorf("unable to retrieve OS name from label %s", nfdOSReleaseIDLabelKey)
}
osVersion, ok := labels[nfdOSVersionIDLabelKey]
if !ok {
return "", fmt.Errorf("unable to retrieve OS version from label %s", nfdOSVersionIDLabelKey)
return "", "", fmt.Errorf("unable to retrieve OS version from label %s", nfdOSVersionIDLabelKey)
}
osMajorVersion := strings.Split(osVersion, ".")[0]
osMajorNumber, err := strconv.Atoi(osMajorVersion)
if err != nil {
return "", fmt.Errorf("error processing OS major version %s: %w", osMajorVersion, err)
}

// If the OS is RockyLinux or RHEL 10 & above, we will omit the minor version when constructing the os image tag
if osName == "rocky" || (osName == "rhel" && osMajorNumber >= 10) {
switch osName {
case "rocky":
osVersion = osMajorVersion
case "rhel":
osMajorNumber, err := parseOSMajorVersion(osVersion)
if err != nil {
return "", "", err
}
if osMajorNumber >= 10 {
osVersion = osMajorVersion
}
}
osTag := fmt.Sprintf("%s%s", osName, osVersion)

return osTag, nil
return osName, osTag, nil
}

func parseOSMajorVersion(osVersion string) (int, error) {
osMajorVersion := strings.Split(osVersion, ".")[0]
osMajorVersion = strings.TrimSpace(osMajorVersion)
osMajorVersion = strings.TrimPrefix(strings.TrimPrefix(osMajorVersion, "v"), "V")
if osMajorVersion == "" {
return 0, fmt.Errorf("empty OS major version")
}

osMajorNumber, err := strconv.Atoi(osMajorVersion)
if err != nil {
return 0, fmt.Errorf("error processing OS major version %s: %w", osMajorVersion, err)
}

return osMajorNumber, nil
}

func (n *ClusterPolicyController) setPodSecurityLabelsForNamespace() error {
Expand Down Expand Up @@ -939,10 +961,11 @@ func (n *ClusterPolicyController) init(ctx context.Context, reconciler *ClusterP
n.hasNFDLabels = hasNFDLabels

if n.hasGPUNodes {
gpuNodeOSTag, err := n.getGPUNodeOSTag()
gpuNodeOSRelease, gpuNodeOSTag, err := n.getGPUNodeOSInfo()
if err != nil {
return fmt.Errorf("failed to retrieve GPU node OS tag: %w", err)
}
n.gpuNodeOSRelease = gpuNodeOSRelease
n.gpuNodeOSTag = gpuNodeOSTag
}
// fetch all nodes and annotate gpu nodes
Expand Down
103 changes: 103 additions & 0 deletions controllers/state_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,119 @@
package controllers

import (
"context"
"errors"
"testing"

"github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client/fake"

gpuv1 "github.com/NVIDIA/gpu-operator/api/nvidia/v1"
)

func TestGetGPUNodeOSInfo(t *testing.T) {
testCases := []struct {
name string
osName string
osVersion string
expected string
expectError bool
errorContainsText string
}{
{
name: "talos version with v prefix",
osName: "talos",
osVersion: "v1.12.6",
expected: "talosv1.12.6",
},
{
name: "rhel 10 omits minor version",
osName: "rhel",
osVersion: "10.2",
expected: "rhel10",
},
{
name: "rocky omits minor version",
osName: "rocky",
osVersion: "9.5",
expected: "rocky9",
},
{
name: "ubuntu preserves full version",
osName: "ubuntu",
osVersion: "24.04",
expected: "ubuntu24.04",
},
{
name: "sles preserves dotted version",
osName: "sles",
osVersion: "15.6",
expected: "sles15.6",
},
{
name: "sles preserves service-pack version",
osName: "sles",
osVersion: "15-SP6",
expected: "sles15-SP6",
},
{
name: "sl-micro preserves dotted version",
osName: "sl-micro",
osVersion: "6.0",
expected: "sl-micro6.0",
},
{
name: "archlinux preserves rolling version",
osName: "archlinux",
osVersion: "rolling",
expected: "archlinuxrolling",
},
{
name: "rhel invalid major version errors",
osName: "rhel",
osVersion: "A.10",
expectError: true,
errorContainsText: "error processing OS major version",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
scheme := runtime.NewScheme()
require.NoError(t, corev1.AddToScheme(scheme))

node := &corev1.Node{
ObjectMeta: metav1.ObjectMeta{
Name: "gpu-node-1",
Labels: map[string]string{
commonGPULabelKey: commonGPULabelValue,
nfdOSReleaseIDLabelKey: tc.osName,
nfdOSVersionIDLabelKey: tc.osVersion,
},
},
}

client := fake.NewClientBuilder().WithScheme(scheme).WithObjects(node).Build()
controller := ClusterPolicyController{ctx: context.Background(), client: client}

osName, osTag, err := controller.getGPUNodeOSInfo()
if tc.expectError {
require.Error(t, err)
require.Contains(t, err.Error(), tc.errorContainsText)
return
}

require.NoError(t, err)
require.Equal(t, tc.osName, osName)
require.Equal(t, tc.expected, osTag)
})
}
}

func TestGetRuntimeString(t *testing.T) {
testCases := []struct {
description string
Expand Down
12 changes: 6 additions & 6 deletions controllers/transforms_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3016,7 +3016,7 @@ func TestTransformDriver(t *testing.T) {
t.Run(tc.description, func(t *testing.T) {
err := TransformDriver(tc.ds.DaemonSet, tc.cpSpec,
ClusterPolicyController{client: tc.client, runtime: gpuv1.Containerd,
operatorNamespace: "test-ns", logger: ctrl.Log.WithName("test"), gpuNodeOSTag: "ubuntu20.04"})
operatorNamespace: "test-ns", logger: ctrl.Log.WithName("test"), gpuNodeOSRelease: "ubuntu", gpuNodeOSTag: "ubuntu20.04"})
if tc.errorExpected {
require.Error(t, err)
return
Expand Down Expand Up @@ -3430,7 +3430,7 @@ func TestTransformDriverWithLicensingConfig(t *testing.T) {
t.Run(tc.description, func(t *testing.T) {
err := TransformDriver(tc.ds.DaemonSet, tc.cpSpec,
ClusterPolicyController{client: tc.client, runtime: gpuv1.Containerd,
operatorNamespace: "test-ns", logger: ctrl.Log.WithName("test"), gpuNodeOSTag: "ubuntu20.04"})
operatorNamespace: "test-ns", logger: ctrl.Log.WithName("test"), gpuNodeOSRelease: "ubuntu", gpuNodeOSTag: "ubuntu20.04"})
if tc.errorExpected {
require.Error(t, err)
return
Expand Down Expand Up @@ -3562,7 +3562,7 @@ func TestTransformDriverWithResources(t *testing.T) {
t.Run(tc.description, func(t *testing.T) {
err := TransformDriver(tc.ds.DaemonSet, tc.cpSpec,
ClusterPolicyController{client: tc.client, runtime: gpuv1.Containerd,
operatorNamespace: "test-ns", logger: ctrl.Log.WithName("test"), gpuNodeOSTag: "ubuntu20.04"})
operatorNamespace: "test-ns", logger: ctrl.Log.WithName("test"), gpuNodeOSRelease: "ubuntu", gpuNodeOSTag: "ubuntu20.04"})
if tc.errorExpected {
require.Error(t, err)
return
Expand Down Expand Up @@ -3657,7 +3657,7 @@ func TestTransformDriverRDMA(t *testing.T) {

err := TransformDriver(ds.DaemonSet, cpSpec,
ClusterPolicyController{client: mockClient, runtime: gpuv1.Containerd,
operatorNamespace: "test-ns", logger: ctrl.Log.WithName("test"), gpuNodeOSTag: "ubuntu20.04"})
operatorNamespace: "test-ns", logger: ctrl.Log.WithName("test"), gpuNodeOSRelease: "ubuntu", gpuNodeOSTag: "ubuntu20.04"})
require.NoError(t, err)

require.EqualValues(t, expectedDs, ds)
Expand Down Expand Up @@ -3740,7 +3740,7 @@ func TestTransformDriverVGPUTopologyConfig(t *testing.T) {

err := TransformDriver(ds.DaemonSet, cpSpec,
ClusterPolicyController{client: mockClient, runtime: gpuv1.Containerd,
operatorNamespace: "test-ns", logger: ctrl.Log.WithName("test"), gpuNodeOSTag: "ubuntu20.04"})
operatorNamespace: "test-ns", logger: ctrl.Log.WithName("test"), gpuNodeOSRelease: "ubuntu", gpuNodeOSTag: "ubuntu20.04"})
require.NoError(t, err)
require.EqualValues(t, expectedDs, ds)
}
Expand Down Expand Up @@ -4173,7 +4173,7 @@ func TestTransformDriverWithAdditionalConfig(t *testing.T) {
t.Run(tc.description, func(t *testing.T) {
err := TransformDriver(tc.ds.DaemonSet, tc.cpSpec,
ClusterPolicyController{client: tc.client, runtime: gpuv1.Containerd,
operatorNamespace: "test-ns", logger: ctrl.Log.WithName("test"), gpuNodeOSTag: "ubuntu24.04"})
operatorNamespace: "test-ns", logger: ctrl.Log.WithName("test"), gpuNodeOSRelease: "ubuntu", gpuNodeOSTag: "ubuntu24.04"})
if tc.errorExpected {
require.Error(t, err)
require.Equal(t, tc.errorMessage, err.Error())
Expand Down
35 changes: 29 additions & 6 deletions internal/state/nodepool.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,40 @@ func getNodePools(ctx context.Context, k8sClient client.Client, selector map[str

func getOSTag(osRelease, osVersion string) (string, error) {
osMajorVersion := strings.Split(osVersion, ".")[0]
osMajorNumber, err := strconv.Atoi(osMajorVersion)
if err != nil {
return "", fmt.Errorf("failed to parse os version: %w", err)
}

var osTagSuffix string
// If the OS is RockyLinux or RHEL 10 & above, we will omit the minor version when constructing the os image tag
if osRelease == "rocky" || (osRelease == "rhel" && osMajorNumber >= 10) {
switch osRelease {
case "rocky":
osTagSuffix = osMajorVersion
} else {
case "rhel":
osMajorNumber, err := parseOSMajorVersion(osVersion)
if err != nil {
return "", fmt.Errorf("failed to parse os version: %w", err)
}
if osMajorNumber >= 10 {
osTagSuffix = osMajorVersion
} else {
osTagSuffix = osVersion
}
default:
osTagSuffix = osVersion
}
return fmt.Sprintf("%s%s", osRelease, osTagSuffix), nil
}

func parseOSMajorVersion(osVersion string) (int, error) {
osMajorVersion := strings.Split(osVersion, ".")[0]
osMajorVersion = strings.TrimSpace(osMajorVersion)
osMajorVersion = strings.TrimPrefix(strings.TrimPrefix(osMajorVersion, "v"), "V")
if osMajorVersion == "" {
return 0, fmt.Errorf("empty OS major version")
}

osMajorNumber, err := strconv.Atoi(osMajorVersion)
if err != nil {
return 0, err
}

return osMajorNumber, nil
}
Loading
Loading