Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
### 4.46.0-beta.1 (Unreleased)

#### Features Added
* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)

#### Breaking Changes

Expand Down
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
### 4.46.0-beta.1 (Unreleased)

#### Features Added
* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)

#### Breaking Changes

Expand Down
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
### 4.46.0-beta.1 (Unreleased)

#### Features Added
* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)

#### Breaking Changes

Expand Down
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
### 4.46.0-beta.1 (Unreleased)

#### Features Added
* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)

#### Breaking Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import com.azure.cosmos.models.{CosmosClientTelemetryConfig, CosmosMetricCategor
import com.azure.cosmos.spark.CosmosPredicates.isOnSparkDriver
import com.azure.cosmos.spark.catalog.{CosmosCatalogClient, CosmosCatalogCosmosSDKClient, CosmosCatalogManagementSDKClient}
import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait
import com.azure.cosmos.{ConsistencyLevel, CosmosAsyncClient, CosmosClientBuilder, CosmosContainerProactiveInitConfigBuilder, CosmosDiagnosticsThresholds, DirectConnectionConfig, GatewayConnectionConfig, ReadConsistencyStrategy, ThrottlingRetryOptions}
import com.azure.cosmos.{ConsistencyLevel, CosmosAsyncClient, CosmosClientBuilder, CosmosContainerProactiveInitConfigBuilder, CosmosDiagnosticsThresholds, CosmosHeaderName, DirectConnectionConfig, GatewayConnectionConfig, ReadConsistencyStrategy, ThrottlingRetryOptions}
import com.azure.identity.{ClientCertificateCredentialBuilder, ClientSecretCredentialBuilder, ManagedIdentityCredentialBuilder}
import com.azure.monitor.opentelemetry.autoconfigure.{AzureMonitorAutoConfigure, AzureMonitorAutoConfigureOptions}
import com.azure.resourcemanager.cosmos.CosmosManager
Expand Down Expand Up @@ -42,7 +42,6 @@ import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}
import java.util.function.BiPredicate
import scala.collection.concurrent.TrieMap

// scalastyle:off underscore.import
import scala.collection.JavaConverters._
// scalastyle:on underscore.import
Expand Down Expand Up @@ -713,6 +712,17 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait {
}
}

// Attach any configured custom HTTP headers (e.g., workload-id) to the client builder.
// Header names were already validated at config-parse time in CosmosConfig.AdditionalHeadersConfig,
// so the CosmosHeaderName.fromString() calls below are pure type conversions of known-good keys.
cosmosClientConfiguration.additionalHeaders.foreach { headers =>
  val javaHeaders = new java.util.HashMap[CosmosHeaderName, String]()
  headers.foreach { case (name, value) =>
    javaHeaders.put(CosmosHeaderName.fromString(name), value)
  }
  builder.additionalHeaders(javaHeaders)
}

var client = builder.buildAsyncClient()

if (cosmosClientConfiguration.clientInterceptors.isDefined) {
Expand Down Expand Up @@ -916,7 +926,10 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait {
clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]],
clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig],
azureMonitorConfig: Option[AzureMonitorConfig]
azureMonitorConfig: Option[AzureMonitorConfig],
// Additional HTTP headers are part of the cache key because different workload-ids
// should produce different CosmosAsyncClient instances
additionalHeaders: Option[Map[String, String]]
)

