diff --git a/controllers/object_controls.go b/controllers/object_controls.go index b436bcab1..07b621274 100644 --- a/controllers/object_controls.go +++ b/controllers/object_controls.go @@ -3298,16 +3298,12 @@ func resolveDriverTag(n ClusterPolicyController, driverSpec interface{}) (string return image, nil } -func getOSName(osTag string) string { - // Extract base OS ID by stripping version suffix from osTag - // Examples: "rhel10" -> "rhel", "ubuntu22.04" -> "ubuntu", "rocky9" -> "rocky" - osID := strings.TrimRight(osTag, "0123456789.") - return osID -} - // getRepoConfigPath returns the standard OS specific path for repository configuration files. func (n ClusterPolicyController) getRepoConfigPath() (string, error) { - osID := getOSName(n.gpuNodeOSTag) + osID := n.gpuNodeOSRelease + if osID == "" { + return "", fmt.Errorf("GPU node OS name is empty") + } if path, ok := RepoConfigPathMap[osID]; ok { return path, nil } @@ -3316,7 +3312,10 @@ func (n ClusterPolicyController) getRepoConfigPath() (string, error) { // getCertConfigPath returns the standard OS specific path for ssl keys/certificates. func (n ClusterPolicyController) getCertConfigPath() (string, error) { - osID := getOSName(n.gpuNodeOSTag) + osID := n.gpuNodeOSRelease + if osID == "" { + return "", fmt.Errorf("GPU node OS name is empty") + } if path, ok := CertConfigPathMap[osID]; ok { return path, nil } @@ -3326,7 +3325,10 @@ func (n ClusterPolicyController) getCertConfigPath() (string, error) { // getSubscriptionPathsToVolumeSources returns the MountPathToVolumeSource map containing all // OS-specific subscription/entitlement paths that need to be mounted in the container. 
func (n ClusterPolicyController) getSubscriptionPathsToVolumeSources() (MountPathToVolumeSource, error) { - osID := getOSName(n.gpuNodeOSTag) + osID := n.gpuNodeOSRelease + if osID == "" { + return nil, fmt.Errorf("GPU node OS name is empty") + } if pathToVolumeSource, ok := SubscriptionPathMap[osID]; ok { return pathToVolumeSource, nil } @@ -3594,7 +3596,10 @@ func transformDriverContainer(obj *appsv1.DaemonSet, config *gpuv1.ClusterPolicy } } - osID := getOSName(n.gpuNodeOSTag) + osID := n.gpuNodeOSRelease + if osID == "" { + return fmt.Errorf("ERROR: failed to determine GPU node OS name") + } // set up subscription entitlements for RHEL(using K8s with a non-CRIO runtime) and SLES if (osID == "rhel" && n.openshift == "" && n.runtime != gpuv1.CRIO) || osID == "sles" || osID == "sl-micro" { n.logger.Info("Mounting subscriptions into the driver container", "OS", osID) diff --git a/controllers/object_controls_test.go b/controllers/object_controls_test.go index f6df7340d..39ad27679 100644 --- a/controllers/object_controls_test.go +++ b/controllers/object_controls_test.go @@ -225,13 +225,14 @@ func setup() error { if gpuNodeCount == 0 { return fmt.Errorf("no gpu nodes in mock cluster") } - gpuNodeOSTag, err := clusterPolicyController.getGPUNodeOSTag() + gpuNodeOSRelease, gpuNodeOSTag, err := clusterPolicyController.getGPUNodeOSInfo() if err != nil { return fmt.Errorf("unable to get GPU node tag: %w", err) } clusterPolicyController.hasGPUNodes = gpuNodeCount != 0 clusterPolicyController.hasNFDLabels = hasNFDLabels + clusterPolicyController.gpuNodeOSRelease = gpuNodeOSRelease clusterPolicyController.gpuNodeOSTag = gpuNodeOSTag // setup kernelVersionMap for pre-compiled driver tests diff --git a/controllers/state_manager.go b/controllers/state_manager.go index e7b11ca69..d78b5403f 100644 --- a/controllers/state_manager.go +++ b/controllers/state_manager.go @@ -164,11 +164,12 @@ type ClusterPolicyController struct { openshift string ocpDriverToolkit OpenShiftDriverToolkit 
- runtime gpuv1.Runtime - gpuNodeOSTag string - hasGPUNodes bool - hasNFDLabels bool - sandboxEnabled bool + runtime gpuv1.Runtime + gpuNodeOSTag string + gpuNodeOSRelease string + hasGPUNodes bool + hasNFDLabels bool + sandboxEnabled bool } func addState(n *ClusterPolicyController, path string) { @@ -637,7 +638,7 @@ func getRuntimeString(node corev1.Node) (gpuv1.Runtime, error) { return runtime, nil } -func (n *ClusterPolicyController) getGPUNodeOSTag() (string, error) { +func (n *ClusterPolicyController) getGPUNodeOSInfo() (string, string, error) { ctx := n.ctx opts := []client.ListOption{ client.MatchingLabels(map[string]string{commonGPULabelKey: commonGPULabelValue}), @@ -646,34 +647,55 @@ func (n *ClusterPolicyController) getGPUNodeOSTag() (string, error) { nodeList := &corev1.NodeList{} err := n.client.List(ctx, nodeList, opts...) if err != nil { - return "", fmt.Errorf("unable to list nodes with GPU present: %w", err) + return "", "", fmt.Errorf("unable to list nodes with GPU present: %w", err) } if len(nodeList.Items) == 0 { - return "", fmt.Errorf("no nodes found with GPU present") + return "", "", fmt.Errorf("no nodes found with GPU present") } labels := nodeList.Items[0].Labels osName, ok := labels[nfdOSReleaseIDLabelKey] if !ok { - return "", fmt.Errorf("unable to retrieve OS name from label %s", nfdOSReleaseIDLabelKey) + return "", "", fmt.Errorf("unable to retrieve OS name from label %s", nfdOSReleaseIDLabelKey) } osVersion, ok := labels[nfdOSVersionIDLabelKey] if !ok { - return "", fmt.Errorf("unable to retrieve OS version from label %s", nfdOSVersionIDLabelKey) + return "", "", fmt.Errorf("unable to retrieve OS version from label %s", nfdOSVersionIDLabelKey) } osMajorVersion := strings.Split(osVersion, ".")[0] - osMajorNumber, err := strconv.Atoi(osMajorVersion) - if err != nil { - return "", fmt.Errorf("error processing OS major version %s: %w", osMajorVersion, err) - } // If the OS is RockyLinux or RHEL 10 & above, we will omit the minor version 
// parseOSMajorVersion returns the numeric major version of an OS version
// string such as "10.2", "v1.12.6" or "V9".
//
// The major component is everything before the first "."; surrounding
// whitespace and at most one leading "v"/"V" are removed before conversion.
// An error is returned when the component is empty or is not a plain,
// unsigned decimal integer (for example "rolling", "15-SP6" or "-10").
func parseOSMajorVersion(osVersion string) (int, error) {
	major, _, _ := strings.Cut(osVersion, ".")
	major = strings.TrimSpace(major)

	// Drop a single version prefix only, so malformed input such as "vV10"
	// is rejected by the conversion below instead of being accepted as 10.
	if major != "" && (major[0] == 'v' || major[0] == 'V') {
		major = major[1:]
	}
	if major == "" {
		return 0, fmt.Errorf("empty OS major version")
	}

	// strconv.Atoi accepts a leading sign, which is never valid in a version.
	if major[0] == '+' || major[0] == '-' {
		return 0, fmt.Errorf("error processing OS major version %s: %w", major, strconv.ErrSyntax)
	}

	osMajorNumber, err := strconv.Atoi(major)
	if err != nil {
		return 0, fmt.Errorf("error processing OS major version %s: %w", major, err)
	}

	return osMajorNumber, nil
}
gpuv1 "github.com/NVIDIA/gpu-operator/api/nvidia/v1" ) +func TestGetGPUNodeOSInfo(t *testing.T) { + testCases := []struct { + name string + osName string + osVersion string + expected string + expectError bool + errorContainsText string + }{ + { + name: "talos version with v prefix", + osName: "talos", + osVersion: "v1.12.6", + expected: "talosv1.12.6", + }, + { + name: "rhel 10 omits minor version", + osName: "rhel", + osVersion: "10.2", + expected: "rhel10", + }, + { + name: "rocky omits minor version", + osName: "rocky", + osVersion: "9.5", + expected: "rocky9", + }, + { + name: "ubuntu preserves full version", + osName: "ubuntu", + osVersion: "24.04", + expected: "ubuntu24.04", + }, + { + name: "sles preserves dotted version", + osName: "sles", + osVersion: "15.6", + expected: "sles15.6", + }, + { + name: "sles preserves service-pack version", + osName: "sles", + osVersion: "15-SP6", + expected: "sles15-SP6", + }, + { + name: "sl-micro preserves dotted version", + osName: "sl-micro", + osVersion: "6.0", + expected: "sl-micro6.0", + }, + { + name: "archlinux preserves rolling version", + osName: "archlinux", + osVersion: "rolling", + expected: "archlinuxrolling", + }, + { + name: "rhel invalid major version errors", + osName: "rhel", + osVersion: "A.10", + expectError: true, + errorContainsText: "error processing OS major version", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + scheme := runtime.NewScheme() + require.NoError(t, corev1.AddToScheme(scheme)) + + node := &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "gpu-node-1", + Labels: map[string]string{ + commonGPULabelKey: commonGPULabelValue, + nfdOSReleaseIDLabelKey: tc.osName, + nfdOSVersionIDLabelKey: tc.osVersion, + }, + }, + } + + client := fake.NewClientBuilder().WithScheme(scheme).WithObjects(node).Build() + controller := ClusterPolicyController{ctx: context.Background(), client: client} + + osName, osTag, err := controller.getGPUNodeOSInfo() + if 
tc.expectError { + require.Error(t, err) + require.Contains(t, err.Error(), tc.errorContainsText) + return + } + + require.NoError(t, err) + require.Equal(t, tc.osName, osName) + require.Equal(t, tc.expected, osTag) + }) + } +} + func TestGetRuntimeString(t *testing.T) { testCases := []struct { description string diff --git a/controllers/transforms_test.go b/controllers/transforms_test.go index 86ade02a2..4c17fba11 100644 --- a/controllers/transforms_test.go +++ b/controllers/transforms_test.go @@ -3016,7 +3016,7 @@ func TestTransformDriver(t *testing.T) { t.Run(tc.description, func(t *testing.T) { err := TransformDriver(tc.ds.DaemonSet, tc.cpSpec, ClusterPolicyController{client: tc.client, runtime: gpuv1.Containerd, - operatorNamespace: "test-ns", logger: ctrl.Log.WithName("test"), gpuNodeOSTag: "ubuntu20.04"}) + operatorNamespace: "test-ns", logger: ctrl.Log.WithName("test"), gpuNodeOSRelease: "ubuntu", gpuNodeOSTag: "ubuntu20.04"}) if tc.errorExpected { require.Error(t, err) return @@ -3430,7 +3430,7 @@ func TestTransformDriverWithLicensingConfig(t *testing.T) { t.Run(tc.description, func(t *testing.T) { err := TransformDriver(tc.ds.DaemonSet, tc.cpSpec, ClusterPolicyController{client: tc.client, runtime: gpuv1.Containerd, - operatorNamespace: "test-ns", logger: ctrl.Log.WithName("test"), gpuNodeOSTag: "ubuntu20.04"}) + operatorNamespace: "test-ns", logger: ctrl.Log.WithName("test"), gpuNodeOSRelease: "ubuntu", gpuNodeOSTag: "ubuntu20.04"}) if tc.errorExpected { require.Error(t, err) return @@ -3562,7 +3562,7 @@ func TestTransformDriverWithResources(t *testing.T) { t.Run(tc.description, func(t *testing.T) { err := TransformDriver(tc.ds.DaemonSet, tc.cpSpec, ClusterPolicyController{client: tc.client, runtime: gpuv1.Containerd, - operatorNamespace: "test-ns", logger: ctrl.Log.WithName("test"), gpuNodeOSTag: "ubuntu20.04"}) + operatorNamespace: "test-ns", logger: ctrl.Log.WithName("test"), gpuNodeOSRelease: "ubuntu", gpuNodeOSTag: "ubuntu20.04"}) if 
// getOSTag builds the OS image tag for a node pool from the OS release ID
// (e.g. "rhel", "ubuntu") and the OS version ID (e.g. "10.2", "24.04").
//
// The tag is osRelease followed by a version suffix:
//   - "rocky": only the major version (text before the first ".") is used.
//   - "rhel": the major version is used when it is numerically >= 10,
//     otherwise the full version is kept.
//   - any other release: the full version is kept unchanged, so non-numeric
//     versions such as "rolling" or "v1.12.6" are accepted.
//
// An error is returned only for "rhel" when its major version cannot be
// parsed as a number.
func getOSTag(osRelease, osVersion string) (string, error) {
	osMajorVersion, _, _ := strings.Cut(osVersion, ".")

	// The full version is the default; only rocky and rhel >= 10 shorten it.
	osTagSuffix := osVersion
	switch osRelease {
	case "rocky":
		osTagSuffix = osMajorVersion
	case "rhel":
		osMajorNumber, err := parseOSMajorVersion(osVersion)
		if err != nil {
			return "", fmt.Errorf("failed to parse os version: %w", err)
		}
		if osMajorNumber >= 10 {
			osTagSuffix = osMajorVersion
		}
	}

	return osRelease + osTagSuffix, nil
}

// parseOSMajorVersion returns the numeric major version of an OS version
// string such as "10.2", "v1.12.6" or "V9".
//
// The major component is everything before the first "."; surrounding
// whitespace and at most one leading "v"/"V" are removed before conversion.
// An error is returned when the component is empty or is not a plain,
// unsigned decimal integer.
func parseOSMajorVersion(osVersion string) (int, error) {
	major, _, _ := strings.Cut(osVersion, ".")
	major = strings.TrimSpace(major)

	// Drop a single version prefix only, so malformed input such as "vV10"
	// is rejected by the conversion below instead of being accepted as 10.
	if major != "" && (major[0] == 'v' || major[0] == 'V') {
		major = major[1:]
	}
	if major == "" {
		return 0, fmt.Errorf("empty OS major version")
	}

	// strconv.Atoi accepts a leading sign, which is never valid in a version.
	if major[0] == '+' || major[0] == '-' {
		return 0, fmt.Errorf("parsing %q: %w", major, strconv.ErrSyntax)
	}

	osMajorNumber, err := strconv.Atoi(major)
	if err != nil {
		return 0, err
	}

	return osMajorNumber, nil
}