diff --git a/pkg/driver/aws-ebs/aws_ebs_tags_controller.go b/pkg/driver/aws-ebs/aws_ebs_tags_controller.go index e42a1dd25..7e3805a15 100644 --- a/pkg/driver/aws-ebs/aws_ebs_tags_controller.go +++ b/pkg/driver/aws-ebs/aws_ebs_tags_controller.go @@ -73,6 +73,37 @@ type pvUpdateQueueItem struct { pvNames []string } +// ec2TagsAPI defines the EC2 API methods used by the tags controller. +// Using an interface allows mocking the EC2 client in unit tests. +type ec2TagsAPI interface { + CreateTags(ctx context.Context, params *ec2.CreateTagsInput, optFns ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error) + DescribeTags(ctx context.Context, params *ec2.DescribeTagsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) +} + +type failedTagError struct { + awsError error +} + +func (f *failedTagError) Error() string { + return f.awsError.Error() +} + +func (f *failedTagError) Unwrap() error { + return f.awsError +} + +type failWholeBatchError struct { + failedTagError +} + +type failOneOrMoreTagError struct { + failedTagError +} + +var _ error = &failedTagError{} +var _ error = &failWholeBatchError{} +var _ error = &failOneOrMoreTagError{} + func NewEBSVolumeTagsController( name string, commonClient *clients.Clients, @@ -270,7 +301,7 @@ func (c *EBSVolumeTagsController) getEBSCloudCredSecret() (*v1.Secret, error) { // processInfrastructure processes the Infrastructure resource and push pvNames for tags update in retry queue worker. 
func (c *EBSVolumeTagsController) processInfrastructure(infra *configv1.Infrastructure) error { if infra.Status.PlatformStatus != nil && infra.Status.PlatformStatus.AWS != nil && - infra.Status.PlatformStatus.AWS.ResourceTags != nil { + len(infra.Status.PlatformStatus.AWS.ResourceTags) > 0 { err := c.fetchAndPushPvsToQueue(infra) if err != nil { klog.Errorf("error processing PVs for infrastructure update: %v", err) @@ -308,20 +339,38 @@ func (c *EBSVolumeTagsController) fetchAndPushPvsToQueue(infra *configv1.Infrast return nil } -// updateEBSTags updates the tags of an AWS EBS volume with rate limiting -func (c *EBSVolumeTagsController) updateEBSTags(ctx context.Context, ec2Client *ec2.Client, resourceTags []configv1.AWSResourceTag, - pvs ...*v1.PersistentVolume) error { +// updateEBSTags updates the tags of AWS EBS volumes with rate limiting. +// It first checks if the volumes already have the desired tags and skips the +// CreateTags call for volumes that are already up to date. +// +// Returns whether tags were actually updated on AWS, and an error. +// A nil error with false return value implies all tags were already updated on AWS. +// A nil error with true return value implies some tags were updated on AWS.
+func (c *EBSVolumeTagsController) updateEBSTags(ctx context.Context, ec2Client ec2TagsAPI, resourceTags []configv1.AWSResourceTag, + pvs ...*v1.PersistentVolume) (bool, error) { // Prepare tags tags := newAndUpdatedTags(resourceTags) - // Create or update the tags - _, err := ec2Client.CreateTags(ctx, &ec2.CreateTagsInput{ - Resources: pvsToResourceIDs(pvs), + + // Filter out volumes that already have all desired tags + pvsNeedingUpdate, err := filterVolumesNeedingTagUpdate(ctx, ec2Client, tags, pvs) + if err != nil { + return false, &failWholeBatchError{failedTagError{err}} + } + + if len(pvsNeedingUpdate) == 0 { + klog.V(4).Infof("All volumes already have the desired tags, skipping CreateTags call") + return false, nil + } + + // Create or update the tags only for volumes that need it + _, err = ec2Client.CreateTags(ctx, &ec2.CreateTagsInput{ + Resources: pvsToResourceIDs(pvsNeedingUpdate), Tags: tags, }) if err != nil { - return err + return false, &failOneOrMoreTagError{failedTagError{err}} } - return nil + return true, nil } // listPersistentVolumes lists the volume @@ -382,6 +431,79 @@ func newAndUpdatedTags(resourceTags []configv1.AWSResourceTag) []ec2types.Tag { return tags } +// volumeHasAllTags returns true if all desired tags already exist on the volume with matching values. +// Extra tags on the volume that are not in the desired set are ignored. +func volumeHasAllTags(existingTags map[string]string, desiredTags []ec2types.Tag) bool { + for _, tag := range desiredTags { + val, ok := existingTags[*tag.Key] + if !ok || val != *tag.Value { + return false + } + } + return true +} + +// filterVolumesNeedingTagUpdate calls DescribeTags to fetch existing tags and returns +// only the PVs whose AWS volumes do not already have all desired tags applied. +// If DescribeTags fails, all PVs are returned unchanged (fail-open). 
+func filterVolumesNeedingTagUpdate(ctx context.Context, ec2Client ec2TagsAPI, desiredTags []ec2types.Tag, pvs []*v1.PersistentVolume) ([]*v1.PersistentVolume, error) { + volumeIDs := pvsToResourceIDs(pvs) + if len(volumeIDs) == 0 { + return pvs, nil + } + + volumeTags, err := fetchTagsOnVolumes(ctx, ec2Client, volumeIDs) + if err != nil { + return pvs, err + } + + var needUpdate []*v1.PersistentVolume + for _, pv := range pvs { + existingTags, found := volumeTags[pv.Spec.CSI.VolumeHandle] + if !found || !volumeHasAllTags(existingTags, desiredTags) { + needUpdate = append(needUpdate, pv) + } else { + klog.V(4).Infof("Skipping tag update for volume %s (%s): all tags already present", pv.Name, pv.Spec.CSI.VolumeHandle) + } + } + return needUpdate, nil +} + +func fetchTagsOnVolumes(ctx context.Context, ec2Client ec2TagsAPI, volumeIDs []string) (map[string]map[string]string, error) { + volumeTags := make(map[string]map[string]string) + var nextToken *string + + for { + input := &ec2.DescribeTagsInput{ + Filters: []ec2types.Filter{ + {Name: aws.String("resource-id"), Values: volumeIDs}, + {Name: aws.String("resource-type"), Values: []string{"volume"}}, + }, + MaxResults: aws.Int32(1000), + NextToken: nextToken, + } + output, err := ec2Client.DescribeTags(ctx, input) + if err != nil { + return volumeTags, fmt.Errorf("fetching tags for one or more volumes with: %w", err) + } + + for _, td := range output.Tags { + if td.ResourceId == nil || td.Key == nil || td.Value == nil { + continue + } + if _, ok := volumeTags[*td.ResourceId]; !ok { + volumeTags[*td.ResourceId] = make(map[string]string) + } + volumeTags[*td.ResourceId][*td.Key] = *td.Value + } + nextToken = output.NextToken + if nextToken == nil { + break + } + } + return volumeTags, nil +} + // filterUpdatableVolumes filters the list of volumes whose tags needs to be updated. 
func (c *EBSVolumeTagsController) filterUpdatableVolumes(volumes []*v1.PersistentVolume, newTagsHash string) []*v1.PersistentVolume { var updatablePVs []*v1.PersistentVolume diff --git a/pkg/driver/aws-ebs/aws_ebs_tags_controller_test.go b/pkg/driver/aws-ebs/aws_ebs_tags_controller_test.go index 202cb329c..0b124d043 100644 --- a/pkg/driver/aws-ebs/aws_ebs_tags_controller_test.go +++ b/pkg/driver/aws-ebs/aws_ebs_tags_controller_test.go @@ -449,3 +449,171 @@ func TestRemoveVolumesFromQueueSet(t *testing.T) { t.Error("removeVolumesFromQueueSet() incorrectly removed PV pv2 from queue set") } } + +func TestVolumeHasAllTags(t *testing.T) { + tests := []struct { + name string + existingTags map[string]string + desiredTags []ec2types.Tag + expected bool + }{ + { + name: "all desired tags present with matching values", + existingTags: map[string]string{ + "key1": "value1", + "key2": "value2", + "extra": "ignored", + }, + desiredTags: []ec2types.Tag{ + {Key: aws.String("key1"), Value: aws.String("value1")}, + {Key: aws.String("key2"), Value: aws.String("value2")}, + }, + expected: true, + }, + { + name: "missing a desired tag", + existingTags: map[string]string{ + "key1": "value1", + }, + desiredTags: []ec2types.Tag{ + {Key: aws.String("key1"), Value: aws.String("value1")}, + {Key: aws.String("key2"), Value: aws.String("value2")}, + }, + expected: false, + }, + { + name: "desired tag exists but value differs", + existingTags: map[string]string{ + "key1": "old-value", + }, + desiredTags: []ec2types.Tag{ + {Key: aws.String("key1"), Value: aws.String("new-value")}, + }, + expected: false, + }, + { + name: "empty desired tags always matches", + existingTags: map[string]string{}, + desiredTags: []ec2types.Tag{}, + expected: true, + }, + { + name: "no existing tags with desired tags", + existingTags: map[string]string{}, + desiredTags: []ec2types.Tag{ + {Key: aws.String("key1"), Value: aws.String("value1")}, + }, + expected: false, + }, + { + name: "nil existing tags map", + 
existingTags: nil, + desiredTags: []ec2types.Tag{ + {Key: aws.String("key1"), Value: aws.String("value1")}, + }, + expected: false, + }, + { + name: "all keys present but one value wrong", + existingTags: map[string]string{ + "key1": "value1", + "key2": "wrong-value", + "key3": "value3", + }, + desiredTags: []ec2types.Tag{ + {Key: aws.String("key1"), Value: aws.String("value1")}, + {Key: aws.String("key2"), Value: aws.String("value2")}, + {Key: aws.String("key3"), Value: aws.String("value3")}, + }, + expected: false, + }, + { + name: "many extra tags do not affect match", + existingTags: map[string]string{ + "aws:cloudformation:stack-name": "my-stack", + "kubernetes.io/cluster/test": "owned", + "Name": "my-volume", + "key1": "value1", + }, + desiredTags: []ec2types.Tag{ + {Key: aws.String("key1"), Value: aws.String("value1")}, + }, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := volumeHasAllTags(tt.existingTags, tt.desiredTags) + if result != tt.expected { + t.Errorf("volumeHasAllTags() = %v, want %v", result, tt.expected) + } + }) + } +} + +// TestProcessInfrastructureSkipsWhenNoTags verifies that processInfrastructure +// returns nil without calling fetchAndPushPvsToQueue when ResourceTags is nil, +// empty, or the platform status is missing. The controller has no informers +// set up, so reaching fetchAndPushPvsToQueue would panic — a clean return +// confirms the guard condition works. 
+func TestProcessInfrastructureSkipsWhenNoTags(t *testing.T) { + c := &EBSVolumeTagsController{} + + tests := []struct { + name string + infra *configv1.Infrastructure + }{ + { + name: "nil PlatformStatus", + infra: &configv1.Infrastructure{ + Status: configv1.InfrastructureStatus{ + PlatformStatus: nil, + }, + }, + }, + { + name: "nil AWS in PlatformStatus", + infra: &configv1.Infrastructure{ + Status: configv1.InfrastructureStatus{ + PlatformStatus: &configv1.PlatformStatus{ + AWS: nil, + }, + }, + }, + }, + { + name: "nil ResourceTags", + infra: &configv1.Infrastructure{ + Status: configv1.InfrastructureStatus{ + PlatformStatus: &configv1.PlatformStatus{ + AWS: &configv1.AWSPlatformStatus{ + ResourceTags: nil, + }, + }, + }, + }, + }, + { + name: "empty ResourceTags", + infra: &configv1.Infrastructure{ + Status: configv1.InfrastructureStatus{ + PlatformStatus: &configv1.PlatformStatus{ + AWS: &configv1.AWSPlatformStatus{ + ResourceTags: []configv1.AWSResourceTag{}, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := c.processInfrastructure(tt.infra) + if err != nil { + t.Errorf("processInfrastructure() returned error: %v, want nil", err) + } + }) + } +} diff --git a/pkg/driver/aws-ebs/aws_ebs_tags_queue_worker.go b/pkg/driver/aws-ebs/aws_ebs_tags_queue_worker.go index ad4717c0b..78ecfd33a 100644 --- a/pkg/driver/aws-ebs/aws_ebs_tags_queue_worker.go +++ b/pkg/driver/aws-ebs/aws_ebs_tags_queue_worker.go @@ -4,7 +4,6 @@ import ( "context" "errors" - "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/smithy-go" v1 "k8s.io/api/core/v1" @@ -89,7 +88,7 @@ func (c *EBSVolumeTagsController) needsTagUpdate(infra *configv1.Infrastructure, // their AWS EBS tags in bulk. If the tag update succeeds, it updates the PV annotations // with the new tag hash. In case of errors, failed PVs are re-queued individually // for retry with a backoff mechanism. 
-func (c *EBSVolumeTagsController) processBatchVolumes(ctx context.Context, item *pvUpdateQueueItem, infra *configv1.Infrastructure, ec2Client *ec2.Client) { +func (c *EBSVolumeTagsController) processBatchVolumes(ctx context.Context, item *pvUpdateQueueItem, infra *configv1.Infrastructure, ec2Client ec2TagsAPI) { pvList := make([]*v1.PersistentVolume, 0) for _, pvName := range item.pvNames { pv, err := c.getPersistentVolumeByName(pvName) @@ -113,11 +112,24 @@ func (c *EBSVolumeTagsController) processBatchVolumes(ctx context.Context, item return } // update the tags for the volume list. - err := c.updateEBSTags(ctx, ec2Client, infra.Status.PlatformStatus.AWS.ResourceTags, pvList...) + tagsUpdatedOnAWS, err := c.updateEBSTags(ctx, ec2Client, infra.Status.PlatformStatus.AWS.ResourceTags, pvList...) if err != nil { klog.Errorf("failed to update EBS tags: %v", err) - c.handleBatchTagUpdateFailure(pvList, err) - c.queue.Forget(item) + var batchErr *failWholeBatchError + var oneOrMoreErr *failOneOrMoreTagError + switch { + case errors.As(err, &batchErr): + // DescribeTags failed — re-queue the whole batch to retry later. + c.queue.AddRateLimited(item) + case errors.As(err, &oneOrMoreErr): + // CreateTags failed — one or more volumes may be bad (e.g. deleted). + // Re-queue each volume individually so the bad one can be identified + // and removed by processIndividualVolume's NotFound handling. 
+ c.handleBatchTagUpdateFailure(pvList, err) + c.queue.Forget(item) + default: + c.queue.AddRateLimited(item) + } return } newTagsHash := computeTagsHash(infra.Status.PlatformStatus.AWS.ResourceTags) @@ -137,7 +149,7 @@ func (c *EBSVolumeTagsController) processBatchVolumes(ctx context.Context, item continue } c.removeVolumesFromQueueSet(volume.Name) - klog.Infof("Successfully updated PV annotations and tags for volume %s", volume.Name) + logTagCompletionMessage(volume.Name, tagsUpdatedOnAWS) } c.queue.Forget(item) } @@ -147,12 +159,13 @@ func (c *EBSVolumeTagsController) processBatchVolumes(ctx context.Context, item // If the tag update succeeds, it updates the PV annotations with the new tag hash. // If the PV is missing or the AWS volume does not exist, it removes it from the queue. // In case of errors, it re-queues the PV for retry with a backoff mechanism. -func (c *EBSVolumeTagsController) processIndividualVolume(ctx context.Context, item *pvUpdateQueueItem, infra *configv1.Infrastructure, ec2Client *ec2.Client) { - pv, err := c.getPersistentVolumeByName(item.pvNames[0]) +func (c *EBSVolumeTagsController) processIndividualVolume(ctx context.Context, item *pvUpdateQueueItem, infra *configv1.Infrastructure, ec2Client ec2TagsAPI) { + pvName := item.pvNames[0] + pv, err := c.getPersistentVolumeByName(pvName) if err != nil { if apierrors.IsNotFound(err) { - klog.Infof("skipping volume tags update because PV %v does not exist", item.pvNames[0]) - c.removeVolumesFromQueueSet(pv.Name) + klog.Infof("skipping volume tags update because PV %v does not exist", pvName) + c.removeVolumesFromQueueSet(pvName) c.queue.Forget(item) return } @@ -166,7 +179,7 @@ func (c *EBSVolumeTagsController) processIndividualVolume(ctx context.Context, i c.queue.Forget(item) return } - err = c.updateEBSTags(ctx, ec2Client, infra.Status.PlatformStatus.AWS.ResourceTags, pv) + tagsUpdatedOnAWS, err := c.updateEBSTags(ctx, ec2Client, infra.Status.PlatformStatus.AWS.ResourceTags, pv) if err != 
nil { var apiErr smithy.APIError if errors.As(err, &apiErr) { @@ -192,6 +205,14 @@ func (c *EBSVolumeTagsController) processIndividualVolume(ctx context.Context, i return } c.removeVolumesFromQueueSet(pv.Name) - klog.Infof("Successfully updated PV annotations and tags for volume %s", pv.Name) + logTagCompletionMessage(pv.Name, tagsUpdatedOnAWS) c.queue.Forget(item) } + +func logTagCompletionMessage(volumeName string, tagsUpdatedOnAWS bool) { + if tagsUpdatedOnAWS { + klog.Infof("Successfully updated PV annotations and tags for volume %s", volumeName) + } else { + klog.Infof("Successfully updated PV annotations for volume %s", volumeName) + } +} diff --git a/pkg/driver/aws-ebs/aws_ebs_tags_queue_worker_test.go b/pkg/driver/aws-ebs/aws_ebs_tags_queue_worker_test.go new file mode 100644 index 000000000..f52d26150 --- /dev/null +++ b/pkg/driver/aws-ebs/aws_ebs_tags_queue_worker_test.go @@ -0,0 +1,561 @@ +package aws_ebs + +import ( + "context" + "fmt" + "reflect" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + configv1 "github.com/openshift/api/config/v1" + "github.com/openshift/csi-operator/pkg/clients" + "github.com/openshift/library-go/pkg/operator/events" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + fakecore "k8s.io/client-go/kubernetes/fake" + "k8s.io/client-go/util/workqueue" + "k8s.io/utils/clock" +) + +// mockEC2Client implements ec2TagsAPI for unit testing. 
+type mockEC2Client struct { + createTagsFunc func(ctx context.Context, params *ec2.CreateTagsInput, optFns ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error) + describeTagsFunc func(ctx context.Context, params *ec2.DescribeTagsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) +} + +func (m *mockEC2Client) CreateTags(ctx context.Context, params *ec2.CreateTagsInput, optFns ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error) { + if m.createTagsFunc != nil { + return m.createTagsFunc(ctx, params, optFns...) + } + return &ec2.CreateTagsOutput{}, nil +} + +func (m *mockEC2Client) DescribeTags(ctx context.Context, params *ec2.DescribeTagsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) { + if m.describeTagsFunc != nil { + return m.describeTagsFunc(ctx, params, optFns...) + } + return &ec2.DescribeTagsOutput{}, nil +} + +func newTestController() *EBSVolumeTagsController { + return &EBSVolumeTagsController{ + queue: workqueue.NewTypedRateLimitingQueue[*pvUpdateQueueItem](workqueue.NewTypedItemExponentialFailureRateLimiter[*pvUpdateQueueItem](1*time.Millisecond, 1*time.Millisecond)), + queueSet: make(map[string]struct{}), + eventRecorder: events.NewInMemoryRecorder("test", &clock.RealClock{}), + } +} + +func newTestControllerWithFakeKubeClient(t *testing.T, pvs ...*v1.PersistentVolume) *EBSVolumeTagsController { + t.Helper() + cr := clients.GetFakeOperatorCR() + c := clients.NewFakeClients("openshift-cluster-csi-drivers", cr) + + // Access the PV informer before starting informers so it gets registered. + // Then add PVs directly to the informer store, which is the most reliable + // way to populate fake informer caches (same pattern as aws_ebs_test.go). 
+ pvInformer := c.KubeInformers.InformersFor("").Core().V1().PersistentVolumes().Informer() + clients.SyncFakeInformers(t, c) + + for _, pv := range pvs { + if err := pvInformer.GetStore().Add(pv); err != nil { + t.Fatalf("failed to add PV %s to informer store: %v", pv.Name, err) + } + // Also add to the fake client so that Updates (e.g. annotation writes) work. + if err := c.KubeClient.(*fakecore.Clientset).Tracker().Add(pv); err != nil { + t.Fatalf("failed to add PV %s to tracker: %v", pv.Name, err) + } + } + + return &EBSVolumeTagsController{ + commonClient: c, + queue: workqueue.NewTypedRateLimitingQueue[*pvUpdateQueueItem](workqueue.NewTypedItemExponentialFailureRateLimiter[*pvUpdateQueueItem](1*time.Millisecond, 1*time.Millisecond)), + queueSet: make(map[string]struct{}), + eventRecorder: events.NewInMemoryRecorder("test", &clock.RealClock{}), + } +} + +func newTestPV(name, volumeHandle string) *v1.PersistentVolume { + return &v1.PersistentVolume{ + ObjectMeta: metav1.ObjectMeta{Name: name}, + Spec: v1.PersistentVolumeSpec{ + PersistentVolumeSource: v1.PersistentVolumeSource{ + CSI: &v1.CSIPersistentVolumeSource{ + Driver: driverName, + VolumeHandle: volumeHandle, + }, + }, + }, + } +} + +func newTestInfra(tags []configv1.AWSResourceTag) *configv1.Infrastructure { + return &configv1.Infrastructure{ + Status: configv1.InfrastructureStatus{ + PlatformStatus: &configv1.PlatformStatus{ + AWS: &configv1.AWSPlatformStatus{ + Region: "us-east-1", + ResourceTags: tags, + }, + }, + }, + } +} + +func TestHandleBatchTagUpdateFailure(t *testing.T) { + tests := []struct { + name string + pvs []*v1.PersistentVolume + wantQueueSize int + }{ + { + name: "small batch re-queued individually", + pvs: []*v1.PersistentVolume{ + newTestPV("pv1", "vol-111"), + newTestPV("pv2", "vol-222"), + newTestPV("pv3", "vol-333"), + }, + wantQueueSize: 3, + }, + { + name: "large batch re-queued individually", + pvs: []*v1.PersistentVolume{ + newTestPV("pv1", "vol-111"), + newTestPV("pv2", 
"vol-222"), + newTestPV("pv3", "vol-333"), + newTestPV("pv4", "vol-444"), + newTestPV("pv5", "vol-555"), + }, + wantQueueSize: 5, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := newTestController() + c.handleBatchTagUpdateFailure(tt.pvs, fmt.Errorf("batch failed")) + + for i := 0; i < tt.wantQueueSize; i++ { + item, quit := c.queue.Get() + if quit { + t.Fatalf("queue shut down unexpectedly after %d items", i) + } + if item.updateType != updateTypeIndividual { + t.Errorf("item %d: updateType = %v, want %v", i, item.updateType, updateTypeIndividual) + } + if len(item.pvNames) != 1 { + t.Errorf("item %d: got %d pvNames, want 1", i, len(item.pvNames)) + } + if item.pvNames[0] != tt.pvs[i].Name { + t.Errorf("item %d: pvName = %s, want %s", i, item.pvNames[0], tt.pvs[i].Name) + } + c.queue.Done(item) + } + + recorder := c.eventRecorder.(events.InMemoryRecorder) + foundWarning := false + for _, event := range recorder.Events() { + if event.Reason == "EBSVolumeTagsUpdateFailed" { + foundWarning = true + break + } + } + if !foundWarning { + t.Error("expected EBSVolumeTagsUpdateFailed event, not found") + } + }) + } +} + +func TestNeedsTagUpdate(t *testing.T) { + resourceTags := []configv1.AWSResourceTag{{Key: "key1", Value: "value1"}} + infra := newTestInfra(resourceTags) + expectedHash := computeTagsHash(resourceTags) + + tests := []struct { + name string + pv *v1.PersistentVolume + expected bool + }{ + { + name: "no hash annotation - needs update", + pv: newTestPV("pv1", "vol-111"), + expected: true, + }, + { + name: "matching hash - no update needed", + pv: &v1.PersistentVolume{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pv2", + Annotations: map[string]string{tagHashAnnotationKey: expectedHash}, + }, + }, + expected: false, + }, + { + name: "stale hash - needs update", + pv: &v1.PersistentVolume{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pv3", + Annotations: map[string]string{tagHashAnnotationKey: "old-hash"}, + }, + }, + expected: 
true, + }, + } + + c := newTestController() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := c.needsTagUpdate(infra, tt.pv) + if result != tt.expected { + t.Errorf("needsTagUpdate() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestProcessIndividualVolume(t *testing.T) { + resourceTags := []configv1.AWSResourceTag{{Key: "env", Value: "prod"}} + infra := newTestInfra(resourceTags) + + tests := []struct { + name string + pvs []*v1.PersistentVolume // PVs to seed in the informer + pvName string // the name put into the queue item + setup func(t *testing.T) *mockEC2Client + verify func(t *testing.T, c *EBSVolumeTagsController, item *pvUpdateQueueItem) + }{ + { + // Regression test: the old code called pv.Name after a not-found error, + // which would panic because pv is nil when the API returns not-found. + // The fix extracts pvName before the call and uses that string instead. + name: "PV not found: removes from queue without panicking", + pvs: nil, // no PV in the informer — will produce a not-found error + pvName: "missing-pv", + setup: func(t *testing.T) *mockEC2Client { + return &mockEC2Client{} // EC2 should never be reached + }, + verify: func(t *testing.T, c *EBSVolumeTagsController, item *pvUpdateQueueItem) { + if c.isVolumeInQueue("missing-pv") { + t.Error("missing-pv should have been removed from queueSet after not-found") + } + // The item should have been Forgotten, not rate-limited — so the + // queue should be empty (no requeue). 
+ if c.queue.Len() != 0 { + t.Errorf("queue should be empty after not-found, got len=%d", c.queue.Len()) + } + }, + }, + { + name: "success: tags applied and annotations updated", + pvs: []*v1.PersistentVolume{newTestPV("pv1", "vol-111")}, + pvName: "pv1", + setup: func(t *testing.T) *mockEC2Client { + return &mockEC2Client{ + describeTagsFunc: func(_ context.Context, _ *ec2.DescribeTagsInput, _ ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) { + return &ec2.DescribeTagsOutput{}, nil + }, + } + }, + verify: func(t *testing.T, c *EBSVolumeTagsController, item *pvUpdateQueueItem) { + if c.isVolumeInQueue("pv1") { + t.Error("pv1 should have been removed from queueSet after successful update") + } + updated, err := c.commonClient.KubeClient.CoreV1().PersistentVolumes().Get(context.Background(), "pv1", metav1.GetOptions{}) + if err != nil { + t.Fatalf("failed to get pv1: %v", err) + } + expectedHash := computeTagsHash(resourceTags) + if updated.Annotations[tagHashAnnotationKey] != expectedHash { + t.Errorf("hash = %q, want %q", updated.Annotations[tagHashAnnotationKey], expectedHash) + } + }, + }, + { + name: "transient error: PV re-queued for retry", + pvs: []*v1.PersistentVolume{newTestPV("pv1", "vol-111")}, + pvName: "pv1", + setup: func(t *testing.T) *mockEC2Client { + return &mockEC2Client{ + describeTagsFunc: func(_ context.Context, _ *ec2.DescribeTagsInput, _ ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) { + return nil, fmt.Errorf("transient network error") + }, + } + }, + verify: func(t *testing.T, c *EBSVolumeTagsController, item *pvUpdateQueueItem) { + c.queue.Forget(item) // avoid rate-limit delay + itemCh := make(chan *pvUpdateQueueItem, 1) + go func() { + got, _ := c.queue.Get() + itemCh <- got + }() + select { + case got := <-itemCh: + if got.updateType != updateTypeIndividual { + t.Errorf("requeued item updateType = %v, want %v", got.updateType, updateTypeIndividual) + } + if len(got.pvNames) != 1 || got.pvNames[0] != "pv1" { + 
t.Errorf("requeued pvNames = %v, want [pv1]", got.pvNames) + } + c.queue.Done(got) + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for requeued item") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mock := tt.setup(t) + c := newTestControllerWithFakeKubeClient(t, tt.pvs...) + + item := &pvUpdateQueueItem{updateType: updateTypeIndividual, pvNames: []string{tt.pvName}} + c.queue.Add(item) + c.addVolumesToQueueSet(tt.pvs...) + // Also add the target name to the queueSet so removeVolumesFromQueueSet + // has something to remove (mirrors real usage where the item was enqueued). + c.queueSet[tt.pvName] = struct{}{} + c.queue.Get() // mark as processing + + c.processIndividualVolume(t.Context(), item, infra, mock) + c.queue.Done(item) + + tt.verify(t, c, item) + }) + } +} + +func TestProcessBatchVolumes(t *testing.T) { + resourceTags := []configv1.AWSResourceTag{{Key: "env", Value: "prod"}} + expectedHash := computeTagsHash(resourceTags) + + // Each test case provides a setup func that builds the mock and returns a + // verify func. This lets the mock and verify share state (e.g. captured + // call arguments) via a closure without awkward struct field tricks. 
+ tests := []struct { + name string + pvs []*v1.PersistentVolume + pvNames []string + setup func(t *testing.T) (mock *mockEC2Client, verify func(t *testing.T, c *EBSVolumeTagsController, workItem *pvUpdateQueueItem)) + }{ + { + name: "success: tags applied and annotations updated", + pvs: []*v1.PersistentVolume{newTestPV("pv1", "vol-111"), newTestPV("pv2", "vol-222")}, + pvNames: []string{"pv1", "pv2"}, + setup: func(t *testing.T) (*mockEC2Client, func(*testing.T, *EBSVolumeTagsController, *pvUpdateQueueItem)) { + mock := &mockEC2Client{ + describeTagsFunc: func(_ context.Context, _ *ec2.DescribeTagsInput, _ ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) { + return &ec2.DescribeTagsOutput{}, nil + }, + } + verify := func(t *testing.T, c *EBSVolumeTagsController, workItem *pvUpdateQueueItem) { + for _, name := range []string{"pv1", "pv2"} { + if c.isVolumeInQueue(name) { + t.Errorf("%s should have been removed from queueSet after successful update", name) + } + updated, err := c.commonClient.KubeClient.CoreV1().PersistentVolumes().Get(context.Background(), name, metav1.GetOptions{}) + if err != nil { + t.Fatalf("failed to get %s: %v", name, err) + } + if updated.Annotations[tagHashAnnotationKey] != expectedHash { + t.Errorf("%s hash = %q, want %q", name, updated.Annotations[tagHashAnnotationKey], expectedHash) + } + } + } + return mock, verify + }, + }, + { + name: "success: all volumes already tagged, CreateTags not called", + pvs: []*v1.PersistentVolume{newTestPV("pv1", "vol-111"), newTestPV("pv2", "vol-222")}, + pvNames: []string{"pv1", "pv2"}, + setup: func(t *testing.T) (*mockEC2Client, func(*testing.T, *EBSVolumeTagsController, *pvUpdateQueueItem)) { + mock := &mockEC2Client{ + describeTagsFunc: func(_ context.Context, _ *ec2.DescribeTagsInput, _ ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) { + return &ec2.DescribeTagsOutput{ + Tags: []ec2types.TagDescription{ + {ResourceId: aws.String("vol-111"), Key: aws.String("env"), Value: 
aws.String("prod"), ResourceType: ec2types.ResourceTypeVolume}, + {ResourceId: aws.String("vol-222"), Key: aws.String("env"), Value: aws.String("prod"), ResourceType: ec2types.ResourceTypeVolume}, + }, + }, nil + }, + createTagsFunc: func(_ context.Context, _ *ec2.CreateTagsInput, _ ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error) { + t.Error("CreateTags should NOT have been called when all volumes already have tags") + return &ec2.CreateTagsOutput{}, nil + }, + } + verify := func(t *testing.T, c *EBSVolumeTagsController, workItem *pvUpdateQueueItem) { + for _, name := range []string{"pv1", "pv2"} { + if c.isVolumeInQueue(name) { + t.Errorf("%s should have been removed from queueSet", name) + } + } + } + return mock, verify + }, + }, + { + name: "success: only untagged volumes sent to CreateTags", + pvs: []*v1.PersistentVolume{newTestPV("pv1", "vol-111"), newTestPV("pv2", "vol-222"), newTestPV("pv3", "vol-333")}, + pvNames: []string{"pv1", "pv2", "pv3"}, + setup: func(t *testing.T) (*mockEC2Client, func(*testing.T, *EBSVolumeTagsController, *pvUpdateQueueItem)) { + var taggedVolumeIDs []string + mock := &mockEC2Client{ + describeTagsFunc: func(_ context.Context, _ *ec2.DescribeTagsInput, _ ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) { + // vol-222 is already tagged; vol-111 and vol-333 need tagging. 
+ return &ec2.DescribeTagsOutput{ + Tags: []ec2types.TagDescription{ + {ResourceId: aws.String("vol-222"), Key: aws.String("env"), Value: aws.String("prod"), ResourceType: ec2types.ResourceTypeVolume}, + }, + }, nil + }, + createTagsFunc: func(_ context.Context, params *ec2.CreateTagsInput, _ ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error) { + taggedVolumeIDs = params.Resources + return &ec2.CreateTagsOutput{}, nil + }, + } + verify := func(t *testing.T, c *EBSVolumeTagsController, workItem *pvUpdateQueueItem) { + if len(taggedVolumeIDs) != 2 { + t.Fatalf("expected CreateTags for 2 volumes, got %d: %v", len(taggedVolumeIDs), taggedVolumeIDs) + } + for _, id := range taggedVolumeIDs { + if id == "vol-222" { + t.Error("vol-222 should have been skipped (already tagged)") + } + } + for _, name := range []string{"pv1", "pv2", "pv3"} { + if c.isVolumeInQueue(name) { + t.Errorf("%s should have been removed from queueSet", name) + } + } + } + return mock, verify + }, + }, + { + name: "error: DescribeTags failure re-queues whole batch", + pvs: []*v1.PersistentVolume{newTestPV("pv1", "vol-111"), newTestPV("pv2", "vol-222"), newTestPV("pv3", "vol-333")}, + pvNames: []string{"pv1", "pv2", "pv3"}, + setup: func(t *testing.T) (*mockEC2Client, func(*testing.T, *EBSVolumeTagsController, *pvUpdateQueueItem)) { + mock := &mockEC2Client{ + describeTagsFunc: func(_ context.Context, _ *ec2.DescribeTagsInput, _ ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) { + return nil, fmt.Errorf("DescribeTags throttled") + }, + } + verify := func(t *testing.T, c *EBSVolumeTagsController, workItem *pvUpdateQueueItem) { + // forget the item here to avoid rate limiting + c.queue.Forget(workItem) + var updateItem *pvUpdateQueueItem + itemCh := make(chan *pvUpdateQueueItem, 1) + go func() { + for { + item, _ := c.queue.Get() + itemCh <- item + break + } + }() + select { + case updateItem = <-itemCh: + case <-time.After(2 * time.Second): + t.Fatalf("Failed waiting for updateItem") + } + + 
expectedPVNames := []string{"pv1", "pv2", "pv3"} + if !reflect.DeepEqual(updateItem.pvNames, expectedPVNames) { + t.Errorf("expected %+v, got %+v", expectedPVNames, updateItem.pvNames) + } + if updateItem.updateType != updateTypeBatch { + t.Errorf("expected batched item, got %+v", updateItem.updateType) + } + for _, pvName := range expectedPVNames { + if !c.isVolumeInQueue(pvName) { + t.Errorf("%s should have been added to queueSet", pvName) + } + } + } + return mock, verify + }, + }, + { + name: "error: CreateTags failure re-queues each volume individually", + pvs: []*v1.PersistentVolume{newTestPV("pv1", "vol-111"), newTestPV("pv2", "vol-222")}, + pvNames: []string{"pv1", "pv2"}, + setup: func(t *testing.T) (*mockEC2Client, func(*testing.T, *EBSVolumeTagsController, *pvUpdateQueueItem)) { + mock := &mockEC2Client{ + describeTagsFunc: func(_ context.Context, _ *ec2.DescribeTagsInput, _ ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) { + return &ec2.DescribeTagsOutput{}, nil + }, + createTagsFunc: func(_ context.Context, _ *ec2.CreateTagsInput, _ ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error) { + return nil, fmt.Errorf("InvalidVolume.NotFound: vol-222 does not exist") + }, + } + verify := func(t *testing.T, c *EBSVolumeTagsController, workItem *pvUpdateQueueItem) { + c.queue.Forget(workItem) + workItems := []*pvUpdateQueueItem{} + itemCh := make(chan *pvUpdateQueueItem, 3) + go func() { + for { + item, quit := c.queue.Get() + if quit { + break + } + itemCh <- item + } + }() + for { + select { + case item := <-itemCh: + workItems = append(workItems, item) + case <-time.After(2 * time.Second): + t.Fatalf("failed waiting for workitems") + } + + if len(workItems) == 2 { + break + } + } + if len(workItems) != 2 { + t.Errorf("Expected 2 work items, got %d", len(workItems)) + } + for _, item := range workItems { + if item.updateType != updateTypeIndividual { + t.Errorf("Expected updateTypeIndividual, got %v", item.updateType) + } + c.queue.Done(item) + } + 
expectedPVNames := []string{"pv1", "pv2"} + for _, pvName := range expectedPVNames { + if !c.isVolumeInQueue(pvName) { + t.Errorf("%s should have been added to queueSet", pvName) + } + } + } + return mock, verify + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mock, verify := tt.setup(t) + + c := newTestControllerWithFakeKubeClient(t, tt.pvs...) + c.addVolumesToQueueSet(tt.pvs...) + + infra := newTestInfra(resourceTags) + item := &pvUpdateQueueItem{updateType: updateTypeBatch, pvNames: tt.pvNames} + c.queue.Add(item) + c.queue.Get() // mark as processing + + c.processBatchVolumes(t.Context(), item, infra, mock) + // Mark the original item done so any rate-limited re-adds become dequeue-able. + c.queue.Done(item) + + verify(t, c, item) + }) + } +}