private[this] object ClientConfigurationWrapper {
Expand All @@ -935,7 +948,8 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait {
clientConfig.clientBuilderInterceptors,
clientConfig.clientInterceptors,
clientConfig.sampledDiagnosticsLoggerConfig,
clientConfig.azureMonitorConfig
clientConfig.azureMonitorConfig,
clientConfig.additionalHeaders
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ private[spark] case class CosmosClientConfiguration (
clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]],
clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig],
azureMonitorConfig: Option[AzureMonitorConfig]
azureMonitorConfig: Option[AzureMonitorConfig],
// Optional additional HTTP headers (e.g., workload-id) to attach to
// all Cosmos DB requests via CosmosClientBuilder.additionalHeaders()
additionalHeaders: Option[Map[String, String]]
) {
private[spark] def getRoleInstanceName(machineId: Option[String]): String = {
CosmosClientConfiguration.getRoleInstanceName(sparkEnvironmentInfo, machineId)
Expand Down Expand Up @@ -94,7 +97,8 @@ private[spark] object CosmosClientConfiguration {
cosmosAccountConfig.clientBuilderInterceptors,
cosmosAccountConfig.clientInterceptors,
diagnosticsConfig.sampledDiagnosticsLoggerConfig,
diagnosticsConfig.azureMonitorConfig
diagnosticsConfig.azureMonitorConfig,
cosmosAccountConfig.additionalHeaders
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
package com.azure.cosmos.spark

import com.azure.core.management.AzureEnvironment
import com.azure.cosmos.{CosmosAsyncClient, CosmosClientBuilder, ReadConsistencyStrategy, spark}
import com.azure.cosmos.{CosmosAsyncClient, CosmosClientBuilder, CosmosHeaderName, ReadConsistencyStrategy, spark}
import com.azure.cosmos.implementation.batch.BatchRequestResponseConstants
import com.azure.cosmos.implementation.routing.LocationHelper
import com.azure.cosmos.implementation.{Configs, SparkBridgeImplementationInternal, Strings}
import com.azure.cosmos.implementation.{Configs, SparkBridgeImplementationInternal, Strings, Utils}
import com.azure.cosmos.models.{CosmosChangeFeedRequestOptions, CosmosContainerIdentity, CosmosParameterizedQuery, DedicatedGatewayRequestOptions, FeedRange, PartitionKeyDefinition}
import com.azure.cosmos.spark.ChangeFeedModes.ChangeFeedMode
import com.azure.cosmos.spark.ChangeFeedStartFromModes.{ChangeFeedStartFromMode, PointInTime}
Expand All @@ -34,6 +34,7 @@ import java.time.format.DateTimeFormatter
import java.time.{Duration, Instant}
import java.util
import java.util.{Locale, ServiceLoader}
import scala.collection.JavaConverters._ // scalastyle:ignore underscore.import
import scala.collection.concurrent.TrieMap
import scala.collection.immutable.{HashSet, List, Map}
import scala.collection.mutable
Expand Down Expand Up @@ -151,6 +152,10 @@ private[spark] object CosmosConfigNames {
val ThroughputControlTargetThroughputThreshold = "spark.cosmos.throughputControl.targetThroughputThreshold"
val ThroughputControlPriorityLevel = "spark.cosmos.throughputControl.priorityLevel"
val ThroughputControlThroughputBucket = "spark.cosmos.throughputControl.throughputBucket"
// Additional HTTP headers to attach to all Cosmos DB requests (e.g., workload-id for resource governance).
// Value is a JSON string like: {"x-ms-cosmos-workload-id": "15"}
// Flows through to CosmosClientBuilder.additionalHeaders().
val AdditionalHeaders = "spark.cosmos.additionalHeaders"
val ThroughputControlGlobalControlDatabase = "spark.cosmos.throughputControl.globalControl.database"
val ThroughputControlGlobalControlContainer = "spark.cosmos.throughputControl.globalControl.container"
val ThroughputControlGlobalControlRenewalIntervalInMS =
Expand Down Expand Up @@ -297,7 +302,8 @@ private[spark] object CosmosConfigNames {
WriteOnRetryCommitInterceptor,
WriteFlushCloseIntervalInSeconds,
WriteMaxNoProgressIntervalInSeconds,
WriteMaxRetryNoProgressIntervalInSeconds
WriteMaxRetryNoProgressIntervalInSeconds,
AdditionalHeaders
)

def validateConfigName(name: String): Unit = {
Expand Down Expand Up @@ -540,7 +546,10 @@ private case class CosmosAccountConfig(endpoint: String,
resourceGroupName: Option[String],
azureEnvironmentEndpoints: java.util.Map[String, String],
clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]],
clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
// Optional additional HTTP headers (e.g., workload-id) parsed from
// spark.cosmos.additionalHeaders JSON config, passed to CosmosClientBuilder.additionalHeaders()
additionalHeaders: Option[Map[String, String]]
)

private object CosmosAccountConfig extends BasicLoggingTrait {
Expand Down Expand Up @@ -727,6 +736,40 @@ private object CosmosAccountConfig extends BasicLoggingTrait {
parseFromStringFunction = clientInterceptorFQDN => clientInterceptorFQDN,
helpMessage = "CosmosAsyncClient interceptors (comma separated) - FQDNs of the service implementing the 'CosmosClientInterceptor' trait.")

// Config entry for custom HTTP headers (e.g., workload-id). Parses a JSON string like
// {"x-ms-cosmos-workload-id": "15"} into a Scala Map[String, String] using Jackson.
// These headers are converted to Map[CosmosHeaderName, String] and passed to
// CosmosClientBuilder.additionalHeaders() in CosmosClientCache.
//
// Validation: after JSON parsing, every header NAME is converted via CosmosHeaderName.fromString()
// (throws IllegalArgumentException for unknown headers) and every header VALUE is checked via
// CosmosHeaderName.validateAdditionalHeaders(). Both run at config-parse time so a bad config
// fails fast — before the Spark job starts and allocates cluster resources — instead of
// failing later on an executor when CosmosClientCache builds the client.
private val AdditionalHeadersConfig = CosmosConfigEntry[Map[String, String]](
  key = CosmosConfigNames.AdditionalHeaders,
  mandatory = false,
  parseFromStringFunction = headersJson => {
    try {
      val typeRef = new com.fasterxml.jackson.core.`type`.TypeReference[java.util.Map[String, String]]() {}
      val parsed = Utils.getSimpleObjectMapperWithAllowDuplicates.readValue(headersJson, typeRef).asScala.toMap

      // Fail fast: convert names (throws for unknown headers like "x-bad-header") and
      // collect into a typed map so values can be validated as well.
      val typedHeaders = new java.util.HashMap[CosmosHeaderName, String]()
      for ((key, value) <- parsed) {
        typedHeaders.put(CosmosHeaderName.fromString(key), value)
      }
      // Also validate header values at parse time to complete fail-fast coverage —
      // without this, a bad value would only surface when CosmosClientBuilder.additionalHeaders()
      // runs validateAdditionalHeaders() on the executor after resources are already allocated.
      CosmosHeaderName.validateAdditionalHeaders(typedHeaders)

      parsed
    } catch {
      case e: IllegalArgumentException => throw e
      case e: Exception => throw new IllegalArgumentException(
        s"Invalid JSON for '${CosmosConfigNames.AdditionalHeaders}': '$headersJson'. " +
          "Expected format: {\"x-ms-cosmos-workload-id\": \"15\"}", e)
    }
  },
  helpMessage = "Optional additional headers as JSON map. Example: {\"x-ms-cosmos-workload-id\": \"15\"}")

private[spark] def parseProactiveConnectionInitConfigs(config: String): java.util.List[CosmosContainerIdentity] = {
val result = new java.util.ArrayList[CosmosContainerIdentity]
try {
Expand Down Expand Up @@ -761,6 +804,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait {
val tenantIdOpt = CosmosConfigEntry.parse(cfg, TenantId)
val clientBuilderInterceptors = CosmosConfigEntry.parse(cfg, ClientBuilderInterceptors)
val clientInterceptors = CosmosConfigEntry.parse(cfg, ClientInterceptors)
// Parse optional additional HTTP headers from JSON config (e.g., {"x-ms-cosmos-workload-id": "15"})
val additionalHeaders = CosmosConfigEntry.parse(cfg, AdditionalHeadersConfig)

val disableTcpConnectionEndpointRediscovery = CosmosConfigEntry.parse(cfg, DisableTcpConnectionEndpointRediscovery)
val preferredRegionsListOpt = CosmosConfigEntry.parse(cfg, PreferredRegionsList)
Expand Down Expand Up @@ -880,7 +925,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait {
resourceGroupNameOpt,
azureEnvironmentOpt.get,
if (clientBuilderInterceptorsList.nonEmpty) { Some(clientBuilderInterceptorsList.toList) } else { None },
if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None })
if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None },
additionalHeaders)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None
azureMonitorConfig = None,
additionalHeaders = None
)
),
(
Expand All @@ -91,7 +92,8 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None
azureMonitorConfig = None,
additionalHeaders = None
)
),
(
Expand All @@ -118,7 +120,8 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None
azureMonitorConfig = None,
additionalHeaders = None
)
),
(
Expand All @@ -145,7 +148,8 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None
azureMonitorConfig = None,
additionalHeaders = None
)
)
)
Expand Down Expand Up @@ -179,8 +183,9 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None
)
azureMonitorConfig = None,
additionalHeaders = None
)

