diff --git a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md index 31ae31b1b013..6551550bd436 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.46.0-beta.1 (Unreleased) #### Features Added +* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md index f4e90bc17b63..961c59201704 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.46.0-beta.1 (Unreleased) #### Features Added +* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md index 39254b96c705..daa856224716 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.46.0-beta.1 (Unreleased) #### Features Added +* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md index 5d677995de71..d3aac976d3c2 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.46.0-beta.1 (Unreleased) #### Features Added +* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala index 3b87ef08c3a0..0f0318057eee 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala @@ -13,7 +13,7 @@ import com.azure.cosmos.models.{CosmosClientTelemetryConfig, CosmosMetricCategor import com.azure.cosmos.spark.CosmosPredicates.isOnSparkDriver import com.azure.cosmos.spark.catalog.{CosmosCatalogClient, CosmosCatalogCosmosSDKClient, CosmosCatalogManagementSDKClient} import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait -import com.azure.cosmos.{ConsistencyLevel, CosmosAsyncClient, CosmosClientBuilder, CosmosContainerProactiveInitConfigBuilder, CosmosDiagnosticsThresholds, DirectConnectionConfig, GatewayConnectionConfig, ReadConsistencyStrategy, ThrottlingRetryOptions} +import com.azure.cosmos.{ConsistencyLevel, CosmosAsyncClient, CosmosClientBuilder, CosmosContainerProactiveInitConfigBuilder, CosmosDiagnosticsThresholds, CosmosHeaderName, DirectConnectionConfig, GatewayConnectionConfig, ReadConsistencyStrategy, ThrottlingRetryOptions} import com.azure.identity.{ClientCertificateCredentialBuilder, ClientSecretCredentialBuilder, ManagedIdentityCredentialBuilder} import com.azure.monitor.opentelemetry.autoconfigure.{AzureMonitorAutoConfigure, AzureMonitorAutoConfigureOptions} import com.azure.resourcemanager.cosmos.CosmosManager @@ -42,7 +42,6 @@ import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit} import java.util.function.BiPredicate import scala.collection.concurrent.TrieMap - // scalastyle:off underscore.import import scala.collection.JavaConverters._ // scalastyle:on underscore.import @@ -713,6 +712,17 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { } } + // Apply additional HTTP headers (e.g., workload-id) to the builder if configured. + // Header name validation already happened at config-parse time in CosmosConfig.AdditionalHeadersConfig, + // so CosmosHeaderName.fromString() here is just a type conversion for already-validated keys. + if (cosmosClientConfiguration.additionalHeaders.isDefined) { + val headerMap = new java.util.HashMap[CosmosHeaderName, String]() + for ((key, value) <- cosmosClientConfiguration.additionalHeaders.get) { + headerMap.put(CosmosHeaderName.fromString(key), value) + } + builder.additionalHeaders(headerMap) + } + var client = builder.buildAsyncClient() if (cosmosClientConfiguration.clientInterceptors.isDefined) { @@ -916,7 +926,10 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]], clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig], - azureMonitorConfig: Option[AzureMonitorConfig] + azureMonitorConfig: Option[AzureMonitorConfig], + // Additional HTTP headers are part of the cache key because different workload-ids + // should produce different CosmosAsyncClient instances + additionalHeaders: Option[Map[String, String]] ) private[this] object ClientConfigurationWrapper { @@ -935,7 +948,8 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { clientConfig.clientBuilderInterceptors, clientConfig.clientInterceptors, clientConfig.sampledDiagnosticsLoggerConfig, - clientConfig.azureMonitorConfig + clientConfig.azureMonitorConfig, + clientConfig.additionalHeaders ) } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala index 6f4e26e1f503..7e09c73698a7 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala @@ -30,7 +30,10 @@ private[spark] case class CosmosClientConfiguration ( clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]], clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig], - azureMonitorConfig: Option[AzureMonitorConfig] + azureMonitorConfig: Option[AzureMonitorConfig], + // Optional additional HTTP headers (e.g., workload-id) to attach to + // all Cosmos DB requests via CosmosClientBuilder.additionalHeaders() + additionalHeaders: Option[Map[String, String]] ) { private[spark] def getRoleInstanceName(machineId: Option[String]): String = { CosmosClientConfiguration.getRoleInstanceName(sparkEnvironmentInfo, machineId) @@ -94,7 +97,8 @@ private[spark] object CosmosClientConfiguration { cosmosAccountConfig.clientBuilderInterceptors, cosmosAccountConfig.clientInterceptors, diagnosticsConfig.sampledDiagnosticsLoggerConfig, - diagnosticsConfig.azureMonitorConfig + diagnosticsConfig.azureMonitorConfig, + cosmosAccountConfig.additionalHeaders ) } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala index 951f4735444d..10b0d5af8623 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala @@ -4,10 +4,10 @@ package com.azure.cosmos.spark import com.azure.core.management.AzureEnvironment -import com.azure.cosmos.{CosmosAsyncClient, CosmosClientBuilder, ReadConsistencyStrategy, spark} +import com.azure.cosmos.{CosmosAsyncClient, CosmosClientBuilder, CosmosHeaderName, ReadConsistencyStrategy, spark} import com.azure.cosmos.implementation.batch.BatchRequestResponseConstants import com.azure.cosmos.implementation.routing.LocationHelper -import com.azure.cosmos.implementation.{Configs, SparkBridgeImplementationInternal, Strings} +import com.azure.cosmos.implementation.{Configs, SparkBridgeImplementationInternal, Strings, Utils} import com.azure.cosmos.models.{CosmosChangeFeedRequestOptions, CosmosContainerIdentity, CosmosParameterizedQuery, DedicatedGatewayRequestOptions, FeedRange, PartitionKeyDefinition} import com.azure.cosmos.spark.ChangeFeedModes.ChangeFeedMode import com.azure.cosmos.spark.ChangeFeedStartFromModes.{ChangeFeedStartFromMode, PointInTime} @@ -34,6 +34,7 @@ import java.time.format.DateTimeFormatter import java.time.{Duration, Instant} import java.util import java.util.{Locale, ServiceLoader} +import scala.collection.JavaConverters._ // scalastyle:ignore underscore.import import scala.collection.concurrent.TrieMap import scala.collection.immutable.{HashSet, List, Map} import scala.collection.mutable @@ -151,6 +152,10 @@ private[spark] object CosmosConfigNames { val ThroughputControlTargetThroughputThreshold = "spark.cosmos.throughputControl.targetThroughputThreshold" val ThroughputControlPriorityLevel = "spark.cosmos.throughputControl.priorityLevel" val ThroughputControlThroughputBucket = "spark.cosmos.throughputControl.throughputBucket" + // Additional HTTP headers to attach to all Cosmos DB requests (e.g., workload-id for resource governance). + // Value is a JSON string like: {"x-ms-cosmos-workload-id": "15"} + // Flows through to CosmosClientBuilder.additionalHeaders(). + val AdditionalHeaders = "spark.cosmos.additionalHeaders" val ThroughputControlGlobalControlDatabase = "spark.cosmos.throughputControl.globalControl.database" val ThroughputControlGlobalControlContainer = "spark.cosmos.throughputControl.globalControl.container" val ThroughputControlGlobalControlRenewalIntervalInMS = @@ -297,7 +302,8 @@ private[spark] object CosmosConfigNames { WriteOnRetryCommitInterceptor, WriteFlushCloseIntervalInSeconds, WriteMaxNoProgressIntervalInSeconds, - WriteMaxRetryNoProgressIntervalInSeconds + WriteMaxRetryNoProgressIntervalInSeconds, + AdditionalHeaders ) def validateConfigName(name: String): Unit = { @@ -540,7 +546,10 @@ private case class CosmosAccountConfig(endpoint: String, resourceGroupName: Option[String], azureEnvironmentEndpoints: java.util.Map[String, String], clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]], - clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], + clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], + // Optional additional HTTP headers (e.g., workload-id) parsed from + // spark.cosmos.additionalHeaders JSON config, passed to CosmosClientBuilder.additionalHeaders() + additionalHeaders: Option[Map[String, String]] ) private object CosmosAccountConfig extends BasicLoggingTrait { @@ -727,6 +736,40 @@ private object CosmosAccountConfig extends BasicLoggingTrait { parseFromStringFunction = clientInterceptorFQDN => clientInterceptorFQDN, helpMessage = "CosmosAsyncClient interceptors (comma separated) - FQDNs of the service implementing the 'CosmosClientInterceptor' trait.") + // Config entry for custom HTTP headers (e.g., workload-id). Parses a JSON string like + // {"x-ms-cosmos-workload-id": "15"} into a Scala Map[String, String] using Jackson. + // These headers are converted to Map[CosmosHeaderName, String] and passed to + // CosmosClientBuilder.additionalHeaders() in CosmosClientCache. + // + // Validation: After JSON parsing, every header name is validated via CosmosHeaderName.fromString() + // to fail fast at config-parse time rather than at runtime during client creation. + // This prevents Spark jobs from starting, allocating cluster resources, and only failing + // later when CosmosClientCache tries to convert String keys to CosmosHeaderName instances. + private val AdditionalHeadersConfig = CosmosConfigEntry[Map[String, String]]( + key = CosmosConfigNames.AdditionalHeaders, + mandatory = false, + parseFromStringFunction = headersJson => { + try { + val typeRef = new com.fasterxml.jackson.core.`type`.TypeReference[java.util.Map[String, String]]() {} + val parsed = Utils.getSimpleObjectMapperWithAllowDuplicates.readValue(headersJson, typeRef).asScala.toMap + + // Fail fast: validate every header name is a known CosmosHeaderName at parse time. + // Without this, unknown headers like {"x-bad-header": "value"} would parse successfully + // and only blow up at runtime in CosmosClientCache when CosmosHeaderName.fromString() is called. + for (key <- parsed.keys) { + CosmosHeaderName.fromString(key) // throws IllegalArgumentException for unknown headers + } + + parsed + } catch { + case e: IllegalArgumentException => throw e + case e: Exception => throw new IllegalArgumentException( + s"Invalid JSON for '${CosmosConfigNames.AdditionalHeaders}': '$headersJson'. " + + "Expected format: {\"x-ms-cosmos-workload-id\": \"15\"}", e) + } + }, + helpMessage = "Optional additional headers as JSON map. Example: {\"x-ms-cosmos-workload-id\": \"15\"}") + private[spark] def parseProactiveConnectionInitConfigs(config: String): java.util.List[CosmosContainerIdentity] = { val result = new java.util.ArrayList[CosmosContainerIdentity] try { @@ -761,6 +804,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait { val tenantIdOpt = CosmosConfigEntry.parse(cfg, TenantId) val clientBuilderInterceptors = CosmosConfigEntry.parse(cfg, ClientBuilderInterceptors) val clientInterceptors = CosmosConfigEntry.parse(cfg, ClientInterceptors) + // Parse optional additional HTTP headers from JSON config (e.g., {"x-ms-cosmos-workload-id": "15"}) + val additionalHeaders = CosmosConfigEntry.parse(cfg, AdditionalHeadersConfig) val disableTcpConnectionEndpointRediscovery = CosmosConfigEntry.parse(cfg, DisableTcpConnectionEndpointRediscovery) val preferredRegionsListOpt = CosmosConfigEntry.parse(cfg, PreferredRegionsList) @@ -880,7 +925,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait { resourceGroupNameOpt, azureEnvironmentOpt.get, if (clientBuilderInterceptorsList.nonEmpty) { Some(clientBuilderInterceptorsList.toList) } else { None }, - if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None }) + if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None }, + additionalHeaders) } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala index ccf36791dc96..8e44968a76df 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala @@ -64,7 +64,8 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) ), ( @@ -91,7 +92,8 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) ), ( @@ -118,7 +120,8 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) ), ( @@ -145,7 +148,8 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) ) ) @@ -179,8 +183,9 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None - ) + azureMonitorConfig = None, + additionalHeaders = None + ) logInfo(s"TestCase: {$testCaseName}") diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala index 7fcc601ba016..9752136066f4 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala @@ -408,4 +408,75 @@ class CosmosClientConfigurationSpec extends UnitSpec { configuration.applicationName shouldEqual s"${CosmosConstants.userAgentSuffix}|$sparkEnvironmentInfo|${ManagementFactory.getRuntimeMXBean.getName}|$myApp".replace("@", " ") configuration.azureMonitorConfig shouldEqual None } + + // Verifies that the spark.cosmos.additionalHeaders configuration option correctly parses + // a JSON string containing a single workload-id header into a Map[String, String] on + // CosmosClientConfiguration. This is the primary use case for the workload-id feature. + it should "parse additionalHeaders JSON" in { + val userConfig = Map( + "spark.cosmos.accountEndpoint" -> "https://localhost:8081", + "spark.cosmos.accountKey" -> "xyz", + "spark.cosmos.additionalHeaders" -> """{"x-ms-cosmos-workload-id": "15"}""" + ) + + val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT + val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + + configuration.additionalHeaders shouldBe defined + configuration.additionalHeaders.get("x-ms-cosmos-workload-id") shouldEqual "15" + } + + // Verifies that when spark.cosmos.additionalHeaders is not specified in the config map, + // CosmosClientConfiguration.additionalHeaders is None. This ensures backward compatibility — + // existing Spark jobs that don't set additionalHeaders continue to work without changes. + it should "handle missing additionalHeaders" in { + val userConfig = Map( + "spark.cosmos.accountEndpoint" -> "https://localhost:8081", + "spark.cosmos.accountKey" -> "xyz" + ) + + val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT + val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + + configuration.additionalHeaders shouldBe None + } + + // Verifies that unknown header names in spark.cosmos.additionalHeaders are rejected + // at config-parse time. This ensures Spark jobs fail fast during configuration rather than + // starting, allocating cluster resources, and only failing later at runtime when + // CosmosClientCache tries to create the client. + // Note: CosmosConfigEntry.parse() wraps all parsing exceptions in RuntimeException, + // so the IllegalArgumentException from CosmosHeaderName.fromString() surfaces as + // RuntimeException with the original IllegalArgumentException as the cause. + it should "reject unknown additional headers at parse time" in { + val userConfig = Map( + "spark.cosmos.accountEndpoint" -> "https://localhost:8081", + "spark.cosmos.accountKey" -> "xyz", + "spark.cosmos.additionalHeaders" -> """{"x-ms-cosmos-workload-id": "20", "x-custom-header": "value"}""" + ) + + val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT + val thrown = the [RuntimeException] thrownBy { + CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + } + thrown.getCause shouldBe a [IllegalArgumentException] + thrown.getCause.getMessage should include ("x-custom-header") + } + + // Verifies that spark.cosmos.additionalHeaders handles an empty JSON object ("{}") gracefully, + // resulting in a defined but empty Map. This ensures the parser doesn't fail on edge cases + // and that no headers are injected when the JSON object is empty. + it should "handle empty additionalHeaders JSON" in { + val userConfig = Map( + "spark.cosmos.accountEndpoint" -> "https://localhost:8081", + "spark.cosmos.accountKey" -> "xyz", + "spark.cosmos.additionalHeaders" -> "{}" + ) + + val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT + val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + + configuration.additionalHeaders shouldBe defined + configuration.additionalHeaders.get shouldBe empty + } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala index 6ef90b55989d..f745bae6b56d 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala @@ -39,7 +39,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -116,7 +117,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -193,7 +195,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -270,7 +273,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -345,7 +349,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -436,7 +441,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -510,7 +516,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -576,7 +583,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala index dfd14c36c80f..c17e7b02dad8 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala @@ -38,7 +38,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) private[this] val contCfg = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -84,7 +85,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -169,7 +171,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -254,7 +257,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -321,7 +325,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -383,7 +388,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -439,7 +445,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -495,7 +502,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -551,7 +559,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -607,7 +616,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -686,7 +696,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -747,7 +758,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -803,7 +815,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -876,7 +889,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -949,7 +963,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -1027,7 +1042,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala new file mode 100644 index 000000000000..0f0a1f200d3e --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala @@ -0,0 +1,160 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos.spark + +import com.azure.cosmos.implementation.TestConfigurations +import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.databind.node.ObjectNode + +import java.util.UUID + +/** + * Integration tests (smoke tests) for the additional headers (workload-id) feature in the Spark connector. + * These are smoke tests — they verify that Spark DataFrame read and write operations succeed + * (no errors, correct data) when the `spark.cosmos.additionalHeaders` configuration is set. + * They do NOT assert that the workload-id header is actually present on the wire request. + * Wire-level header propagation is verified by: + * - Java SDK unit tests: RxGatewayStoreModelTest, GatewayAddressCacheTest + * - Java SDK integration tests: WorkloadIdE2ETests (interceptor-based wire assertions) + * + * Test cases: + * 1. Read with workload-id header — Spark read succeeds, correct item returned + * 2. Write with workload-id header — Spark write succeeds, item verified via SDK read-back + * 3. No additionalHeaders (regression) — operations succeed without the config set + */ +class SparkE2EWorkloadIdITest + extends IntegrationSpec + with Spark + with CosmosClient + with AutoCleanableCosmosContainer + with BasicLoggingTrait { + + val objectMapper = new ObjectMapper() + + //scalastyle:off multiple.string.literals + //scalastyle:off magic.number + //scalastyle:off null + + // Integration smoke test #1: Spark read with workload-id header. + // Creates an item via SDK, then reads it back via Spark DataFrame with + // spark.cosmos.additionalHeaders set to {"x-ms-cosmos-workload-id": "15"}. + // Verifies the read succeeds and returns the correct item. + // This proves the header flows through the Spark config pipeline without causing errors. + "spark query with additionalHeaders" can "read items with workload-id header" in { + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + + val id = UUID.randomUUID().toString + val rawItem = + s""" + | { + | "id" : "$id", + | "name" : "testItem" + | } + |""".stripMargin + + val objectNode = objectMapper.readValue(rawItem, classOf[ObjectNode]) + + val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer) + container.createItem(objectNode).block() + + val cfg = Map( + "spark.cosmos.accountEndpoint" -> cosmosEndpoint, + "spark.cosmos.accountKey" -> cosmosMasterKey, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> cosmosContainer, + "spark.cosmos.additionalHeaders" -> """{"x-ms-cosmos-workload-id": "15"}""", + "spark.cosmos.read.partitioning.strategy" -> "Restrictive" + ) + + val df = spark.read.format("cosmos.oltp").options(cfg).load() + val rowsArray = df.where(s"id = '$id'").collect() + rowsArray should have size 1 + + val item = rowsArray(0) + item.getAs[String]("id") shouldEqual id + } + + // Integration smoke test #2: Spark write with workload-id header. + // Writes an item via Spark DataFrame with spark.cosmos.additionalHeaders set to + // {"x-ms-cosmos-workload-id": "20"}, then reads it back via SDK to confirm + // write was persisted correctly. + // This proves the header flows through the Spark write pipeline without causing errors. + "spark write with additionalHeaders" can "write items with workload-id header" in { + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + + val id = UUID.randomUUID().toString + val rawItem = + s""" + | { + | "id" : "$id", + | "name" : "testWriteItem" + | } + |""".stripMargin + + val cfg = Map( + "spark.cosmos.accountEndpoint" -> cosmosEndpoint, + "spark.cosmos.accountKey" -> cosmosMasterKey, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> cosmosContainer, + "spark.cosmos.additionalHeaders" -> """{"x-ms-cosmos-workload-id": "20"}""", + "spark.cosmos.write.strategy" -> "ItemOverwrite", + "spark.cosmos.write.bulk.enabled" -> "false", + "spark.cosmos.serialization.inclusionMode" -> "NonDefault" + ) + + val spark_session = spark + import spark_session.implicits._ + val df = spark.read.json(Seq(rawItem).toDS()) + + df.write.format("cosmos.oltp").options(cfg).mode("Append").save() + + // Verify the item was written by reading it back via the SDK directly + val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer) + val readItem = container.readItem(id, new com.azure.cosmos.models.PartitionKey(id), classOf[ObjectNode]).block() + readItem.getItem.get("id").textValue() shouldEqual id + readItem.getItem.get("name").textValue() shouldEqual "testWriteItem" + } + + // Integration smoke test #3: Regression test — no additionalHeaders configured. + // Verifies that Spark read operations continue to work correctly when + // spark.cosmos.additionalHeaders is NOT specified in the config map. + // This ensures the feature addition does not break existing Spark jobs + // that don't use additional headers. + "spark operations without additionalHeaders" can "still succeed" in { + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + + val id = UUID.randomUUID().toString + val rawItem = + s""" + | { + | "id" : "$id", + | "name" : "noHeadersItem" + | } + |""".stripMargin + + val objectNode = objectMapper.readValue(rawItem, classOf[ObjectNode]) + val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer) + container.createItem(objectNode).block() + + val cfg = Map( + "spark.cosmos.accountEndpoint" -> cosmosEndpoint, + "spark.cosmos.accountKey" -> cosmosMasterKey, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> cosmosContainer, + "spark.cosmos.read.partitioning.strategy" -> "Restrictive" + ) + + val df = spark.read.format("cosmos.oltp").options(cfg).load() + val rowsArray = df.where(s"id = '$id'").collect() + rowsArray should have size 1 + rowsArray(0).getAs[String]("id") shouldEqual id + } + + //scalastyle:on magic.number + //scalastyle:on multiple.string.literals + //scalastyle:on null +} diff --git a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md index 9a438a3f1c53..b36949f18c57 100644 --- a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.46.0-beta.1 (Unreleased) #### Features Added +* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/WorkloadIdHeaderTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/WorkloadIdHeaderTests.java new file mode 100644 index 000000000000..41fc62f62ca3 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/WorkloadIdHeaderTests.java @@ -0,0 +1,397 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos; + +import com.azure.cosmos.implementation.HttpConstants; +import com.azure.cosmos.implementation.ImplementationBridgeHelpers; +import com.azure.cosmos.implementation.RequestOptions; +import com.azure.cosmos.models.CosmosBatchRequestOptions; +import com.azure.cosmos.models.CosmosBulkExecutionOptions; +import com.azure.cosmos.models.CosmosChangeFeedRequestOptions; +import com.azure.cosmos.models.CosmosContainerRequestOptions; +import com.azure.cosmos.models.CosmosDatabaseRequestOptions; +import com.azure.cosmos.models.CosmosItemRequestOptions; +import com.azure.cosmos.models.CosmosQueryRequestOptions; +import com.azure.cosmos.models.CosmosReadManyRequestOptions; +import com.azure.cosmos.models.CosmosStoredProcedureRequestOptions; +import com.azure.cosmos.models.FeedRange; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Comprehensive unit tests for the workload-id / additional headers feature. + *

