diff --git a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md index 31ae31b1b013..6551550bd436 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.46.0-beta.1 (Unreleased) #### Features Added +* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md index f4e90bc17b63..961c59201704 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.46.0-beta.1 (Unreleased) #### Features Added +* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md index 39254b96c705..daa856224716 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.46.0-beta.1 (Unreleased) #### Features Added +* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. 
- See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md index 5d677995de71..d3aac976d3c2 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.46.0-beta.1 (Unreleased) #### Features Added +* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala index 3b87ef08c3a0..0f0318057eee 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala @@ -13,7 +13,7 @@ import com.azure.cosmos.models.{CosmosClientTelemetryConfig, CosmosMetricCategor import com.azure.cosmos.spark.CosmosPredicates.isOnSparkDriver import com.azure.cosmos.spark.catalog.{CosmosCatalogClient, CosmosCatalogCosmosSDKClient, CosmosCatalogManagementSDKClient} import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait -import com.azure.cosmos.{ConsistencyLevel, CosmosAsyncClient, CosmosClientBuilder, CosmosContainerProactiveInitConfigBuilder, CosmosDiagnosticsThresholds, DirectConnectionConfig, GatewayConnectionConfig, ReadConsistencyStrategy, ThrottlingRetryOptions} +import com.azure.cosmos.{ConsistencyLevel, CosmosAsyncClient, CosmosClientBuilder, CosmosContainerProactiveInitConfigBuilder, CosmosDiagnosticsThresholds, CosmosHeaderName, DirectConnectionConfig, GatewayConnectionConfig, ReadConsistencyStrategy, ThrottlingRetryOptions} import 
com.azure.identity.{ClientCertificateCredentialBuilder, ClientSecretCredentialBuilder, ManagedIdentityCredentialBuilder} import com.azure.monitor.opentelemetry.autoconfigure.{AzureMonitorAutoConfigure, AzureMonitorAutoConfigureOptions} import com.azure.resourcemanager.cosmos.CosmosManager @@ -42,7 +42,6 @@ import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit} import java.util.function.BiPredicate import scala.collection.concurrent.TrieMap - // scalastyle:off underscore.import import scala.collection.JavaConverters._ // scalastyle:on underscore.import @@ -713,6 +712,17 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { } } + // Apply additional HTTP headers (e.g., workload-id) to the builder if configured. + // Header name validation already happened at config-parse time in CosmosConfig.AdditionalHeadersConfig, + // so CosmosHeaderName.fromString() here is just a type conversion for already-validated keys. 
+ if (cosmosClientConfiguration.additionalHeaders.isDefined) { + val headerMap = new java.util.HashMap[CosmosHeaderName, String]() + for ((key, value) <- cosmosClientConfiguration.additionalHeaders.get) { + headerMap.put(CosmosHeaderName.fromString(key), value) + } + builder.additionalHeaders(headerMap) + } + var client = builder.buildAsyncClient() if (cosmosClientConfiguration.clientInterceptors.isDefined) { @@ -916,7 +926,10 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]], clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig], - azureMonitorConfig: Option[AzureMonitorConfig] + azureMonitorConfig: Option[AzureMonitorConfig], + // Additional HTTP headers are part of the cache key because different workload-ids + // should produce different CosmosAsyncClient instances + additionalHeaders: Option[Map[String, String]] ) private[this] object ClientConfigurationWrapper { @@ -935,7 +948,8 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { clientConfig.clientBuilderInterceptors, clientConfig.clientInterceptors, clientConfig.sampledDiagnosticsLoggerConfig, - clientConfig.azureMonitorConfig + clientConfig.azureMonitorConfig, + clientConfig.additionalHeaders ) } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala index 6f4e26e1f503..7e09c73698a7 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala @@ -30,7 +30,10 @@ private[spark] case class CosmosClientConfiguration ( clientBuilderInterceptors: Option[List[CosmosClientBuilder 
=> CosmosClientBuilder]], clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig], - azureMonitorConfig: Option[AzureMonitorConfig] + azureMonitorConfig: Option[AzureMonitorConfig], + // Optional additional HTTP headers (e.g., workload-id) to attach to + // all Cosmos DB requests via CosmosClientBuilder.additionalHeaders() + additionalHeaders: Option[Map[String, String]] ) { private[spark] def getRoleInstanceName(machineId: Option[String]): String = { CosmosClientConfiguration.getRoleInstanceName(sparkEnvironmentInfo, machineId) @@ -94,7 +97,8 @@ private[spark] object CosmosClientConfiguration { cosmosAccountConfig.clientBuilderInterceptors, cosmosAccountConfig.clientInterceptors, diagnosticsConfig.sampledDiagnosticsLoggerConfig, - diagnosticsConfig.azureMonitorConfig + diagnosticsConfig.azureMonitorConfig, + cosmosAccountConfig.additionalHeaders ) } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala index 951f4735444d..10b0d5af8623 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala @@ -4,10 +4,10 @@ package com.azure.cosmos.spark import com.azure.core.management.AzureEnvironment -import com.azure.cosmos.{CosmosAsyncClient, CosmosClientBuilder, ReadConsistencyStrategy, spark} +import com.azure.cosmos.{CosmosAsyncClient, CosmosClientBuilder, CosmosHeaderName, ReadConsistencyStrategy, spark} import com.azure.cosmos.implementation.batch.BatchRequestResponseConstants import com.azure.cosmos.implementation.routing.LocationHelper -import com.azure.cosmos.implementation.{Configs, SparkBridgeImplementationInternal, Strings} +import com.azure.cosmos.implementation.{Configs, 
SparkBridgeImplementationInternal, Strings, Utils} import com.azure.cosmos.models.{CosmosChangeFeedRequestOptions, CosmosContainerIdentity, CosmosParameterizedQuery, DedicatedGatewayRequestOptions, FeedRange, PartitionKeyDefinition} import com.azure.cosmos.spark.ChangeFeedModes.ChangeFeedMode import com.azure.cosmos.spark.ChangeFeedStartFromModes.{ChangeFeedStartFromMode, PointInTime} @@ -34,6 +34,7 @@ import java.time.format.DateTimeFormatter import java.time.{Duration, Instant} import java.util import java.util.{Locale, ServiceLoader} +import scala.collection.JavaConverters._ // scalastyle:ignore underscore.import import scala.collection.concurrent.TrieMap import scala.collection.immutable.{HashSet, List, Map} import scala.collection.mutable @@ -151,6 +152,10 @@ private[spark] object CosmosConfigNames { val ThroughputControlTargetThroughputThreshold = "spark.cosmos.throughputControl.targetThroughputThreshold" val ThroughputControlPriorityLevel = "spark.cosmos.throughputControl.priorityLevel" val ThroughputControlThroughputBucket = "spark.cosmos.throughputControl.throughputBucket" + // Additional HTTP headers to attach to all Cosmos DB requests (e.g., workload-id for resource governance). + // Value is a JSON string like: {"x-ms-cosmos-workload-id": "15"} + // Flows through to CosmosClientBuilder.additionalHeaders(). 
+ val AdditionalHeaders = "spark.cosmos.additionalHeaders" val ThroughputControlGlobalControlDatabase = "spark.cosmos.throughputControl.globalControl.database" val ThroughputControlGlobalControlContainer = "spark.cosmos.throughputControl.globalControl.container" val ThroughputControlGlobalControlRenewalIntervalInMS = @@ -297,7 +302,8 @@ private[spark] object CosmosConfigNames { WriteOnRetryCommitInterceptor, WriteFlushCloseIntervalInSeconds, WriteMaxNoProgressIntervalInSeconds, - WriteMaxRetryNoProgressIntervalInSeconds + WriteMaxRetryNoProgressIntervalInSeconds, + AdditionalHeaders ) def validateConfigName(name: String): Unit = { @@ -540,7 +546,10 @@ private case class CosmosAccountConfig(endpoint: String, resourceGroupName: Option[String], azureEnvironmentEndpoints: java.util.Map[String, String], clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]], - clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], + clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], + // Optional additional HTTP headers (e.g., workload-id) parsed from + // spark.cosmos.additionalHeaders JSON config, passed to CosmosClientBuilder.additionalHeaders() + additionalHeaders: Option[Map[String, String]] ) private object CosmosAccountConfig extends BasicLoggingTrait { @@ -727,6 +736,40 @@ private object CosmosAccountConfig extends BasicLoggingTrait { parseFromStringFunction = clientInterceptorFQDN => clientInterceptorFQDN, helpMessage = "CosmosAsyncClient interceptors (comma separated) - FQDNs of the service implementing the 'CosmosClientInterceptor' trait.") + // Config entry for custom HTTP headers (e.g., workload-id). Parses a JSON string like + // {"x-ms-cosmos-workload-id": "15"} into a Scala Map[String, String] using Jackson. + // These headers are converted to Map[CosmosHeaderName, String] and passed to + // CosmosClientBuilder.additionalHeaders() in CosmosClientCache. 
+ // + // Validation: After JSON parsing, every header name is validated via CosmosHeaderName.fromString() + // to fail fast at config-parse time rather than at runtime during client creation. + // This prevents Spark jobs from starting, allocating cluster resources, and only failing + // later when CosmosClientCache tries to convert String keys to CosmosHeaderName instances. + private val AdditionalHeadersConfig = CosmosConfigEntry[Map[String, String]]( + key = CosmosConfigNames.AdditionalHeaders, + mandatory = false, + parseFromStringFunction = headersJson => { + try { + val typeRef = new com.fasterxml.jackson.core.`type`.TypeReference[java.util.Map[String, String]]() {} + val parsed = Utils.getSimpleObjectMapperWithAllowDuplicates.readValue(headersJson, typeRef).asScala.toMap + + // Fail fast: validate every header name is a known CosmosHeaderName at parse time. + // Without this, unknown headers like {"x-bad-header": "value"} would parse successfully + // and only blow up at runtime in CosmosClientCache when CosmosHeaderName.fromString() is called. + for (key <- parsed.keys) { + CosmosHeaderName.fromString(key) // throws IllegalArgumentException for unknown headers + } + + parsed + } catch { + case e: IllegalArgumentException => throw e + case e: Exception => throw new IllegalArgumentException( + s"Invalid JSON for '${CosmosConfigNames.AdditionalHeaders}': '$headersJson'. " + + "Expected format: {\"x-ms-cosmos-workload-id\": \"15\"}", e) + } + }, + helpMessage = "Optional additional headers as JSON map. 
Example: {\"x-ms-cosmos-workload-id\": \"15\"}") + private[spark] def parseProactiveConnectionInitConfigs(config: String): java.util.List[CosmosContainerIdentity] = { val result = new java.util.ArrayList[CosmosContainerIdentity] try { @@ -761,6 +804,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait { val tenantIdOpt = CosmosConfigEntry.parse(cfg, TenantId) val clientBuilderInterceptors = CosmosConfigEntry.parse(cfg, ClientBuilderInterceptors) val clientInterceptors = CosmosConfigEntry.parse(cfg, ClientInterceptors) + // Parse optional additional HTTP headers from JSON config (e.g., {"x-ms-cosmos-workload-id": "15"}) + val additionalHeaders = CosmosConfigEntry.parse(cfg, AdditionalHeadersConfig) val disableTcpConnectionEndpointRediscovery = CosmosConfigEntry.parse(cfg, DisableTcpConnectionEndpointRediscovery) val preferredRegionsListOpt = CosmosConfigEntry.parse(cfg, PreferredRegionsList) @@ -880,7 +925,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait { resourceGroupNameOpt, azureEnvironmentOpt.get, if (clientBuilderInterceptorsList.nonEmpty) { Some(clientBuilderInterceptorsList.toList) } else { None }, - if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None }) + if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None }, + additionalHeaders) } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala index ccf36791dc96..8e44968a76df 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala @@ -64,7 +64,8 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = 
None + azureMonitorConfig = None, + additionalHeaders = None ) ), ( @@ -91,7 +92,8 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) ), ( @@ -118,7 +120,8 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) ), ( @@ -145,7 +148,8 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) ) ) @@ -179,8 +183,9 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None - ) + azureMonitorConfig = None, + additionalHeaders = None + ) logInfo(s"TestCase: {$testCaseName}") diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala index 7fcc601ba016..9752136066f4 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala @@ -408,4 +408,75 @@ class CosmosClientConfigurationSpec extends UnitSpec { configuration.applicationName shouldEqual s"${CosmosConstants.userAgentSuffix}|$sparkEnvironmentInfo|${ManagementFactory.getRuntimeMXBean.getName}|$myApp".replace("@", " ") configuration.azureMonitorConfig shouldEqual None } + + // Verifies that the spark.cosmos.additionalHeaders configuration option correctly parses + // a JSON string containing a single 
workload-id header into a Map[String, String] on + // CosmosClientConfiguration. This is the primary use case for the workload-id feature. + it should "parse additionalHeaders JSON" in { + val userConfig = Map( + "spark.cosmos.accountEndpoint" -> "https://localhost:8081", + "spark.cosmos.accountKey" -> "xyz", + "spark.cosmos.additionalHeaders" -> """{"x-ms-cosmos-workload-id": "15"}""" + ) + + val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT + val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + + configuration.additionalHeaders shouldBe defined + configuration.additionalHeaders.get("x-ms-cosmos-workload-id") shouldEqual "15" + } + + // Verifies that when spark.cosmos.additionalHeaders is not specified in the config map, + // CosmosClientConfiguration.additionalHeaders is None. This ensures backward compatibility — + // existing Spark jobs that don't set additionalHeaders continue to work without changes. + it should "handle missing additionalHeaders" in { + val userConfig = Map( + "spark.cosmos.accountEndpoint" -> "https://localhost:8081", + "spark.cosmos.accountKey" -> "xyz" + ) + + val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT + val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + + configuration.additionalHeaders shouldBe None + } + + // Verifies that unknown header names in spark.cosmos.additionalHeaders are rejected + // at config-parse time. This ensures Spark jobs fail fast during configuration rather than + // starting, allocating cluster resources, and only failing later at runtime when + // CosmosClientCache tries to create the client. + // Note: CosmosConfigEntry.parse() wraps all parsing exceptions in RuntimeException, + // so the IllegalArgumentException from CosmosHeaderName.fromString() surfaces as + // RuntimeException with the original IllegalArgumentException as the cause. 
+ it should "reject unknown additional headers at parse time" in { + val userConfig = Map( + "spark.cosmos.accountEndpoint" -> "https://localhost:8081", + "spark.cosmos.accountKey" -> "xyz", + "spark.cosmos.additionalHeaders" -> """{"x-ms-cosmos-workload-id": "20", "x-custom-header": "value"}""" + ) + + val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT + val thrown = the [RuntimeException] thrownBy { + CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + } + thrown.getCause shouldBe an [IllegalArgumentException] + thrown.getCause.getMessage should include ("x-custom-header") + } + + // Verifies that spark.cosmos.additionalHeaders handles an empty JSON object ("{}") gracefully, + // resulting in a defined but empty Map. This ensures the parser doesn't fail on edge cases + // and that no headers are injected when the JSON object is empty. + it should "handle empty additionalHeaders JSON" in { + val userConfig = Map( + "spark.cosmos.accountEndpoint" -> "https://localhost:8081", + "spark.cosmos.accountKey" -> "xyz", + "spark.cosmos.additionalHeaders" -> "{}" + ) + + val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT + val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + + configuration.additionalHeaders shouldBe defined + configuration.additionalHeaders.get shouldBe empty + } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala index 6ef90b55989d..f745bae6b56d 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala @@ -39,7 +39,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None,
clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -116,7 +117,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -193,7 +195,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -270,7 +273,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -345,7 +349,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -436,7 +441,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = 
CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -510,7 +516,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -576,7 +583,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala index dfd14c36c80f..c17e7b02dad8 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala @@ -38,7 +38,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) private[this] val contCfg = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -84,7 +85,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ 
-169,7 +171,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -254,7 +257,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -321,7 +325,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -383,7 +388,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -439,7 +445,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -495,7 +502,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + 
azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -551,7 +559,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -607,7 +616,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -686,7 +696,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -747,7 +758,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -803,7 +815,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -876,7 +889,8 @@ class 
PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -949,7 +963,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -1027,7 +1042,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala new file mode 100644 index 000000000000..0f0a1f200d3e --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala @@ -0,0 +1,160 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos.spark + +import com.azure.cosmos.implementation.TestConfigurations +import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.databind.node.ObjectNode + +import java.util.UUID + +/** + * Integration tests (smoke tests) for the additional headers (workload-id) feature in the Spark connector. 
+ * These are smoke tests — they verify that Spark DataFrame read and write operations succeed + * (no errors, correct data) when the `spark.cosmos.additionalHeaders` configuration is set. + * They do NOT assert that the workload-id header is actually present on the wire request. + * Wire-level header propagation is verified by: + * - Java SDK unit tests: RxGatewayStoreModelTest, GatewayAddressCacheTest + * - Java SDK integration tests: WorkloadIdE2ETests (interceptor-based wire assertions) + * + * Test cases: + * 1. Read with workload-id header — Spark read succeeds, correct item returned + * 2. Write with workload-id header — Spark write succeeds, item verified via SDK read-back + * 3. No additionalHeaders (regression) — operations succeed without the config set + */ +class SparkE2EWorkloadIdITest + extends IntegrationSpec + with Spark + with CosmosClient + with AutoCleanableCosmosContainer + with BasicLoggingTrait { + + val objectMapper = new ObjectMapper() + + //scalastyle:off multiple.string.literals + //scalastyle:off magic.number + //scalastyle:off null + + // Integration smoke test #1: Spark read with workload-id header. + // Creates an item via SDK, then reads it back via Spark DataFrame with + // spark.cosmos.additionalHeaders set to {"x-ms-cosmos-workload-id": "15"}. + // Verifies the read succeeds and returns the correct item. + // This proves the header flows through the Spark config pipeline without causing errors. 
+ "spark query with additionalHeaders" can "read items with workload-id header" in { + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + + val id = UUID.randomUUID().toString + val rawItem = + s""" + | { + | "id" : "$id", + | "name" : "testItem" + | } + |""".stripMargin + + val objectNode = objectMapper.readValue(rawItem, classOf[ObjectNode]) + + val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer) + container.createItem(objectNode).block() + + val cfg = Map( + "spark.cosmos.accountEndpoint" -> cosmosEndpoint, + "spark.cosmos.accountKey" -> cosmosMasterKey, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> cosmosContainer, + "spark.cosmos.additionalHeaders" -> """{"x-ms-cosmos-workload-id": "15"}""", + "spark.cosmos.read.partitioning.strategy" -> "Restrictive" + ) + + val df = spark.read.format("cosmos.oltp").options(cfg).load() + val rowsArray = df.where(s"id = '$id'").collect() + rowsArray should have size 1 + + val item = rowsArray(0) + item.getAs[String]("id") shouldEqual id + } + + // Integration smoke test #2: Spark write with workload-id header. + // Writes an item via Spark DataFrame with spark.cosmos.additionalHeaders set to + // {"x-ms-cosmos-workload-id": "20"}, then reads it back via SDK to confirm + // write was persisted correctly. + // This proves the header flows through the Spark write pipeline without causing errors. 
+ "spark write with additionalHeaders" can "write items with workload-id header" in { + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + + val id = UUID.randomUUID().toString + val rawItem = + s""" + | { + | "id" : "$id", + | "name" : "testWriteItem" + | } + |""".stripMargin + + val cfg = Map( + "spark.cosmos.accountEndpoint" -> cosmosEndpoint, + "spark.cosmos.accountKey" -> cosmosMasterKey, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> cosmosContainer, + "spark.cosmos.additionalHeaders" -> """{"x-ms-cosmos-workload-id": "20"}""", + "spark.cosmos.write.strategy" -> "ItemOverwrite", + "spark.cosmos.write.bulk.enabled" -> "false", + "spark.cosmos.serialization.inclusionMode" -> "NonDefault" + ) + + val spark_session = spark + import spark_session.implicits._ + val df = spark.read.json(Seq(rawItem).toDS()) + + df.write.format("cosmos.oltp").options(cfg).mode("Append").save() + + // Verify the item was written by reading it back via the SDK directly + val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer) + val readItem = container.readItem(id, new com.azure.cosmos.models.PartitionKey(id), classOf[ObjectNode]).block() + readItem.getItem.get("id").textValue() shouldEqual id + readItem.getItem.get("name").textValue() shouldEqual "testWriteItem" + } + + // Integration smoke test #3: Regression test — no additionalHeaders configured. + // Verifies that Spark read operations continue to work correctly when + // spark.cosmos.additionalHeaders is NOT specified in the config map. + // This ensures the feature addition does not break existing Spark jobs + // that don't use additional headers. 
+ "spark operations without additionalHeaders" can "still succeed" in { + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + + val id = UUID.randomUUID().toString + val rawItem = + s""" + | { + | "id" : "$id", + | "name" : "noHeadersItem" + | } + |""".stripMargin + + val objectNode = objectMapper.readValue(rawItem, classOf[ObjectNode]) + val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer) + container.createItem(objectNode).block() + + val cfg = Map( + "spark.cosmos.accountEndpoint" -> cosmosEndpoint, + "spark.cosmos.accountKey" -> cosmosMasterKey, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> cosmosContainer, + "spark.cosmos.read.partitioning.strategy" -> "Restrictive" + ) + + val df = spark.read.format("cosmos.oltp").options(cfg).load() + val rowsArray = df.where(s"id = '$id'").collect() + rowsArray should have size 1 + rowsArray(0).getAs[String]("id") shouldEqual id + } + + //scalastyle:on magic.number + //scalastyle:on multiple.string.literals + //scalastyle:on null +} diff --git a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md index 9a438a3f1c53..b36949f18c57 100644 --- a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.46.0-beta.1 (Unreleased) #### Features Added +* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. 
- See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/WorkloadIdHeaderTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/WorkloadIdHeaderTests.java new file mode 100644 index 000000000000..41fc62f62ca3 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/WorkloadIdHeaderTests.java @@ -0,0 +1,397 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos; + +import com.azure.cosmos.implementation.HttpConstants; +import com.azure.cosmos.implementation.ImplementationBridgeHelpers; +import com.azure.cosmos.implementation.RequestOptions; +import com.azure.cosmos.models.CosmosBatchRequestOptions; +import com.azure.cosmos.models.CosmosBulkExecutionOptions; +import com.azure.cosmos.models.CosmosChangeFeedRequestOptions; +import com.azure.cosmos.models.CosmosContainerRequestOptions; +import com.azure.cosmos.models.CosmosDatabaseRequestOptions; +import com.azure.cosmos.models.CosmosItemRequestOptions; +import com.azure.cosmos.models.CosmosQueryRequestOptions; +import com.azure.cosmos.models.CosmosReadManyRequestOptions; +import com.azure.cosmos.models.CosmosStoredProcedureRequestOptions; +import com.azure.cosmos.models.FeedRange; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Comprehensive unit tests for the workload-id / additional headers feature. + *
+ * Covers three layers: + *
+ * This is for the workload-id feature: ThinClientStoreModel extends
+ * RxGatewayStoreModel, and the additionalHeaders must flow through the constructor
+ * chain so that performRequest() injects them into every request — including
+ * metadata requests (collection cache, PKRange cache, etc.).
+ */
+ @Test(groups = "unit")
+ public void testAdditionalHeadersFlowThroughThinClientStoreModel() throws Exception {
+ DiagnosticsClientContext clientContext = Mockito.mock(DiagnosticsClientContext.class);
+ Mockito.doReturn(new DiagnosticsClientContext.DiagnosticsClientConfig()).when(clientContext).getConfig();
+ Mockito
+ .doReturn(ImplementationBridgeHelpers
+ .CosmosDiagnosticsHelper
+ .getCosmosDiagnosticsAccessor()
+ .create(clientContext, 1d))
+ .when(clientContext).createDiagnostics();
+
+ String sdkGlobalSessionToken = "1#100#1=20#2=5#3=30";
+ ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class);
+ Mockito.doReturn(sdkGlobalSessionToken).when(sessionContainer).resolveGlobalSessionToken(any());
+
+ GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class);
+ Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost:8080")))
+ .when(globalEndpointManager).resolveServiceEndpoint(any());
+
+ // Capture the HttpRequest sent by performRequest() to verify headers
+ HttpClient httpClient = Mockito.mock(HttpClient.class);
+ ArgumentCaptor
+ *
+ * These tests verify that the WorkloadId enum entry exists with the correct wire ID (0x00DC),
+ * correct token type (Byte), is not required, and is not in the thin-client ordered header list
+ * (so it will be auto-encoded in the second pass of RntbdTokenStream.encode()).
+ */
+public class RntbdWorkloadIdTests {
+
+    /**
+     * Verifies that the WORKLOAD_ID HTTP header constant exists in HttpConstants.HttpHeaders
+     * with the correct canonical name "x-ms-cosmos-workload-id" used in Gateway mode and
+     * as the lookup key in RntbdRequestHeaders for HTTP-to-RNTBD mapping.
+     */
+    @Test(groups = { "unit" })
+    public void workloadIdConstantExists() {
+        assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id");
+    }
+
+    /**
+     * Verifies that the WorkloadId enum entry exists in RntbdConstants.RntbdRequestHeader
+     * with the correct wire ID (0x00DC). This ID is used to identify the header in the
+     * binary RNTBD protocol when communicating in Direct mode.
+     */
+    @Test(groups = { "unit" })
+    public void workloadIdRntbdHeaderExists() {
+        // Verify the WorkloadId enum value exists with the expected wire ID.
+        RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId;
+        assertThat(workloadIdHeader).isNotNull();
+        assertThat(workloadIdHeader.id()).isEqualTo((short) 0x00DC);
+    }
+
+    /**
+     * Verifies that the WorkloadId RNTBD header is defined as Byte token type,
+     * consistent with the ThroughputBucket pattern. The workload ID value (1-50)
+     * is encoded as a single byte on the wire.
+     */
+    @Test(groups = { "unit" })
+    public void workloadIdRntbdHeaderIsByteType() {
+        // Same token type as ThroughputBucket — a single byte on the wire.
+        RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId;
+        assertThat(workloadIdHeader.type()).isEqualTo(RntbdTokenType.Byte);
+    }
+
+    /**
+     * Verifies that WorkloadId is not a required RNTBD header. The header is optional —
+     * requests without a workload ID are valid and should not be rejected by the SDK.
+     */
+    @Test(groups = { "unit" })
+    public void workloadIdRntbdHeaderIsNotRequired() {
+        // WorkloadId must remain optional so requests without it are still valid.
+        RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId;
+        assertThat(workloadIdHeader.isRequired()).isFalse();
+    }
+
+    /**
+     * Verifies that WorkloadId is NOT in the thin client ordered header list. Thin client
+     * mode uses a pre-ordered list of headers for its first encoding pass. WorkloadId is
+     * excluded from this list and will be auto-encoded in the second pass of
+     * RntbdTokenStream.encode() along with other non-ordered headers.
+     */
+    @Test(groups = { "unit" })
+    public void workloadIdNotInThinClientOrderedList() {
+        // Exclusion from the ordered list means the header is picked up automatically
+        // in the second encoding pass of RntbdTokenStream.encode().
+        assertThat(RntbdConstants.RntbdRequestHeader.thinClientHeadersInOrderList)
+            .doesNotContain(RntbdConstants.RntbdRequestHeader.WorkloadId);
+    }
+
+    /**
+     * Verifies that valid workload ID values (1-50) can be parsed from String to int
+     * and cast to byte without data loss. Note: the SDK itself does not validate the
+     * range — this test confirms the encoding path works for expected values.
+     */
+    @Test(groups = { "unit" })
+    public void workloadIdValidValues() {
+        // Boundary values and midpoint of the documented 1-50 range.
+        String[] validValues = {"1", "25", "50"};
+        for (String value : validValues) {
+            int parsed = Integer.parseInt(value);
+            byte byteVal = (byte) parsed;
+            // The narrowing int->byte cast must be lossless for in-range values.
+            assertThat((int) byteVal).isEqualTo(parsed);
+            assertThat(byteVal).isBetween((byte) 1, (byte) 50);
+        }
+    }
+
+    /**
+     * Verifies that out-of-range workload ID values (0, 51, -1, 100) do not cause
+     * exceptions in the SDK's parsing path. The SDK intentionally does not validate
+     * the range — invalid values are accepted and sent to the service, which silently
+     * ignores them.
+     */
+    @Test(groups = { "unit" })
+    public void workloadIdInvalidValuesAcceptedBySdk() {
+        // SDK does NOT validate range — the service silently ignores invalid values.
+        String[] invalidValues = {"0", "51", "-1", "100"};
+        for (String value : invalidValues) {
+            int parsed = Integer.parseInt(value);
+            byte byteVal = (byte) parsed;
+            // Fix: the original asserted isNotNull() on a primitive byte, which is
+            // vacuously true after autoboxing and verifies nothing. Instead assert the
+            // narrowing cast preserved the value — true for every value in this data
+            // set (all fit in a signed byte) and proves the parse/cast path is clean.
+            assertThat((int) byteVal).isEqualTo(parsed);
+        }
+    }
+}
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdDirectInterceptorTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdDirectInterceptorTests.java
new file mode 100644
index 000000000000..6882cac1193e
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdDirectInterceptorTests.java
@@ -0,0 +1,226 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.cosmos.rx;
+
+import com.azure.cosmos.CosmosAsyncClient;
+import com.azure.cosmos.CosmosAsyncContainer;
+import com.azure.cosmos.CosmosAsyncDatabase;
+import com.azure.cosmos.CosmosClientBuilder;
+import com.azure.cosmos.CosmosHeaderName;
+import com.azure.cosmos.TestObject;
+import com.azure.cosmos.implementation.HttpConstants;
+import com.azure.cosmos.implementation.ResourceType;
+import com.azure.cosmos.implementation.RxDocumentServiceRequest;
+import com.azure.cosmos.models.CosmosContainerProperties;
+import com.azure.cosmos.models.CosmosItemRequestOptions;
+import com.azure.cosmos.models.PartitionKey;
+import com.azure.cosmos.models.PartitionKeyDefinition;
+import com.azure.cosmos.test.implementation.interceptor.CosmosInterceptorHelper;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Factory;
+import org.testng.annotations.Test;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentLinkedQueue;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Interceptor-based integration tests for the workload-id / additional headers feature
+ * running in Direct mode (RNTBD).
+ *
+ * These tests use {@link CosmosInterceptorHelper#registerTransportClientInterceptor} to
+ * capture {@link RxDocumentServiceRequest} objects at the TransportClient level and assert
+ * that the {@code x-ms-cosmos-workload-id} header is present with the correct value on
+ * wire requests. The interceptor only fires in Direct mode because Gateway mode data-plane
+ * requests go through {@code RxGatewayStoreModel → HttpClient}, bypassing the
+ * TransportClient entirely.
+ *
+ * Uses {@code @Factory(dataProvider = "simpleClientBuildersWithJustDirectTcp")} to ensure
+ * all tests run exclusively in Direct/TCP mode where the transport interceptor is active.
+ *
+ * Gateway-mode header injection is separately verified by:
+ *
+ * Registers a transport client interceptor that captures every request before it hits
+ * the RNTBD wire. After performing a createItem, asserts that at least one captured
+ * request with {@code ResourceType.Document} carries the {@code x-ms-cosmos-workload-id}
+ * header with value {@code "15"}.
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void verifyWorkloadIdHeaderPresentOnDataPlaneRequest() {
+ ConcurrentLinkedQueue
+ * This confirms that the request-level header set via
+ * {@link CosmosItemRequestOptions#setAdditionalHeaders(Map)} takes precedence over
+ * the client-level header set via {@link CosmosClientBuilder#additionalHeaders(Map)}
+ * in Direct mode (RNTBD).
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void verifyRequestLevelOverrideOnWire() {
+ ConcurrentLinkedQueue
+ * Test type: EMULATOR INTEGRATION TEST — requires a Cosmos DB account or emulator.
+ *
+ * Uses {@code @Factory(dataProvider = "simpleClientBuilderGatewaySession")} to run all tests
+ * in Gateway mode with Session consistency.
+ *
+ * These are smoke tests — they verify CRUD/query operations succeed (status code
+ * 200/201/204) when the workload-id header is set. They prove the header doesn't break
+ * anything but do NOT assert the header is actually present on the wire request.
+ *
+ * For wire-level assertion tests that verify the header is actually present on
+ * {@link com.azure.cosmos.implementation.RxDocumentServiceRequest}, see
+ * {@link WorkloadIdDirectInterceptorTests} which runs in Direct mode (RNTBD) using
+ * the transport client interceptor.
+ */
+public class WorkloadIdE2ETests extends TestSuiteBase {
+
+ private static final String DATABASE_ID = "workloadIdTestDb-" + UUID.randomUUID();
+ private static final String CONTAINER_ID = "workloadIdTestContainer-" + UUID.randomUUID();
+
+ private CosmosAsyncClient clientWithWorkloadId;
+ private CosmosAsyncDatabase database;
+ private CosmosAsyncContainer container;
+
+ /** Constructs the suite; the injected builder pins Gateway mode with Session consistency. */
+ @Factory(dataProvider = "simpleClientBuilderGatewaySession")
+ public WorkloadIdE2ETests(CosmosClientBuilder clientBuilder) {
+ super(clientBuilder);
+ }
+
+ @BeforeClass(groups = { "emulator" }, timeOut = SETUP_TIMEOUT)
+ public void beforeClass() {
+ Map
+ *
+ */
+public class WorkloadIdDirectInterceptorTests extends TestSuiteBase {
+
+ private static final String DATABASE_ID = "workloadIdDirectTestDb-" + UUID.randomUUID();
+ private static final String CONTAINER_ID = "workloadIdDirectTestContainer-" + UUID.randomUUID();
+
+ private CosmosAsyncClient clientWithWorkloadId;
+ private CosmosAsyncDatabase database;
+ private CosmosAsyncContainer container;
+
+ /** Constructs the suite; the injected builder pins Direct/TCP mode so the transport interceptor fires. */
+ @Factory(dataProvider = "simpleClientBuildersWithJustDirectTcp")
+ public WorkloadIdDirectInterceptorTests(CosmosClientBuilder clientBuilder) {
+ super(clientBuilder);
+ }
+
+ @BeforeClass(groups = { "emulator" }, timeOut = SETUP_TIMEOUT)
+ public void beforeClass() {
+ Map