logInfo(s"TestCase: {$testCaseName}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,4 +408,75 @@ class CosmosClientConfigurationSpec extends UnitSpec {
configuration.applicationName shouldEqual s"${CosmosConstants.userAgentSuffix}|$sparkEnvironmentInfo|${ManagementFactory.getRuntimeMXBean.getName}|$myApp".replace("@", " ")
configuration.azureMonitorConfig shouldEqual None
}

// Happy-path check for `spark.cosmos.additionalHeaders`: a JSON object carrying a single
// workload-id entry must surface as a populated Map[String, String] on the resulting
// CosmosClientConfiguration. This is the primary use case for the workload-id feature.
it should "parse additionalHeaders JSON" in {
  val cfg = Map(
    "spark.cosmos.accountEndpoint" -> "https://localhost:8081",
    "spark.cosmos.accountKey" -> "xyz",
    "spark.cosmos.additionalHeaders" -> """{"x-ms-cosmos-workload-id": "15"}"""
  )

  val parsed = CosmosClientConfiguration(cfg, ReadConsistencyStrategy.DEFAULT, sparkEnvironmentInfo = "")

  parsed.additionalHeaders shouldBe defined
  parsed.additionalHeaders.get("x-ms-cosmos-workload-id") shouldEqual "15"
}