+ * Covers three layers: + *

    + *
  1. Public API surface — {@link CosmosHeaderName} constants, {@code CosmosClientBuilder.additionalHeaders()}, + * and that {@code setAdditionalHeaders()} is callable on every request options class.
  2. + *
  3. Validation — non-numeric workload-id rejected at builder and request-options levels; + * out-of-range values accepted (range enforcement is the backend's responsibility).
  4. + *
  5. Internal wiring — headers set via {@code setAdditionalHeaders()} actually reach + * {@code RequestOptions.getHeaders()}, which is what {@code RxDocumentClientImpl.getRequestHeaders()} + * reads to populate outbound request headers. Covers both data-plane and control-plane classes.
  6. + *
+ * E2E tests (SDK → Gateway/Backend → Jarvis) are in {@code WorkloadIdJarvisValidationTests}. + */ +public class WorkloadIdHeaderTests { + + private static final String WORKLOAD_ID_HEADER = HttpConstants.HttpHeaders.WORKLOAD_ID; + private static final String TEST_WORKLOAD_ID = "42"; + + // ============================================================================================== + // 1. CosmosHeaderName constants + // ============================================================================================== + + @Test(groups = { "unit" }) + public void workloadIdHttpHeaderConstant() { + assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id"); + } + + @Test(groups = { "unit" }) + public void cosmosHeaderNameWorkloadIdValue() { + assertThat(CosmosHeaderName.WORKLOAD_ID.getHeaderName()).isEqualTo("x-ms-cosmos-workload-id"); + } + + @Test(groups = { "unit" }) + public void cosmosHeaderNameFromStringResolvesKnownHeader() { + assertThat(CosmosHeaderName.fromString("x-ms-cosmos-workload-id")).isEqualTo(CosmosHeaderName.WORKLOAD_ID); + } + + @Test(groups = { "unit" }) + public void cosmosHeaderNameFromStringIsCaseInsensitive() { + assertThat(CosmosHeaderName.fromString("X-MS-COSMOS-WORKLOAD-ID")).isEqualTo(CosmosHeaderName.WORKLOAD_ID); + } + + @Test(groups = { "unit" }) + public void cosmosHeaderNameFromStringRejectsUnknownHeader() { + assertThatThrownBy(() -> CosmosHeaderName.fromString("x-ms-custom-header")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("x-ms-custom-header"); + } + + // ============================================================================================== + // 2. CosmosClientBuilder — additionalHeaders() + // ============================================================================================== + + @Test(groups = { "unit" }) + public void builderStoresAdditionalHeaders() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "25"); + + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .additionalHeaders(headers); + + assertThat(builder.getAdditionalHeaders()) + .containsEntry(WORKLOAD_ID_HEADER, "25"); + } + + @Test(groups = { "unit" }) + public void builderHandlesNullAdditionalHeaders() { + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .additionalHeaders(null); + + assertThat(builder.getAdditionalHeaders()).isNull(); + } + + @Test(groups = { "unit" }) + public void builderHandlesEmptyAdditionalHeaders() { + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .additionalHeaders(new HashMap<>()); + + assertThat(builder.getAdditionalHeaders()).isEmpty(); + } + + @Test(groups = { "unit" }) + public void clonedBuilderPreservesAdditionalHeaders() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "25"); + + CosmosClientBuilder original = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .additionalHeaders(headers); + + CosmosClientBuilder cloned = CosmosBridgeInternal.cloneCosmosClientBuilder(original); + + assertThat(cloned.getAdditionalHeaders()) + .as("cloned builder should preserve additionalHeaders") + .containsEntry(WORKLOAD_ID_HEADER, "25"); + } + + @Test(groups = { "unit" }) + public void clonedBuilderHandlesNullAdditionalHeaders() { + CosmosClientBuilder original = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ=="); + + CosmosClientBuilder cloned = CosmosBridgeInternal.cloneCosmosClientBuilder(original); + + assertThat(cloned.getAdditionalHeaders()) + .as("cloned builder should handle null additionalHeaders") + .isNull(); + } + + // ============================================================================================== + // 3. Validation — non-numeric rejected, out-of-range accepted + // ============================================================================================== + + @Test(groups = { "unit" }) + public void nonNumericWorkloadIdRejectedAtBuilderLevel() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "abc"); + + assertThatThrownBy(() -> new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .additionalHeaders(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("abc") + .hasMessageContaining("valid integer"); + } + + @Test(groups = { "unit" }) + public void nonNumericWorkloadIdRejectedAtItemRequestOptionsLevel() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "abc"); + + assertThatThrownBy(() -> new CosmosItemRequestOptions().setAdditionalHeaders(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("abc"); + } + + @Test(groups = { "unit" }) + public void nonNumericWorkloadIdRejectedAtBulkExecutionOptionsLevel() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "not-a-number"); + + assertThatThrownBy(() -> new CosmosBulkExecutionOptions().setAdditionalHeaders(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("not-a-number"); + } + + @Test(groups = { "unit" }) + public void nonNumericWorkloadIdRejectedAtDatabaseRequestOptionsLevel() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "not-a-number"); + + assertThatThrownBy(() -> new CosmosDatabaseRequestOptions().setAdditionalHeaders(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("not-a-number"); + } + + /** Range validation [1, 50] is the backend's responsibility — SDK only validates integer format. */ + @Test(groups = { "unit" }) + public void outOfRangeWorkloadIdAcceptedByBuilder() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "51"); + + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .additionalHeaders(headers); + + assertThat(builder.getAdditionalHeaders()) + .containsEntry(WORKLOAD_ID_HEADER, "51"); + } + + // ============================================================================================== + // 4. Internal wiring — workload-id reaches RequestOptions.getHeaders() + // This is the acceptance proof that the header will be on the wire. + // ============================================================================================== + + /** + * Data provider covering request options classes that have a no-arg {@code toRequestOptions()}. + * Includes both data-plane and control-plane classes. + */ + @DataProvider(name = "requestOptionsWithToRequestOptions") + public Object[][] requestOptionsWithToRequestOptions() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, TEST_WORKLOAD_ID); + + return new Object[][] { + // Data-plane + { "CosmosItemRequestOptions", + reflectToRequestOptions(new CosmosItemRequestOptions().setAdditionalHeaders(headers)) }, + { "CosmosBatchRequestOptions", + reflectToRequestOptions(new CosmosBatchRequestOptions().setAdditionalHeaders(headers)) }, + // Control-plane + { "CosmosDatabaseRequestOptions", + reflectToRequestOptions(new CosmosDatabaseRequestOptions().setAdditionalHeaders(headers)) }, + { "CosmosContainerRequestOptions", + reflectToRequestOptions(new CosmosContainerRequestOptions().setAdditionalHeaders(headers)) }, + { "CosmosStoredProcedureRequestOptions", + reflectToRequestOptions(new CosmosStoredProcedureRequestOptions().setAdditionalHeaders(headers)) }, + }; + } + + @Test(groups = { "unit" }, dataProvider = "requestOptionsWithToRequestOptions") + public void workloadIdReachesOutboundHeaders(String optionsClassName, RequestOptions requestOptions) { + assertThat(requestOptions.getHeaders()) + .as("workload-id should reach RequestOptions.getHeaders() for " + optionsClassName) + .isNotNull() + .containsEntry(WORKLOAD_ID_HEADER, TEST_WORKLOAD_ID); + } + + /** CosmosQueryRequestOptions uses a bridge accessor pattern (no simple toRequestOptions()). */ + @Test(groups = { "unit" }) + public void workloadIdReachesQueryRequestOptions() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, TEST_WORKLOAD_ID); + + RequestOptions opts = ImplementationBridgeHelpers + .CosmosQueryRequestOptionsHelper + .getCosmosQueryRequestOptionsAccessor() + .toRequestOptions(new CosmosQueryRequestOptions().setAdditionalHeaders(headers)); + + assertThat(opts.getHeaders()).containsEntry(WORKLOAD_ID_HEADER, TEST_WORKLOAD_ID); + } + + /** CosmosChangeFeedRequestOptions uses a bridge accessor that exposes getHeaders(). */ + @Test(groups = { "unit" }) + public void workloadIdReachesChangeFeedRequestOptions() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, TEST_WORKLOAD_ID); + + Map extracted = ImplementationBridgeHelpers + .CosmosChangeFeedRequestOptionsHelper + .getCosmosChangeFeedRequestOptionsAccessor() + .getHeaders( + CosmosChangeFeedRequestOptions + .createForProcessingFromBeginning(FeedRange.forFullRange()) + .setAdditionalHeaders(headers)); + + assertThat(extracted).containsEntry(WORKLOAD_ID_HEADER, TEST_WORKLOAD_ID); + } + + /** CosmosReadManyRequestOptions uses getImpl().applyToRequestOptions(). */ + @Test(groups = { "unit" }) + public void workloadIdReachesReadManyRequestOptions() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, TEST_WORKLOAD_ID); + + RequestOptions opts = ImplementationBridgeHelpers + .CosmosReadManyRequestOptionsHelper + .getCosmosReadManyRequestOptionsAccessor() + .getImpl(new CosmosReadManyRequestOptions().setAdditionalHeaders(headers)) + .applyToRequestOptions(new RequestOptions()); + + assertThat(opts.getHeaders()).containsEntry(WORKLOAD_ID_HEADER, TEST_WORKLOAD_ID); + } + + /** CosmosBulkExecutionOptions delegates to an internal CosmosBatchRequestOptions. */ + @Test(groups = { "unit" }) + public void workloadIdReachesBulkExecutionOptions() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, TEST_WORKLOAD_ID); + + Map extracted = reflectGetHeaders( + new CosmosBulkExecutionOptions().setAdditionalHeaders(headers)); + + assertThat(extracted).containsEntry(WORKLOAD_ID_HEADER, TEST_WORKLOAD_ID); + } + + // ============================================================================================== + // 5. Null/empty additionalHeaders does not inject workload-id into outbound headers + // ============================================================================================== + + @Test(groups = { "unit" }) + public void nullAdditionalHeadersDoesNotInjectWorkloadId() { + assertNoWorkloadId(reflectToRequestOptions(new CosmosItemRequestOptions().setAdditionalHeaders(null))); + assertNoWorkloadId(reflectToRequestOptions(new CosmosDatabaseRequestOptions().setAdditionalHeaders(null))); + assertNoWorkloadId(reflectToRequestOptions(new CosmosContainerRequestOptions().setAdditionalHeaders(null))); + assertNoWorkloadId(reflectToRequestOptions(new CosmosStoredProcedureRequestOptions().setAdditionalHeaders(null))); + } + + @Test(groups = { "unit" }) + public void emptyAdditionalHeadersDoesNotInjectWorkloadId() { + Map empty = new HashMap<>(); + assertNoWorkloadId(reflectToRequestOptions(new CosmosItemRequestOptions().setAdditionalHeaders(empty))); + assertNoWorkloadId(reflectToRequestOptions(new CosmosDatabaseRequestOptions().setAdditionalHeaders(empty))); + assertNoWorkloadId(reflectToRequestOptions(new CosmosContainerRequestOptions().setAdditionalHeaders(empty))); + assertNoWorkloadId(reflectToRequestOptions(new CosmosStoredProcedureRequestOptions().setAdditionalHeaders(empty))); + } + + // ============================================================================================== + // 6. Coverage matrix guard — setAdditionalHeaders() exists on all expected classes + // ============================================================================================== + + /** Guard against accidental removal of setAdditionalHeaders() from any request options class. */ + @Test(groups = { "unit" }) + public void allExpectedRequestOptionsClassesSupportAdditionalHeaders() { + Class[] expectedClasses = { + // Data-plane + CosmosItemRequestOptions.class, + CosmosBatchRequestOptions.class, + CosmosBulkExecutionOptions.class, + CosmosQueryRequestOptions.class, + CosmosReadManyRequestOptions.class, + CosmosChangeFeedRequestOptions.class, + // Control-plane + CosmosDatabaseRequestOptions.class, + CosmosContainerRequestOptions.class, + CosmosStoredProcedureRequestOptions.class, + }; + + for (Class clazz : expectedClasses) { + assertThat(hasSetAdditionalHeaders(clazz)) + .as(clazz.getSimpleName() + " should have setAdditionalHeaders()") + .isTrue(); + } + } + + // ============================================================================================== + // Helpers + // ============================================================================================== + + private static RequestOptions reflectToRequestOptions(Object cosmosRequestOptions) { + try { + java.lang.reflect.Method m = cosmosRequestOptions.getClass().getDeclaredMethod("toRequestOptions"); + m.setAccessible(true); + return (RequestOptions) m.invoke(cosmosRequestOptions); + } catch (Exception e) { + throw new RuntimeException( + "Failed to call toRequestOptions() on " + cosmosRequestOptions.getClass().getSimpleName(), e); + } + } + + @SuppressWarnings("unchecked") + private static Map reflectGetHeaders(Object cosmosRequestOptions) { + try { + java.lang.reflect.Method m = cosmosRequestOptions.getClass().getDeclaredMethod("getHeaders"); + m.setAccessible(true); + return (Map) m.invoke(cosmosRequestOptions); + } catch (Exception e) { + throw new RuntimeException( + "Failed to call getHeaders() on " + cosmosRequestOptions.getClass().getSimpleName(), e); + } + } + + private static void assertNoWorkloadId(RequestOptions options) { + Map headers = options.getHeaders(); + if (headers != null) { + assertThat(headers).doesNotContainKey(WORKLOAD_ID_HEADER); + } + } + + private static boolean hasSetAdditionalHeaders(Class clazz) { + try { + clazz.getMethod("setAdditionalHeaders", Map.class); + return true; + } catch (NoSuchMethodException e) { + return false; + } + } +} + diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java index a9f5cb35549c..b4cb04e0ae3e 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java @@ -19,6 +19,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import static org.mockito.Mockito.doAnswer; @@ -75,7 +76,8 @@ RxGatewayStoreModel createRxGatewayProxy( GlobalEndpointManager globalEndpointManager, GlobalPartitionEndpointManagerForPerPartitionCircuitBreaker globalPartitionEndpointManagerForPerPartitionCircuitBreaker, HttpClient rxOrigClient, - ApiType apiType) { + ApiType apiType, + Map additionalHeaders) { origHttpClient = rxOrigClient; spyHttpClient = Mockito.spy(rxOrigClient); @@ -93,6 +95,7 @@ RxGatewayStoreModel createRxGatewayProxy( userAgentContainer, globalEndpointManager, spyHttpClient, - apiType); + apiType, + additionalHeaders); } } diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java index 54440ecfabc5..edf16806489c 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java @@ -27,6 +27,8 @@ import java.net.SocketException; import java.net.URI; import java.time.Duration; +import java.util.HashMap; +import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import static com.azure.cosmos.implementation.TestUtils.mockDiagnosticsClientContext; @@ -102,6 +104,7 @@ public void readTimeout() throws Exception { userAgentContainer, globalEndpointManager, httpClient, + null, null); storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader); @@ -146,6 +149,7 @@ public void serviceUnavailable() throws Exception { userAgentContainer, globalEndpointManager, httpClient, + null, null); storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader); @@ -205,7 +209,8 @@ public void applySessionToken( new UserAgentContainer(), globalEndpointManager, httpClient, - apiType); + apiType, + null); storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader); RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( @@ -277,7 +282,8 @@ public void validateApiType() throws Exception { new UserAgentContainer(), globalEndpointManager, httpClient, - apiType); + apiType, + null); RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( clientContext, @@ -391,6 +397,7 @@ private boolean runCancelAfterRetainIteration() throws Exception { new UserAgentContainer(), globalEndpointManager, httpClient, + null, null); storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader); @@ -428,6 +435,167 @@ private boolean runCancelAfterRetainIteration() throws Exception { return false; } + /** + * Verifies that client-level additionalHeaders (e.g., workload-id) are injected into + * outgoing HTTP requests by performRequest(). This covers metadata requests + * (collection cache, partition key range) that don't go through getRequestHeaders(). + */ + @Test(groups = "unit") + public void additionalHeadersInjectedInPerformRequest() throws Exception { + DiagnosticsClientContext clientContext = mockDiagnosticsClientContext(); + ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class); + GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class); + + Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost"))) + .when(globalEndpointManager).resolveServiceEndpoint(any()); + + HttpClient httpClient = Mockito.mock(HttpClient.class); + ArgumentCaptor httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException())); + + Map additionalHeaders = new HashMap<>(); + additionalHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + + RxGatewayStoreModel storeModel = new RxGatewayStoreModel( + clientContext, + sessionContainer, + ConsistencyLevel.SESSION, + QueryCompatibilityMode.Default, + new UserAgentContainer(), + globalEndpointManager, + httpClient, + null, + additionalHeaders); + + // Simulate a metadata request (e.g., collection cache lookup) — no additionalHeaders on the request itself + RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( + clientContext, + OperationType.Read, + "/dbs/db/colls/col", + ResourceType.DocumentCollection); + dsr.requestContext = new DocumentServiceRequestContext(); + dsr.requestContext.regionalRoutingContextToRoute = new RegionalRoutingContext(new URI("https://localhost")); + + try { + storeModel.performRequest(dsr).block(); + fail("Request should fail"); + } catch (Exception e) { + // expected + } + + Mockito.verify(httpClient).send(httpClientRequestCaptor.capture(), any()); + HttpRequest httpRequest = httpClientRequestCaptor.getValue(); + HttpHeaders headers = ReflectionUtils.getHttpHeaders(httpRequest); + assertThat(headers.toMap().get(HttpConstants.HttpHeaders.WORKLOAD_ID)).isEqualTo("25"); + } + + /** + * Verifies that request-level headers take precedence over client-level additionalHeaders. + * If a request already has workload-id set (e.g., via getRequestHeaders()), performRequest() + * should NOT overwrite it. + */ + @Test(groups = "unit") + public void requestLevelHeadersTakePrecedenceOverAdditionalHeaders() throws Exception { + DiagnosticsClientContext clientContext = mockDiagnosticsClientContext(); + ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class); + GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class); + + Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost"))) + .when(globalEndpointManager).resolveServiceEndpoint(any()); + + HttpClient httpClient = Mockito.mock(HttpClient.class); + ArgumentCaptor httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException())); + + Map additionalHeaders = new HashMap<>(); + additionalHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "10"); + + RxGatewayStoreModel storeModel = new RxGatewayStoreModel( + clientContext, + sessionContainer, + ConsistencyLevel.SESSION, + QueryCompatibilityMode.Default, + new UserAgentContainer(), + globalEndpointManager, + httpClient, + null, + additionalHeaders); + + RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( + clientContext, + OperationType.Read, + "/dbs/db/colls/col/docs/doc1", + ResourceType.Document); + dsr.requestContext = new DocumentServiceRequestContext(); + dsr.requestContext.regionalRoutingContextToRoute = new RegionalRoutingContext(new URI("https://localhost")); + + // Simulate request-level header already set (e.g., by getRequestHeaders()) + dsr.getHeaders().put(HttpConstants.HttpHeaders.WORKLOAD_ID, "42"); + + try { + storeModel.performRequest(dsr).block(); + fail("Request should fail"); + } catch (Exception e) { + // expected + } + + Mockito.verify(httpClient).send(httpClientRequestCaptor.capture(), any()); + HttpRequest httpRequest = httpClientRequestCaptor.getValue(); + HttpHeaders headers = ReflectionUtils.getHttpHeaders(httpRequest); + // Request-level header "42" should win over client-level "10" + assertThat(headers.toMap().get(HttpConstants.HttpHeaders.WORKLOAD_ID)).isEqualTo("42"); + } + + /** + * Verifies that when additionalHeaders is null, performRequest() still works normally + * without injecting any extra headers. + */ + @Test(groups = "unit") + public void nullAdditionalHeadersDoesNotAffectPerformRequest() throws Exception { + DiagnosticsClientContext clientContext = mockDiagnosticsClientContext(); + ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class); + GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class); + + Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost"))) + .when(globalEndpointManager).resolveServiceEndpoint(any()); + + HttpClient httpClient = Mockito.mock(HttpClient.class); + ArgumentCaptor httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException())); + + RxGatewayStoreModel storeModel = new RxGatewayStoreModel( + clientContext, + sessionContainer, + ConsistencyLevel.SESSION, + QueryCompatibilityMode.Default, + new UserAgentContainer(), + globalEndpointManager, + httpClient, + null, + null); + + RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( + clientContext, + OperationType.Read, + "/dbs/db/colls/col", + ResourceType.DocumentCollection); + dsr.requestContext = new DocumentServiceRequestContext(); + dsr.requestContext.regionalRoutingContextToRoute = new RegionalRoutingContext(new URI("https://localhost")); + + try { + storeModel.performRequest(dsr).block(); + fail("Request should fail"); + } catch (Exception e) { + // expected + } + + Mockito.verify(httpClient).send(httpClientRequestCaptor.capture(), any()); + HttpRequest httpRequest = httpClientRequestCaptor.getValue(); + HttpHeaders headers = ReflectionUtils.getHttpHeaders(httpRequest); + // No workload-id header should be present + assertThat(headers.toMap().get(HttpConstants.HttpHeaders.WORKLOAD_ID)).isNull(); + } + enum SessionTokenType { NONE, // no session token applied USER, // userControlled session token diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java index b06d6f89b8e9..aa9029576b24 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java @@ -25,6 +25,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; import java.util.stream.Collectors; @@ -126,7 +127,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer, UserAgentContainer userAgentContainer, GlobalEndpointManager globalEndpointManager, HttpClient rxClient, - ApiType apiType) { + ApiType apiType, + Map additionalHeaders) { this.origRxGatewayStoreModel = super.createRxGatewayProxy( sessionContainer, consistencyLevel, @@ -134,7 +136,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer, userAgentContainer, globalEndpointManager, rxClient, - apiType); + apiType, + additionalHeaders); this.requests = Collections.synchronizedList(new ArrayList<>()); this.spyRxGatewayStoreModel = Mockito.spy(this.origRxGatewayStoreModel); this.initRequestCapture(); diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ThinClientStoreModelTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ThinClientStoreModelTest.java index 64cd7fe37115..1b5e3670ace0 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ThinClientStoreModelTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ThinClientStoreModelTest.java @@ -2,14 +2,19 @@ import com.azure.cosmos.ConsistencyLevel; import com.azure.cosmos.implementation.http.HttpClient; +import com.azure.cosmos.implementation.http.HttpRequest; import com.azure.cosmos.implementation.routing.RegionalRoutingContext; import io.netty.channel.ConnectTimeoutException; +import org.mockito.ArgumentCaptor; import org.mockito.Mockito; import org.testng.annotations.Test; import reactor.core.publisher.Mono; import java.net.URI; +import java.util.HashMap; +import java.util.Map; +import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; public class ThinClientStoreModelTest { @@ -44,7 +49,8 @@ public void testThinClientStoreModel() throws Exception { ConsistencyLevel.SESSION, new UserAgentContainer(), globalEndpointManager, - httpClient); + httpClient, + null); RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( clientContext, @@ -58,4 +64,125 @@ public void testThinClientStoreModel() throws Exception { //no-op } } + + /** + * Verifies that additionalHeaders (e.g., workload-id) passed to ThinClientStoreModel's + * constructor are correctly propagated to the parent RxGatewayStoreModel and injected + * into outgoing requests via performRequest(). + *

+ * This is for the workload-id feature: ThinClientStoreModel extends + * RxGatewayStoreModel, and the additionalHeaders must flow through the constructor + * chain so that performRequest() injects them into every request — including + * metadata requests (collection cache, PKRange cache, etc.). + */ + @Test(groups = "unit") + public void testAdditionalHeadersFlowThroughThinClientStoreModel() throws Exception { + DiagnosticsClientContext clientContext = Mockito.mock(DiagnosticsClientContext.class); + Mockito.doReturn(new DiagnosticsClientContext.DiagnosticsClientConfig()).when(clientContext).getConfig(); + Mockito + .doReturn(ImplementationBridgeHelpers + .CosmosDiagnosticsHelper + .getCosmosDiagnosticsAccessor() + .create(clientContext, 1d)) + .when(clientContext).createDiagnostics(); + + String sdkGlobalSessionToken = "1#100#1=20#2=5#3=30"; + ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class); + Mockito.doReturn(sdkGlobalSessionToken).when(sessionContainer).resolveGlobalSessionToken(any()); + + GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class); + Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost:8080"))) + .when(globalEndpointManager).resolveServiceEndpoint(any()); + + // Capture the HttpRequest sent by performRequest() to verify headers + HttpClient httpClient = Mockito.mock(HttpClient.class); + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + Mockito.when(httpClient.send(requestCaptor.capture(), any())) + .thenReturn(Mono.error(new ConnectTimeoutException())); + + // Set up additionalHeaders with workload-id + Map additionalHeaders = new HashMap<>(); + additionalHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "15"); + + ThinClientStoreModel storeModel = new ThinClientStoreModel( + clientContext, + sessionContainer, + ConsistencyLevel.SESSION, + new UserAgentContainer(), + globalEndpointManager, + httpClient, + additionalHeaders); + + RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( + clientContext, + OperationType.Read, + "/fakeResourceFullName", + ResourceType.Document); + + try { + storeModel.performRequest(dsr).block(); + } catch (Exception e) { + // Expected — mock HTTP client throws ConnectTimeoutException + } + + // Verify that the workload-id header was injected into the request + assertThat(dsr.getHeaders().get(HttpConstants.HttpHeaders.WORKLOAD_ID)) + .as("workload-id header should be injected into request by performRequest()") + .isEqualTo("15"); + } + + /** + * Verifies that ThinClientStoreModel works correctly when additionalHeaders is null + * (the default case when no workload-id is configured). This ensures backward + * compatibility — the null case should not throw or inject unexpected headers. + */ + @Test(groups = "unit") + public void testNullAdditionalHeadersThinClientStoreModel() throws Exception { + DiagnosticsClientContext clientContext = Mockito.mock(DiagnosticsClientContext.class); + Mockito.doReturn(new DiagnosticsClientContext.DiagnosticsClientConfig()).when(clientContext).getConfig(); + Mockito + .doReturn(ImplementationBridgeHelpers + .CosmosDiagnosticsHelper + .getCosmosDiagnosticsAccessor() + .create(clientContext, 1d)) + .when(clientContext).createDiagnostics(); + + ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class); + Mockito.doReturn("1#100#1=20#2=5#3=30").when(sessionContainer).resolveGlobalSessionToken(any()); + + GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class); + Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost:8080"))) + .when(globalEndpointManager).resolveServiceEndpoint(any()); + + HttpClient httpClient = Mockito.mock(HttpClient.class); + Mockito.when(httpClient.send(any(), any())) + .thenReturn(Mono.error(new ConnectTimeoutException())); + + // Pass null for additionalHeaders — this is the default case + ThinClientStoreModel storeModel = new ThinClientStoreModel( + clientContext, + sessionContainer, + ConsistencyLevel.SESSION, + new UserAgentContainer(), + globalEndpointManager, + httpClient, + null); + + RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( + clientContext, + OperationType.Read, + "/fakeResourceFullName", + ResourceType.Document); + + try { + storeModel.performRequest(dsr).block(); + } catch (Exception e) { + // Expected — mock HTTP client throws ConnectTimeoutException + } + + // Verify that no workload-id header was injected + assertThat(dsr.getHeaders().get(HttpConstants.HttpHeaders.WORKLOAD_ID)) + .as("workload-id header should NOT be present when additionalHeaders is null") + .isNull(); + } } diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java index 4001ac6a343b..ab0cc053026b 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java @@ -57,6 +57,7 @@ import java.util.Arrays; import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.UUID; import java.util.concurrent.TimeUnit; @@ -145,6 +146,7 @@ public void getServerAddressesViaGateway(List partitionKeyRangeIds, null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); for (int i = 0; i < 2; i++) { @@ -186,6 +188,7 @@ public void getMasterAddressesViaGatewayAsync(Protocol protocol) throws Exceptio null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); for (int i = 0; i < 2; i++) { @@ -238,6 +241,7 @@ public void tryGetAddresses_ForDataPartitions(String partitionKeyRangeId, String null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); RxDocumentServiceRequest req = @@ -296,6 +300,7 @@ public void tryGetAddresses_ForDataPartitions_AddressCachedByOpenAsync_NoHttpReq null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); String collectionRid = createdCollection.getResourceId(); @@ -366,6 +371,7 @@ public void tryGetAddresses_ForDataPartitions_ForceRefresh( null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); String collectionRid = createdCollection.getResourceId(); @@ -472,6 +478,7 @@ public void tryGetAddresses_ForDataPartitions_Suboptimal_Refresh( null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); String collectionRid = createdCollection.getResourceId(); @@ -614,6 +621,7 @@ public void tryGetAddresses_ForMasterPartition(Protocol protocol) throws Excepti null, null, null, + null, null); RxDocumentServiceRequest req = @@ -666,6 +674,7 @@ public void tryGetAddresses_ForMasterPartition_MasterPartitionAddressAlreadyCach null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); RxDocumentServiceRequest req = @@ -717,6 +726,7 @@ public void tryGetAddresses_ForMasterPartition_ForceRefresh() throws Exception { null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); RxDocumentServiceRequest req = @@ -775,6 +785,7 @@ public void tryGetAddresses_SuboptimalMasterPartition_NotStaleEnough_NoRefresh() null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); GatewayAddressCache spyCache = Mockito.spy(origCache); @@ -873,6 +884,7 @@ public void tryGetAddresses_SuboptimalMasterPartition_Stale_DoRefresh() throws E null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); GatewayAddressCache spyCache = Mockito.spy(origCache); @@ -990,6 +1002,7 @@ public void tryGetAddress_replicaValidationTests(boolean replicaValidationEnable null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); RxDocumentServiceRequest req = @@ -1152,6 +1165,7 @@ public void tryGetAddress_failedEndpointTests() throws Exception { null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); RxDocumentServiceRequest req = @@ -1214,6 +1228,7 @@ public void tryGetAddress_unhealthyStatus_forceRefresh() throws Exception { null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); RxDocumentServiceRequest req = @@ -1304,6 +1319,7 @@ public void tryGetAddress_repeatedlySetUnhealthyStatus_forceRefresh() throws Int null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); RxDocumentServiceRequest req = @@ -1396,6 +1412,7 @@ public void validateReplicaAddressesTests(boolean isCollectionUnderWarmUpFlow) t null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); Mockito.when(proactiveOpenConnectionsProcessorMock.submitOpenConnectionTaskOutsideLoop(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.anyInt())).thenReturn(dummyOpenConnectionsTask); @@ -1495,6 +1512,7 @@ public void mergeAddressesTests() throws URISyntaxException, NoSuchMethodExcepti null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); // connected status @@ -1628,6 +1646,113 @@ private HttpClientUnderTestWrapper getHttpClientUnderTestWrapper(Configs configs return new HttpClientUnderTestWrapper(origHttpClient); } + /** + * Verifies that client-level additionalHeaders (e.g., workload-id) are included in + * GatewayAddressCache's defaultRequestHeaders, which are sent on every address + * resolution request. + */ + @Test(groups = { "unit" }) + public void additionalHeadersIncludedInDefaultRequestHeaders() throws Exception { + URI serviceEndpoint = new URI("https://localhost"); + + Map additionalHeaders = new HashMap<>(); + additionalHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + + GatewayAddressCache cache = new GatewayAddressCache( + mockDiagnosticsClientContext(), + serviceEndpoint, + Protocol.HTTPS, + Mockito.mock(IAuthorizationTokenProvider.class), + null, + Mockito.mock(HttpClient.class), + null, + null, + null, + null, + null, + additionalHeaders); + + Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders"); + defaultRequestHeadersField.setAccessible(true); + @SuppressWarnings("unchecked") + HashMap defaultRequestHeaders = (HashMap) defaultRequestHeadersField.get(cache); + + assertThat(defaultRequestHeaders).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + } + + /** + * Verifies that additionalHeaders do NOT overwrite SDK system headers (USER_AGENT, VERSION, etc.) + * in GatewayAddressCache's defaultRequestHeaders. putIfAbsent is used so SDK headers + * set before additionalHeaders are preserved. + */ + @Test(groups = { "unit" }) + public void additionalHeadersDoNotOverwriteSdkSystemHeaders() throws Exception { + URI serviceEndpoint = new URI("https://localhost"); + + Map additionalHeaders = new HashMap<>(); + additionalHeaders.put(HttpConstants.HttpHeaders.USER_AGENT, "malicious-agent"); + additionalHeaders.put(HttpConstants.HttpHeaders.VERSION, "bad-version"); + additionalHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + + GatewayAddressCache cache = new GatewayAddressCache( + mockDiagnosticsClientContext(), + serviceEndpoint, + Protocol.HTTPS, + Mockito.mock(IAuthorizationTokenProvider.class), + null, + Mockito.mock(HttpClient.class), + null, + null, + null, + null, + null, + additionalHeaders); + + Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders"); + defaultRequestHeadersField.setAccessible(true); + @SuppressWarnings("unchecked") + HashMap defaultRequestHeaders = (HashMap) defaultRequestHeadersField.get(cache); + + // SDK headers should NOT be overwritten + assertThat(defaultRequestHeaders.get(HttpConstants.HttpHeaders.USER_AGENT)).isNotEqualTo("malicious-agent"); + assertThat(defaultRequestHeaders.get(HttpConstants.HttpHeaders.VERSION)).isEqualTo(HttpConstants.Versions.CURRENT_VERSION); + // Additional header should still be added + assertThat(defaultRequestHeaders).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + } + + /** + * Verifies that when additionalHeaders is null, GatewayAddressCache's defaultRequestHeaders + * contains only SDK system headers and no extra entries. + */ + @Test(groups = { "unit" }) + public void nullAdditionalHeadersDoesNotAffectDefaultRequestHeaders() throws Exception { + URI serviceEndpoint = new URI("https://localhost"); + + GatewayAddressCache cache = new GatewayAddressCache( + mockDiagnosticsClientContext(), + serviceEndpoint, + Protocol.HTTPS, + Mockito.mock(IAuthorizationTokenProvider.class), + null, + Mockito.mock(HttpClient.class), + null, + null, + null, + null, + null, + null); + + Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders"); + defaultRequestHeadersField.setAccessible(true); + @SuppressWarnings("unchecked") + HashMap defaultRequestHeaders = (HashMap) defaultRequestHeadersField.get(cache); + + // Should only contain SDK system headers, no workload-id + assertThat(defaultRequestHeaders).containsKey(HttpConstants.HttpHeaders.USER_AGENT); + assertThat(defaultRequestHeaders).containsKey(HttpConstants.HttpHeaders.VERSION); + assertThat(defaultRequestHeaders).doesNotContainKey(HttpConstants.HttpHeaders.WORKLOAD_ID); + } + public String getNameBasedCollectionLink() { return "dbs/" + createdDatabase.getId() + "/colls/" + createdCollection.getId(); } diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java index 331be53cc7af..5879e7d3e61c 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java @@ -110,7 +110,7 @@ public void resolveAsync() throws Exception { GlobalAddressResolver globalAddressResolver = new GlobalAddressResolver(mockDiagnosticsClientContext(), httpClient, endpointManager, Protocol.HTTPS, authorizationTokenProvider, collectionCache, routingMapProvider, userAgentContainer, - serviceConfigReader, connectionPolicy, null); + serviceConfigReader, connectionPolicy, null, null); RxDocumentServiceRequest request; request = RxDocumentServiceRequest.createFromName(mockDiagnosticsClientContext(), OperationType.Read, @@ -145,6 +145,7 @@ public void submitOpenConnectionTasksAndInitCaches() { userAgentContainer, serviceConfigReader, connectionPolicy, + null, null); GlobalAddressResolver.EndpointCache endpointCache = new GlobalAddressResolver.EndpointCache(); GatewayAddressCache gatewayAddressCache = Mockito.mock(GatewayAddressCache.class); diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java new file mode 100644 index 000000000000..9ca123e16160 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.implementation.directconnectivity.rntbd; + +import com.azure.cosmos.implementation.HttpConstants; +import org.testng.annotations.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for the WorkloadId RNTBD header definition in RntbdConstants. + *

+ * + * These tests verify that the WorkloadId enum entry exists with the correct wire ID (0x00DC), + * correct token type (Byte), is not required, and is not in the thin-client ordered header list + * (so it will be auto-encoded in the second pass of RntbdTokenStream.encode()). + */ +public class RntbdWorkloadIdTests { + + /** + * Verifies that the WORKLOAD_ID HTTP header constant exists in HttpConstants.HttpHeaders + * with the correct canonical name "x-ms-cosmos-workload-id" used in Gateway mode and + * as the lookup key in RntbdRequestHeaders for HTTP-to-RNTBD mapping. + */ + @Test(groups = { "unit" }) + public void workloadIdConstantExists() { + assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id"); + } + + /** + * Verifies that the WorkloadId enum entry exists in RntbdConstants.RntbdRequestHeader + * with the correct wire ID (0x00DC). This ID is used to identify the header in the + * binary RNTBD protocol when communicating in Direct mode. + */ + @Test(groups = { "unit" }) + public void workloadIdRntbdHeaderExists() { + // Verify WorkloadId enum value exists with correct ID + RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId; + assertThat(workloadIdHeader).isNotNull(); + assertThat(workloadIdHeader.id()).isEqualTo((short) 0x00DC); + } + + /** + * Verifies that the WorkloadId RNTBD header is defined as Byte token type, + * consistent with the ThroughputBucket pattern. The workload ID value (1-50) + * is encoded as a single byte on the wire. + */ + @Test(groups = { "unit" }) + public void workloadIdRntbdHeaderIsByteType() { + // Verify WorkloadId is Byte type (same as ThroughputBucket pattern) + RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId; + assertThat(workloadIdHeader.type()).isEqualTo(RntbdTokenType.Byte); + } + + /** + * Verifies that WorkloadId is not a required RNTBD header. The header is optional — + * requests without a workload ID are valid and should not be rejected by the SDK. + */ + @Test(groups = { "unit" }) + public void workloadIdRntbdHeaderIsNotRequired() { + // WorkloadId should not be a required header + RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId; + assertThat(workloadIdHeader.isRequired()).isFalse(); + } + + /** + * Verifies that WorkloadId is NOT in the thin client ordered header list. Thin client + * mode uses a pre-ordered list of headers for its first encoding pass. WorkloadId is + * excluded from this list and will be auto-encoded in the second pass of + * RntbdTokenStream.encode() along with other non-ordered headers. + */ + @Test(groups = { "unit" }) + public void workloadIdNotInThinClientOrderedList() { + // WorkloadId should NOT be in thinClientHeadersInOrderList + // It will be automatically encoded in the second pass of RntbdTokenStream.encode() + assertThat(RntbdConstants.RntbdRequestHeader.thinClientHeadersInOrderList) + .doesNotContain(RntbdConstants.RntbdRequestHeader.WorkloadId); + } + + /** + * Verifies that valid workload ID values (1-50) can be parsed from String to int + * and cast to byte without data loss. Note: the SDK itself does not validate the + * range — this test confirms the encoding path works for expected values. + */ + @Test(groups = { "unit" }) + public void workloadIdValidValues() { + // Test valid range 1-50 — SDK does NOT validate, just verify the values parse correctly + String[] validValues = {"1", "25", "50"}; + for (String value : validValues) { + int parsed = Integer.parseInt(value); + byte byteVal = (byte) parsed; + assertThat(byteVal).isBetween((byte) 1, (byte) 50); + } + } + + /** + * Verifies that out-of-range workload ID values (0, 51, -1, 100) do not cause + * exceptions in the SDK's parsing path. The SDK intentionally does not validate + * the range — invalid values are accepted and sent to the service, which silently + * ignores them. + */ + @Test(groups = { "unit" }) + public void workloadIdInvalidValuesAcceptedBySdk() { + // SDK does NOT validate range — service silently ignores invalid values + // These should not throw exceptions in SDK + String[] invalidValues = {"0", "51", "-1", "100"}; + for (String value : invalidValues) { + int parsed = Integer.parseInt(value); + byte byteVal = (byte) parsed; + // SDK accepts any integer value that fits in a byte + assertThat(byteVal).isNotNull(); + } + } +} diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdDirectInterceptorTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdDirectInterceptorTests.java new file mode 100644 index 000000000000..6882cac1193e --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdDirectInterceptorTests.java @@ -0,0 +1,226 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos.rx; + +import com.azure.cosmos.CosmosAsyncClient; +import com.azure.cosmos.CosmosAsyncContainer; +import com.azure.cosmos.CosmosAsyncDatabase; +import com.azure.cosmos.CosmosClientBuilder; +import com.azure.cosmos.CosmosHeaderName; +import com.azure.cosmos.TestObject; +import com.azure.cosmos.implementation.HttpConstants; +import com.azure.cosmos.implementation.ResourceType; +import com.azure.cosmos.implementation.RxDocumentServiceRequest; +import com.azure.cosmos.models.CosmosContainerProperties; +import com.azure.cosmos.models.CosmosItemRequestOptions; +import com.azure.cosmos.models.PartitionKey; +import com.azure.cosmos.models.PartitionKeyDefinition; +import com.azure.cosmos.test.implementation.interceptor.CosmosInterceptorHelper; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Factory; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentLinkedQueue; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Interceptor-based integration tests for the workload-id / additional headers feature + * running in Direct mode (RNTBD). + *

+ * These tests use {@link CosmosInterceptorHelper#registerTransportClientInterceptor} to + * capture {@link RxDocumentServiceRequest} objects at the TransportClient level and assert + * that the {@code x-ms-cosmos-workload-id} header is present with the correct value on + * wire requests. The interceptor only fires in Direct mode because Gateway mode data-plane + * requests go through {@code RxGatewayStoreModel → HttpClient}, bypassing the + * TransportClient entirely. + *

+ * Uses {@code @Factory(dataProvider = "simpleClientBuildersWithJustDirectTcp")} to ensure + * all tests run exclusively in Direct/TCP mode where the transport interceptor is active. + *

+ * Gateway-mode header injection is separately verified by: + *

    + *
  • Unit tests in {@code RxGatewayStoreModelTest} (data-plane)
  • + *
  • Unit tests in {@code GatewayAddressCacheTest} (metadata requests)
  • + *
  • Smoke tests in {@link WorkloadIdE2ETests} (Gateway mode end-to-end)
  • + *
+ */ +public class WorkloadIdDirectInterceptorTests extends TestSuiteBase { + + private static final String DATABASE_ID = "workloadIdDirectTestDb-" + UUID.randomUUID(); + private static final String CONTAINER_ID = "workloadIdDirectTestContainer-" + UUID.randomUUID(); + + private CosmosAsyncClient clientWithWorkloadId; + private CosmosAsyncDatabase database; + private CosmosAsyncContainer container; + + @Factory(dataProvider = "simpleClientBuildersWithJustDirectTcp") + public WorkloadIdDirectInterceptorTests(CosmosClientBuilder clientBuilder) { + super(clientBuilder); + } + + @BeforeClass(groups = { "emulator" }, timeOut = SETUP_TIMEOUT) + public void beforeClass() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "15"); + + // Clone the shared builder before setting additionalHeaders. + // getClientBuilder() returns the same mutable instance from the data provider. + // Calling .additionalHeaders() directly on it would mutate the shared builder, + clientWithWorkloadId = copyCosmosClientBuilder(getClientBuilder()) + .additionalHeaders(headers) + .buildAsyncClient(); + + database = createDatabase(clientWithWorkloadId, DATABASE_ID); + + PartitionKeyDefinition partitionKeyDef = new PartitionKeyDefinition(); + ArrayList paths = new ArrayList<>(); + paths.add("/mypk"); + partitionKeyDef.setPaths(paths); + CosmosContainerProperties containerProperties = new CosmosContainerProperties(CONTAINER_ID, partitionKeyDef); + database.createContainer(containerProperties).block(); + container = database.getContainer(CONTAINER_ID); + } + + /** + * Verifies that the client-level workload-id header (value "15") is present on the + * {@link RxDocumentServiceRequest} for create (data-plane) operation in Direct mode. + *

+ * Registers a transport client interceptor that captures every request before it hits + * the RNTBD wire. After performing a createItem, asserts that at least one captured + * request with {@code ResourceType.Document} carries the {@code x-ms-cosmos-workload-id} + * header with value {@code "15"}. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void verifyWorkloadIdHeaderPresentOnDataPlaneRequest() { + ConcurrentLinkedQueue capturedRequests = new ConcurrentLinkedQueue<>(); + + CosmosInterceptorHelper.registerTransportClientInterceptor( + clientWithWorkloadId, + (request, storeResponse) -> { + capturedRequests.add(request); + return storeResponse; + } + ); + + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + // In Direct mode, the interceptor MUST capture requests — fail if it didn't + assertThat(capturedRequests) + .as("Transport interceptor should capture requests in Direct mode") + .isNotEmpty(); + + // Assert that at least one Document-type request carries the workload-id header + boolean foundWorkloadIdOnDocument = capturedRequests.stream() + .filter(r -> r.getResourceType() == ResourceType.Document) + .anyMatch(r -> "15".equals(r.getHeaders().get(HttpConstants.HttpHeaders.WORKLOAD_ID))); + + assertThat(foundWorkloadIdOnDocument) + .as("Expected workload-id header '15' on at least one Document request") + .isTrue(); + } + + /** + * Verifies that a per-request workload-id override (value "30") is present on the wire + * request instead of the client-level default (value "15"). + *

+ * This confirms that the request-level header set via + * {@link CosmosItemRequestOptions#setAdditionalHeaders(Map)} takes precedence over + * the client-level header set via {@link CosmosClientBuilder#additionalHeaders(Map)} + * in Direct mode (RNTBD). + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void verifyRequestLevelOverrideOnWire() { + ConcurrentLinkedQueue capturedRequests = new ConcurrentLinkedQueue<>(); + + CosmosInterceptorHelper.registerTransportClientInterceptor( + clientWithWorkloadId, + (request, storeResponse) -> { + capturedRequests.add(request); + return storeResponse; + } + ); + + Map requestHeaders = new HashMap<>(); + requestHeaders.put(CosmosHeaderName.WORKLOAD_ID, "30"); + + CosmosItemRequestOptions options = new CosmosItemRequestOptions() + .setAdditionalHeaders(requestHeaders); + + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), options).block(); + + // In Direct mode, the interceptor MUST capture requests — fail if it didn't + assertThat(capturedRequests) + .as("Transport interceptor should capture requests in Direct mode") + .isNotEmpty(); + + // Assert that the Document request carries the overridden value "30", not the client default "15" + boolean foundOverriddenWorkloadId = capturedRequests.stream() + .filter(r -> r.getResourceType() == ResourceType.Document) + .anyMatch(r -> "30".equals(r.getHeaders().get(HttpConstants.HttpHeaders.WORKLOAD_ID))); + + assertThat(foundOverriddenWorkloadId) + .as("Expected workload-id header '30' (request-level override) on Document request") + .isTrue(); + } + + /** + * Negative test: verifies that a client created WITHOUT additional headers does NOT + * have the workload-id header on wire requests in Direct mode. Ensures the header is + * only present when explicitly configured. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void verifyNoWorkloadIdHeaderWhenNotConfigured() { + CosmosAsyncClient clientWithoutHeaders = copyCosmosClientBuilder(getClientBuilder()) + .buildAsyncClient(); + + try { + ConcurrentLinkedQueue capturedRequests = new ConcurrentLinkedQueue<>(); + + CosmosInterceptorHelper.registerTransportClientInterceptor( + clientWithoutHeaders, + (request, storeResponse) -> { + capturedRequests.add(request); + return storeResponse; + } + ); + + CosmosAsyncContainer c = clientWithoutHeaders + .getDatabase(DATABASE_ID) + .getContainer(CONTAINER_ID); + + TestObject doc = TestObject.create(); + c.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + // In Direct mode, the interceptor MUST capture requests + assertThat(capturedRequests) + .as("Transport interceptor should capture requests in Direct mode") + .isNotEmpty(); + + // Assert that NO Document-type request carries the workload-id header + boolean anyDocRequestHasWorkloadId = capturedRequests.stream() + .filter(r -> r.getResourceType() == ResourceType.Document) + .anyMatch(r -> r.getHeaders().containsKey(HttpConstants.HttpHeaders.WORKLOAD_ID)); + + assertThat(anyDocRequestHasWorkloadId) + .as("Expected NO workload-id header on Document requests when not configured") + .isFalse(); + } finally { + safeClose(clientWithoutHeaders); + } + } + + @AfterClass(groups = { "emulator" }, timeOut = SHUTDOWN_TIMEOUT, alwaysRun = true) + public void afterClass() { + safeDeleteDatabase(database); + safeClose(clientWithWorkloadId); + } +} + diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java new file mode 100644 index 000000000000..823acf3c01f2 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java @@ -0,0 +1,309 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos.rx; + +import com.azure.cosmos.CosmosAsyncClient; +import com.azure.cosmos.CosmosAsyncContainer; +import com.azure.cosmos.CosmosAsyncDatabase; +import com.azure.cosmos.CosmosClientBuilder; +import com.azure.cosmos.CosmosHeaderName; +import com.azure.cosmos.TestObject; +import com.azure.cosmos.models.CosmosContainerProperties; +import com.azure.cosmos.models.CosmosItemRequestOptions; +import com.azure.cosmos.models.CosmosItemResponse; +import com.azure.cosmos.models.CosmosQueryRequestOptions; +import com.azure.cosmos.models.PartitionKey; +import com.azure.cosmos.models.PartitionKeyDefinition; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Factory; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * End-to-end smoke tests for the additional headers / workload-id feature in Gateway mode. + *

+ * Test type: EMULATOR INTEGRATION TEST — requires a Cosmos DB account or emulator. + *

+ * Uses {@code @Factory(dataProvider = "simpleClientBuilderGatewaySession")} to run all tests + * in Gateway mode with Session consistency. + *

+ * These are smoke tests — they verify CRUD/query operations succeed (status code + * 200/201/204) when the workload-id header is set. They prove the header doesn't break + * anything but do NOT assert the header is actually present on the wire request. + *

+ * For wire-level assertion tests that verify the header is actually present on + * {@link com.azure.cosmos.implementation.RxDocumentServiceRequest}, see + * {@link WorkloadIdDirectInterceptorTests} which runs in Direct mode (RNTBD) using + * the transport client interceptor. + */ +public class WorkloadIdE2ETests extends TestSuiteBase { + + private static final String DATABASE_ID = "workloadIdTestDb-" + UUID.randomUUID(); + private static final String CONTAINER_ID = "workloadIdTestContainer-" + UUID.randomUUID(); + + private CosmosAsyncClient clientWithWorkloadId; + private CosmosAsyncDatabase database; + private CosmosAsyncContainer container; + + @Factory(dataProvider = "simpleClientBuilderGatewaySession") + public WorkloadIdE2ETests(CosmosClientBuilder clientBuilder) { + super(clientBuilder); + } + + @BeforeClass(groups = { "emulator" }, timeOut = SETUP_TIMEOUT) + public void beforeClass() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "15"); + + // Clone the shared builder before setting additionalHeaders. + // getClientBuilder() returns the same mutable instance from the data provider. + // Calling .additionalHeaders() directly on it would mutate the shared builder, + clientWithWorkloadId = copyCosmosClientBuilder(getClientBuilder()) + .additionalHeaders(headers) + .buildAsyncClient(); + + database = createDatabase(clientWithWorkloadId, DATABASE_ID); + + PartitionKeyDefinition partitionKeyDef = new PartitionKeyDefinition(); + ArrayList paths = new ArrayList<>(); + paths.add("/mypk"); + partitionKeyDef.setPaths(paths); + CosmosContainerProperties containerProperties = new CosmosContainerProperties(CONTAINER_ID, partitionKeyDef); + database.createContainer(containerProperties).block(); + container = database.getContainer(CONTAINER_ID); + } + + /** + * verifies that a create (POST) operation succeeds when the client + * has a workload-id additional header set at the builder level. Confirms the header + * flows through the request pipeline without causing errors. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void createItemWithClientLevelWorkloadId() { + TestObject doc = TestObject.create(); + + CosmosItemResponse response = container + .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(201); + } + + /** + * Verifies that a read (GET) operation succeeds with the client-level workload-id + * header and that the correct document is returned. Ensures the header does not + * interfere with normal read semantics. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void readItemWithClientLevelWorkloadId() { + // Verify read operation succeeds with workload-id header + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + CosmosItemResponse response = container + .readItem(doc.getId(), new PartitionKey(doc.getMypk()), TestObject.class) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(200); + assertThat(response.getItem().getId()).isEqualTo(doc.getId()); + } + + /** + * Verifies that a replace (PUT) operation succeeds with the client-level workload-id + * header. Confirms the header propagates correctly for update operations. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void replaceItemWithClientLevelWorkloadId() { + // Verify replace operation succeeds with workload-id header + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + doc.setStringProp("updated-" + UUID.randomUUID()); + CosmosItemResponse response = container + .replaceItem(doc, doc.getId(), new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(200); + } + + /** + * Verifies that a delete operation succeeds with the client-level workload-id header + * and returns the expected 204 No Content status code. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void deleteItemWithClientLevelWorkloadId() { + // Verify delete operation succeeds with workload-id header + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + CosmosItemResponse response = container + .deleteItem(doc.getId(), new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(204); + } + + /** + * Verifies that a per-request workload-id header override via + * {@code CosmosItemRequestOptions.setAdditionalHeaders()} works. The request-level header + * (value "30") should take precedence over the client-level default (value "15"). + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void createItemWithRequestLevelWorkloadIdOverride() { + // Verify per-request header override works — request-level should take precedence + TestObject doc = TestObject.create(); + + Map requestHeaders = new HashMap<>(); + requestHeaders.put(CosmosHeaderName.WORKLOAD_ID, "30"); + + CosmosItemRequestOptions options = new CosmosItemRequestOptions() + .setAdditionalHeaders(requestHeaders); + + CosmosItemResponse response = container + .createItem(doc, new PartitionKey(doc.getMypk()), options) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(201); + } + + /** + * Verifies that a cross-partition query operation succeeds when the client has a + * workload-id additional header. Confirms the header flows correctly through the + * query pipeline and does not affect result correctness. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void queryItemsWithClientLevelWorkloadId() { + // Verify query operation succeeds with workload-id header + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + CosmosQueryRequestOptions queryOptions = new CosmosQueryRequestOptions(); + long count = container + .queryItems("SELECT * FROM c WHERE c.id = '" + doc.getId() + "'", queryOptions, TestObject.class) + .collectList() + .block() + .size(); + + assertThat(count).isGreaterThanOrEqualTo(1); + } + + /** + * Verifies that a per-request workload-id header override on + * {@code CosmosQueryRequestOptions.setAdditionalHeaders()} works for query operations. + * The request-level header (value "42") should take precedence over the + * client-level default. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void queryItemsWithRequestLevelWorkloadIdOverride() { + // Verify per-request header override on query options works + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + Map requestHeaders = new HashMap<>(); + requestHeaders.put(CosmosHeaderName.WORKLOAD_ID, "42"); + + CosmosQueryRequestOptions queryOptions = new CosmosQueryRequestOptions() + .setAdditionalHeaders(requestHeaders); + + long count = container + .queryItems("SELECT * FROM c WHERE c.id = '" + doc.getId() + "'", queryOptions, TestObject.class) + .collectList() + .block() + .size(); + + assertThat(count).isGreaterThanOrEqualTo(1); + } + + /** + * Regression test: verifies that a client created without any additional headers + * continues to work normally. Ensures the additional headers feature does not + * introduce regressions for clients that do not use it. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void clientWithNoAdditionalHeadersStillWorks() { + // Verify that a client without additional headers works normally (no regression) + CosmosAsyncClient clientWithoutHeaders = copyCosmosClientBuilder(getClientBuilder()) + .buildAsyncClient(); + + try { + CosmosAsyncContainer c = clientWithoutHeaders + .getDatabase(DATABASE_ID) + .getContainer(CONTAINER_ID); + + TestObject doc = TestObject.create(); + CosmosItemResponse response = c + .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(201); + } finally { + safeClose(clientWithoutHeaders); + } + } + + /** + * Verifies that a client created with an empty additional headers map works normally. + * An empty map should behave identically to no additional headers — no errors, + * no unexpected behavior. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void clientWithEmptyAdditionalHeaders() { + // Verify that a client with empty additional headers map works normally + CosmosAsyncClient clientWithEmptyHeaders = copyCosmosClientBuilder(getClientBuilder()) + .additionalHeaders(new HashMap<>()) + .buildAsyncClient(); + + try { + CosmosAsyncContainer c = clientWithEmptyHeaders + .getDatabase(DATABASE_ID) + .getContainer(CONTAINER_ID); + + TestObject doc = TestObject.create(); + CosmosItemResponse response = c + .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(201); + } finally { + safeClose(clientWithEmptyHeaders); + } + } + + + /** + * Verifies that the {@link CosmosHeaderName} allowlist rejects unknown + * header names at client build time. Attempting to set an unrecognized header via + * {@code additionalHeaders()} should throw {@link IllegalArgumentException} from + * {@link CosmosHeaderName#fromString(String)}, preventing arbitrary headers from + * being sent. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT, + expectedExceptions = IllegalArgumentException.class) + public void unknownAdditionalHeadersRejectedByAllowlist() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "15"); + // Use fromString with an unknown header — should throw IllegalArgumentException + CosmosHeaderName.fromString("x-unknown-header"); + } + + @AfterClass(groups = { "emulator" }, timeOut = SHUTDOWN_TIMEOUT, alwaysRun = true) + public void afterClass() { + safeDeleteDatabase(database); + safeClose(clientWithWorkloadId); + } +} + diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index cd85fb9aaafe..b0a010955989 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -5,6 +5,7 @@ #### Features Added * Added support for N-Region synchronous commit feature - See [PR 47757](https://github.com/Azure/azure-sdk-for-java/pull/47757) * Added support for Query Advisor feature - See [48160](https://github.com/Azure/azure-sdk-for-java/pull/48160) +* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java index ec0dd64af008..13f526552301 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java @@ -186,6 +186,7 @@ public final class CosmosAsyncClient implements Closeable { .withDefaultSerializer(this.defaultCustomSerializer) .withRegionScopedSessionCapturingEnabled(builder.isRegionScopedSessionCapturingEnabled()) .withPerPartitionAutomaticFailoverEnabled(builder.isPerPartitionAutomaticFailoverEnabled()) + .withAdditionalHeaders(builder.getAdditionalHeaders()) .build(); this.accountConsistencyLevel = this.asyncDocumentClient.getDefaultConsistencyLevelOfAccount(); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosBridgeInternal.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosBridgeInternal.java index 81beda97d5b4..35d8fd6aece4 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosBridgeInternal.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosBridgeInternal.java @@ -115,6 +115,10 @@ public static CosmosClientBuilder cloneCosmosClientBuilder(CosmosClientBuilder b .multipleWriteRegionsEnabled(builder.isMultipleWriteRegionsEnabled()) .readRequestsFallbackEnabled(builder.isReadRequestsFallbackEnabled()); + if (builder.getAdditionalHeadersRaw() != null) { + copy.additionalHeaders(builder.getAdditionalHeadersRaw()); + } + return copy; } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java index 12d022e69ee7..9f999d899cfa 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java @@ -33,10 +33,12 @@ import java.time.Duration; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.function.Function; @@ -155,6 +157,7 @@ public class CosmosClientBuilder implements private boolean serverCertValidationDisabled = false; private Function containerFactory = null; + private Map additionalHeaders; /** * Instantiates a new Cosmos client builder. @@ -734,6 +737,58 @@ public CosmosClientBuilder userAgentSuffix(String userAgentSuffix) { return this; } + /** + * Sets additional HTTP headers that will be included with every request from this client. + *

+ * {@link CosmosHeaderName} defines exactly which headers are supported. Currently + * the only supported header is {@link CosmosHeaderName#WORKLOAD_ID} + * ({@code x-ms-cosmos-workload-id}). + *

+ * This restriction exists because in Direct mode (RNTBD), only headers with explicit + * encoding support are sent on the wire. Using {@link CosmosHeaderName} ensures consistent + * behavior across both Gateway and Direct modes. + *

+ * If the same header is also set on request options (e.g., + * {@code CosmosItemRequestOptions.setAdditionalHeaders(Map)}), + * the request-level value takes precedence over the client-level value. + * + * @param additionalHeaders map of {@link CosmosHeaderName} to value + * @return current CosmosClientBuilder + * @throws IllegalArgumentException if the workload-id value is not a valid integer + */ + public CosmosClientBuilder additionalHeaders(Map additionalHeaders) { + CosmosHeaderName.validateAdditionalHeaders(additionalHeaders); + this.additionalHeaders = additionalHeaders != null + ? new HashMap<>(additionalHeaders) + : null; + return this; + } + + /** + * Gets the additional headers configured on this builder, converted to a + * {@code Map} for internal use. + * @return the additional headers map with string keys, or null if not set + */ + Map getAdditionalHeaders() { + if (this.additionalHeaders == null) { + return null; + } + Map result = new HashMap<>(); + for (Map.Entry entry : this.additionalHeaders.entrySet()) { + result.put(entry.getKey().getHeaderName(), entry.getValue()); + } + return result; + } + + /** + * Gets the additional headers configured on this builder in their original + * {@code Map} form. + * @return the additional headers map, or null if not set + */ + Map getAdditionalHeadersRaw() { + return this.additionalHeaders; + } + /** * Sets the retry policy options associated with the DocumentClient instance. *

diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosHeaderName.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosHeaderName.java new file mode 100644 index 000000000000..f7ed5d6f6f24 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosHeaderName.java @@ -0,0 +1,160 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos; + +import com.azure.cosmos.implementation.HttpConstants; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.StringJoiner; + +import static com.azure.cosmos.implementation.guava25.base.Preconditions.checkNotNull; + +/** + * Defines the set of additional headers that can be set on a {@link CosmosClientBuilder} + * via {@link CosmosClientBuilder#additionalHeaders(java.util.Map)}. + *

+ * Only headers with RNTBD encoding support are included, ensuring consistent + * behavior across both Gateway mode (HTTP) and Direct mode (RNTBD binary protocol). + *

+ * This class uses the non-exhaustive final class pattern (rather than Java enum) for + * binary compatibility when new header names are added in future releases. See + * Azure SDK Java Guidelines — Enumerations. + */ +public final class CosmosHeaderName { + + private final String headerName; + + private CosmosHeaderName(String headerName) { + checkNotNull(headerName, "Argument 'headerName' must not be null."); + this.headerName = headerName; + } + + /** + * The workload ID header ({@code x-ms-cosmos-workload-id}). + *

+ * Valid values: a string representation of an integer (e.g., {@code "15"}). + * The service accepts values in the range 1–50 for Azure Monitor metrics attribution. + * The SDK validates that the value is a valid integer but does not enforce range limits — + * range validation is the backend's responsibility. + */ + public static final CosmosHeaderName WORKLOAD_ID = new CosmosHeaderName( + HttpConstants.HttpHeaders.WORKLOAD_ID); + + // IMPORTANT: ADDITIONAL_HEADERS must be declared AFTER all public static final fields above, + // because Java initializes static fields in declaration order. If this map were declared + // before WORKLOAD_ID, the map would contain null values. + private static final Map ADDITIONAL_HEADERS = createAdditionalHeadersMap(); + + /** + * Gets the canonical HTTP header name string (e.g., {@code "x-ms-cosmos-workload-id"}). + * + * @return the header name string + */ + public String getHeaderName() { + return this.headerName; + } + + /** + * Converts a header name string to the corresponding {@link CosmosHeaderName} instance. + *

+ * This is primarily used by the Spark connector, which parses header names from JSON + * configuration strings and needs to convert them to {@link CosmosHeaderName} instances + * before calling {@link CosmosClientBuilder#additionalHeaders(java.util.Map)}. + * + * @param headerName the header name string (e.g., {@code "x-ms-cosmos-workload-id"}) + * @return the matching {@link CosmosHeaderName} + * @throws IllegalArgumentException if the header name does not match any known value + */ + public static CosmosHeaderName fromString(String headerName) { + checkNotNull(headerName, "Argument 'headerName' must not be null."); + + String normalizedName = headerName.trim().toLowerCase(Locale.ROOT); + CosmosHeaderName result = ADDITIONAL_HEADERS.getOrDefault(normalizedName, null); + + if (result == null) { + throw new IllegalArgumentException( + "Unknown header: '" + headerName + "'. Allowed headers: " + getValidValues()); + } + + return result; + } + + /** + * Validates all entries in an additional-headers map. + *

+ * Each {@link CosmosHeaderName} instance carries its own validation rules. Currently: + *

    + *
  • {@link #WORKLOAD_ID}: value must be a valid integer string
  • + *
+ *

+ * This method is called by {@link CosmosClientBuilder#additionalHeaders(Map)} and + * by every request-options class's {@code setAdditionalHeaders} method, so the + * validation logic lives in one place. + * + * @param additionalHeaders the map to validate (may be null — no-op in that case) + * @throws IllegalArgumentException if any header value fails validation + */ + public static void validateAdditionalHeaders(Map additionalHeaders) { + if (additionalHeaders == null) { + return; + } + for (Map.Entry entry : additionalHeaders.entrySet()) { + CosmosHeaderName key = entry.getKey(); + String value = entry.getValue(); + + if (WORKLOAD_ID.equals(key) && value != null) { + try { + Integer.parseInt(value); + } catch (NumberFormatException e) { + throw new IllegalArgumentException( + "Invalid value '" + value + "' for header '" + key.getHeaderName() + + "'. The value must be a valid integer.", e); + } + } + } + } + + @Override + public String toString() { + return this.headerName; + } + + @Override + public int hashCode() { + return Objects.hash(CosmosHeaderName.class, this.headerName); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (obj == null) { + return false; + } + if (!(obj instanceof CosmosHeaderName)) { + return false; + } + + CosmosHeaderName other = (CosmosHeaderName) obj; + return Objects.equals(this.headerName, other.headerName); + } + + private static Map createAdditionalHeadersMap() { + Map map = new HashMap<>(); + map.put(HttpConstants.HttpHeaders.WORKLOAD_ID.toLowerCase(Locale.ROOT), WORKLOAD_ID); + return Collections.unmodifiableMap(map); + } + + private static String getValidValues() { + StringJoiner sj = new StringJoiner(", "); + for (CosmosHeaderName header : ADDITIONAL_HEADERS.values()) { + sj.add(header.headerName); + } + return sj.toString(); + } +} diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java index 03590c1f8a5d..81a8f2826f04 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java @@ -116,6 +116,7 @@ class Builder { private boolean isRegionScopedSessionCapturingEnabled; private boolean isPerPartitionAutomaticFailoverEnabled; private List operationPolicies; + private Map additionalHeaders; public Builder withServiceEndpoint(String serviceEndpoint) { try { @@ -288,6 +289,11 @@ public Builder withPerPartitionAutomaticFailoverEnabled(boolean isPerPartitionAu return this; } + public Builder withAdditionalHeaders(Map additionalHeaders) { + this.additionalHeaders = additionalHeaders; + return this; + } + private void ifThrowIllegalArgException(boolean value, String error) { if (value) { throw new IllegalArgumentException(error); @@ -328,7 +334,8 @@ public AsyncDocumentClient build() { defaultCustomSerializer, isRegionScopedSessionCapturingEnabled, operationPolicies, - isPerPartitionAutomaticFailoverEnabled); + isPerPartitionAutomaticFailoverEnabled, + additionalHeaders); client.init(state, null); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java index 1d2f96316fcf..e086ffba7be5 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java @@ -300,6 +300,9 @@ public static class HttpHeaders { // Region affinity headers public static final String HUB_REGION_PROCESSING_ONLY = "x-ms-cosmos-hub-region-processing-only"; + + // Workload ID header for Azure Monitor metrics attribution + public static final String WORKLOAD_ID = "x-ms-cosmos-workload-id"; } public static class A_IMHeaderValues { diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java index 191b5a969cd3..e984217353a8 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java @@ -294,6 +294,7 @@ public class RxDocumentClientImpl implements AsyncDocumentClient, IAuthorization private final AtomicReference cachedCosmosAsyncClientSnapshot; private CosmosEndToEndOperationLatencyPolicyConfig ppafEnforcedE2ELatencyPolicyConfigForReads; private Consumer perPartitionFailoverConfigModifier; + private Map additionalHeaders; public RxDocumentClientImpl(URI serviceEndpoint, String masterKeyOrResourceToken, @@ -367,6 +368,60 @@ public RxDocumentClientImpl(URI serviceEndpoint, boolean isRegionScopedSessionCapturingEnabled, List operationPolicies, boolean isPerPartitionAutomaticFailoverEnabled) { + this( + serviceEndpoint, + masterKeyOrResourceToken, + permissionFeed, + connectionPolicy, + consistencyLevel, + readConsistencyStrategy, + configs, + cosmosAuthorizationTokenResolver, + credential, + tokenCredential, + sessionCapturingOverride, + connectionSharingAcrossClientsEnabled, + contentResponseOnWriteEnabled, + metadataCachesSnapshot, + apiType, + clientTelemetryConfig, + clientCorrelationId, + cosmosEndToEndOperationLatencyPolicyConfig, + sessionRetryOptions, + containerProactiveInitConfig, + defaultCustomSerializer, + isRegionScopedSessionCapturingEnabled, + operationPolicies, + isPerPartitionAutomaticFailoverEnabled, + null + ); + } + + public RxDocumentClientImpl(URI serviceEndpoint, + String masterKeyOrResourceToken, + List permissionFeed, + ConnectionPolicy connectionPolicy, + ConsistencyLevel consistencyLevel, + ReadConsistencyStrategy readConsistencyStrategy, + Configs configs, + CosmosAuthorizationTokenResolver cosmosAuthorizationTokenResolver, + AzureKeyCredential credential, + TokenCredential tokenCredential, + boolean sessionCapturingOverride, + boolean connectionSharingAcrossClientsEnabled, + boolean contentResponseOnWriteEnabled, + CosmosClientMetadataCachesSnapshot metadataCachesSnapshot, + ApiType apiType, + CosmosClientTelemetryConfig clientTelemetryConfig, + String clientCorrelationId, + CosmosEndToEndOperationLatencyPolicyConfig cosmosEndToEndOperationLatencyPolicyConfig, + SessionRetryOptions sessionRetryOptions, + CosmosContainerProactiveInitConfig containerProactiveInitConfig, + CosmosItemSerializer defaultCustomSerializer, + boolean isRegionScopedSessionCapturingEnabled, + List operationPolicies, + boolean isPerPartitionAutomaticFailoverEnabled, + Map additionalHeaders) { this( serviceEndpoint, masterKeyOrResourceToken, @@ -393,6 +448,7 @@ public RxDocumentClientImpl(URI serviceEndpoint, this.cosmosAuthorizationTokenResolver = cosmosAuthorizationTokenResolver; this.operationPolicies = operationPolicies; + this.additionalHeaders = additionalHeaders; } private RxDocumentClientImpl(URI serviceEndpoint, @@ -808,13 +864,15 @@ public void init(CosmosClientMetadataCachesSnapshot metadataCachesSnapshot, Func this.userAgentContainer, this.globalEndpointManager, this.reactorHttpClient, - this.apiType); + this.apiType, + this.additionalHeaders); this.thinProxy = createThinProxy(this.sessionContainer, this.consistencyLevel, this.userAgentContainer, this.globalEndpointManager, - this.reactorHttpClient); + this.reactorHttpClient, + this.additionalHeaders); this.perPartitionFailoverConfigModifier = (databaseAccount -> { @@ -925,7 +983,8 @@ private void initializeDirectConnectivity() { // this.gatewayConfigurationReader, null, this.connectionPolicy, - this.apiType); + this.apiType, + this.additionalHeaders); this.storeClientFactory = new StoreClientFactory( this.addressResolver, @@ -969,7 +1028,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer, UserAgentContainer userAgentContainer, GlobalEndpointManager globalEndpointManager, HttpClient httpClient, - ApiType apiType) { + ApiType apiType, + Map additionalHeaders) { return new RxGatewayStoreModel( this, sessionContainer, @@ -978,21 +1038,24 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer, userAgentContainer, globalEndpointManager, httpClient, - apiType); + apiType, + additionalHeaders); } ThinClientStoreModel createThinProxy(ISessionContainer sessionContainer, ConsistencyLevel consistencyLevel, UserAgentContainer userAgentContainer, GlobalEndpointManager globalEndpointManager, - HttpClient httpClient) { + HttpClient httpClient, + Map additionalHeaders) { return new ThinClientStoreModel( this, sessionContainer, consistencyLevel, userAgentContainer, globalEndpointManager, - httpClient); + httpClient, + additionalHeaders); } private HttpClient httpClient() { @@ -1896,6 +1959,11 @@ public void validateAndLogNonDefaultReadConsistencyStrategy(String readConsisten private Map getRequestHeaders(RequestOptions options, ResourceType resourceType, OperationType operationType) { Map headers = new HashMap<>(); + // Apply client-level additional headers first (e.g., workload-id from CosmosClientBuilder.additionalHeaders()) + if (this.additionalHeaders != null && !this.additionalHeaders.isEmpty()) { + headers.putAll(this.additionalHeaders); + } + if (this.useMultipleWriteLocations) { headers.put(HttpConstants.HttpHeaders.ALLOW_TENTATIVE_WRITES, Boolean.TRUE.toString()); } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java index 42172026ad5b..689bccdca2a9 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java @@ -91,6 +91,7 @@ public class RxGatewayStoreModel implements RxStoreModel, HttpTransportSerialize private GatewayServiceConfigurationReader gatewayServiceConfigurationReader; private RxClientCollectionCache collectionCache; private GatewayServerErrorInjector gatewayServerErrorInjector; + private final Map additionalHeaders; public RxGatewayStoreModel( DiagnosticsClientContext clientContext, @@ -100,7 +101,8 @@ public RxGatewayStoreModel( UserAgentContainer userAgentContainer, GlobalEndpointManager globalEndpointManager, HttpClient httpClient, - ApiType apiType) { + ApiType apiType, + Map additionalHeaders) { this.clientContext = clientContext; @@ -116,6 +118,7 @@ public RxGatewayStoreModel( this.httpClient = httpClient; this.sessionContainer = sessionContainer; + this.additionalHeaders = additionalHeaders; } public RxGatewayStoreModel(RxGatewayStoreModel inner) { @@ -127,6 +130,7 @@ public RxGatewayStoreModel(RxGatewayStoreModel inner) { this.httpClient = inner.httpClient; this.sessionContainer = inner.sessionContainer; + this.additionalHeaders = inner.additionalHeaders; } protected Map getDefaultHeaders( @@ -279,6 +283,17 @@ public Mono performRequest(RxDocumentServiceRequest r request.requestContext.cosmosDiagnostics = clientContext.createDiagnostics(); } + // Apply client-level additional headers (e.g., workload-id) to all requests + // including metadata requests (collection cache, partition key range, etc.) + if (this.additionalHeaders != null && !this.additionalHeaders.isEmpty()) { + for (Map.Entry entry : this.additionalHeaders.entrySet()) { + // Only set if not already present — request-level headers take precedence + if (!request.getHeaders().containsKey(entry.getKey())) { + request.getHeaders().put(entry.getKey(), entry.getValue()); + } + } + } + URI uri = getUri(request); request.requestContext.resourcePhysicalAddress = uri.toString(); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java index 92d1c197525e..975747faa453 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java @@ -49,7 +49,8 @@ public ThinClientStoreModel( ConsistencyLevel defaultConsistencyLevel, UserAgentContainer userAgentContainer, GlobalEndpointManager globalEndpointManager, - HttpClient httpClient) { + HttpClient httpClient, + Map additionalHeaders) { super( clientContext, sessionContainer, @@ -58,7 +59,8 @@ public ThinClientStoreModel( userAgentContainer, globalEndpointManager, httpClient, - ApiType.SQL); + ApiType.SQL, + additionalHeaders); String userAgent = userAgentContainer != null ? userAgentContainer.getUserAgent() diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java index e62d7b8c6ca4..1c09f37e6fa8 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java @@ -123,7 +123,8 @@ public GatewayAddressCache( GlobalEndpointManager globalEndpointManager, ConnectionPolicy connectionPolicy, ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor, - GatewayServerErrorInjector gatewayServerErrorInjector) { + GatewayServerErrorInjector gatewayServerErrorInjector, + Map additionalHeaders) { this.clientContext = clientContext; try { @@ -165,6 +166,16 @@ public GatewayAddressCache( HttpConstants.HttpHeaders.SDK_SUPPORTED_CAPABILITIES, HttpConstants.SDKSupportedCapabilities.SUPPORTED_CAPABILITIES); + // Apply client-level additional headers (e.g., workload-id) to address resolution requests. + // Address resolution is the one metadata path that bypasses RxGatewayStoreModel (which + // handles all other metadata requests) — GatewayAddressCache calls httpClient.send() directly. + // Use putIfAbsent to ensure SDK system headers (USER_AGENT, VERSION, etc.) are not overwritten. + if (additionalHeaders != null && !additionalHeaders.isEmpty()) { + for (Map.Entry entry : additionalHeaders.entrySet()) { + this.defaultRequestHeaders.putIfAbsent(entry.getKey(), entry.getValue()); + } + } + this.lastForcedRefreshMap = new ConcurrentHashMap<>(); this.globalEndpointManager = globalEndpointManager; this.proactiveOpenConnectionsProcessor = proactiveOpenConnectionsProcessor; @@ -188,7 +199,8 @@ public GatewayAddressCache( GlobalEndpointManager globalEndpointManager, ConnectionPolicy connectionPolicy, ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor, - GatewayServerErrorInjector gatewayServerErrorInjector) { + GatewayServerErrorInjector gatewayServerErrorInjector, + Map additionalHeaders) { this(clientContext, serviceEndpoint, protocol, @@ -200,7 +212,8 @@ public GatewayAddressCache( globalEndpointManager, connectionPolicy, proactiveOpenConnectionsProcessor, - gatewayServerErrorInjector); + gatewayServerErrorInjector, + additionalHeaders); } @Override diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java index 00905682b4d1..6dded78b8a07 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java @@ -62,6 +62,7 @@ public class GlobalAddressResolver implements IAddressResolver { private ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor; private ConnectionPolicy connectionPolicy; private GatewayServerErrorInjector gatewayServerErrorInjector; + private final Map additionalHeaders; public GlobalAddressResolver( DiagnosticsClientContext diagnosticsClientContext, @@ -74,7 +75,8 @@ public GlobalAddressResolver( UserAgentContainer userAgentContainer, GatewayServiceConfigurationReader serviceConfigReader, ConnectionPolicy connectionPolicy, - ApiType apiType) { + ApiType apiType, + Map additionalHeaders) { this.diagnosticsClientContext = diagnosticsClientContext; this.httpClient = httpClient; this.endpointManager = endpointManager; @@ -86,6 +88,7 @@ public GlobalAddressResolver( this.serviceConfigReader = serviceConfigReader; this.tcpConnectionEndpointRediscoveryEnabled = connectionPolicy.isTcpConnectionEndpointRediscoveryEnabled(); this.connectionPolicy = connectionPolicy; + this.additionalHeaders = additionalHeaders; int maxBackupReadEndpoints = (connectionPolicy.isReadRequestsFallbackEnabled()) ? GlobalAddressResolver.MaxBackupReadRegions : 0; this.maxEndpoints = maxBackupReadEndpoints + 2; // for write and alternate write getEndpoint (during failover) @@ -290,7 +293,8 @@ private EndpointCache getOrAddEndpoint(URI endpoint) { this.endpointManager, this.connectionPolicy, this.proactiveOpenConnectionsProcessor, - this.gatewayServerErrorInjector); + this.gatewayServerErrorInjector, + this.additionalHeaders); AddressResolver addressResolver = new AddressResolver(); addressResolver.initializeCaches(this.collectionCache, this.routingMapProvider, gatewayAddressCache); EndpointCache cache = new EndpointCache(); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java index 40b9b82f6e3d..ae596cec0494 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java @@ -597,8 +597,9 @@ public enum RntbdRequestHeader implements RntbdHeader { ChangeFeedWireFormatVersion((short) 0x00B2, RntbdTokenType.String, false), PriorityLevel((short) 0x00BF, RntbdTokenType.Byte, false), GlobalDatabaseAccountName((short) 0x00CE, RntbdTokenType.String, false), - ThroughputBucket((short)0x00DB, RntbdTokenType.Byte, false), PopulateQueryAdvice((short) 0x00DA, RntbdTokenType.Byte, false), + ThroughputBucket((short)0x00DB, RntbdTokenType.Byte, false), + WorkloadId((short)0x00DC, RntbdTokenType.Byte, false), HubRegionProcessingOnly((short)0x00EF, RntbdTokenType.Byte , false); public static final List thinClientHeadersInOrderList = Arrays.asList( diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java index b5abd5b19d88..1bb005e0ad8a 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java @@ -23,6 +23,8 @@ import com.fasterxml.jackson.annotation.JsonFilter; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.nio.charset.StandardCharsets; import java.util.Base64; @@ -51,6 +53,7 @@ final class RntbdRequestHeaders extends RntbdTokenStream { // region Fields + private static final Logger logger = LoggerFactory.getLogger(RntbdRequestHeaders.class); private static final String URL_TRIM = "/"; // endregion @@ -135,6 +138,7 @@ final class RntbdRequestHeaders extends RntbdTokenStream { this.addThroughputBucket(headers); this.addPopulateQueryAdvice(headers); this.addHubRegionProcessingOnly(headers); + this.addWorkloadId(headers); // Normal headers (Strings, Ints, Longs, etc.) @@ -300,6 +304,8 @@ private RntbdToken getCorrelatedActivityId() { private RntbdToken getHubRegionProcessingOnly() { return this.get(RntbdRequestHeader.HubRegionProcessingOnly); } + private RntbdToken getWorkloadId() { return this.get(RntbdRequestHeader.WorkloadId); } + private RntbdToken getGlobalDatabaseAccountName() { return this.get(RntbdRequestHeader.GlobalDatabaseAccountName); } @@ -826,6 +832,19 @@ private void addHubRegionProcessingOnly(final Map headers) { } } + private void addWorkloadId(final Map headers) { + final String value = headers.get(HttpHeaders.WORKLOAD_ID); + + if (StringUtils.isNotEmpty(value)) { + try { + final int workloadId = Integer.parseInt(value); + this.getWorkloadId().setValue((byte) workloadId); + } catch (NumberFormatException e) { + logger.warn("Invalid value for workload id header: {}", value, e); + } + } + } + private void addGlobalDatabaseAccountName(final Map headers) { final String value = headers.get(HttpHeaders.GLOBAL_DATABASE_ACCOUNT_NAME); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java index 7d5a27324f95..978f0a43ec37 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java @@ -4,6 +4,7 @@ package com.azure.cosmos.models; import com.azure.cosmos.ConsistencyLevel; +import com.azure.cosmos.CosmosHeaderName; import com.azure.cosmos.CosmosDiagnosticsThresholds; import com.azure.cosmos.CosmosEndToEndOperationLatencyPolicyConfig; import com.azure.cosmos.CosmosItemSerializer; @@ -154,11 +155,35 @@ RequestOptions toRequestOptions() { } /** - * Sets the custom batch request option value by key + * Sets additional headers to be included with this specific request. + *

+ * The {@link CosmosHeaderName} class defines exactly which headers are supported. + * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. + *

+ * If the same header is also set at the client level, the request-level value + * takes precedence. * - * @param name a string representing the custom option's name - * @param value a string representing the custom option's value + * @param additionalHeaders map of {@link CosmosHeaderName} to value + * @return the CosmosBatchRequestOptions. + * @throws IllegalArgumentException if the workload-id value is not a valid integer + */ + public CosmosBatchRequestOptions setAdditionalHeaders(Map additionalHeaders) { + CosmosHeaderName.validateAdditionalHeaders(additionalHeaders); + if (additionalHeaders != null) { + for (Map.Entry entry : additionalHeaders.entrySet()) { + this.setHeader(entry.getKey().getHeaderName(), entry.getValue()); + } + } + return this; + } + + /** + * Sets a header to be included with this specific request. * + * @param name the header name + * @param value the header value * @return the CosmosBatchRequestOptions. */ CosmosBatchRequestOptions setHeader(String name, String value) { diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java index f125c02d6725..f576d31cb86b 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java @@ -4,6 +4,7 @@ package com.azure.cosmos.models; import com.azure.cosmos.CosmosEndToEndOperationLatencyPolicyConfig; +import com.azure.cosmos.CosmosHeaderName; import com.azure.cosmos.CosmosItemSerializer; import com.azure.cosmos.implementation.CosmosBulkExecutionOptionsImpl; import com.azure.cosmos.implementation.ImplementationBridgeHelpers; @@ -257,10 +258,35 @@ void setOperationContextAndListenerTuple(OperationContextAndListenerTuple operat } /** - * Sets the custom bulk request option value by key + * Sets additional headers to be included with this specific request. + *

+ * The {@link CosmosHeaderName} class defines exactly which headers are supported. + * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. + *

+ * If the same header is also set at the client level, the request-level value + * takes precedence. * - * @param name a string representing the custom option's name - * @param value a string representing the custom option's value + * @param additionalHeaders map of {@link CosmosHeaderName} to value + * @return the CosmosBulkExecutionOptions. + * @throws IllegalArgumentException if the workload-id value is not a valid integer + */ + public CosmosBulkExecutionOptions setAdditionalHeaders(Map additionalHeaders) { + CosmosHeaderName.validateAdditionalHeaders(additionalHeaders); + if (additionalHeaders != null) { + for (Map.Entry entry : additionalHeaders.entrySet()) { + this.actualRequestOptions.setHeader(entry.getKey().getHeaderName(), entry.getValue()); + } + } + return this; + } + + /** + * Sets a header to be included with this specific request. + * + * @param name the header name + * @param value the header value * @return the CosmosBulkExecutionOptions. */ CosmosBulkExecutionOptions setHeader(String name, String value) { diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java index 3ac526de6d63..f08dee8b04fe 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java @@ -4,6 +4,7 @@ package com.azure.cosmos.models; import com.azure.cosmos.ConsistencyLevel; +import com.azure.cosmos.CosmosHeaderName; import com.azure.cosmos.CosmosDiagnosticsThresholds; import com.azure.cosmos.CosmosItemSerializer; import com.azure.cosmos.ReadConsistencyStrategy; @@ -564,11 +565,35 @@ public List getExcludedRegions() { } /** - * Sets the custom change feed request option value by key + * Sets additional headers to be included with this specific request. + *

+ * The {@link CosmosHeaderName} class defines exactly which headers are supported. + * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. + *

+ * If the same header is also set at the client level, the request-level value + * takes precedence. * - * @param name a string representing the custom option's name - * @param value a string representing the custom option's value + * @param additionalHeaders map of {@link CosmosHeaderName} to value + * @return the CosmosChangeFeedRequestOptions. + * @throws IllegalArgumentException if the workload-id value is not a valid integer + */ + public CosmosChangeFeedRequestOptions setAdditionalHeaders(Map additionalHeaders) { + CosmosHeaderName.validateAdditionalHeaders(additionalHeaders); + if (additionalHeaders != null) { + for (Map.Entry entry : additionalHeaders.entrySet()) { + this.actualRequestOptions.setHeader(entry.getKey().getHeaderName(), entry.getValue()); + } + } + return this; + } + + /** + * Sets a header to be included with this specific request. * + * @param name the header name + * @param value the header value * @return the CosmosChangeFeedRequestOptions. */ CosmosChangeFeedRequestOptions setHeader(String name, String value) { diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosContainerRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosContainerRequestOptions.java index 1a2cad1b85ed..1747d4605a80 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosContainerRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosContainerRequestOptions.java @@ -2,9 +2,13 @@ // Licensed under the MIT License. package com.azure.cosmos.models; +import com.azure.cosmos.CosmosHeaderName; import com.azure.cosmos.ConsistencyLevel; import com.azure.cosmos.implementation.RequestOptions; +import java.util.HashMap; +import java.util.Map; + /** * Encapsulates options that can be specified for a request issued to Cosmos container. */ @@ -15,6 +19,7 @@ public final class CosmosContainerRequestOptions { private String ifMatchETag; private String ifNoneMatchETag; private ThroughputProperties throughputProperties; + private Map customOptions; /** * Gets the quotaInfoEnabled setting for cosmos container read requests in the Azure Cosmos DB database service. @@ -141,6 +146,39 @@ CosmosContainerRequestOptions setThroughputProperties(ThroughputProperties throu return this; } + /** + * Sets additional headers to be included with this specific request. + *

+ * The {@link CosmosHeaderName} class defines exactly which headers are supported. + * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. + *

+ * If the same header is also set at the client level, the request-level value + * takes precedence. + * + * @param additionalHeaders map of {@link CosmosHeaderName} to value + * @return the CosmosContainerRequestOptions. + * @throws IllegalArgumentException if the workload-id value is not a valid integer + */ + public CosmosContainerRequestOptions setAdditionalHeaders(Map additionalHeaders) { + CosmosHeaderName.validateAdditionalHeaders(additionalHeaders); + if (additionalHeaders != null) { + for (Map.Entry entry : additionalHeaders.entrySet()) { + this.setHeader(entry.getKey().getHeaderName(), entry.getValue()); + } + } + return this; + } + + CosmosContainerRequestOptions setHeader(String name, String value) { + if (this.customOptions == null) { + this.customOptions = new HashMap<>(); + } + this.customOptions.put(name, value); + return this; + } + RequestOptions toRequestOptions() { RequestOptions options = new RequestOptions(); options.setIfMatchETag(getIfMatchETag()); @@ -149,6 +187,11 @@ RequestOptions toRequestOptions() { options.setSessionToken(sessionToken); options.setConsistencyLevel(consistencyLevel); options.setThroughputProperties(this.throughputProperties); + if (this.customOptions != null) { + for (Map.Entry entry : this.customOptions.entrySet()) { + options.setHeader(entry.getKey(), entry.getValue()); + } + } return options; } } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosDatabaseRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosDatabaseRequestOptions.java index a4817e14be7e..618bf03ffa99 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosDatabaseRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosDatabaseRequestOptions.java @@ -2,8 +2,12 @@ // Licensed under the MIT License. package com.azure.cosmos.models; +import com.azure.cosmos.CosmosHeaderName; import com.azure.cosmos.implementation.RequestOptions; +import java.util.HashMap; +import java.util.Map; + /** * Encapsulates options that can be specified for a request issued to cosmos database. */ @@ -11,6 +15,7 @@ public final class CosmosDatabaseRequestOptions { private String ifMatchETag; private String ifNoneMatchETag; private ThroughputProperties throughputProperties; + private Map customOptions; /** * Gets the If-Match (ETag) associated with the request in the Azure Cosmos DB service. @@ -73,11 +78,49 @@ CosmosDatabaseRequestOptions setThroughputProperties(ThroughputProperties throug return this; } + /** + * Sets additional headers to be included with this specific request. + *

+ * The {@link CosmosHeaderName} class defines exactly which headers are supported. + * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. + *

+ * If the same header is also set at the client level, the request-level value + * takes precedence. + * + * @param additionalHeaders map of {@link CosmosHeaderName} to value + * @return the CosmosDatabaseRequestOptions. + * @throws IllegalArgumentException if the workload-id value is not a valid integer + */ + public CosmosDatabaseRequestOptions setAdditionalHeaders(Map additionalHeaders) { + CosmosHeaderName.validateAdditionalHeaders(additionalHeaders); + if (additionalHeaders != null) { + for (Map.Entry entry : additionalHeaders.entrySet()) { + this.setHeader(entry.getKey().getHeaderName(), entry.getValue()); + } + } + return this; + } + + CosmosDatabaseRequestOptions setHeader(String name, String value) { + if (this.customOptions == null) { + this.customOptions = new HashMap<>(); + } + this.customOptions.put(name, value); + return this; + } + RequestOptions toRequestOptions() { RequestOptions options = new RequestOptions(); options.setIfMatchETag(getIfMatchETag()); options.setIfNoneMatchETag(getIfNoneMatchETag()); options.setThroughputProperties(this.throughputProperties); + if (this.customOptions != null) { + for (Map.Entry entry : this.customOptions.entrySet()) { + options.setHeader(entry.getKey(), entry.getValue()); + } + } return options; } } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java index 72eb108a6428..36c7a97c318a 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java @@ -3,6 +3,7 @@ package com.azure.cosmos.models; import com.azure.cosmos.ConsistencyLevel; +import com.azure.cosmos.CosmosHeaderName; import com.azure.cosmos.CosmosClientBuilder; import com.azure.cosmos.CosmosDiagnosticsThresholds; import com.azure.cosmos.CosmosEndToEndOperationLatencyPolicyConfig; @@ -566,11 +567,35 @@ public CosmosItemRequestOptions setThresholdForDiagnosticsOnTracer(Duration thre } /** - * Sets the custom item request option value by key + * Sets additional headers to be included with this specific request. + *

+ * The {@link CosmosHeaderName} class defines exactly which headers are supported. + * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. + *

+ * If the same header is also set at the client level, the request-level value + * takes precedence. * - * @param name a string representing the custom option's name - * @param value a string representing the custom option's value + * @param additionalHeaders map of {@link CosmosHeaderName} to value + * @return the CosmosItemRequestOptions. + * @throws IllegalArgumentException if the workload-id value is not a valid integer + */ + public CosmosItemRequestOptions setAdditionalHeaders(Map additionalHeaders) { + CosmosHeaderName.validateAdditionalHeaders(additionalHeaders); + if (additionalHeaders != null) { + for (Map.Entry entry : additionalHeaders.entrySet()) { + this.setHeader(entry.getKey().getHeaderName(), entry.getValue()); + } + } + return this; + } + + /** + * Sets a header to be included with this specific request. * + * @param name the header name + * @param value the header value * @return the CosmosItemRequestOptions. */ CosmosItemRequestOptions setHeader(String name, String value) { diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java index 722f66bf65d0..75b84c7f9cc6 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java @@ -4,6 +4,7 @@ package com.azure.cosmos.models; import com.azure.cosmos.ConsistencyLevel; +import com.azure.cosmos.CosmosHeaderName; import com.azure.cosmos.CosmosDiagnostics; import com.azure.cosmos.CosmosDiagnosticsThresholds; import com.azure.cosmos.CosmosEndToEndOperationLatencyPolicyConfig; @@ -260,6 +261,43 @@ public CosmosQueryRequestOptions setExcludedRegions(List excludeRegions) return this; } + /** + * Sets additional headers to be included with this specific request. + *

+ * The {@link CosmosHeaderName} class defines exactly which headers are supported. + * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. + *

+ * If the same header is also set at the client level, the request-level value + * takes precedence. + * + * @param additionalHeaders map of {@link CosmosHeaderName} to value + * @return the CosmosQueryRequestOptions. + * @throws IllegalArgumentException if the workload-id value is not a valid integer + */ + public CosmosQueryRequestOptions setAdditionalHeaders(Map additionalHeaders) { + CosmosHeaderName.validateAdditionalHeaders(additionalHeaders); + if (additionalHeaders != null) { + for (Map.Entry entry : additionalHeaders.entrySet()) { + this.actualRequestOptions.setHeader(entry.getKey().getHeaderName(), entry.getValue()); + } + } + return this; + } + + /** + * Sets a header to be included with this specific request. + * + * @param name the header name + * @param value the header value + * @return the CosmosQueryRequestOptions. + */ + CosmosQueryRequestOptions setHeader(String name, String value) { + this.actualRequestOptions.setHeader(name, value); + return this; + } + /** * Gets the list of regions to exclude for the request/retries. These regions are excluded * from the preferred region list. diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java index f6e570258042..04ac31f30f06 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java @@ -4,6 +4,7 @@ package com.azure.cosmos.models; import com.azure.cosmos.ConsistencyLevel; +import com.azure.cosmos.CosmosHeaderName; import com.azure.cosmos.CosmosDiagnosticsThresholds; import com.azure.cosmos.CosmosEndToEndOperationLatencyPolicyConfig; import com.azure.cosmos.CosmosItemSerializer; @@ -15,6 +16,7 @@ import java.time.Duration; import java.util.List; +import java.util.Map; import java.util.Set; /** @@ -366,6 +368,43 @@ public Set getKeywordIdentifiers() { return this.actualRequestOptions.getKeywordIdentifiers(); } + /** + * Sets additional headers to be included with this specific request. + *

+ * The {@link CosmosHeaderName} class defines exactly which headers are supported. + * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. + *

+ * If the same header is also set at the client level, the request-level value + * takes precedence. + * + * @param additionalHeaders map of {@link CosmosHeaderName} to value + * @return the CosmosReadManyRequestOptions. + * @throws IllegalArgumentException if the workload-id value is not a valid integer + */ + public CosmosReadManyRequestOptions setAdditionalHeaders(Map additionalHeaders) { + CosmosHeaderName.validateAdditionalHeaders(additionalHeaders); + if (additionalHeaders != null) { + for (Map.Entry entry : additionalHeaders.entrySet()) { + this.actualRequestOptions.setHeader(entry.getKey().getHeaderName(), entry.getValue()); + } + } + return this; + } + + /** + * Sets a header to be included with this specific request. + * + * @param name the header name + * @param value the header value + * @return the CosmosReadManyRequestOptions. + */ + CosmosReadManyRequestOptions setHeader(String name, String value) { + this.actualRequestOptions.setHeader(name, value); + return this; + } + CosmosQueryRequestOptionsBase getImpl() { return this.actualRequestOptions; } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosStoredProcedureRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosStoredProcedureRequestOptions.java index 02f360299c42..42486a1ad6b7 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosStoredProcedureRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosStoredProcedureRequestOptions.java @@ -2,9 +2,13 @@ // Licensed under the MIT License. package com.azure.cosmos.models; +import com.azure.cosmos.CosmosHeaderName; import com.azure.cosmos.ConsistencyLevel; import com.azure.cosmos.implementation.RequestOptions; +import java.util.HashMap; +import java.util.Map; + /** * Encapsulates options that can be specified for a request issued to cosmos stored procedure. */ @@ -15,6 +19,7 @@ public final class CosmosStoredProcedureRequestOptions { private String ifMatchETag; private String ifNoneMatchETag; private boolean scriptLoggingEnabled; + private Map customOptions; /** * Gets the If-Match (ETag) associated with the request in the Azure Cosmos DB service. @@ -158,6 +163,39 @@ public CosmosStoredProcedureRequestOptions setScriptLoggingEnabled(boolean scrip return this; } + /** + * Sets additional headers to be included with this specific request. + *

+ * The {@link CosmosHeaderName} class defines exactly which headers are supported. + * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. + *

+ * If the same header is also set at the client level, the request-level value + * takes precedence. + * + * @param additionalHeaders map of {@link CosmosHeaderName} to value + * @return the CosmosStoredProcedureRequestOptions. + * @throws IllegalArgumentException if the workload-id value is not a valid integer + */ + public CosmosStoredProcedureRequestOptions setAdditionalHeaders(Map additionalHeaders) { + CosmosHeaderName.validateAdditionalHeaders(additionalHeaders); + if (additionalHeaders != null) { + for (Map.Entry entry : additionalHeaders.entrySet()) { + this.setHeader(entry.getKey().getHeaderName(), entry.getValue()); + } + } + return this; + } + + CosmosStoredProcedureRequestOptions setHeader(String name, String value) { + if (this.customOptions == null) { + this.customOptions = new HashMap<>(); + } + this.customOptions.put(name, value); + return this; + } + RequestOptions toRequestOptions() { RequestOptions requestOptions = new RequestOptions(); requestOptions.setIfMatchETag(getIfMatchETag()); @@ -166,6 +204,11 @@ RequestOptions toRequestOptions() { requestOptions.setPartitionKey(partitionKey); requestOptions.setSessionToken(sessionToken); requestOptions.setScriptLoggingEnabled(scriptLoggingEnabled); + if (this.customOptions != null) { + for (Map.Entry entry : this.customOptions.entrySet()) { + requestOptions.setHeader(entry.getKey(), entry.getValue()); + } + } return requestOptions; } }