Skip to content
Open
140 changes: 131 additions & 9 deletions pkg/driver/aws-ebs/aws_ebs_tags_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,37 @@ type pvUpdateQueueItem struct {
pvNames []string
}

// ec2TagsAPI defines the EC2 API methods used by the tags controller.
// Using an interface allows mocking the EC2 client in unit tests.
type ec2TagsAPI interface {
CreateTags(ctx context.Context, params *ec2.CreateTagsInput, optFns ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error)
DescribeTags(ctx context.Context, params *ec2.DescribeTagsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error)
}

type failedTagError struct {
awsError error
}

func (f *failedTagError) Error() string {
return f.awsError.Error()
}

func (f *failedTagError) Unwrap() error {
return f.awsError
}

type failWholeBatchError struct {
failedTagError
}

type failOneOrMoreTagError struct {
failedTagError
}

var _ error = &failedTagError{}
var _ error = &failWholeBatchError{}
var _ error = &failOneOrMoreTagError{}

func NewEBSVolumeTagsController(
name string,
commonClient *clients.Clients,
Expand Down Expand Up @@ -270,7 +301,7 @@ func (c *EBSVolumeTagsController) getEBSCloudCredSecret() (*v1.Secret, error) {
// processInfrastructure processes the Infrastructure resource and push pvNames for tags update in retry queue worker.
func (c *EBSVolumeTagsController) processInfrastructure(infra *configv1.Infrastructure) error {
if infra.Status.PlatformStatus != nil && infra.Status.PlatformStatus.AWS != nil &&
infra.Status.PlatformStatus.AWS.ResourceTags != nil {
infra.Status.PlatformStatus.AWS.ResourceTags != nil && len(infra.Status.PlatformStatus.AWS.ResourceTags) > 0 {
err := c.fetchAndPushPvsToQueue(infra)
if err != nil {
klog.Errorf("error processing PVs for infrastructure update: %v", err)
Expand Down Expand Up @@ -308,20 +339,38 @@ func (c *EBSVolumeTagsController) fetchAndPushPvsToQueue(infra *configv1.Infrast
return nil
}

// updateEBSTags updates the tags of an AWS EBS volume with rate limiting
func (c *EBSVolumeTagsController) updateEBSTags(ctx context.Context, ec2Client *ec2.Client, resourceTags []configv1.AWSResourceTag,
pvs ...*v1.PersistentVolume) error {
// updateEBSTags updates the tags of an AWS EBS volume with rate limiting.
// It first checks if the volumes already have the desired tags and skips the
// CreateTags call for volumes that are already up to date.
//
// Returns error and whether tags were actually updated on AWS.
// A nil error with false return value implies all tags were already updated on AWS.
// A nil error with true return value implies some tags were updated on AWS.
func (c *EBSVolumeTagsController) updateEBSTags(ctx context.Context, ec2Client ec2TagsAPI, resourceTags []configv1.AWSResourceTag,
pvs ...*v1.PersistentVolume) (bool, error) {
// Prepare tags
tags := newAndUpdatedTags(resourceTags)
// Create or update the tags
_, err := ec2Client.CreateTags(ctx, &ec2.CreateTagsInput{
Resources: pvsToResourceIDs(pvs),

// Filter out volumes that already have all desired tags
pvsNeedingUpdate, err := filterVolumesNeedingTagUpdate(ctx, ec2Client, tags, pvs)
if err != nil {
return false, &failWholeBatchError{failedTagError{err}}
}

if len(pvsNeedingUpdate) == 0 {
klog.V(4).Infof("All volumes already have the desired tags, skipping CreateTags call")
return false, nil
}

// Create or update the tags only for volumes that need it
_, err = ec2Client.CreateTags(ctx, &ec2.CreateTagsInput{
Resources: pvsToResourceIDs(pvsNeedingUpdate),
Tags: tags,
})
if err != nil {
return err
return false, &failOneOrMoreTagError{failedTagError{err}}
}
return nil
return true, nil
}

// listPersistentVolumes lists the volume
Expand Down Expand Up @@ -382,6 +431,79 @@ func newAndUpdatedTags(resourceTags []configv1.AWSResourceTag) []ec2types.Tag {
return tags
}

// volumeHasAllTags returns true if all desired tags already exist on the volume with matching values.
// Extra tags on the volume that are not in the desired set are ignored.
func volumeHasAllTags(existingTags map[string]string, desiredTags []ec2types.Tag) bool {
for _, tag := range desiredTags {
val, ok := existingTags[*tag.Key]
if !ok || val != *tag.Value {
return false
}
}
return true
}

// filterVolumesNeedingTagUpdate calls DescribeTags to fetch existing tags and returns
// only the PVs whose AWS volumes do not already have all desired tags applied.
// If DescribeTags fails, all PVs are returned unchanged (fail-open).
func filterVolumesNeedingTagUpdate(ctx context.Context, ec2Client ec2TagsAPI, desiredTags []ec2types.Tag, pvs []*v1.PersistentVolume) ([]*v1.PersistentVolume, error) {
volumeIDs := pvsToResourceIDs(pvs)
if len(volumeIDs) == 0 {
return pvs, nil
}

volumeTags, err := fetchTagsOnVolumes(ctx, ec2Client, volumeIDs)
if err != nil {
return pvs, err
}

var needUpdate []*v1.PersistentVolume
for _, pv := range pvs {
existingTags, found := volumeTags[pv.Spec.CSI.VolumeHandle]
if !found || !volumeHasAllTags(existingTags, desiredTags) {
needUpdate = append(needUpdate, pv)
} else {
klog.V(4).Infof("Skipping tag update for volume %s (%s): all tags already present", pv.Name, pv.Spec.CSI.VolumeHandle)
}
}
return needUpdate, nil
}

func fetchTagsOnVolumes(ctx context.Context, ec2Client ec2TagsAPI, volumeIDs []string) (map[string]map[string]string, error) {
volumeTags := make(map[string]map[string]string)
var nextToken *string

for {
input := &ec2.DescribeTagsInput{
Filters: []ec2types.Filter{
{Name: aws.String("resource-id"), Values: volumeIDs},
{Name: aws.String("resource-type"), Values: []string{"volume"}},
},
MaxResults: aws.Int32(1000),
NextToken: nextToken,
}
output, err := ec2Client.DescribeTags(ctx, input)
if err != nil {
return volumeTags, fmt.Errorf("fetching tags for one or more volumes with: %w", err)
}

for _, td := range output.Tags {
if td.ResourceId == nil || td.Key == nil || td.Value == nil {
continue
}
if _, ok := volumeTags[*td.ResourceId]; !ok {
volumeTags[*td.ResourceId] = make(map[string]string)
}
volumeTags[*td.ResourceId][*td.Key] = *td.Value
}
nextToken = output.NextToken
if nextToken == nil {
break
}
}
return volumeTags, nil
}

// filterUpdatableVolumes filters the list of volumes whose tags needs to be updated.
func (c *EBSVolumeTagsController) filterUpdatableVolumes(volumes []*v1.PersistentVolume, newTagsHash string) []*v1.PersistentVolume {
var updatablePVs []*v1.PersistentVolume
Expand Down
168 changes: 168 additions & 0 deletions pkg/driver/aws-ebs/aws_ebs_tags_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -449,3 +449,171 @@ func TestRemoveVolumesFromQueueSet(t *testing.T) {
t.Error("removeVolumesFromQueueSet() incorrectly removed PV pv2 from queue set")
}
}

func TestVolumeHasAllTags(t *testing.T) {
tests := []struct {
name string
existingTags map[string]string
desiredTags []ec2types.Tag
expected bool
}{
{
name: "all desired tags present with matching values",
existingTags: map[string]string{
"key1": "value1",
"key2": "value2",
"extra": "ignored",
},
desiredTags: []ec2types.Tag{
{Key: aws.String("key1"), Value: aws.String("value1")},
{Key: aws.String("key2"), Value: aws.String("value2")},
},
expected: true,
},
{
name: "missing a desired tag",
existingTags: map[string]string{
"key1": "value1",
},
desiredTags: []ec2types.Tag{
{Key: aws.String("key1"), Value: aws.String("value1")},
{Key: aws.String("key2"), Value: aws.String("value2")},
},
expected: false,
},
{
name: "desired tag exists but value differs",
existingTags: map[string]string{
"key1": "old-value",
},
desiredTags: []ec2types.Tag{
{Key: aws.String("key1"), Value: aws.String("new-value")},
},
expected: false,
},
{
name: "empty desired tags always matches",
existingTags: map[string]string{},
desiredTags: []ec2types.Tag{},
expected: true,
},
{
name: "no existing tags with desired tags",
existingTags: map[string]string{},
desiredTags: []ec2types.Tag{
{Key: aws.String("key1"), Value: aws.String("value1")},
},
expected: false,
},
{
name: "nil existing tags map",
existingTags: nil,
desiredTags: []ec2types.Tag{
{Key: aws.String("key1"), Value: aws.String("value1")},
},
expected: false,
},
{
name: "all keys present but one value wrong",
existingTags: map[string]string{
"key1": "value1",
"key2": "wrong-value",
"key3": "value3",
},
desiredTags: []ec2types.Tag{
{Key: aws.String("key1"), Value: aws.String("value1")},
{Key: aws.String("key2"), Value: aws.String("value2")},
{Key: aws.String("key3"), Value: aws.String("value3")},
},
expected: false,
},
{
name: "many extra tags do not affect match",
existingTags: map[string]string{
"aws:cloudformation:stack-name": "my-stack",
"kubernetes.io/cluster/test": "owned",
"Name": "my-volume",
"key1": "value1",
},
desiredTags: []ec2types.Tag{
{Key: aws.String("key1"), Value: aws.String("value1")},
},
expected: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := volumeHasAllTags(tt.existingTags, tt.desiredTags)
if result != tt.expected {
t.Errorf("volumeHasAllTags() = %v, want %v", result, tt.expected)
}
})
}
}

// TestProcessInfrastructureSkipsWhenNoTags verifies that processInfrastructure
// returns nil without calling fetchAndPushPvsToQueue when ResourceTags is nil,
// empty, or the platform status is missing. The controller has no informers
// set up, so reaching fetchAndPushPvsToQueue would panic — a clean return
// confirms the guard condition works.
func TestProcessInfrastructureSkipsWhenNoTags(t *testing.T) {
c := &EBSVolumeTagsController{}

tests := []struct {
name string
infra *configv1.Infrastructure
}{
{
name: "nil PlatformStatus",
infra: &configv1.Infrastructure{
Status: configv1.InfrastructureStatus{
PlatformStatus: nil,
},
},
},
{
name: "nil AWS in PlatformStatus",
infra: &configv1.Infrastructure{
Status: configv1.InfrastructureStatus{
PlatformStatus: &configv1.PlatformStatus{
AWS: nil,
},
},
},
},
{
name: "nil ResourceTags",
infra: &configv1.Infrastructure{
Status: configv1.InfrastructureStatus{
PlatformStatus: &configv1.PlatformStatus{
AWS: &configv1.AWSPlatformStatus{
ResourceTags: nil,
},
},
},
},
},
{
name: "empty ResourceTags",
infra: &configv1.Infrastructure{
Status: configv1.InfrastructureStatus{
PlatformStatus: &configv1.PlatformStatus{
AWS: &configv1.AWSPlatformStatus{
ResourceTags: []configv1.AWSResourceTag{},
},
},
},
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := c.processInfrastructure(tt.infra)
if err != nil {
t.Errorf("processInfrastructure() returned error: %v, want nil", err)
}
})
}
}
Loading