// Backward-compatibility check: when `spark.cosmos.additionalHeaders` is absent from the
// user config, CosmosClientConfiguration.additionalHeaders must be None so existing Spark
// jobs that never set the option keep working without changes.
it should "handle missing additionalHeaders" in {
  val cfg = Map(
    "spark.cosmos.accountEndpoint" -> "https://localhost:8081",
    "spark.cosmos.accountKey" -> "xyz"
  )

  val parsed = CosmosClientConfiguration(cfg, ReadConsistencyStrategy.DEFAULT, sparkEnvironmentInfo = "")

  parsed.additionalHeaders shouldBe None
}

// Fail-fast check: an unknown header name inside `spark.cosmos.additionalHeaders` must be
// rejected while the configuration is parsed — not later at client-creation time — so a
// misconfigured job never gets far enough to allocate cluster resources.
// CosmosConfigEntry.parse() wraps all parser exceptions in RuntimeException, so the
// IllegalArgumentException raised by CosmosHeaderName.fromString() appears as the cause.
it should "reject unknown additional headers at parse time" in {
  val cfg = Map(
    "spark.cosmos.accountEndpoint" -> "https://localhost:8081",
    "spark.cosmos.accountKey" -> "xyz",
    "spark.cosmos.additionalHeaders" -> """{"x-ms-cosmos-workload-id": "20", "x-custom-header": "value"}"""
  )

  val failure = the [RuntimeException] thrownBy {
    CosmosClientConfiguration(cfg, ReadConsistencyStrategy.DEFAULT, sparkEnvironmentInfo = "")
  }

  failure.getCause shouldBe an [IllegalArgumentException]
  failure.getCause.getMessage should include ("x-custom-header")
}

// Edge-case check: an empty JSON object ("{}") for `spark.cosmos.additionalHeaders` must
// parse cleanly into Some(empty map) — the option is recognized, the parser does not fail,
// and no headers will be injected into requests.
it should "handle empty additionalHeaders JSON" in {
  val cfg = Map(
    "spark.cosmos.accountEndpoint" -> "https://localhost:8081",
    "spark.cosmos.accountKey" -> "xyz",
    "spark.cosmos.additionalHeaders" -> "{}"
  )

  val parsed = CosmosClientConfiguration(cfg, ReadConsistencyStrategy.DEFAULT, sparkEnvironmentInfo = "")

  parsed.additionalHeaders shouldBe defined
  parsed.additionalHeaders.get shouldBe empty
}
}
Loading
Loading