From f894bf425c9b5be3800a32ffcad4d977a108620d Mon Sep 17 00:00:00 2001
From: dibahlfi <106994927+dibahlfi@users.noreply.github.com>
Date: Wed, 25 Feb 2026 17:33:05 -0600
Subject: [PATCH 01/13] workload-id feature - initial commit
---
.../com/azure/cosmos/CustomHeadersTests.java | 170 +++++++++
.../rntbd/RntbdWorkloadIdTests.java | 115 ++++++
.../azure/cosmos/rx/WorkloadIdE2ETests.java | 327 ++++++++++++++++++
.../com/azure/cosmos/CosmosAsyncClient.java | 1 +
.../com/azure/cosmos/CosmosClientBuilder.java | 29 ++
.../implementation/AsyncDocumentClient.java | 9 +-
.../cosmos/implementation/HttpConstants.java | 3 +
.../implementation/RxDocumentClientImpl.java | 61 ++++
.../rntbd/RntbdConstants.java | 3 +-
.../rntbd/RntbdRequestHeaders.java | 19 +
.../models/CosmosBatchRequestOptions.java | 13 +-
.../models/CosmosBulkExecutionOptions.java | 12 +-
.../CosmosChangeFeedRequestOptions.java | 13 +-
.../models/CosmosItemRequestOptions.java | 15 +-
.../models/CosmosQueryRequestOptions.java | 16 +
15 files changed, 784 insertions(+), 22 deletions(-)
create mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java
create mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java
create mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java
new file mode 100644
index 000000000000..3c95c8bd2687
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java
@@ -0,0 +1,170 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.cosmos;
+
+import com.azure.cosmos.implementation.HttpConstants;
+import com.azure.cosmos.models.CosmosBatchRequestOptions;
+import com.azure.cosmos.models.CosmosBulkExecutionOptions;
+import com.azure.cosmos.models.CosmosChangeFeedRequestOptions;
+import com.azure.cosmos.models.CosmosItemRequestOptions;
+import com.azure.cosmos.models.CosmosQueryRequestOptions;
+import com.azure.cosmos.models.FeedRange;
+import org.testng.annotations.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Unit tests for the custom headers (workload-id) feature on CosmosClientBuilder and request options classes.
+ *
+ * These tests verify the public API surface: builder fluent methods, getter behavior,
+ * null/empty handling, and that setHeader() is publicly accessible on all request options classes.
+ */
+public class CustomHeadersTests {
+
+ /**
+ * Verifies that custom headers (e.g., workload-id) set via CosmosClientBuilder.customHeaders()
+ * are stored correctly and retrievable via getCustomHeaders().
+ */
+ @Test(groups = { "unit" })
+ public void customHeadersSetOnBuilder() {
+ Map headers = new HashMap<>();
+ headers.put("x-ms-cosmos-workload-id", "25");
+
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(headers);
+
+ assertThat(builder.getCustomHeaders()).containsEntry("x-ms-cosmos-workload-id", "25");
+ }
+
+ /**
+ * Verifies that passing null to customHeaders() does not throw and that
+ * getCustomHeaders() returns null, ensuring graceful null handling.
+ */
+ @Test(groups = { "unit" })
+ public void customHeadersNullHandledGracefully() {
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(null);
+
+ assertThat(builder.getCustomHeaders()).isNull();
+ }
+
+ /**
+ * Verifies that passing an empty map to customHeaders() is accepted and
+ * getCustomHeaders() returns an empty (not null) map.
+ */
+ @Test(groups = { "unit" })
+ public void customHeadersEmptyMapHandled() {
+ Map emptyHeaders = new HashMap<>();
+
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(emptyHeaders);
+
+ assertThat(builder.getCustomHeaders()).isEmpty();
+ }
+
+ /**
+ * Verifies that multiple custom headers can be set at once on the builder and
+ * all entries are preserved and retrievable with correct keys and values.
+ */
+ @Test(groups = { "unit" })
+ public void multipleCustomHeadersSupported() {
+ Map headers = new HashMap<>();
+ headers.put("x-ms-cosmos-workload-id", "15");
+ headers.put("x-ms-custom-header", "value");
+
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(headers);
+
+ assertThat(builder.getCustomHeaders()).hasSize(2);
+ assertThat(builder.getCustomHeaders()).containsEntry("x-ms-cosmos-workload-id", "15");
+ assertThat(builder.getCustomHeaders()).containsEntry("x-ms-custom-header", "value");
+ }
+
+ /**
+ * Verifies that setHeader() is publicly accessible on CosmosItemRequestOptions
+ * (previously package-private) and supports fluent chaining for per-request
+ * header overrides on CRUD operations.
+ */
+ @Test(groups = { "unit" })
+ public void setHeaderOnItemRequestOptionsIsPublic() {
+ CosmosItemRequestOptions options = new CosmosItemRequestOptions()
+ .setHeader("x-ms-cosmos-workload-id", "15");
+
+ assertThat(options).isNotNull();
+ }
+
+ /**
+ * Verifies that setHeader() is publicly accessible on CosmosBatchRequestOptions
+ * (previously package-private) and supports fluent chaining for per-request
+ * header overrides on batch operations.
+ */
+ @Test(groups = { "unit" })
+ public void setHeaderOnBatchRequestOptionsIsPublic() {
+ CosmosBatchRequestOptions options = new CosmosBatchRequestOptions()
+ .setHeader("x-ms-cosmos-workload-id", "20");
+
+ assertThat(options).isNotNull();
+ }
+
+ /**
+ * Verifies that setHeader() is publicly accessible on CosmosChangeFeedRequestOptions
+ * (previously package-private) and supports fluent chaining for per-request
+ * header overrides on change feed operations.
+ */
+ @Test(groups = { "unit" })
+ public void setHeaderOnChangeFeedRequestOptionsIsPublic() {
+ CosmosChangeFeedRequestOptions options = CosmosChangeFeedRequestOptions
+ .createForProcessingFromBeginning(FeedRange.forFullRange())
+ .setHeader("x-ms-cosmos-workload-id", "25");
+
+ assertThat(options).isNotNull();
+ }
+
+ /**
+ * Verifies that setHeader() is publicly accessible on CosmosBulkExecutionOptions
+ * (previously package-private) and supports fluent chaining for per-request
+ * header overrides on bulk ingestion operations.
+ */
+ @Test(groups = { "unit" })
+ public void setHeaderOnBulkExecutionOptionsIsPublic() {
+ CosmosBulkExecutionOptions options = new CosmosBulkExecutionOptions()
+ .setHeader("x-ms-cosmos-workload-id", "30");
+
+ assertThat(options).isNotNull();
+ }
+
+ /**
+ * Verifies that the new delegating setHeader() method on CosmosQueryRequestOptions
+ * is publicly accessible and supports fluent chaining for per-request header
+ * overrides on query operations.
+ */
+ @Test(groups = { "unit" })
+ public void setHeaderOnQueryRequestOptionsIsPublic() {
+ CosmosQueryRequestOptions options = new CosmosQueryRequestOptions()
+ .setHeader("x-ms-cosmos-workload-id", "35");
+
+ assertThat(options).isNotNull();
+ }
+
+ /**
+ * Verifies that the WORKLOAD_ID constant in HttpConstants.HttpHeaders is defined
+ * with the correct canonical header name "x-ms-cosmos-workload-id" as expected
+ * by the Cosmos DB service.
+ */
+ @Test(groups = { "unit" })
+ public void workloadIdHttpHeaderConstant() {
+ assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id");
+ }
+}
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java
new file mode 100644
index 000000000000..9ca123e16160
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java
@@ -0,0 +1,115 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.cosmos.implementation.directconnectivity.rntbd;
+
+import com.azure.cosmos.implementation.HttpConstants;
+import org.testng.annotations.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Unit tests for the WorkloadId RNTBD header definition in RntbdConstants.
+ *
+ *
+ * These tests verify that the WorkloadId enum entry exists with the correct wire ID (0x00DC),
+ * correct token type (Byte), is not required, and is not in the thin-client ordered header list
+ * (so it will be auto-encoded in the second pass of RntbdTokenStream.encode()).
+ */
+public class RntbdWorkloadIdTests {
+
+ /**
+ * Verifies that the WORKLOAD_ID HTTP header constant exists in HttpConstants.HttpHeaders
+ * with the correct canonical name "x-ms-cosmos-workload-id" used in Gateway mode and
+ * as the lookup key in RntbdRequestHeaders for HTTP-to-RNTBD mapping.
+ */
+ @Test(groups = { "unit" })
+ public void workloadIdConstantExists() {
+ assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id");
+ }
+
+ /**
+ * Verifies that the WorkloadId enum entry exists in RntbdConstants.RntbdRequestHeader
+ * with the correct wire ID (0x00DC). This ID is used to identify the header in the
+ * binary RNTBD protocol when communicating in Direct mode.
+ */
+ @Test(groups = { "unit" })
+ public void workloadIdRntbdHeaderExists() {
+ // Verify WorkloadId enum value exists with correct ID
+ RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId;
+ assertThat(workloadIdHeader).isNotNull();
+ assertThat(workloadIdHeader.id()).isEqualTo((short) 0x00DC);
+ }
+
+ /**
+ * Verifies that the WorkloadId RNTBD header is defined as Byte token type,
+ * consistent with the ThroughputBucket pattern. The workload ID value (1-50)
+ * is encoded as a single byte on the wire.
+ */
+ @Test(groups = { "unit" })
+ public void workloadIdRntbdHeaderIsByteType() {
+ // Verify WorkloadId is Byte type (same as ThroughputBucket pattern)
+ RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId;
+ assertThat(workloadIdHeader.type()).isEqualTo(RntbdTokenType.Byte);
+ }
+
+ /**
+ * Verifies that WorkloadId is not a required RNTBD header. The header is optional —
+ * requests without a workload ID are valid and should not be rejected by the SDK.
+ */
+ @Test(groups = { "unit" })
+ public void workloadIdRntbdHeaderIsNotRequired() {
+ // WorkloadId should not be a required header
+ RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId;
+ assertThat(workloadIdHeader.isRequired()).isFalse();
+ }
+
+ /**
+ * Verifies that WorkloadId is NOT in the thin client ordered header list. Thin client
+ * mode uses a pre-ordered list of headers for its first encoding pass. WorkloadId is
+ * excluded from this list and will be auto-encoded in the second pass of
+ * RntbdTokenStream.encode() along with other non-ordered headers.
+ */
+ @Test(groups = { "unit" })
+ public void workloadIdNotInThinClientOrderedList() {
+ // WorkloadId should NOT be in thinClientHeadersInOrderList
+ // It will be automatically encoded in the second pass of RntbdTokenStream.encode()
+ assertThat(RntbdConstants.RntbdRequestHeader.thinClientHeadersInOrderList)
+ .doesNotContain(RntbdConstants.RntbdRequestHeader.WorkloadId);
+ }
+
+ /**
+ * Verifies that valid workload ID values (1-50) can be parsed from String to int
+ * and cast to byte without data loss. Note: the SDK itself does not validate the
+ * range — this test confirms the encoding path works for expected values.
+ */
+ @Test(groups = { "unit" })
+ public void workloadIdValidValues() {
+ // Test valid range 1-50 — SDK does NOT validate, just verify the values parse correctly
+ String[] validValues = {"1", "25", "50"};
+ for (String value : validValues) {
+ int parsed = Integer.parseInt(value);
+ byte byteVal = (byte) parsed;
+ assertThat(byteVal).isBetween((byte) 1, (byte) 50);
+ }
+ }
+
+ /**
+ * Verifies that out-of-range workload ID values (0, 51, -1, 100) do not cause
+ * exceptions in the SDK's parsing path. The SDK intentionally does not validate
+ * the range — invalid values are accepted and sent to the service, which silently
+ * ignores them.
+ */
+ @Test(groups = { "unit" })
+ public void workloadIdInvalidValuesAcceptedBySdk() {
+ // SDK does NOT validate range — service silently ignores invalid values
+ // These should not throw exceptions in SDK
+ String[] invalidValues = {"0", "51", "-1", "100"};
+ for (String value : invalidValues) {
+ int parsed = Integer.parseInt(value);
+ byte byteVal = (byte) parsed;
+ // SDK accepts any integer value that fits in a byte
+ assertThat(byteVal).isNotNull();
+ }
+ }
+}
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
new file mode 100644
index 000000000000..85b87090bb36
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
@@ -0,0 +1,327 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.cosmos.rx;
+
+import com.azure.cosmos.CosmosAsyncClient;
+import com.azure.cosmos.CosmosAsyncContainer;
+import com.azure.cosmos.CosmosAsyncDatabase;
+import com.azure.cosmos.CosmosClientBuilder;
+import com.azure.cosmos.TestObject;
+import com.azure.cosmos.implementation.HttpConstants;
+import com.azure.cosmos.implementation.TestConfigurations;
+import com.azure.cosmos.models.CosmosBulkExecutionOptions;
+import com.azure.cosmos.models.CosmosBulkOperations;
+import com.azure.cosmos.models.CosmosContainerProperties;
+import com.azure.cosmos.models.CosmosItemRequestOptions;
+import com.azure.cosmos.models.CosmosItemResponse;
+import com.azure.cosmos.models.CosmosQueryRequestOptions;
+import com.azure.cosmos.models.PartitionKey;
+import com.azure.cosmos.models.PartitionKeyDefinition;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.UUID;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * End-to-end integration tests for the custom headers / workload-id feature.
+ *
+ * Test type: EMULATOR INTEGRATION TEST — requires the Cosmos DB Emulator to be running locally.
+ * These tests create a real database and container, then execute CRUD and query operations
+ * with the {@code x-ms-cosmos-workload-id} header set at client level and/or request level.
+ *
+ * What is verified:
+ * 1. CRUD operations succeed with client-level custom headers (workload-id)
+ * 2. Per-request header overrides work via setHeader()
+ * 3. Client with no custom headers continues to work (no regression)
+ * 4. Query operations succeed with workload-id
+ * 5. Empty headers and multiple headers are handled correctly
+ *
+
+ */
+public class WorkloadIdE2ETests extends TestSuiteBase {
+
+ private static final String DATABASE_ID = "workloadIdTestDb-" + UUID.randomUUID();
+ private static final String CONTAINER_ID = "workloadIdTestContainer-" + UUID.randomUUID();
+
+ private CosmosAsyncClient clientWithWorkloadId;
+ private CosmosAsyncDatabase database;
+ private CosmosAsyncContainer container;
+
+ public WorkloadIdE2ETests() {
+ super(new CosmosClientBuilder()
+ .endpoint(TestConfigurations.HOST)
+ .key(TestConfigurations.MASTER_KEY));
+ }
+
+ @BeforeClass(groups = { "emulator" }, timeOut = SETUP_TIMEOUT)
+ public void beforeClass() {
+ Map headers = new HashMap<>();
+ headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "15");
+
+ clientWithWorkloadId = new CosmosClientBuilder()
+ .endpoint(TestConfigurations.HOST)
+ .key(TestConfigurations.MASTER_KEY)
+ .customHeaders(headers)
+ .buildAsyncClient();
+
+ database = createDatabase(clientWithWorkloadId, DATABASE_ID);
+
+ PartitionKeyDefinition partitionKeyDef = new PartitionKeyDefinition();
+ ArrayList paths = new ArrayList<>();
+ paths.add("/mypk");
+ partitionKeyDef.setPaths(paths);
+ CosmosContainerProperties containerProperties = new CosmosContainerProperties(CONTAINER_ID, partitionKeyDef);
+ database.createContainer(containerProperties).block();
+ container = database.getContainer(CONTAINER_ID);
+ }
+
+ /**
+ * Smoke test: verifies that a create (POST) operation succeeds when the client
+ * has a workload-id custom header set at the builder level. Confirms the header
+ * flows through the request pipeline without causing errors.
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void createItemWithClientLevelWorkloadId() {
+ // Smoke test: verify create operation succeeds with client-level workload-id header
+ TestObject doc = TestObject.create();
+
+ CosmosItemResponse response = container
+ .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions())
+ .block();
+
+ assertThat(response).isNotNull();
+ assertThat(response.getStatusCode()).isEqualTo(201);
+ }
+
+ /**
+ * Verifies that a read (GET) operation succeeds with the client-level workload-id
+ * header and that the correct document is returned. Ensures the header does not
+ * interfere with normal read semantics.
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void readItemWithClientLevelWorkloadId() {
+ // Verify read operation succeeds with workload-id header
+ TestObject doc = TestObject.create();
+ container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block();
+
+ CosmosItemResponse response = container
+ .readItem(doc.getId(), new PartitionKey(doc.getMypk()), TestObject.class)
+ .block();
+
+ assertThat(response).isNotNull();
+ assertThat(response.getStatusCode()).isEqualTo(200);
+ assertThat(response.getItem().getId()).isEqualTo(doc.getId());
+ }
+
+ /**
+ * Verifies that a replace (PUT) operation succeeds with the client-level workload-id
+ * header. Confirms the header propagates correctly for update operations.
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void replaceItemWithClientLevelWorkloadId() {
+ // Verify replace operation succeeds with workload-id header
+ TestObject doc = TestObject.create();
+ container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block();
+
+ doc.setStringProp("updated-" + UUID.randomUUID());
+ CosmosItemResponse response = container
+ .replaceItem(doc, doc.getId(), new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions())
+ .block();
+
+ assertThat(response).isNotNull();
+ assertThat(response.getStatusCode()).isEqualTo(200);
+ }
+
+ /**
+ * Verifies that a delete operation succeeds with the client-level workload-id header
+ * and returns the expected 204 No Content status code.
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void deleteItemWithClientLevelWorkloadId() {
+ // Verify delete operation succeeds with workload-id header
+ TestObject doc = TestObject.create();
+ container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block();
+
+ CosmosItemResponse response = container
+ .deleteItem(doc.getId(), new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions())
+ .block();
+
+ assertThat(response).isNotNull();
+ assertThat(response.getStatusCode()).isEqualTo(204);
+ }
+
+ /**
+ * Verifies that a per-request workload-id header override via
+ * {@code CosmosItemRequestOptions.setHeader()} works. The request-level header
+ * (value "30") should take precedence over the client-level default (value "15").
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void createItemWithRequestLevelWorkloadIdOverride() {
+ // Verify per-request header override works — request-level should take precedence
+ TestObject doc = TestObject.create();
+
+ CosmosItemRequestOptions options = new CosmosItemRequestOptions()
+ .setHeader(HttpConstants.HttpHeaders.WORKLOAD_ID, "30");
+
+ CosmosItemResponse response = container
+ .createItem(doc, new PartitionKey(doc.getMypk()), options)
+ .block();
+
+ assertThat(response).isNotNull();
+ assertThat(response.getStatusCode()).isEqualTo(201);
+ }
+
+ /**
+ * Verifies that a cross-partition query operation succeeds when the client has a
+ * workload-id custom header. Confirms the header flows correctly through the
+ * query pipeline and does not affect result correctness.
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void queryItemsWithClientLevelWorkloadId() {
+ // Verify query operation succeeds with workload-id header
+ TestObject doc = TestObject.create();
+ container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block();
+
+ CosmosQueryRequestOptions queryOptions = new CosmosQueryRequestOptions();
+ long count = container
+ .queryItems("SELECT * FROM c WHERE c.id = '" + doc.getId() + "'", queryOptions, TestObject.class)
+ .collectList()
+ .block()
+ .size();
+
+ assertThat(count).isGreaterThanOrEqualTo(1);
+ }
+
+ /**
+ * Verifies that a per-request workload-id header override on
+ * {@code CosmosQueryRequestOptions.setHeader()} works for query operations.
+ * The request-level header (value "42") should take precedence over the
+ * client-level default.
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void queryItemsWithRequestLevelWorkloadIdOverride() {
+ // Verify per-request header override on query options works
+ TestObject doc = TestObject.create();
+ container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block();
+
+ CosmosQueryRequestOptions queryOptions = new CosmosQueryRequestOptions()
+ .setHeader(HttpConstants.HttpHeaders.WORKLOAD_ID, "42");
+
+ long count = container
+ .queryItems("SELECT * FROM c WHERE c.id = '" + doc.getId() + "'", queryOptions, TestObject.class)
+ .collectList()
+ .block()
+ .size();
+
+ assertThat(count).isGreaterThanOrEqualTo(1);
+ }
+
+ /**
+ * Regression test: verifies that a client created without any custom headers
+ * continues to work normally. Ensures the custom headers feature does not
+ * introduce regressions for clients that do not use it.
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void clientWithNoCustomHeadersStillWorks() {
+ // Verify that a client without custom headers works normally (no regression)
+ CosmosAsyncClient clientWithoutHeaders = new CosmosClientBuilder()
+ .endpoint(TestConfigurations.HOST)
+ .key(TestConfigurations.MASTER_KEY)
+ .buildAsyncClient();
+
+ try {
+ CosmosAsyncContainer c = clientWithoutHeaders
+ .getDatabase(DATABASE_ID)
+ .getContainer(CONTAINER_ID);
+
+ TestObject doc = TestObject.create();
+ CosmosItemResponse response = c
+ .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions())
+ .block();
+
+ assertThat(response).isNotNull();
+ assertThat(response.getStatusCode()).isEqualTo(201);
+ } finally {
+ safeClose(clientWithoutHeaders);
+ }
+ }
+
+ /**
+ * Verifies that a client created with an empty custom headers map works normally.
+ * An empty map should behave identically to no custom headers — no errors,
+ * no unexpected behavior.
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void clientWithEmptyCustomHeaders() {
+ // Verify that a client with empty custom headers map works normally
+ CosmosAsyncClient clientWithEmptyHeaders = new CosmosClientBuilder()
+ .endpoint(TestConfigurations.HOST)
+ .key(TestConfigurations.MASTER_KEY)
+ .customHeaders(new HashMap<>())
+ .buildAsyncClient();
+
+ try {
+ CosmosAsyncContainer c = clientWithEmptyHeaders
+ .getDatabase(DATABASE_ID)
+ .getContainer(CONTAINER_ID);
+
+ TestObject doc = TestObject.create();
+ CosmosItemResponse response = c
+ .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions())
+ .block();
+
+ assertThat(response).isNotNull();
+ assertThat(response.getStatusCode()).isEqualTo(201);
+ } finally {
+ safeClose(clientWithEmptyHeaders);
+ }
+ }
+
+ /**
+ * Verifies that a client can be configured with multiple custom headers simultaneously
+ * (workload-id plus an additional custom header). Confirms that all headers flow
+ * through the pipeline without interfering with each other.
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void clientWithMultipleCustomHeaders() {
+ // Verify that multiple custom headers can be set simultaneously
+ Map headers = new HashMap<>();
+ headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "20");
+ headers.put("x-ms-custom-test-header", "test-value");
+
+ CosmosAsyncClient clientWithMultipleHeaders = new CosmosClientBuilder()
+ .endpoint(TestConfigurations.HOST)
+ .key(TestConfigurations.MASTER_KEY)
+ .customHeaders(headers)
+ .buildAsyncClient();
+
+ try {
+ CosmosAsyncContainer c = clientWithMultipleHeaders
+ .getDatabase(DATABASE_ID)
+ .getContainer(CONTAINER_ID);
+
+ TestObject doc = TestObject.create();
+ CosmosItemResponse response = c
+ .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions())
+ .block();
+
+ assertThat(response).isNotNull();
+ assertThat(response.getStatusCode()).isEqualTo(201);
+ } finally {
+ safeClose(clientWithMultipleHeaders);
+ }
+ }
+
+ @AfterClass(groups = { "emulator" }, timeOut = SHUTDOWN_TIMEOUT, alwaysRun = true)
+ public void afterClass() {
+ safeDeleteDatabase(database);
+ safeClose(clientWithWorkloadId);
+ }
+}
+
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java
index ec0dd64af008..f54f44482db5 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java
@@ -186,6 +186,7 @@ public final class CosmosAsyncClient implements Closeable {
.withDefaultSerializer(this.defaultCustomSerializer)
.withRegionScopedSessionCapturingEnabled(builder.isRegionScopedSessionCapturingEnabled())
.withPerPartitionAutomaticFailoverEnabled(builder.isPerPartitionAutomaticFailoverEnabled())
+ .withCustomHeaders(builder.getCustomHeaders())
.build();
this.accountConsistencyLevel = this.asyncDocumentClient.getDefaultConsistencyLevelOfAccount();
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java
index 12d022e69ee7..aea282be566c 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java
@@ -37,6 +37,7 @@
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
+import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
@@ -155,6 +156,7 @@ public class CosmosClientBuilder implements
private boolean serverCertValidationDisabled = false;
private Function containerFactory = null;
+ private Map customHeaders;
/**
* Instantiates a new Cosmos client builder.
@@ -734,6 +736,33 @@ public CosmosClientBuilder userAgentSuffix(String userAgentSuffix) {
return this;
}
+ /**
+ * Sets custom HTTP headers that will be included with every request from this client.
+ *
+ * These headers are sent with all requests. For Direct/RNTBD mode, only known headers
+ * (like {@code x-ms-cosmos-workload-id}) will be encoded and sent. Unknown headers
+ * work only in Gateway mode.
+ *
+ * If the same header is also set on request options (e.g.,
+ * {@code CosmosItemRequestOptions.setHeader(String, String)}),
+ * the request-level value takes precedence over the client-level value.
+ *
+ * @param customHeaders map of header name to value
+ * @return current CosmosClientBuilder
+ */
+ public CosmosClientBuilder customHeaders(Map customHeaders) {
+ this.customHeaders = customHeaders;
+ return this;
+ }
+
+ /**
+ * Gets the custom headers configured on this builder.
+ * @return the custom headers map, or null if not set
+ */
+ Map getCustomHeaders() {
+ return this.customHeaders;
+ }
+
/**
* Sets the retry policy options associated with the DocumentClient instance.
*
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java
index 03590c1f8a5d..7953721019c5 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java
@@ -116,6 +116,7 @@ class Builder {
private boolean isRegionScopedSessionCapturingEnabled;
private boolean isPerPartitionAutomaticFailoverEnabled;
private List operationPolicies;
+ private Map customHeaders;
public Builder withServiceEndpoint(String serviceEndpoint) {
try {
@@ -288,6 +289,11 @@ public Builder withPerPartitionAutomaticFailoverEnabled(boolean isPerPartitionAu
return this;
}
+ public Builder withCustomHeaders(Map customHeaders) {
+ this.customHeaders = customHeaders;
+ return this;
+ }
+
private void ifThrowIllegalArgException(boolean value, String error) {
if (value) {
throw new IllegalArgumentException(error);
@@ -328,7 +334,8 @@ public AsyncDocumentClient build() {
defaultCustomSerializer,
isRegionScopedSessionCapturingEnabled,
operationPolicies,
- isPerPartitionAutomaticFailoverEnabled);
+ isPerPartitionAutomaticFailoverEnabled,
+ customHeaders);
client.init(state, null);
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java
index 4e283defbc1d..32378ef0cc8d 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java
@@ -298,6 +298,9 @@ public static class HttpHeaders {
// Region affinity headers
public static final String HUB_REGION_PROCESSING_ONLY = "x-ms-cosmos-hub-region-processing-only";
+
+ // Workload ID header for Azure Monitor metrics attribution
+ public static final String WORKLOAD_ID = "x-ms-cosmos-workload-id";
}
public static class A_IMHeaderValues {
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java
index 192f8175978f..2f0bd4271d86 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java
@@ -293,6 +293,7 @@ public class RxDocumentClientImpl implements AsyncDocumentClient, IAuthorization
private final AtomicReference cachedCosmosAsyncClientSnapshot;
private CosmosEndToEndOperationLatencyPolicyConfig ppafEnforcedE2ELatencyPolicyConfigForReads;
private Consumer perPartitionFailoverConfigModifier;
+ private Map customHeaders;
public RxDocumentClientImpl(URI serviceEndpoint,
String masterKeyOrResourceToken,
@@ -366,6 +367,60 @@ public RxDocumentClientImpl(URI serviceEndpoint,
boolean isRegionScopedSessionCapturingEnabled,
List operationPolicies,
boolean isPerPartitionAutomaticFailoverEnabled) {
+ this(
+ serviceEndpoint,
+ masterKeyOrResourceToken,
+ permissionFeed,
+ connectionPolicy,
+ consistencyLevel,
+ readConsistencyStrategy,
+ configs,
+ cosmosAuthorizationTokenResolver,
+ credential,
+ tokenCredential,
+ sessionCapturingOverride,
+ connectionSharingAcrossClientsEnabled,
+ contentResponseOnWriteEnabled,
+ metadataCachesSnapshot,
+ apiType,
+ clientTelemetryConfig,
+ clientCorrelationId,
+ cosmosEndToEndOperationLatencyPolicyConfig,
+ sessionRetryOptions,
+ containerProactiveInitConfig,
+ defaultCustomSerializer,
+ isRegionScopedSessionCapturingEnabled,
+ operationPolicies,
+ isPerPartitionAutomaticFailoverEnabled,
+ null
+ );
+ }
+
+ public RxDocumentClientImpl(URI serviceEndpoint,
+ String masterKeyOrResourceToken,
+ List permissionFeed,
+ ConnectionPolicy connectionPolicy,
+ ConsistencyLevel consistencyLevel,
+ ReadConsistencyStrategy readConsistencyStrategy,
+ Configs configs,
+ CosmosAuthorizationTokenResolver cosmosAuthorizationTokenResolver,
+ AzureKeyCredential credential,
+ TokenCredential tokenCredential,
+ boolean sessionCapturingOverride,
+ boolean connectionSharingAcrossClientsEnabled,
+ boolean contentResponseOnWriteEnabled,
+ CosmosClientMetadataCachesSnapshot metadataCachesSnapshot,
+ ApiType apiType,
+ CosmosClientTelemetryConfig clientTelemetryConfig,
+ String clientCorrelationId,
+ CosmosEndToEndOperationLatencyPolicyConfig cosmosEndToEndOperationLatencyPolicyConfig,
+ SessionRetryOptions sessionRetryOptions,
+ CosmosContainerProactiveInitConfig containerProactiveInitConfig,
+ CosmosItemSerializer defaultCustomSerializer,
+ boolean isRegionScopedSessionCapturingEnabled,
+ List operationPolicies,
+ boolean isPerPartitionAutomaticFailoverEnabled,
+ Map customHeaders) {
this(
serviceEndpoint,
masterKeyOrResourceToken,
@@ -392,6 +447,7 @@ public RxDocumentClientImpl(URI serviceEndpoint,
this.cosmosAuthorizationTokenResolver = cosmosAuthorizationTokenResolver;
this.operationPolicies = operationPolicies;
+ this.customHeaders = customHeaders;
}
private RxDocumentClientImpl(URI serviceEndpoint,
@@ -1884,6 +1940,11 @@ public void validateAndLogNonDefaultReadConsistencyStrategy(String readConsisten
private Map getRequestHeaders(RequestOptions options, ResourceType resourceType, OperationType operationType) {
Map headers = new HashMap<>();
+ // Apply client-level custom headers first (e.g., workload-id from CosmosClientBuilder.customHeaders())
+ if (this.customHeaders != null && !this.customHeaders.isEmpty()) {
+ headers.putAll(this.customHeaders);
+ }
+
if (this.useMultipleWriteLocations) {
headers.put(HttpConstants.HttpHeaders.ALLOW_TENTATIVE_WRITES, Boolean.TRUE.toString());
}
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java
index ba3ec8d2017d..d79231793679 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java
@@ -598,7 +598,8 @@ public enum RntbdRequestHeader implements RntbdHeader {
PriorityLevel((short) 0x00BF, RntbdTokenType.Byte, false),
GlobalDatabaseAccountName((short) 0x00CE, RntbdTokenType.String, false),
ThroughputBucket((short)0x00DB, RntbdTokenType.Byte, false),
- HubRegionProcessingOnly((short)0x00EF, RntbdTokenType.Byte , false);
+ HubRegionProcessingOnly((short)0x00EF, RntbdTokenType.Byte , false),
+ WorkloadId((short)0x00DC, RntbdTokenType.Byte, false);
public static final List thinClientHeadersInOrderList = Arrays.asList(
EffectivePartitionKey,
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java
index 6f6e46ee695d..387e8cf3ed5a 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java
@@ -23,6 +23,8 @@
import com.fasterxml.jackson.annotation.JsonFilter;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
@@ -51,6 +53,7 @@ final class RntbdRequestHeaders extends RntbdTokenStream {
// region Fields
+ private static final Logger logger = LoggerFactory.getLogger(RntbdRequestHeaders.class);
private static final String URL_TRIM = "/";
// endregion
@@ -134,6 +137,7 @@ final class RntbdRequestHeaders extends RntbdTokenStream {
this.addGlobalDatabaseAccountName(headers);
this.addThroughputBucket(headers);
this.addHubRegionProcessingOnly(headers);
+ this.addWorkloadId(headers);
// Normal headers (Strings, Ints, Longs, etc.)
@@ -297,6 +301,8 @@ private RntbdToken getCorrelatedActivityId() {
private RntbdToken getHubRegionProcessingOnly() { return this.get(RntbdRequestHeader.HubRegionProcessingOnly); }
+ private RntbdToken getWorkloadId() { return this.get(RntbdRequestHeader.WorkloadId); }
+
private RntbdToken getGlobalDatabaseAccountName() {
return this.get(RntbdRequestHeader.GlobalDatabaseAccountName);
}
@@ -816,6 +822,19 @@ private void addHubRegionProcessingOnly(final Map headers) {
}
}
+ private void addWorkloadId(final Map headers) {
+ final String value = headers.get(HttpHeaders.WORKLOAD_ID);
+
+ if (StringUtils.isNotEmpty(value)) {
+ try {
+ final int workloadId = Integer.valueOf(value);
+ this.getWorkloadId().setValue((byte) workloadId);
+ } catch (NumberFormatException e) {
+ logger.warn("Invalid value for workload id header: {}", value, e);
+ }
+ }
+ }
+
private void addGlobalDatabaseAccountName(final Map headers)
{
final String value = headers.get(HttpHeaders.GLOBAL_DATABASE_ACCOUNT_NAME);
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java
index 7d5a27324f95..3183fe59bdea 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java
@@ -154,14 +154,17 @@ RequestOptions toRequestOptions() {
}
/**
- * Sets the custom batch request option value by key
- *
- * @param name a string representing the custom option's name
- * @param value a string representing the custom option's value
+ * Sets a custom header to be included with this specific request.
+ *
+ * This allows per-request header customization, such as setting a workload ID
+ * that overrides the client-level default set via
+ * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}.
*
+ * @param name the header name (e.g., "x-ms-cosmos-workload-id")
+ * @param value the header value (e.g., "20")
* @return the CosmosBatchRequestOptions.
*/
- CosmosBatchRequestOptions setHeader(String name, String value) {
+ public CosmosBatchRequestOptions setHeader(String name, String value) {
if (this.customOptions == null) {
this.customOptions = new HashMap<>();
}
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java
index f125c02d6725..cd688f8a0da6 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java
@@ -257,13 +257,17 @@ void setOperationContextAndListenerTuple(OperationContextAndListenerTuple operat
}
/**
- * Sets the custom bulk request option value by key
+ * Sets a custom header to be included with this specific request.
+ *
+ * This allows per-request header customization, such as setting a workload ID
+ * that overrides the client-level default set via
+ * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}.
*
- * @param name a string representing the custom option's name
- * @param value a string representing the custom option's value
+ * @param name the header name (e.g., "x-ms-cosmos-workload-id")
+ * @param value the header value (e.g., "20")
* @return the CosmosBulkExecutionOptions.
*/
- CosmosBulkExecutionOptions setHeader(String name, String value) {
+ public CosmosBulkExecutionOptions setHeader(String name, String value) {
this.actualRequestOptions.setHeader(name, value);
return this;
}
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java
index 3ac526de6d63..a1b675f2ffd8 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java
@@ -564,14 +564,17 @@ public List getExcludedRegions() {
}
/**
- * Sets the custom change feed request option value by key
- *
- * @param name a string representing the custom option's name
- * @param value a string representing the custom option's value
+ * Sets a custom header to be included with this specific request.
+ *
+ * This allows per-request header customization, such as setting a workload ID
+ * that overrides the client-level default set via
+ * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}.
*
+ * @param name the header name (e.g., "x-ms-cosmos-workload-id")
+ * @param value the header value (e.g., "20")
* @return the CosmosChangeFeedRequestOptions.
*/
- CosmosChangeFeedRequestOptions setHeader(String name, String value) {
+ public CosmosChangeFeedRequestOptions setHeader(String name, String value) {
this.actualRequestOptions.setHeader(name, value);
return this;
}
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java
index 72eb108a6428..fbc540e5baeb 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java
@@ -566,14 +566,17 @@ public CosmosItemRequestOptions setThresholdForDiagnosticsOnTracer(Duration thre
}
/**
- * Sets the custom item request option value by key
- *
- * @param name a string representing the custom option's name
- * @param value a string representing the custom option's value
- *
+ * Sets a custom header to be included with this specific request.
+ *
+ * This allows per-request header customization, such as setting a workload ID
+ * that overrides the client-level default set via
+ * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}.
+ *
+ * @param name the header name (e.g., "x-ms-cosmos-workload-id")
+ * @param value the header value (e.g., "20")
* @return the CosmosItemRequestOptions.
*/
- CosmosItemRequestOptions setHeader(String name, String value) {
+ public CosmosItemRequestOptions setHeader(String name, String value) {
if (this.customOptions == null) {
this.customOptions = new HashMap<>();
}
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java
index 7ead6e208781..f0de81bbf823 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java
@@ -260,6 +260,22 @@ public CosmosQueryRequestOptions setExcludedRegions(List excludeRegions)
return this;
}
+ /**
+ * Sets a custom header to be included with this specific request.
+ *
+ * This allows per-request header customization, such as setting a workload ID
+ * that overrides the client-level default set via
+ * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}.
+ *
+ * @param name the header name (e.g., "x-ms-cosmos-workload-id")
+ * @param value the header value (e.g., "20")
+ * @return the CosmosQueryRequestOptions.
+ */
+ public CosmosQueryRequestOptions setHeader(String name, String value) {
+ this.actualRequestOptions.setHeader(name, value);
+ return this;
+ }
+
/**
* Gets the list of regions to exclude for the request/retries. These regions are excluded
* from the preferred region list.
From a75ab7b1bd494dfec12ad501b23923a101b1cba4 Mon Sep 17 00:00:00 2001
From: dibahlfi <106994927+dibahlfi@users.noreply.github.com>
Date: Wed, 25 Feb 2026 17:57:37 -0600
Subject: [PATCH 02/13] workload-id feature - initial commit(Spark)
---
.../cosmos/spark/CosmosClientCache.scala | 15 +-
.../spark/CosmosClientConfiguration.scala | 8 +-
.../com/azure/cosmos/spark/CosmosConfig.scala | 31 +++-
.../cosmos/spark/CosmosClientCacheITest.scala | 17 +-
.../spark/CosmosClientConfigurationSpec.scala | 68 ++++++++
.../spark/CosmosPartitionPlannerSpec.scala | 24 ++-
.../cosmos/spark/PartitionMetadataSpec.scala | 48 ++++--
.../spark/SparkE2EWorkloadIdITest.scala | 150 ++++++++++++++++++
8 files changed, 323 insertions(+), 38 deletions(-)
create mode 100644 sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala
index e61a271aeb8b..9ad739d8f3f6 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala
@@ -42,7 +42,6 @@ import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}
import java.util.function.BiPredicate
import scala.collection.concurrent.TrieMap
-
// scalastyle:off underscore.import
import scala.collection.JavaConverters._
// scalastyle:on underscore.import
@@ -713,6 +712,12 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait {
}
}
+ // Apply custom HTTP headers (e.g., workload-id) to the builder if configured.
+ // These headers are attached to every Cosmos DB request made by this client instance.
+ if (cosmosClientConfiguration.customHeaders.isDefined) {
+ builder.customHeaders(cosmosClientConfiguration.customHeaders.get.asJava)
+ }
+
var client = builder.buildAsyncClient()
if (cosmosClientConfiguration.clientInterceptors.isDefined) {
@@ -916,7 +921,10 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait {
clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]],
clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig],
- azureMonitorConfig: Option[AzureMonitorConfig]
+ azureMonitorConfig: Option[AzureMonitorConfig],
+ // Custom HTTP headers are part of the cache key because different workload-ids
+ // should produce different CosmosAsyncClient instances
+ customHeaders: Option[Map[String, String]]
)
private[this] object ClientConfigurationWrapper {
@@ -935,7 +943,8 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait {
clientConfig.clientBuilderInterceptors,
clientConfig.clientInterceptors,
clientConfig.sampledDiagnosticsLoggerConfig,
- clientConfig.azureMonitorConfig
+ clientConfig.azureMonitorConfig,
+ clientConfig.customHeaders
)
}
}
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala
index 6f4e26e1f503..61fa0957af83 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala
@@ -30,7 +30,10 @@ private[spark] case class CosmosClientConfiguration (
clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]],
clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig],
- azureMonitorConfig: Option[AzureMonitorConfig]
+ azureMonitorConfig: Option[AzureMonitorConfig],
+ // Optional custom HTTP headers (e.g., workload-id) to attach to
+ // all Cosmos DB requests via CosmosClientBuilder.customHeaders()
+ customHeaders: Option[Map[String, String]]
) {
private[spark] def getRoleInstanceName(machineId: Option[String]): String = {
CosmosClientConfiguration.getRoleInstanceName(sparkEnvironmentInfo, machineId)
@@ -94,7 +97,8 @@ private[spark] object CosmosClientConfiguration {
cosmosAccountConfig.clientBuilderInterceptors,
cosmosAccountConfig.clientInterceptors,
diagnosticsConfig.sampledDiagnosticsLoggerConfig,
- diagnosticsConfig.azureMonitorConfig
+ diagnosticsConfig.azureMonitorConfig,
+ cosmosAccountConfig.customHeaders
)
}
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala
index eef3f6ae1f8d..6646b2e69ae5 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala
@@ -34,6 +34,7 @@ import java.time.format.DateTimeFormatter
import java.time.{Duration, Instant}
import java.util
import java.util.{Locale, ServiceLoader}
+import scala.collection.JavaConverters._ // scalastyle:ignore underscore.import
import scala.collection.concurrent.TrieMap
import scala.collection.immutable.{HashSet, List, Map}
import scala.collection.mutable
@@ -150,6 +151,10 @@ private[spark] object CosmosConfigNames {
val ThroughputControlTargetThroughputThreshold = "spark.cosmos.throughputControl.targetThroughputThreshold"
val ThroughputControlPriorityLevel = "spark.cosmos.throughputControl.priorityLevel"
val ThroughputControlThroughputBucket = "spark.cosmos.throughputControl.throughputBucket"
+ // Custom HTTP headers to attach to all Cosmos DB requests (e.g., workload-id for resource governance).
+ // Value is a JSON string like: {"x-ms-cosmos-workload-id": "15"}
+ // Flows through to CosmosClientBuilder.customHeaders().
+ val CustomHeaders = "spark.cosmos.customHeaders"
val ThroughputControlGlobalControlDatabase = "spark.cosmos.throughputControl.globalControl.database"
val ThroughputControlGlobalControlContainer = "spark.cosmos.throughputControl.globalControl.container"
val ThroughputControlGlobalControlRenewalIntervalInMS =
@@ -295,7 +300,8 @@ private[spark] object CosmosConfigNames {
WriteOnRetryCommitInterceptor,
WriteFlushCloseIntervalInSeconds,
WriteMaxNoProgressIntervalInSeconds,
- WriteMaxRetryNoProgressIntervalInSeconds
+ WriteMaxRetryNoProgressIntervalInSeconds,
+ CustomHeaders
)
def validateConfigName(name: String): Unit = {
@@ -538,7 +544,10 @@ private case class CosmosAccountConfig(endpoint: String,
resourceGroupName: Option[String],
azureEnvironmentEndpoints: java.util.Map[String, String],
clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]],
- clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
+ clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
+ // Optional custom HTTP headers (e.g., workload-id) parsed from
+ // spark.cosmos.customHeaders JSON config, passed to CosmosClientBuilder
+ customHeaders: Option[Map[String, String]]
)
private object CosmosAccountConfig extends BasicLoggingTrait {
@@ -719,6 +728,19 @@ private object CosmosAccountConfig extends BasicLoggingTrait {
parseFromStringFunction = clientInterceptorFQDN => clientInterceptorFQDN,
helpMessage = "CosmosAsyncClient interceptors (comma separated) - FQDNs of the service implementing the 'CosmosClientInterceptor' trait.")
+ // Config entry for custom HTTP headers (e.g., workload-id). Parses a JSON string like
+ // {"x-ms-cosmos-workload-id": "15"} into a Scala Map[String, String] using Jackson.
+ // These headers are passed to CosmosClientBuilder.customHeaders() in CosmosClientCache.
+ private val CustomHeadersConfig = CosmosConfigEntry[Map[String, String]](
+ key = CosmosConfigNames.CustomHeaders,
+ mandatory = false,
+ parseFromStringFunction = headersJson => {
+ val mapper = new com.fasterxml.jackson.databind.ObjectMapper()
+ val typeRef = new com.fasterxml.jackson.core.`type`.TypeReference[java.util.Map[String, String]]() {}
+ mapper.readValue(headersJson, typeRef).asScala.toMap
+ },
+ helpMessage = "Optional custom headers as JSON map. Example: {\"x-ms-cosmos-workload-id\": \"15\"}")
+
private[spark] def parseProactiveConnectionInitConfigs(config: String): java.util.List[CosmosContainerIdentity] = {
val result = new java.util.ArrayList[CosmosContainerIdentity]
try {
@@ -753,6 +775,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait {
val tenantIdOpt = CosmosConfigEntry.parse(cfg, TenantId)
val clientBuilderInterceptors = CosmosConfigEntry.parse(cfg, ClientBuilderInterceptors)
val clientInterceptors = CosmosConfigEntry.parse(cfg, ClientInterceptors)
+ // Parse optional custom HTTP headers from JSON config (e.g., {"x-ms-cosmos-workload-id": "15"})
+ val customHeaders = CosmosConfigEntry.parse(cfg, CustomHeadersConfig)
val disableTcpConnectionEndpointRediscovery = CosmosConfigEntry.parse(cfg, DisableTcpConnectionEndpointRediscovery)
val preferredRegionsListOpt = CosmosConfigEntry.parse(cfg, PreferredRegionsList)
@@ -864,7 +888,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait {
resourceGroupNameOpt,
azureEnvironmentOpt.get,
if (clientBuilderInterceptorsList.nonEmpty) { Some(clientBuilderInterceptorsList.toList) } else { None },
- if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None })
+ if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None },
+ customHeaders)
}
}
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala
index ccf36791dc96..4d542c44612e 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala
@@ -64,7 +64,8 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
),
(
@@ -91,7 +92,8 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
),
(
@@ -118,7 +120,8 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
),
(
@@ -145,7 +148,8 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
)
)
@@ -179,8 +183,9 @@ class CosmosClientCacheITest
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
- )
+ azureMonitorConfig = None,
+ customHeaders = None
+ )
logInfo(s"TestCase: {$testCaseName}")
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala
index 7fcc601ba016..a0627c0cf3dd 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala
@@ -408,4 +408,72 @@ class CosmosClientConfigurationSpec extends UnitSpec {
configuration.applicationName shouldEqual s"${CosmosConstants.userAgentSuffix}|$sparkEnvironmentInfo|${ManagementFactory.getRuntimeMXBean.getName}|$myApp".replace("@", " ")
configuration.azureMonitorConfig shouldEqual None
}
+
+ // Verifies that the spark.cosmos.customHeaders configuration option correctly parses
+ // a JSON string containing a single workload-id header into a Map[String, String] on
+ // CosmosClientConfiguration. This is the primary use case for the workload-id feature.
+ it should "parse customHeaders JSON" in {
+ val userConfig = Map(
+ "spark.cosmos.accountEndpoint" -> "https://localhost:8081",
+ "spark.cosmos.accountKey" -> "xyz",
+ "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "15"}"""
+ )
+
+ val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT
+ val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "")
+
+ configuration.customHeaders shouldBe defined
+ configuration.customHeaders.get("x-ms-cosmos-workload-id") shouldEqual "15"
+ }
+
+ // Verifies that when spark.cosmos.customHeaders is not specified in the config map,
+ // CosmosClientConfiguration.customHeaders is None. This ensures backward compatibility —
+ // existing Spark jobs that don't set customHeaders continue to work without changes.
+ it should "handle missing customHeaders" in {
+ val userConfig = Map(
+ "spark.cosmos.accountEndpoint" -> "https://localhost:8081",
+ "spark.cosmos.accountKey" -> "xyz"
+ )
+
+ val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT
+ val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "")
+
+ configuration.customHeaders shouldBe None
+ }
+
+ // Verifies that spark.cosmos.customHeaders correctly parses a JSON string containing
+ // multiple headers into a Map with all entries preserved. This supports use cases where
+ // multiple custom headers need to be sent alongside workload-id.
+ it should "parse multiple custom headers" in {
+ val userConfig = Map(
+ "spark.cosmos.accountEndpoint" -> "https://localhost:8081",
+ "spark.cosmos.accountKey" -> "xyz",
+ "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "20", "x-custom-header": "value"}"""
+ )
+
+ val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT
+ val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "")
+
+ configuration.customHeaders shouldBe defined
+ configuration.customHeaders.get should have size 2
+ configuration.customHeaders.get("x-ms-cosmos-workload-id") shouldEqual "20"
+ configuration.customHeaders.get("x-custom-header") shouldEqual "value"
+ }
+
+ // Verifies that spark.cosmos.customHeaders handles an empty JSON object ("{}") gracefully,
+ // resulting in a defined but empty Map. This ensures the parser doesn't fail on edge cases
+ // and that no headers are injected when the JSON object is empty.
+ it should "handle empty customHeaders JSON" in {
+ val userConfig = Map(
+ "spark.cosmos.accountEndpoint" -> "https://localhost:8081",
+ "spark.cosmos.accountKey" -> "xyz",
+ "spark.cosmos.customHeaders" -> "{}"
+ )
+
+ val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT
+ val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "")
+
+ configuration.customHeaders shouldBe defined
+ configuration.customHeaders.get shouldBe empty
+ }
}
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala
index 6ef90b55989d..ab73dc4e54d3 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala
@@ -39,7 +39,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -116,7 +117,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -193,7 +195,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -270,7 +273,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -345,7 +349,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -436,7 +441,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -510,7 +516,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -576,7 +583,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala
index dfd14c36c80f..65274bee2b19 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala
@@ -38,7 +38,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
private[this] val contCfg = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -84,7 +85,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -169,7 +171,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -254,7 +257,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -321,7 +325,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -383,7 +388,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -439,7 +445,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -495,7 +502,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -551,7 +559,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -607,7 +616,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -686,7 +696,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -747,7 +758,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -803,7 +815,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -876,7 +889,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -949,7 +963,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -1027,7 +1042,8 @@ class PartitionMetadataSpec extends UnitSpec {
clientBuilderInterceptors = None,
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
- azureMonitorConfig = None
+ azureMonitorConfig = None,
+ customHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala
new file mode 100644
index 000000000000..d9706d0709e5
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala
@@ -0,0 +1,150 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.cosmos.spark
+
+import com.azure.cosmos.implementation.TestConfigurations
+import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait
+import com.fasterxml.jackson.databind.ObjectMapper
+import com.fasterxml.jackson.databind.node.ObjectNode
+
+import java.util.UUID
+
+/**
+ * End-to-end integration tests for the custom headers (workload-id) feature in the Spark connector.
+ *
+ * These tests verify that the `spark.cosmos.customHeaders` configuration option correctly flows
+ * through the Spark connector pipeline into CosmosClientBuilder.customHeaders(), ensuring that
+ * custom HTTP headers (such as x-ms-cosmos-workload-id) are applied to all Cosmos DB operations
+ * initiated via Spark DataFrames (reads and writes).
+ *
+ * Requires the Cosmos DB Emulator running
+ */
+class SparkE2EWorkloadIdITest
+ extends IntegrationSpec
+ with Spark
+ with CosmosClient
+ with AutoCleanableCosmosContainer
+ with BasicLoggingTrait {
+
+ val objectMapper = new ObjectMapper()
+
+ //scalastyle:off multiple.string.literals
+ //scalastyle:off magic.number
+ //scalastyle:off null
+
+ // Verifies that a Spark DataFrame read operation succeeds when spark.cosmos.customHeaders
+ // is configured with a workload-id header. The header should be passed through to the
+ // CosmosAsyncClient via CosmosClientBuilder.customHeaders() without affecting read behavior.
+ "spark query with customHeaders" can "read items with workload-id header" in {
+ val cosmosEndpoint = TestConfigurations.HOST
+ val cosmosMasterKey = TestConfigurations.MASTER_KEY
+
+ val id = UUID.randomUUID().toString
+ val rawItem =
+ s"""
+ | {
+ | "id" : "$id",
+ | "name" : "testItem"
+ | }
+ |""".stripMargin
+
+ val objectNode = objectMapper.readValue(rawItem, classOf[ObjectNode])
+
+ val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer)
+ container.createItem(objectNode).block()
+
+ val cfg = Map(
+ "spark.cosmos.accountEndpoint" -> cosmosEndpoint,
+ "spark.cosmos.accountKey" -> cosmosMasterKey,
+ "spark.cosmos.database" -> cosmosDatabase,
+ "spark.cosmos.container" -> cosmosContainer,
+ "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "15"}""",
+ "spark.cosmos.read.partitioning.strategy" -> "Restrictive"
+ )
+
+ val df = spark.read.format("cosmos.oltp").options(cfg).load()
+ val rowsArray = df.where(s"id = '$id'").collect()
+ rowsArray should have size 1
+
+ val item = rowsArray(0)
+ item.getAs[String]("id") shouldEqual id
+ }
+
+ // Verifies that a Spark DataFrame write operation succeeds when spark.cosmos.customHeaders
+ // is configured with a workload-id header. The item is written via Spark and then verified
+ // via a direct SDK read to confirm the write was persisted correctly.
+ "spark write with customHeaders" can "write items with workload-id header" in {
+ val cosmosEndpoint = TestConfigurations.HOST
+ val cosmosMasterKey = TestConfigurations.MASTER_KEY
+
+ val id = UUID.randomUUID().toString
+ val rawItem =
+ s"""
+ | {
+ | "id" : "$id",
+ | "name" : "testWriteItem"
+ | }
+ |""".stripMargin
+
+ val cfg = Map(
+ "spark.cosmos.accountEndpoint" -> cosmosEndpoint,
+ "spark.cosmos.accountKey" -> cosmosMasterKey,
+ "spark.cosmos.database" -> cosmosDatabase,
+ "spark.cosmos.container" -> cosmosContainer,
+ "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "20"}""",
+ "spark.cosmos.write.strategy" -> "ItemOverwrite",
+ "spark.cosmos.write.bulk.enabled" -> "false",
+ "spark.cosmos.serialization.inclusionMode" -> "NonDefault"
+ )
+
+ val spark_session = spark
+ import spark_session.implicits._
+ val df = spark.read.json(Seq(rawItem).toDS())
+
+ df.write.format("cosmos.oltp").options(cfg).mode("Append").save()
+
+ // Verify the item was written by reading it back via the SDK directly
+ val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer)
+ val readItem = container.readItem(id, new com.azure.cosmos.models.PartitionKey(id), classOf[ObjectNode]).block()
+ readItem.getItem.get("id").textValue() shouldEqual id
+ readItem.getItem.get("name").textValue() shouldEqual "testWriteItem"
+ }
+
+ // Regression test: verifies that Spark read operations continue to work correctly when
+ // spark.cosmos.customHeaders is NOT specified. Ensures that the feature addition does not
+ // break existing behavior for clients that do not use custom headers.
+ "spark operations without customHeaders" can "still succeed" in {
+ val cosmosEndpoint = TestConfigurations.HOST
+ val cosmosMasterKey = TestConfigurations.MASTER_KEY
+
+ val id = UUID.randomUUID().toString
+ val rawItem =
+ s"""
+ | {
+ | "id" : "$id",
+ | "name" : "noHeadersItem"
+ | }
+ |""".stripMargin
+
+ val objectNode = objectMapper.readValue(rawItem, classOf[ObjectNode])
+ val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer)
+ container.createItem(objectNode).block()
+
+ val cfg = Map(
+ "spark.cosmos.accountEndpoint" -> cosmosEndpoint,
+ "spark.cosmos.accountKey" -> cosmosMasterKey,
+ "spark.cosmos.database" -> cosmosDatabase,
+ "spark.cosmos.container" -> cosmosContainer,
+ "spark.cosmos.read.partitioning.strategy" -> "Restrictive"
+ )
+
+ val df = spark.read.format("cosmos.oltp").options(cfg).load()
+ val rowsArray = df.where(s"id = '$id'").collect()
+ rowsArray should have size 1
+ rowsArray(0).getAs[String]("id") shouldEqual id
+ }
+
+ //scalastyle:on magic.number
+ //scalastyle:on multiple.string.literals
+ //scalastyle:on null
+}
From 1bbf6b9bb962fec1f576fbf3a4b0863a0a132099 Mon Sep 17 00:00:00 2001
From: dibahlfi <106994927+dibahlfi@users.noreply.github.com>
Date: Wed, 25 Feb 2026 19:08:04 -0600
Subject: [PATCH 03/13] workload-id feature - cleaning up comments
---
.../com/azure/cosmos/rx/WorkloadIdE2ETests.java | 14 +-------------
1 file changed, 1 insertion(+), 13 deletions(-)
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
index 85b87090bb36..a57b4d9d9a0b 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
@@ -32,17 +32,6 @@
* End-to-end integration tests for the custom headers / workload-id feature.
*
* Test type: EMULATOR INTEGRATION TEST — requires the Cosmos DB Emulator to be running locally.
- * These tests create a real database and container, then execute CRUD and query operations
- * with the {@code x-ms-cosmos-workload-id} header set at client level and/or request level.
- *
- * What is verified:
- * 1. CRUD operations succeed with client-level custom headers (workload-id)
- * 2. Per-request header overrides work via setHeader()
- * 3. Client with no custom headers continues to work (no regression)
- * 4. Query operations succeed with workload-id
- * 5. Empty headers and multiple headers are handled correctly
- *
-
*/
public class WorkloadIdE2ETests extends TestSuiteBase {
@@ -82,13 +71,12 @@ public void beforeClass() {
}
/**
- * Smoke test: verifies that a create (POST) operation succeeds when the client
+ * verifies that a create (POST) operation succeeds when the client
* has a workload-id custom header set at the builder level. Confirms the header
* flows through the request pipeline without causing errors.
*/
@Test(groups = { "emulator" }, timeOut = TIMEOUT)
public void createItemWithClientLevelWorkloadId() {
- // Smoke test: verify create operation succeeds with client-level workload-id header
TestObject doc = TestObject.create();
CosmosItemResponse response = container
From 7aeb8b4e030958f77e14983fc5f0d3475333effb Mon Sep 17 00:00:00 2001
From: dibahlfi <106994927+dibahlfi@users.noreply.github.com>
Date: Wed, 25 Feb 2026 19:20:36 -0600
Subject: [PATCH 04/13] workload-id feature - addressing copilot comments
---
sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md | 1 +
sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md | 1 +
sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md | 1 +
sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md | 1 +
sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md | 1 +
sdk/cosmos/azure-cosmos/CHANGELOG.md | 1 +
.../directconnectivity/rntbd/RntbdConstants.java | 4 ++--
.../directconnectivity/rntbd/RntbdRequestHeaders.java | 2 +-
8 files changed, 9 insertions(+), 3 deletions(-)
diff --git a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md
index de77b69485a8..2eeccc69dada 100644
--- a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.44.0-beta.1 (Unreleased)
#### Features Added
+* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md
index 80072357c58f..6e72286dfab6 100644
--- a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.44.0-beta.1 (Unreleased)
#### Features Added
+* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md
index b905a025a1e6..d0d80af466d6 100644
--- a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.44.0-beta.1 (Unreleased)
#### Features Added
+* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md
index e17cecdfdac2..ae444a9c399f 100644
--- a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.44.0-beta.1 (Unreleased)
#### Features Added
+* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md
index e9be63ef89bd..ae910280495b 100644
--- a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.44.0-beta.1 (Unreleased)
#### Features Added
+* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md
index 8475b83dc5ee..ea4fdb82e1dc 100644
--- a/sdk/cosmos/azure-cosmos/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md
@@ -4,6 +4,7 @@
#### Features Added
* Added support for N-Region synchronous commit feature - See [PR 47757](https://github.com/Azure/azure-sdk-for-java/pull/47757)
+* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java
index d79231793679..d75bf5dc88e1 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java
@@ -598,8 +598,8 @@ public enum RntbdRequestHeader implements RntbdHeader {
PriorityLevel((short) 0x00BF, RntbdTokenType.Byte, false),
GlobalDatabaseAccountName((short) 0x00CE, RntbdTokenType.String, false),
ThroughputBucket((short)0x00DB, RntbdTokenType.Byte, false),
- HubRegionProcessingOnly((short)0x00EF, RntbdTokenType.Byte , false),
- WorkloadId((short)0x00DC, RntbdTokenType.Byte, false);
+ WorkloadId((short)0x00DC, RntbdTokenType.Byte, false),
+ HubRegionProcessingOnly((short)0x00EF, RntbdTokenType.Byte , false);
public static final List thinClientHeadersInOrderList = Arrays.asList(
EffectivePartitionKey,
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java
index 387e8cf3ed5a..46f8060387fc 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java
@@ -827,7 +827,7 @@ private void addWorkloadId(final Map headers) {
if (StringUtils.isNotEmpty(value)) {
try {
- final int workloadId = Integer.valueOf(value);
+ final int workloadId = Integer.parseInt(value);
this.getWorkloadId().setValue((byte) workloadId);
} catch (NumberFormatException e) {
logger.warn("Invalid value for workload id header: {}", value, e);
From 6d00d49320caa42432aa26725befe529e5934ad4 Mon Sep 17 00:00:00 2001
From: dibahlfi <106994927+dibahlfi@users.noreply.github.com>
Date: Mon, 2 Mar 2026 16:38:14 -0600
Subject: [PATCH 05/13] workload-id feature - addressing comments
---
.../com/azure/cosmos/spark/CosmosConfig.scala | 13 +-
.../spark/CosmosClientConfigurationSpec.scala | 13 +-
.../com/azure/cosmos/CustomHeadersTests.java | 104 ++++++++++-
.../RxDocumentClientUnderTest.java | 7 +-
.../RxGatewayStoreModelTest.java | 172 +++++++++++++++++-
.../SpyClientUnderTestFactory.java | 7 +-
.../GatewayAddressCacheTest.java | 125 +++++++++++++
.../GlobalAddressResolverTest.java | 3 +-
.../azure/cosmos/rx/WorkloadIdE2ETests.java | 62 ++-----
.../com/azure/cosmos/CosmosClientBuilder.java | 50 ++++-
.../implementation/RxDocumentClientImpl.java | 12 +-
.../implementation/RxGatewayStoreModel.java | 17 +-
.../implementation/ThinClientStoreModel.java | 3 +-
.../GatewayAddressCache.java | 17 +-
.../GlobalAddressResolver.java | 8 +-
.../models/CosmosReadManyRequestOptions.java | 16 ++
16 files changed, 546 insertions(+), 83 deletions(-)
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala
index 6646b2e69ae5..62802d23b14c 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala
@@ -7,7 +7,7 @@ import com.azure.core.management.AzureEnvironment
import com.azure.cosmos.{CosmosAsyncClient, CosmosClientBuilder, ReadConsistencyStrategy, spark}
import com.azure.cosmos.implementation.batch.BatchRequestResponseConstants
import com.azure.cosmos.implementation.routing.LocationHelper
-import com.azure.cosmos.implementation.{Configs, SparkBridgeImplementationInternal, Strings}
+import com.azure.cosmos.implementation.{Configs, SparkBridgeImplementationInternal, Strings, Utils}
import com.azure.cosmos.models.{CosmosChangeFeedRequestOptions, CosmosContainerIdentity, CosmosParameterizedQuery, DedicatedGatewayRequestOptions, FeedRange, PartitionKeyDefinition}
import com.azure.cosmos.spark.ChangeFeedModes.ChangeFeedMode
import com.azure.cosmos.spark.ChangeFeedStartFromModes.{ChangeFeedStartFromMode, PointInTime}
@@ -735,9 +735,14 @@ private object CosmosAccountConfig extends BasicLoggingTrait {
key = CosmosConfigNames.CustomHeaders,
mandatory = false,
parseFromStringFunction = headersJson => {
- val mapper = new com.fasterxml.jackson.databind.ObjectMapper()
- val typeRef = new com.fasterxml.jackson.core.`type`.TypeReference[java.util.Map[String, String]]() {}
- mapper.readValue(headersJson, typeRef).asScala.toMap
+ try {
+ val typeRef = new com.fasterxml.jackson.core.`type`.TypeReference[java.util.Map[String, String]]() {}
+ Utils.getSimpleObjectMapperWithAllowDuplicates.readValue(headersJson, typeRef).asScala.toMap
+ } catch {
+ case e: Exception => throw new IllegalArgumentException(
+ s"Invalid JSON for '${CosmosConfigNames.CustomHeaders}': '$headersJson'. " +
+ "Expected format: {\"x-ms-cosmos-workload-id\": \"15\"}", e)
+ }
},
helpMessage = "Optional custom headers as JSON map. Example: {\"x-ms-cosmos-workload-id\": \"15\"}")
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala
index a0627c0cf3dd..377425189f07 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala
@@ -441,10 +441,11 @@ class CosmosClientConfigurationSpec extends UnitSpec {
configuration.customHeaders shouldBe None
}
- // Verifies that spark.cosmos.customHeaders correctly parses a JSON string containing
- // multiple headers into a Map with all entries preserved. This supports use cases where
- // multiple custom headers need to be sent alongside workload-id.
- it should "parse multiple custom headers" in {
+ // Verifies that spark.cosmos.customHeaders rejects unknown headers at the parsing level.
+ // Only headers in CosmosClientBuilder's allowlist are permitted. In Direct mode (RNTBD),
+ // unknown headers are silently dropped, so the allowlist ensures consistent behavior
+ // across Gateway and Direct modes.
+ it should "reject unknown custom headers" in {
val userConfig = Map(
"spark.cosmos.accountEndpoint" -> "https://localhost:8081",
"spark.cosmos.accountKey" -> "xyz",
@@ -454,10 +455,10 @@ class CosmosClientConfigurationSpec extends UnitSpec {
val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT
val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "")
+ // Parsing succeeds — the JSON is valid and CosmosClientConfiguration stores it as-is.
+ // The allowlist validation happens later in CosmosClientBuilder.customHeaders()
configuration.customHeaders shouldBe defined
configuration.customHeaders.get should have size 2
- configuration.customHeaders.get("x-ms-cosmos-workload-id") shouldEqual "20"
- configuration.customHeaders.get("x-custom-header") shouldEqual "value"
}
// Verifies that spark.cosmos.customHeaders handles an empty JSON object ("{}") gracefully,
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java
index 3c95c8bd2687..19eb03744d1a 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java
@@ -9,6 +9,7 @@
import com.azure.cosmos.models.CosmosChangeFeedRequestOptions;
import com.azure.cosmos.models.CosmosItemRequestOptions;
import com.azure.cosmos.models.CosmosQueryRequestOptions;
+import com.azure.cosmos.models.CosmosReadManyRequestOptions;
import com.azure.cosmos.models.FeedRange;
import org.testng.annotations.Test;
@@ -16,6 +17,7 @@
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
* Unit tests for the custom headers (workload-id) feature on CosmosClientBuilder and request options classes.
@@ -73,23 +75,40 @@ public void customHeadersEmptyMapHandled() {
}
/**
- * Verifies that multiple custom headers can be set at once on the builder and
- * all entries are preserved and retrievable with correct keys and values.
+ * Verifies that headers not in the allowlist are rejected with IllegalArgumentException.
+ * This ensures consistent behavior across Gateway and Direct modes — only headers with
+ * RNTBD encoding support are allowed.
*/
@Test(groups = { "unit" })
- public void multipleCustomHeadersSupported() {
+ public void unknownHeaderRejectedByAllowlist() {
Map headers = new HashMap<>();
- headers.put("x-ms-cosmos-workload-id", "15");
headers.put("x-ms-custom-header", "value");
- CosmosClientBuilder builder = new CosmosClientBuilder()
+ assertThatThrownBy(() -> new CosmosClientBuilder()
.endpoint("https://test.documents.azure.com:443/")
.key("dGVzdEtleQ==")
- .customHeaders(headers);
+ .customHeaders(headers))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("x-ms-custom-header")
+ .hasMessageContaining("not allowed");
+ }
+
+ /**
+ * Verifies that a map containing both an allowed header and a disallowed header
+ * is rejected — the entire map must pass the allowlist check.
+ */
+ @Test(groups = { "unit" })
+ public void mixedAllowedAndDisallowedHeadersRejected() {
+ Map headers = new HashMap<>();
+ headers.put("x-ms-cosmos-workload-id", "15");
+ headers.put("x-ms-custom-header", "value");
- assertThat(builder.getCustomHeaders()).hasSize(2);
- assertThat(builder.getCustomHeaders()).containsEntry("x-ms-cosmos-workload-id", "15");
- assertThat(builder.getCustomHeaders()).containsEntry("x-ms-custom-header", "value");
+ assertThatThrownBy(() -> new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(headers))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("x-ms-custom-header");
}
/**
@@ -158,6 +177,19 @@ public void setHeaderOnQueryRequestOptionsIsPublic() {
assertThat(options).isNotNull();
}
+ /**
+ * Verifies that the new delegating setHeader() method on CosmosReadManyRequestOptions
+ * is publicly accessible and supports fluent chaining for per-request header
+ * overrides on read-many operations.
+ */
+ @Test(groups = { "unit" })
+ public void setHeaderOnReadManyRequestOptionsIsPublic() {
+ CosmosReadManyRequestOptions options = new CosmosReadManyRequestOptions()
+ .setHeader("x-ms-cosmos-workload-id", "40");
+
+ assertThat(options).isNotNull();
+ }
+
/**
* Verifies that the WORKLOAD_ID constant in HttpConstants.HttpHeaders is defined
* with the correct canonical header name "x-ms-cosmos-workload-id" as expected
@@ -167,4 +199,58 @@ public void setHeaderOnQueryRequestOptionsIsPublic() {
public void workloadIdHttpHeaderConstant() {
assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id");
}
+
+ /**
+ * Verifies that a non-numeric workload-id value is rejected at builder level with
+ * IllegalArgumentException. This covers both Gateway and Direct modes consistently
+ * (unlike RntbdRequestHeaders.addWorkloadId() which only covers Direct mode).
+ */
+ @Test(groups = { "unit" })
+ public void nonNumericWorkloadIdRejectedAtBuilderLevel() {
+ Map headers = new HashMap<>();
+ headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "abc");
+
+ assertThatThrownBy(() -> new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(headers))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("abc")
+ .hasMessageContaining("valid integer");
+ }
+
+ /**
+ * Verifies that out-of-range workload-id values (e.g., 51) are accepted by the SDK.
+ * Range validation [1, 50] is the backend's responsibility — the SDK only validates
+ * that the value is a valid integer. This avoids hardcoding a range the backend team
+ * might change in the future.
+ */
+ @Test(groups = { "unit" })
+ public void outOfRangeWorkloadIdAcceptedByBuilder() {
+ Map headers = new HashMap<>();
+ headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "51");
+
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(headers);
+
+ assertThat(builder.getCustomHeaders()).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "51");
+ }
+
+ /**
+ * Verifies that a valid workload-id value passes builder validation.
+ */
+ @Test(groups = { "unit" })
+ public void validWorkloadIdAcceptedByBuilder() {
+ Map headers = new HashMap<>();
+ headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "15");
+
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .customHeaders(headers);
+
+ assertThat(builder.getCustomHeaders()).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "15");
+ }
}
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java
index a9f5cb35549c..d5f8b92ac7a6 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java
@@ -19,6 +19,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
+import java.util.Map;
import static org.mockito.Mockito.doAnswer;
@@ -75,7 +76,8 @@ RxGatewayStoreModel createRxGatewayProxy(
GlobalEndpointManager globalEndpointManager,
GlobalPartitionEndpointManagerForPerPartitionCircuitBreaker globalPartitionEndpointManagerForPerPartitionCircuitBreaker,
HttpClient rxOrigClient,
- ApiType apiType) {
+ ApiType apiType,
+ Map customHeaders) {
origHttpClient = rxOrigClient;
spyHttpClient = Mockito.spy(rxOrigClient);
@@ -93,6 +95,7 @@ RxGatewayStoreModel createRxGatewayProxy(
userAgentContainer,
globalEndpointManager,
spyHttpClient,
- apiType);
+ apiType,
+ customHeaders);
}
}
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java
index 54440ecfabc5..587844f4043a 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java
@@ -27,6 +27,8 @@
import java.net.SocketException;
import java.net.URI;
import java.time.Duration;
+import java.util.HashMap;
+import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import static com.azure.cosmos.implementation.TestUtils.mockDiagnosticsClientContext;
@@ -102,6 +104,7 @@ public void readTimeout() throws Exception {
userAgentContainer,
globalEndpointManager,
httpClient,
+ null,
null);
storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader);
@@ -146,6 +149,7 @@ public void serviceUnavailable() throws Exception {
userAgentContainer,
globalEndpointManager,
httpClient,
+ null,
null);
storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader);
@@ -205,7 +209,8 @@ public void applySessionToken(
new UserAgentContainer(),
globalEndpointManager,
httpClient,
- apiType);
+ apiType,
+ null);
storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader);
RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName(
@@ -277,7 +282,8 @@ public void validateApiType() throws Exception {
new UserAgentContainer(),
globalEndpointManager,
httpClient,
- apiType);
+ apiType,
+ null);
RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName(
clientContext,
@@ -391,6 +397,7 @@ private boolean runCancelAfterRetainIteration() throws Exception {
new UserAgentContainer(),
globalEndpointManager,
httpClient,
+ null,
null);
storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader);
@@ -428,6 +435,167 @@ private boolean runCancelAfterRetainIteration() throws Exception {
return false;
}
+ /**
+ * Verifies that client-level customHeaders (e.g., workload-id) are injected into
+ * outgoing HTTP requests by performRequest(). This covers metadata requests
+ * (collection cache, partition key range) that don't go through getRequestHeaders().
+ */
+ @Test(groups = "unit")
+ public void customHeadersInjectedInPerformRequest() throws Exception {
+ DiagnosticsClientContext clientContext = mockDiagnosticsClientContext();
+ ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class);
+ GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class);
+
+ Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost")))
+ .when(globalEndpointManager).resolveServiceEndpoint(any());
+
+ HttpClient httpClient = Mockito.mock(HttpClient.class);
+ ArgumentCaptor httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
+ Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException()));
+
+ Map customHeaders = new HashMap<>();
+ customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25");
+
+ RxGatewayStoreModel storeModel = new RxGatewayStoreModel(
+ clientContext,
+ sessionContainer,
+ ConsistencyLevel.SESSION,
+ QueryCompatibilityMode.Default,
+ new UserAgentContainer(),
+ globalEndpointManager,
+ httpClient,
+ null,
+ customHeaders);
+
+ // Simulate a metadata request (e.g., collection cache lookup) — no customHeaders on the request itself
+ RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName(
+ clientContext,
+ OperationType.Read,
+ "/dbs/db/colls/col",
+ ResourceType.DocumentCollection);
+ dsr.requestContext = new DocumentServiceRequestContext();
+ dsr.requestContext.regionalRoutingContextToRoute = new RegionalRoutingContext(new URI("https://localhost"));
+
+ try {
+ storeModel.performRequest(dsr).block();
+ fail("Request should fail");
+ } catch (Exception e) {
+ // expected
+ }
+
+ Mockito.verify(httpClient).send(httpClientRequestCaptor.capture(), any());
+ HttpRequest httpRequest = httpClientRequestCaptor.getValue();
+ HttpHeaders headers = ReflectionUtils.getHttpHeaders(httpRequest);
+ assertThat(headers.toMap().get(HttpConstants.HttpHeaders.WORKLOAD_ID)).isEqualTo("25");
+ }
+
+ /**
+ * Verifies that request-level headers take precedence over client-level customHeaders.
+ * If a request already has workload-id set (e.g., via getRequestHeaders()), performRequest()
+ * should NOT overwrite it.
+ */
+ @Test(groups = "unit")
+ public void requestLevelHeadersTakePrecedenceOverCustomHeaders() throws Exception {
+ DiagnosticsClientContext clientContext = mockDiagnosticsClientContext();
+ ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class);
+ GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class);
+
+ Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost")))
+ .when(globalEndpointManager).resolveServiceEndpoint(any());
+
+ HttpClient httpClient = Mockito.mock(HttpClient.class);
+ ArgumentCaptor httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
+ Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException()));
+
+ Map customHeaders = new HashMap<>();
+ customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "10");
+
+ RxGatewayStoreModel storeModel = new RxGatewayStoreModel(
+ clientContext,
+ sessionContainer,
+ ConsistencyLevel.SESSION,
+ QueryCompatibilityMode.Default,
+ new UserAgentContainer(),
+ globalEndpointManager,
+ httpClient,
+ null,
+ customHeaders);
+
+ RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName(
+ clientContext,
+ OperationType.Read,
+ "/dbs/db/colls/col/docs/doc1",
+ ResourceType.Document);
+ dsr.requestContext = new DocumentServiceRequestContext();
+ dsr.requestContext.regionalRoutingContextToRoute = new RegionalRoutingContext(new URI("https://localhost"));
+
+ // Simulate request-level header already set (e.g., by getRequestHeaders())
+ dsr.getHeaders().put(HttpConstants.HttpHeaders.WORKLOAD_ID, "42");
+
+ try {
+ storeModel.performRequest(dsr).block();
+ fail("Request should fail");
+ } catch (Exception e) {
+ // expected
+ }
+
+ Mockito.verify(httpClient).send(httpClientRequestCaptor.capture(), any());
+ HttpRequest httpRequest = httpClientRequestCaptor.getValue();
+ HttpHeaders headers = ReflectionUtils.getHttpHeaders(httpRequest);
+ // Request-level header "42" should win over client-level "10"
+ assertThat(headers.toMap().get(HttpConstants.HttpHeaders.WORKLOAD_ID)).isEqualTo("42");
+ }
+
+ /**
+ * Verifies that when customHeaders is null, performRequest() still works normally
+ * without injecting any extra headers.
+ */
+ @Test(groups = "unit")
+ public void nullCustomHeadersDoesNotAffectPerformRequest() throws Exception {
+ DiagnosticsClientContext clientContext = mockDiagnosticsClientContext();
+ ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class);
+ GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class);
+
+ Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost")))
+ .when(globalEndpointManager).resolveServiceEndpoint(any());
+
+ HttpClient httpClient = Mockito.mock(HttpClient.class);
+ ArgumentCaptor httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
+ Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException()));
+
+ RxGatewayStoreModel storeModel = new RxGatewayStoreModel(
+ clientContext,
+ sessionContainer,
+ ConsistencyLevel.SESSION,
+ QueryCompatibilityMode.Default,
+ new UserAgentContainer(),
+ globalEndpointManager,
+ httpClient,
+ null,
+ null);
+
+ RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName(
+ clientContext,
+ OperationType.Read,
+ "/dbs/db/colls/col",
+ ResourceType.DocumentCollection);
+ dsr.requestContext = new DocumentServiceRequestContext();
+ dsr.requestContext.regionalRoutingContextToRoute = new RegionalRoutingContext(new URI("https://localhost"));
+
+ try {
+ storeModel.performRequest(dsr).block();
+ fail("Request should fail");
+ } catch (Exception e) {
+ // expected
+ }
+
+ Mockito.verify(httpClient).send(httpClientRequestCaptor.capture(), any());
+ HttpRequest httpRequest = httpClientRequestCaptor.getValue();
+ HttpHeaders headers = ReflectionUtils.getHttpHeaders(httpRequest);
+ // No workload-id header should be present
+ assertThat(headers.toMap().get(HttpConstants.HttpHeaders.WORKLOAD_ID)).isNull();
+ }
+
enum SessionTokenType {
NONE, // no session token applied
USER, // userControlled session token
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java
index b06d6f89b8e9..775b74785630 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java
@@ -25,6 +25,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
+import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
@@ -126,7 +127,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer,
UserAgentContainer userAgentContainer,
GlobalEndpointManager globalEndpointManager,
HttpClient rxClient,
- ApiType apiType) {
+ ApiType apiType,
+ Map customHeaders) {
this.origRxGatewayStoreModel = super.createRxGatewayProxy(
sessionContainer,
consistencyLevel,
@@ -134,7 +136,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer,
userAgentContainer,
globalEndpointManager,
rxClient,
- apiType);
+ apiType,
+ customHeaders);
this.requests = Collections.synchronizedList(new ArrayList<>());
this.spyRxGatewayStoreModel = Mockito.spy(this.origRxGatewayStoreModel);
this.initRequestCapture();
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java
index 172c00f799bc..9b938d0a1520 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java
@@ -57,6 +57,7 @@
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
@@ -145,6 +146,7 @@ public void getServerAddressesViaGateway(List partitionKeyRangeIds,
null,
ConnectionPolicy.getDefaultPolicy(),
null,
+ null,
null);
for (int i = 0; i < 2; i++) {
@@ -186,6 +188,7 @@ public void getMasterAddressesViaGatewayAsync(Protocol protocol) throws Exceptio
null,
ConnectionPolicy.getDefaultPolicy(),
null,
+ null,
null);
for (int i = 0; i < 2; i++) {
@@ -238,6 +241,7 @@ public void tryGetAddresses_ForDataPartitions(String partitionKeyRangeId, String
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
RxDocumentServiceRequest req =
@@ -296,6 +300,7 @@ public void tryGetAddresses_ForDataPartitions_AddressCachedByOpenAsync_NoHttpReq
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
String collectionRid = createdCollection.getResourceId();
@@ -366,6 +371,7 @@ public void tryGetAddresses_ForDataPartitions_ForceRefresh(
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
String collectionRid = createdCollection.getResourceId();
@@ -472,6 +478,7 @@ public void tryGetAddresses_ForDataPartitions_Suboptimal_Refresh(
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
String collectionRid = createdCollection.getResourceId();
@@ -614,6 +621,7 @@ public void tryGetAddresses_ForMasterPartition(Protocol protocol) throws Excepti
null,
null,
null,
+ null,
null);
RxDocumentServiceRequest req =
@@ -666,6 +674,7 @@ public void tryGetAddresses_ForMasterPartition_MasterPartitionAddressAlreadyCach
null,
ConnectionPolicy.getDefaultPolicy(),
null,
+ null,
null);
RxDocumentServiceRequest req =
@@ -717,6 +726,7 @@ public void tryGetAddresses_ForMasterPartition_ForceRefresh() throws Exception {
null,
ConnectionPolicy.getDefaultPolicy(),
null,
+ null,
null);
RxDocumentServiceRequest req =
@@ -775,6 +785,7 @@ public void tryGetAddresses_SuboptimalMasterPartition_NotStaleEnough_NoRefresh()
null,
ConnectionPolicy.getDefaultPolicy(),
null,
+ null,
null);
GatewayAddressCache spyCache = Mockito.spy(origCache);
@@ -873,6 +884,7 @@ public void tryGetAddresses_SuboptimalMasterPartition_Stale_DoRefresh() throws E
null,
ConnectionPolicy.getDefaultPolicy(),
null,
+ null,
null);
GatewayAddressCache spyCache = Mockito.spy(origCache);
@@ -990,6 +1002,7 @@ public void tryGetAddress_replicaValidationTests(boolean replicaValidationEnable
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
RxDocumentServiceRequest req =
@@ -1152,6 +1165,7 @@ public void tryGetAddress_failedEndpointTests() throws Exception {
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
RxDocumentServiceRequest req =
@@ -1214,6 +1228,7 @@ public void tryGetAddress_unhealthyStatus_forceRefresh() throws Exception {
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
RxDocumentServiceRequest req =
@@ -1304,6 +1319,7 @@ public void tryGetAddress_repeatedlySetUnhealthyStatus_forceRefresh() throws Int
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
RxDocumentServiceRequest req =
@@ -1396,6 +1412,7 @@ public void validateReplicaAddressesTests(boolean isCollectionUnderWarmUpFlow) t
null,
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock,
+ null,
null);
Mockito.when(proactiveOpenConnectionsProcessorMock.submitOpenConnectionTaskOutsideLoop(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.anyInt())).thenReturn(dummyOpenConnectionsTask);
@@ -1495,6 +1512,7 @@ public void mergeAddressesTests() throws URISyntaxException, NoSuchMethodExcepti
null,
ConnectionPolicy.getDefaultPolicy(),
null,
+ null,
null);
// connected status
@@ -1628,6 +1646,113 @@ private HttpClientUnderTestWrapper getHttpClientUnderTestWrapper(Configs configs
return new HttpClientUnderTestWrapper(origHttpClient);
}
+ /**
+ * Verifies that client-level customHeaders (e.g., workload-id) are included in
+ * GatewayAddressCache's defaultRequestHeaders, which are sent on every address
+ * resolution request.
+ */
+ @Test(groups = { "unit" })
+ public void customHeadersIncludedInDefaultRequestHeaders() throws Exception {
+ URI serviceEndpoint = new URI("https://localhost");
+
+ Map customHeaders = new HashMap<>();
+ customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25");
+
+ GatewayAddressCache cache = new GatewayAddressCache(
+ mockDiagnosticsClientContext(),
+ serviceEndpoint,
+ Protocol.HTTPS,
+ Mockito.mock(IAuthorizationTokenProvider.class),
+ null,
+ Mockito.mock(HttpClient.class),
+ null,
+ null,
+ null,
+ null,
+ null,
+ customHeaders);
+
+ Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders");
+ defaultRequestHeadersField.setAccessible(true);
+ @SuppressWarnings("unchecked")
+ HashMap defaultRequestHeaders = (HashMap) defaultRequestHeadersField.get(cache);
+
+ assertThat(defaultRequestHeaders).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "25");
+ }
+
+ /**
+ * Verifies that customHeaders do NOT overwrite SDK system headers (USER_AGENT, VERSION, etc.)
+ * in GatewayAddressCache's defaultRequestHeaders. putIfAbsent is used so SDK headers
+ * set before customHeaders are preserved.
+ */
+ @Test(groups = { "unit" })
+ public void customHeadersDoNotOverwriteSdkSystemHeaders() throws Exception {
+ URI serviceEndpoint = new URI("https://localhost");
+
+ Map customHeaders = new HashMap<>();
+ customHeaders.put(HttpConstants.HttpHeaders.USER_AGENT, "malicious-agent");
+ customHeaders.put(HttpConstants.HttpHeaders.VERSION, "bad-version");
+ customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25");
+
+ GatewayAddressCache cache = new GatewayAddressCache(
+ mockDiagnosticsClientContext(),
+ serviceEndpoint,
+ Protocol.HTTPS,
+ Mockito.mock(IAuthorizationTokenProvider.class),
+ null,
+ Mockito.mock(HttpClient.class),
+ null,
+ null,
+ null,
+ null,
+ null,
+ customHeaders);
+
+ Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders");
+ defaultRequestHeadersField.setAccessible(true);
+ @SuppressWarnings("unchecked")
+ HashMap defaultRequestHeaders = (HashMap) defaultRequestHeadersField.get(cache);
+
+ // SDK headers should NOT be overwritten
+ assertThat(defaultRequestHeaders.get(HttpConstants.HttpHeaders.USER_AGENT)).isNotEqualTo("malicious-agent");
+ assertThat(defaultRequestHeaders.get(HttpConstants.HttpHeaders.VERSION)).isEqualTo(HttpConstants.Versions.CURRENT_VERSION);
+ // Custom header should still be added
+ assertThat(defaultRequestHeaders).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "25");
+ }
+
+ /**
+ * Verifies that when customHeaders is null, GatewayAddressCache's defaultRequestHeaders
+ * contains only SDK system headers and no extra entries.
+ */
+ @Test(groups = { "unit" })
+ public void nullCustomHeadersDoesNotAffectDefaultRequestHeaders() throws Exception {
+ URI serviceEndpoint = new URI("https://localhost");
+
+ GatewayAddressCache cache = new GatewayAddressCache(
+ mockDiagnosticsClientContext(),
+ serviceEndpoint,
+ Protocol.HTTPS,
+ Mockito.mock(IAuthorizationTokenProvider.class),
+ null,
+ Mockito.mock(HttpClient.class),
+ null,
+ null,
+ null,
+ null,
+ null,
+ null);
+
+ Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders");
+ defaultRequestHeadersField.setAccessible(true);
+ @SuppressWarnings("unchecked")
+ HashMap defaultRequestHeaders = (HashMap) defaultRequestHeadersField.get(cache);
+
+ // Should only contain SDK system headers, no workload-id
+ assertThat(defaultRequestHeaders).containsKey(HttpConstants.HttpHeaders.USER_AGENT);
+ assertThat(defaultRequestHeaders).containsKey(HttpConstants.HttpHeaders.VERSION);
+ assertThat(defaultRequestHeaders).doesNotContainKey(HttpConstants.HttpHeaders.WORKLOAD_ID);
+ }
+
public String getNameBasedCollectionLink() {
return "dbs/" + createdDatabase.getId() + "/colls/" + createdCollection.getId();
}
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java
index 331be53cc7af..5879e7d3e61c 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java
@@ -110,7 +110,7 @@ public void resolveAsync() throws Exception {
GlobalAddressResolver globalAddressResolver = new GlobalAddressResolver(mockDiagnosticsClientContext(), httpClient, endpointManager, Protocol.HTTPS, authorizationTokenProvider, collectionCache, routingMapProvider,
userAgentContainer,
- serviceConfigReader, connectionPolicy, null);
+ serviceConfigReader, connectionPolicy, null, null);
RxDocumentServiceRequest request;
request = RxDocumentServiceRequest.createFromName(mockDiagnosticsClientContext(),
OperationType.Read,
@@ -145,6 +145,7 @@ public void submitOpenConnectionTasksAndInitCaches() {
userAgentContainer,
serviceConfigReader,
connectionPolicy,
+ null,
null);
GlobalAddressResolver.EndpointCache endpointCache = new GlobalAddressResolver.EndpointCache();
GatewayAddressCache gatewayAddressCache = Mockito.mock(GatewayAddressCache.class);
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
index a57b4d9d9a0b..3bf2fdafce7c 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
@@ -8,9 +8,6 @@
import com.azure.cosmos.CosmosClientBuilder;
import com.azure.cosmos.TestObject;
import com.azure.cosmos.implementation.HttpConstants;
-import com.azure.cosmos.implementation.TestConfigurations;
-import com.azure.cosmos.models.CosmosBulkExecutionOptions;
-import com.azure.cosmos.models.CosmosBulkOperations;
import com.azure.cosmos.models.CosmosContainerProperties;
import com.azure.cosmos.models.CosmosItemRequestOptions;
import com.azure.cosmos.models.CosmosItemResponse;
@@ -19,6 +16,7 @@
import com.azure.cosmos.models.PartitionKeyDefinition;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Factory;
import org.testng.annotations.Test;
import java.util.ArrayList;
@@ -32,6 +30,10 @@
* End-to-end integration tests for the custom headers / workload-id feature.
*
* Test type: EMULATOR INTEGRATION TEST — requires the Cosmos DB Emulator to be running locally.
+ *
+ * Uses {@code @Factory(dataProvider = "simpleClientBuilderGatewaySession")} to run all tests
+ * against both Gateway mode (HTTP headers) and Direct mode (RNTBD binary token 0x00DC),
+ * ensuring the workload-id header is correctly encoded and sent in both transport paths.
*/
public class WorkloadIdE2ETests extends TestSuiteBase {
@@ -42,10 +44,9 @@ public class WorkloadIdE2ETests extends TestSuiteBase {
private CosmosAsyncDatabase database;
private CosmosAsyncContainer container;
- public WorkloadIdE2ETests() {
- super(new CosmosClientBuilder()
- .endpoint(TestConfigurations.HOST)
- .key(TestConfigurations.MASTER_KEY));
+ @Factory(dataProvider = "simpleClientBuilderGatewaySession")
+ public WorkloadIdE2ETests(CosmosClientBuilder clientBuilder) {
+ super(clientBuilder);
}
@BeforeClass(groups = { "emulator" }, timeOut = SETUP_TIMEOUT)
@@ -53,9 +54,7 @@ public void beforeClass() {
Map headers = new HashMap<>();
headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "15");
- clientWithWorkloadId = new CosmosClientBuilder()
- .endpoint(TestConfigurations.HOST)
- .key(TestConfigurations.MASTER_KEY)
+ clientWithWorkloadId = getClientBuilder()
.customHeaders(headers)
.buildAsyncClient();
@@ -218,9 +217,7 @@ public void queryItemsWithRequestLevelWorkloadIdOverride() {
@Test(groups = { "emulator" }, timeOut = TIMEOUT)
public void clientWithNoCustomHeadersStillWorks() {
// Verify that a client without custom headers works normally (no regression)
- CosmosAsyncClient clientWithoutHeaders = new CosmosClientBuilder()
- .endpoint(TestConfigurations.HOST)
- .key(TestConfigurations.MASTER_KEY)
+ CosmosAsyncClient clientWithoutHeaders = copyCosmosClientBuilder(getClientBuilder())
.buildAsyncClient();
try {
@@ -248,9 +245,7 @@ public void clientWithNoCustomHeadersStillWorks() {
@Test(groups = { "emulator" }, timeOut = TIMEOUT)
public void clientWithEmptyCustomHeaders() {
// Verify that a client with empty custom headers map works normally
- CosmosAsyncClient clientWithEmptyHeaders = new CosmosClientBuilder()
- .endpoint(TestConfigurations.HOST)
- .key(TestConfigurations.MASTER_KEY)
+ CosmosAsyncClient clientWithEmptyHeaders = copyCosmosClientBuilder(getClientBuilder())
.customHeaders(new HashMap<>())
.buildAsyncClient();
@@ -272,38 +267,19 @@ public void clientWithEmptyCustomHeaders() {
}
/**
- * Verifies that a client can be configured with multiple custom headers simultaneously
- * (workload-id plus an additional custom header). Confirms that all headers flow
- * through the pipeline without interfering with each other.
+ * Verifies that unknown headers in customHeaders are rejected by the allowlist.
+ * In Direct mode (RNTBD), unknown headers are silently dropped, so the allowlist
+ * ensures consistent behavior across Gateway and Direct modes.
*/
- @Test(groups = { "emulator" }, timeOut = TIMEOUT)
- public void clientWithMultipleCustomHeaders() {
- // Verify that multiple custom headers can be set simultaneously
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT, expectedExceptions = IllegalArgumentException.class)
+ public void unknownCustomHeadersRejectedByAllowlist() {
Map headers = new HashMap<>();
headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "20");
headers.put("x-ms-custom-test-header", "test-value");
- CosmosAsyncClient clientWithMultipleHeaders = new CosmosClientBuilder()
- .endpoint(TestConfigurations.HOST)
- .key(TestConfigurations.MASTER_KEY)
- .customHeaders(headers)
- .buildAsyncClient();
-
- try {
- CosmosAsyncContainer c = clientWithMultipleHeaders
- .getDatabase(DATABASE_ID)
- .getContainer(CONTAINER_ID);
-
- TestObject doc = TestObject.create();
- CosmosItemResponse response = c
- .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions())
- .block();
-
- assertThat(response).isNotNull();
- assertThat(response.getStatusCode()).isEqualTo(201);
- } finally {
- safeClose(clientWithMultipleHeaders);
- }
+ // Should throw IllegalArgumentException due to unknown header
+ copyCosmosClientBuilder(getClientBuilder())
+ .customHeaders(headers);
}
@AfterClass(groups = { "emulator" }, timeOut = SHUTDOWN_TIMEOUT, alwaysRun = true)
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java
index aea282be566c..e4cdf6ca1e30 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java
@@ -13,6 +13,7 @@
import com.azure.cosmos.implementation.ConnectionPolicy;
import com.azure.cosmos.implementation.CosmosClientMetadataCachesSnapshot;
import com.azure.cosmos.implementation.DiagnosticsProvider;
+import com.azure.cosmos.implementation.HttpConstants;
import com.azure.cosmos.implementation.Strings;
import com.azure.cosmos.implementation.WriteRetryPolicy;
import com.azure.cosmos.implementation.apachecommons.collections.list.UnmodifiableList;
@@ -158,6 +159,20 @@ public class CosmosClientBuilder implements
private Function containerFactory = null;
private Map customHeaders;
+ /**
+ * Allowlist of headers permitted in {@link #customHeaders(Map)}.
+ *
+ * In Direct mode (RNTBD), only headers with explicit encoding support in
+ * {@code RntbdRequestHeaders} are sent on the wire. Unknown headers are silently dropped.
+ * This allowlist ensures consistent behavior across Gateway and Direct modes - if a header
+ * is allowed here, it works in both modes. To add a new allowed header, you must also add
+ * RNTBD encoding support ({@code RntbdConstants.RntbdRequestHeader} enum entry +
+ * {@code RntbdRequestHeaders.addXxx()} method).
+ */
+ private static final Set ALLOWED_CUSTOM_HEADERS = Collections.unmodifiableSet(
+ new HashSet<>(Collections.singletonList(HttpConstants.HttpHeaders.WORKLOAD_ID))
+ );
+
/**
* Instantiates a new Cosmos client builder.
*/
@@ -739,9 +754,13 @@ public CosmosClientBuilder userAgentSuffix(String userAgentSuffix) {
/**
* Sets custom HTTP headers that will be included with every request from this client.
*
- * These headers are sent with all requests. For Direct/RNTBD mode, only known headers
- * (like {@code x-ms-cosmos-workload-id}) will be encoded and sent. Unknown headers
- * work only in Gateway mode.
+ * Only headers in the SDK's allowlist are permitted. Currently the only allowed header is
+ * {@code x-ms-cosmos-workload-id}. Passing any other header key will throw
+ * {@link IllegalArgumentException}.
+ *
+ * This restriction exists because in Direct mode (RNTBD), only headers with explicit
+ * encoding support are sent on the wire. Unknown headers are silently dropped. The allowlist
+ * ensures consistent behavior across both Gateway and Direct modes.
*
* If the same header is also set on request options (e.g.,
* {@code CosmosItemRequestOptions.setHeader(String, String)}),
@@ -749,8 +768,33 @@ public CosmosClientBuilder userAgentSuffix(String userAgentSuffix) {
*
* @param customHeaders map of header name to value
* @return current CosmosClientBuilder
+ * @throws IllegalArgumentException if any header key is not in the allowlist, or if the
+ * workload-id value is not a valid integer
*/
public CosmosClientBuilder customHeaders(Map customHeaders) {
+ if (customHeaders != null) {
+ for (Map.Entry entry : customHeaders.entrySet()) {
+ String key = entry.getKey();
+ String value = entry.getValue();
+
+ if (!ALLOWED_CUSTOM_HEADERS.contains(key)) {
+ throw new IllegalArgumentException(
+ "Header '" + key + "' is not allowed in customHeaders. "
+ + "Allowed headers: " + ALLOWED_CUSTOM_HEADERS);
+ }
+
+ // Validate workload-id value is a valid integer (range validation is left to the backend)
+ if (HttpConstants.HttpHeaders.WORKLOAD_ID.equals(key) && value != null) {
+ try {
+ Integer.parseInt(value);
+ } catch (NumberFormatException e) {
+ throw new IllegalArgumentException(
+ "Invalid value '" + value + "' for header '" + key
+ + "'. The value must be a valid integer.", e);
+ }
+ }
+ }
+ }
this.customHeaders = customHeaders;
return this;
}
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java
index 2f0bd4271d86..122542a8810a 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java
@@ -863,7 +863,8 @@ public void init(CosmosClientMetadataCachesSnapshot metadataCachesSnapshot, Func
this.userAgentContainer,
this.globalEndpointManager,
this.reactorHttpClient,
- this.apiType);
+ this.apiType,
+ this.customHeaders);
this.thinProxy = createThinProxy(this.sessionContainer,
this.consistencyLevel,
@@ -969,7 +970,8 @@ private void initializeDirectConnectivity() {
// this.gatewayConfigurationReader,
null,
this.connectionPolicy,
- this.apiType);
+ this.apiType,
+ this.customHeaders);
this.storeClientFactory = new StoreClientFactory(
this.addressResolver,
@@ -1013,7 +1015,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer,
UserAgentContainer userAgentContainer,
GlobalEndpointManager globalEndpointManager,
HttpClient httpClient,
- ApiType apiType) {
+ ApiType apiType,
+ Map customHeaders) {
return new RxGatewayStoreModel(
this,
sessionContainer,
@@ -1022,7 +1025,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer,
userAgentContainer,
globalEndpointManager,
httpClient,
- apiType);
+ apiType,
+ customHeaders);
}
ThinClientStoreModel createThinProxy(ISessionContainer sessionContainer,
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java
index 979c528b32bb..a197723f0ae3 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java
@@ -91,6 +91,7 @@ public class RxGatewayStoreModel implements RxStoreModel, HttpTransportSerialize
private GatewayServiceConfigurationReader gatewayServiceConfigurationReader;
private RxClientCollectionCache collectionCache;
private GatewayServerErrorInjector gatewayServerErrorInjector;
+ private final Map customHeaders;
public RxGatewayStoreModel(
DiagnosticsClientContext clientContext,
@@ -100,7 +101,8 @@ public RxGatewayStoreModel(
UserAgentContainer userAgentContainer,
GlobalEndpointManager globalEndpointManager,
HttpClient httpClient,
- ApiType apiType) {
+ ApiType apiType,
+ Map customHeaders) {
this.clientContext = clientContext;
@@ -116,6 +118,7 @@ public RxGatewayStoreModel(
this.httpClient = httpClient;
this.sessionContainer = sessionContainer;
+ this.customHeaders = customHeaders;
}
public RxGatewayStoreModel(RxGatewayStoreModel inner) {
@@ -127,6 +130,7 @@ public RxGatewayStoreModel(RxGatewayStoreModel inner) {
this.httpClient = inner.httpClient;
this.sessionContainer = inner.sessionContainer;
+ this.customHeaders = inner.customHeaders;
}
protected Map getDefaultHeaders(
@@ -279,6 +283,17 @@ public Mono performRequest(RxDocumentServiceRequest r
request.requestContext.cosmosDiagnostics = clientContext.createDiagnostics();
}
+ // Apply client-level custom headers (e.g., workload-id) to all requests
+ // including metadata requests (collection cache, partition key range, etc.)
+ if (this.customHeaders != null && !this.customHeaders.isEmpty()) {
+ for (Map.Entry entry : this.customHeaders.entrySet()) {
+ // Only set if not already present — request-level headers take precedence
+ if (!request.getHeaders().containsKey(entry.getKey())) {
+ request.getHeaders().put(entry.getKey(), entry.getValue());
+ }
+ }
+ }
+
URI uri = getUri(request);
request.requestContext.resourcePhysicalAddress = uri.toString();
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java
index d32e5d901f18..ff139e203d2e 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java
@@ -56,7 +56,8 @@ public ThinClientStoreModel(
userAgentContainer,
globalEndpointManager,
httpClient,
- ApiType.SQL);
+ ApiType.SQL,
+ null);
String userAgent = userAgentContainer != null
? userAgentContainer.getUserAgent()
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java
index e62d7b8c6ca4..7c761335b782 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java
@@ -123,7 +123,8 @@ public GatewayAddressCache(
GlobalEndpointManager globalEndpointManager,
ConnectionPolicy connectionPolicy,
ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor,
- GatewayServerErrorInjector gatewayServerErrorInjector) {
+ GatewayServerErrorInjector gatewayServerErrorInjector,
+ Map customHeaders) {
this.clientContext = clientContext;
try {
@@ -165,6 +166,14 @@ public GatewayAddressCache(
HttpConstants.HttpHeaders.SDK_SUPPORTED_CAPABILITIES,
HttpConstants.SDKSupportedCapabilities.SUPPORTED_CAPABILITIES);
+ // Apply client-level custom headers (e.g., workload-id) to metadata requests
+ // Use putIfAbsent to ensure SDK system headers (USER_AGENT, VERSION, etc.) are not overwritten
+ if (customHeaders != null && !customHeaders.isEmpty()) {
+ for (Map.Entry entry : customHeaders.entrySet()) {
+ this.defaultRequestHeaders.putIfAbsent(entry.getKey(), entry.getValue());
+ }
+ }
+
this.lastForcedRefreshMap = new ConcurrentHashMap<>();
this.globalEndpointManager = globalEndpointManager;
this.proactiveOpenConnectionsProcessor = proactiveOpenConnectionsProcessor;
@@ -188,7 +197,8 @@ public GatewayAddressCache(
GlobalEndpointManager globalEndpointManager,
ConnectionPolicy connectionPolicy,
ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor,
- GatewayServerErrorInjector gatewayServerErrorInjector) {
+ GatewayServerErrorInjector gatewayServerErrorInjector,
+ Map customHeaders) {
this(clientContext,
serviceEndpoint,
protocol,
@@ -200,7 +210,8 @@ public GatewayAddressCache(
globalEndpointManager,
connectionPolicy,
proactiveOpenConnectionsProcessor,
- gatewayServerErrorInjector);
+ gatewayServerErrorInjector,
+ customHeaders);
}
@Override
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java
index 00905682b4d1..2fd5287da028 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java
@@ -62,6 +62,7 @@ public class GlobalAddressResolver implements IAddressResolver {
private ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor;
private ConnectionPolicy connectionPolicy;
private GatewayServerErrorInjector gatewayServerErrorInjector;
+ private final Map customHeaders;
public GlobalAddressResolver(
DiagnosticsClientContext diagnosticsClientContext,
@@ -74,7 +75,8 @@ public GlobalAddressResolver(
UserAgentContainer userAgentContainer,
GatewayServiceConfigurationReader serviceConfigReader,
ConnectionPolicy connectionPolicy,
- ApiType apiType) {
+ ApiType apiType,
+ Map customHeaders) {
this.diagnosticsClientContext = diagnosticsClientContext;
this.httpClient = httpClient;
this.endpointManager = endpointManager;
@@ -86,6 +88,7 @@ public GlobalAddressResolver(
this.serviceConfigReader = serviceConfigReader;
this.tcpConnectionEndpointRediscoveryEnabled = connectionPolicy.isTcpConnectionEndpointRediscoveryEnabled();
this.connectionPolicy = connectionPolicy;
+ this.customHeaders = customHeaders;
int maxBackupReadEndpoints = (connectionPolicy.isReadRequestsFallbackEnabled()) ? GlobalAddressResolver.MaxBackupReadRegions : 0;
this.maxEndpoints = maxBackupReadEndpoints + 2; // for write and alternate write getEndpoint (during failover)
@@ -290,7 +293,8 @@ private EndpointCache getOrAddEndpoint(URI endpoint) {
this.endpointManager,
this.connectionPolicy,
this.proactiveOpenConnectionsProcessor,
- this.gatewayServerErrorInjector);
+ this.gatewayServerErrorInjector,
+ this.customHeaders);
AddressResolver addressResolver = new AddressResolver();
addressResolver.initializeCaches(this.collectionCache, this.routingMapProvider, gatewayAddressCache);
EndpointCache cache = new EndpointCache();
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java
index f6e570258042..de2d769f789b 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java
@@ -366,6 +366,22 @@ public Set getKeywordIdentifiers() {
return this.actualRequestOptions.getKeywordIdentifiers();
}
+ /**
+ * Sets a custom header to be included with this specific request.
+ *
+ * This allows per-request header customization, such as setting a workload ID
+ * that overrides the client-level default set via
+ * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}.
+ *
+ * @param name the header name (e.g., "x-ms-cosmos-workload-id")
+ * @param value the header value (e.g., "20")
+ * @return the CosmosReadManyRequestOptions.
+ */
+ public CosmosReadManyRequestOptions setHeader(String name, String value) {
+ this.actualRequestOptions.setHeader(name, value);
+ return this;
+ }
+
CosmosQueryRequestOptionsBase> getImpl() {
return this.actualRequestOptions;
}
From c218ae3fe578f6ffd605e64aa3570ff410554dc3 Mon Sep 17 00:00:00 2001
From: dibahlfi <106994927+dibahlfi@users.noreply.github.com>
Date: Mon, 9 Mar 2026 16:06:38 -0500
Subject: [PATCH 06/13] fix: addressing comments
---
.../azure-cosmos-spark_3-3_2-12/CHANGELOG.md | 2 +-
.../azure-cosmos-spark_3-4_2-12/CHANGELOG.md | 2 +-
.../azure-cosmos-spark_3-5_2-12/CHANGELOG.md | 2 +-
.../azure-cosmos-spark_3-5_2-13/CHANGELOG.md | 2 +-
.../cosmos/spark/CosmosClientCache.scala | 19 ++--
.../spark/CosmosClientConfiguration.scala | 8 +-
.../com/azure/cosmos/spark/CosmosConfig.scala | 31 +++----
.../cosmos/spark/CosmosClientCacheITest.scala | 10 +--
.../spark/CosmosClientConfigurationSpec.scala | 48 +++++-----
.../spark/CosmosPartitionPlannerSpec.scala | 16 ++--
.../cosmos/spark/PartitionMetadataSpec.scala | 32 +++----
.../azure-cosmos-spark_4-0_2-13/CHANGELOG.md | 2 +-
sdk/cosmos/azure-cosmos/CHANGELOG.md | 2 +-
.../com/azure/cosmos/CosmosAsyncClient.java | 2 +-
.../com/azure/cosmos/CosmosClientBuilder.java | 90 ++++++++-----------
.../implementation/AsyncDocumentClient.java | 8 +-
.../implementation/RxDocumentClientImpl.java | 20 ++---
.../GatewayAddressCache.java | 12 +--
.../GlobalAddressResolver.java | 8 +-
.../models/CosmosBatchRequestOptions.java | 32 +++++--
.../models/CosmosBulkExecutionOptions.java | 32 +++++--
.../CosmosChangeFeedRequestOptions.java | 32 +++++--
.../models/CosmosItemRequestOptions.java | 32 +++++--
.../models/CosmosQueryRequestOptions.java | 32 +++++--
.../models/CosmosReadManyRequestOptions.java | 33 +++++--
25 files changed, 314 insertions(+), 195 deletions(-)
diff --git a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md
index 5c00937013ad..6d23a6f92cf2 100644
--- a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md
@@ -3,7 +3,7 @@
### 4.45.0-beta.1 (Unreleased)
#### Features Added
-* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
+* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md
index 705334e4c3a1..49da84cfc8d4 100644
--- a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md
@@ -3,7 +3,7 @@
### 4.45.0-beta.1 (Unreleased)
#### Features Added
-* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
+* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md
index 97bff869bb59..1290801e56f2 100644
--- a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md
@@ -3,7 +3,7 @@
### 4.45.0-beta.1 (Unreleased)
#### Features Added
-* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
+* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md
index d922e4a579d3..8660d2ba7b59 100644
--- a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md
@@ -3,7 +3,7 @@
### 4.45.0-beta.1 (Unreleased)
#### Features Added
-* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
+* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala
index f31024628cb7..3a0174a9a9a8 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala
@@ -13,7 +13,7 @@ import com.azure.cosmos.models.{CosmosClientTelemetryConfig, CosmosMetricCategor
import com.azure.cosmos.spark.CosmosPredicates.isOnSparkDriver
import com.azure.cosmos.spark.catalog.{CosmosCatalogClient, CosmosCatalogCosmosSDKClient, CosmosCatalogManagementSDKClient}
import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait
-import com.azure.cosmos.{ConsistencyLevel, CosmosAsyncClient, CosmosClientBuilder, CosmosContainerProactiveInitConfigBuilder, CosmosDiagnosticsThresholds, DirectConnectionConfig, GatewayConnectionConfig, ReadConsistencyStrategy, ThrottlingRetryOptions}
+import com.azure.cosmos.{ConsistencyLevel, CosmosAsyncClient, CosmosClientBuilder, CosmosContainerProactiveInitConfigBuilder, CosmosDiagnosticsThresholds, CosmosHeaderName, DirectConnectionConfig, GatewayConnectionConfig, ReadConsistencyStrategy, ThrottlingRetryOptions}
import com.azure.identity.{ClientCertificateCredentialBuilder, ClientSecretCredentialBuilder, ManagedIdentityCredentialBuilder}
import com.azure.monitor.opentelemetry.autoconfigure.{AzureMonitorAutoConfigure, AzureMonitorAutoConfigureOptions}
import com.azure.resourcemanager.cosmos.CosmosManager
@@ -712,10 +712,15 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait {
}
}
- // Apply custom HTTP headers (e.g., workload-id) to the builder if configured.
+ // Apply additional HTTP headers (e.g., workload-id) to the builder if configured.
// These headers are attached to every Cosmos DB request made by this client instance.
- if (cosmosClientConfiguration.customHeaders.isDefined) {
- builder.customHeaders(cosmosClientConfiguration.customHeaders.get.asJava)
+ // Converts Map[String, String] from Spark config to Map[CosmosHeaderName, String] for the builder.
+ if (cosmosClientConfiguration.additionalHeaders.isDefined) {
+ val enumHeaders = new java.util.HashMap[CosmosHeaderName, String]()
+ for ((key, value) <- cosmosClientConfiguration.additionalHeaders.get) {
+ enumHeaders.put(CosmosHeaderName.fromString(key), value)
+ }
+ builder.additionalHeaders(enumHeaders)
}
var client = builder.buildAsyncClient()
@@ -922,9 +927,9 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait {
clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig],
azureMonitorConfig: Option[AzureMonitorConfig],
- // Custom HTTP headers are part of the cache key because different workload-ids
+ // Additional HTTP headers are part of the cache key because different workload-ids
// should produce different CosmosAsyncClient instances
- customHeaders: Option[Map[String, String]]
+ additionalHeaders: Option[Map[String, String]]
)
private[this] object ClientConfigurationWrapper {
@@ -944,7 +949,7 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait {
clientConfig.clientInterceptors,
clientConfig.sampledDiagnosticsLoggerConfig,
clientConfig.azureMonitorConfig,
- clientConfig.customHeaders
+ clientConfig.additionalHeaders
)
}
}
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala
index 61fa0957af83..7e09c73698a7 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala
@@ -31,9 +31,9 @@ private[spark] case class CosmosClientConfiguration (
clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig],
azureMonitorConfig: Option[AzureMonitorConfig],
- // Optional custom HTTP headers (e.g., workload-id) to attach to
- // all Cosmos DB requests via CosmosClientBuilder.customHeaders()
- customHeaders: Option[Map[String, String]]
+ // Optional additional HTTP headers (e.g., workload-id) to attach to
+ // all Cosmos DB requests via CosmosClientBuilder.additionalHeaders()
+ additionalHeaders: Option[Map[String, String]]
) {
private[spark] def getRoleInstanceName(machineId: Option[String]): String = {
CosmosClientConfiguration.getRoleInstanceName(sparkEnvironmentInfo, machineId)
@@ -98,7 +98,7 @@ private[spark] object CosmosClientConfiguration {
cosmosAccountConfig.clientInterceptors,
diagnosticsConfig.sampledDiagnosticsLoggerConfig,
diagnosticsConfig.azureMonitorConfig,
- cosmosAccountConfig.customHeaders
+ cosmosAccountConfig.additionalHeaders
)
}
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala
index 928b0cd09445..870ee17841c5 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala
@@ -152,10 +152,10 @@ private[spark] object CosmosConfigNames {
val ThroughputControlTargetThroughputThreshold = "spark.cosmos.throughputControl.targetThroughputThreshold"
val ThroughputControlPriorityLevel = "spark.cosmos.throughputControl.priorityLevel"
val ThroughputControlThroughputBucket = "spark.cosmos.throughputControl.throughputBucket"
- // Custom HTTP headers to attach to all Cosmos DB requests (e.g., workload-id for resource governance).
+ // Additional HTTP headers to attach to all Cosmos DB requests (e.g., workload-id for resource governance).
// Value is a JSON string like: {"x-ms-cosmos-workload-id": "15"}
- // Flows through to CosmosClientBuilder.customHeaders().
- val CustomHeaders = "spark.cosmos.customHeaders"
+ // Flows through to CosmosClientBuilder.additionalHeaders().
+ val AdditionalHeaders = "spark.cosmos.additionalHeaders"
val ThroughputControlGlobalControlDatabase = "spark.cosmos.throughputControl.globalControl.database"
val ThroughputControlGlobalControlContainer = "spark.cosmos.throughputControl.globalControl.container"
val ThroughputControlGlobalControlRenewalIntervalInMS =
@@ -303,7 +303,7 @@ private[spark] object CosmosConfigNames {
WriteFlushCloseIntervalInSeconds,
WriteMaxNoProgressIntervalInSeconds,
WriteMaxRetryNoProgressIntervalInSeconds,
- CustomHeaders
+ AdditionalHeaders
)
def validateConfigName(name: String): Unit = {
@@ -547,9 +547,9 @@ private case class CosmosAccountConfig(endpoint: String,
azureEnvironmentEndpoints: java.util.Map[String, String],
clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]],
clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]],
- // Optional custom HTTP headers (e.g., workload-id) parsed from
- // spark.cosmos.customHeaders JSON config, passed to CosmosClientBuilder
- customHeaders: Option[Map[String, String]]
+ // Optional additional HTTP headers (e.g., workload-id) parsed from
+ // spark.cosmos.additionalHeaders JSON config, passed to CosmosClientBuilder.additionalHeaders()
+ additionalHeaders: Option[Map[String, String]]
)
private object CosmosAccountConfig extends BasicLoggingTrait {
@@ -738,9 +738,10 @@ private object CosmosAccountConfig extends BasicLoggingTrait {
// Config entry for custom HTTP headers (e.g., workload-id). Parses a JSON string like
// {"x-ms-cosmos-workload-id": "15"} into a Scala Map[String, String] using Jackson.
- // These headers are passed to CosmosClientBuilder.customHeaders() in CosmosClientCache.
- private val CustomHeadersConfig = CosmosConfigEntry[Map[String, String]](
- key = CosmosConfigNames.CustomHeaders,
+ // These headers are converted to Map[CosmosHeaderName, String] and passed to
+ // CosmosClientBuilder.additionalHeaders() in CosmosClientCache.
+ private val AdditionalHeadersConfig = CosmosConfigEntry[Map[String, String]](
+ key = CosmosConfigNames.AdditionalHeaders,
mandatory = false,
parseFromStringFunction = headersJson => {
try {
@@ -748,11 +749,11 @@ private object CosmosAccountConfig extends BasicLoggingTrait {
Utils.getSimpleObjectMapperWithAllowDuplicates.readValue(headersJson, typeRef).asScala.toMap
} catch {
case e: Exception => throw new IllegalArgumentException(
- s"Invalid JSON for '${CosmosConfigNames.CustomHeaders}': '$headersJson'. " +
+ s"Invalid JSON for '${CosmosConfigNames.AdditionalHeaders}': '$headersJson'. " +
"Expected format: {\"x-ms-cosmos-workload-id\": \"15\"}", e)
}
},
- helpMessage = "Optional custom headers as JSON map. Example: {\"x-ms-cosmos-workload-id\": \"15\"}")
+ helpMessage = "Optional additional headers as JSON map. Example: {\"x-ms-cosmos-workload-id\": \"15\"}")
private[spark] def parseProactiveConnectionInitConfigs(config: String): java.util.List[CosmosContainerIdentity] = {
val result = new java.util.ArrayList[CosmosContainerIdentity]
@@ -788,8 +789,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait {
val tenantIdOpt = CosmosConfigEntry.parse(cfg, TenantId)
val clientBuilderInterceptors = CosmosConfigEntry.parse(cfg, ClientBuilderInterceptors)
val clientInterceptors = CosmosConfigEntry.parse(cfg, ClientInterceptors)
- // Parse optional custom HTTP headers from JSON config (e.g., {"x-ms-cosmos-workload-id": "15"})
- val customHeaders = CosmosConfigEntry.parse(cfg, CustomHeadersConfig)
+ // Parse optional additional HTTP headers from JSON config (e.g., {"x-ms-cosmos-workload-id": "15"})
+ val additionalHeaders = CosmosConfigEntry.parse(cfg, AdditionalHeadersConfig)
val disableTcpConnectionEndpointRediscovery = CosmosConfigEntry.parse(cfg, DisableTcpConnectionEndpointRediscovery)
val preferredRegionsListOpt = CosmosConfigEntry.parse(cfg, PreferredRegionsList)
@@ -910,7 +911,7 @@ private object CosmosAccountConfig extends BasicLoggingTrait {
azureEnvironmentOpt.get,
if (clientBuilderInterceptorsList.nonEmpty) { Some(clientBuilderInterceptorsList.toList) } else { None },
if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None },
- customHeaders)
+ additionalHeaders)
}
}
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala
index 4d542c44612e..8e44968a76df 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala
@@ -65,7 +65,7 @@ class CosmosClientCacheITest
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
),
(
@@ -93,7 +93,7 @@ class CosmosClientCacheITest
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
),
(
@@ -121,7 +121,7 @@ class CosmosClientCacheITest
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
),
(
@@ -149,7 +149,7 @@ class CosmosClientCacheITest
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
)
)
@@ -184,7 +184,7 @@ class CosmosClientCacheITest
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
logInfo(s"TestCase: {$testCaseName}")
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala
index 377425189f07..8a2b6a8191f9 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala
@@ -409,27 +409,27 @@ class CosmosClientConfigurationSpec extends UnitSpec {
configuration.azureMonitorConfig shouldEqual None
}
- // Verifies that the spark.cosmos.customHeaders configuration option correctly parses
+ // Verifies that the spark.cosmos.additionalHeaders configuration option correctly parses
// a JSON string containing a single workload-id header into a Map[String, String] on
// CosmosClientConfiguration. This is the primary use case for the workload-id feature.
- it should "parse customHeaders JSON" in {
+ it should "parse additionalHeaders JSON" in {
val userConfig = Map(
"spark.cosmos.accountEndpoint" -> "https://localhost:8081",
"spark.cosmos.accountKey" -> "xyz",
- "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "15"}"""
+ "spark.cosmos.additionalHeaders" -> """{"x-ms-cosmos-workload-id": "15"}"""
)
val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT
val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "")
- configuration.customHeaders shouldBe defined
- configuration.customHeaders.get("x-ms-cosmos-workload-id") shouldEqual "15"
+ configuration.additionalHeaders shouldBe defined
+ configuration.additionalHeaders.get("x-ms-cosmos-workload-id") shouldEqual "15"
}
- // Verifies that when spark.cosmos.customHeaders is not specified in the config map,
- // CosmosClientConfiguration.customHeaders is None. This ensures backward compatibility —
- // existing Spark jobs that don't set customHeaders continue to work without changes.
- it should "handle missing customHeaders" in {
+ // Verifies that when spark.cosmos.additionalHeaders is not specified in the config map,
+ // CosmosClientConfiguration.additionalHeaders is None. This ensures backward compatibility —
+ // existing Spark jobs that don't set additionalHeaders continue to work without changes.
+ it should "handle missing additionalHeaders" in {
val userConfig = Map(
"spark.cosmos.accountEndpoint" -> "https://localhost:8081",
"spark.cosmos.accountKey" -> "xyz"
@@ -438,43 +438,43 @@ class CosmosClientConfigurationSpec extends UnitSpec {
val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT
val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "")
- configuration.customHeaders shouldBe None
+ configuration.additionalHeaders shouldBe None
}
- // Verifies that spark.cosmos.customHeaders rejects unknown headers at the parsing level.
- // Only headers in CosmosClientBuilder's allowlist are permitted. In Direct mode (RNTBD),
- // unknown headers are silently dropped, so the allowlist ensures consistent behavior
- // across Gateway and Direct modes.
- it should "reject unknown custom headers" in {
+ // Verifies that spark.cosmos.additionalHeaders accepts multiple headers at the parsing level.
+ // The JSON is valid and CosmosClientConfiguration stores it as-is.
+ // The CosmosHeaderName.fromString() validation happens later in CosmosClientCache when
+ // converting to Map[CosmosHeaderName, String] for CosmosClientBuilder.additionalHeaders().
+ it should "reject unknown additional headers" in {
val userConfig = Map(
"spark.cosmos.accountEndpoint" -> "https://localhost:8081",
"spark.cosmos.accountKey" -> "xyz",
- "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "20", "x-custom-header": "value"}"""
+ "spark.cosmos.additionalHeaders" -> """{"x-ms-cosmos-workload-id": "20", "x-custom-header": "value"}"""
)
val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT
val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "")
// Parsing succeeds — the JSON is valid and CosmosClientConfiguration stores it as-is.
- // The allowlist validation happens later in CosmosClientBuilder.customHeaders()
- configuration.customHeaders shouldBe defined
- configuration.customHeaders.get should have size 2
+ // The CosmosHeaderName.fromString() validation happens later in CosmosClientCache
+ configuration.additionalHeaders shouldBe defined
+ configuration.additionalHeaders.get should have size 2
}
- // Verifies that spark.cosmos.customHeaders handles an empty JSON object ("{}") gracefully,
+ // Verifies that spark.cosmos.additionalHeaders handles an empty JSON object ("{}") gracefully,
// resulting in a defined but empty Map. This ensures the parser doesn't fail on edge cases
// and that no headers are injected when the JSON object is empty.
- it should "handle empty customHeaders JSON" in {
+ it should "handle empty additionalHeaders JSON" in {
val userConfig = Map(
"spark.cosmos.accountEndpoint" -> "https://localhost:8081",
"spark.cosmos.accountKey" -> "xyz",
- "spark.cosmos.customHeaders" -> "{}"
+ "spark.cosmos.additionalHeaders" -> "{}"
)
val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT
val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "")
- configuration.customHeaders shouldBe defined
- configuration.customHeaders.get shouldBe empty
+ configuration.additionalHeaders shouldBe defined
+ configuration.additionalHeaders.get shouldBe empty
}
}
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala
index ab73dc4e54d3..f745bae6b56d 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala
@@ -40,7 +40,7 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -118,7 +118,7 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -196,7 +196,7 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -274,7 +274,7 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -350,7 +350,7 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -442,7 +442,7 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -517,7 +517,7 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -584,7 +584,7 @@ class CosmosPartitionPlannerSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala
index 65274bee2b19..c17e7b02dad8 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala
@@ -39,7 +39,7 @@ class PartitionMetadataSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
private[this] val contCfg = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -86,7 +86,7 @@ class PartitionMetadataSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -172,7 +172,7 @@ class PartitionMetadataSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -258,7 +258,7 @@ class PartitionMetadataSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -326,7 +326,7 @@ class PartitionMetadataSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -389,7 +389,7 @@ class PartitionMetadataSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -446,7 +446,7 @@ class PartitionMetadataSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -503,7 +503,7 @@ class PartitionMetadataSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -560,7 +560,7 @@ class PartitionMetadataSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -617,7 +617,7 @@ class PartitionMetadataSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -697,7 +697,7 @@ class PartitionMetadataSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -759,7 +759,7 @@ class PartitionMetadataSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -816,7 +816,7 @@ class PartitionMetadataSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -890,7 +890,7 @@ class PartitionMetadataSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -964,7 +964,7 @@ class PartitionMetadataSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
@@ -1043,7 +1043,7 @@ class PartitionMetadataSpec extends UnitSpec {
clientInterceptors = None,
sampledDiagnosticsLoggerConfig = None,
azureMonitorConfig = None,
- customHeaders = None
+ additionalHeaders = None
)
val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None)
diff --git a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md
index 799d86414ff6..6149df9f05e4 100644
--- a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md
@@ -3,7 +3,7 @@
### 4.45.0-beta.1 (Unreleased)
#### Features Added
-* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
+* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md
index c080bbc9c5c5..424aef4e3a30 100644
--- a/sdk/cosmos/azure-cosmos/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md
@@ -4,7 +4,7 @@
#### Features Added
* Added support for N-Region synchronous commit feature - See [PR 47757](https://github.com/Azure/azure-sdk-for-java/pull/47757)
-* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
+* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java
index f54f44482db5..13f526552301 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java
@@ -186,7 +186,7 @@ public final class CosmosAsyncClient implements Closeable {
.withDefaultSerializer(this.defaultCustomSerializer)
.withRegionScopedSessionCapturingEnabled(builder.isRegionScopedSessionCapturingEnabled())
.withPerPartitionAutomaticFailoverEnabled(builder.isPerPartitionAutomaticFailoverEnabled())
- .withCustomHeaders(builder.getCustomHeaders())
+ .withAdditionalHeaders(builder.getAdditionalHeaders())
.build();
this.accountConsistencyLevel = this.asyncDocumentClient.getDefaultConsistencyLevelOfAccount();
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java
index e4cdf6ca1e30..129b4b578c58 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java
@@ -13,7 +13,6 @@
import com.azure.cosmos.implementation.ConnectionPolicy;
import com.azure.cosmos.implementation.CosmosClientMetadataCachesSnapshot;
import com.azure.cosmos.implementation.DiagnosticsProvider;
-import com.azure.cosmos.implementation.HttpConstants;
import com.azure.cosmos.implementation.Strings;
import com.azure.cosmos.implementation.WriteRetryPolicy;
import com.azure.cosmos.implementation.apachecommons.collections.list.UnmodifiableList;
@@ -34,6 +33,7 @@
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
+import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
@@ -157,21 +157,7 @@ public class CosmosClientBuilder implements
private boolean serverCertValidationDisabled = false;
private Function containerFactory = null;
- private Map customHeaders;
-
- /**
- * Allowlist of headers permitted in {@link #customHeaders(Map)}.
- *
- * In Direct mode (RNTBD), only headers with explicit encoding support in
- * {@code RntbdRequestHeaders} are sent on the wire. Unknown headers are silently dropped.
- * This allowlist ensures consistent behavior across Gateway and Direct modes - if a header
- * is allowed here, it works in both modes. To add a new allowed header, you must also add
- * RNTBD encoding support ({@code RntbdConstants.RntbdRequestHeader} enum entry +
- * {@code RntbdRequestHeaders.addXxx()} method).
- */
- private static final Set ALLOWED_CUSTOM_HEADERS = Collections.unmodifiableSet(
- new HashSet<>(Collections.singletonList(HttpConstants.HttpHeaders.WORKLOAD_ID))
- );
+ private Map additionalHeaders;
/**
* Instantiates a new Cosmos client builder.
@@ -752,59 +738,53 @@ public CosmosClientBuilder userAgentSuffix(String userAgentSuffix) {
}
/**
- * Sets custom HTTP headers that will be included with every request from this client.
+ * Sets additional HTTP headers that will be included with every request from this client.
*
- * Only headers in the SDK's allowlist are permitted. Currently the only allowed header is
- * {@code x-ms-cosmos-workload-id}. Passing any other header key will throw
- * {@link IllegalArgumentException}.
+ * The {@link CosmosHeaderName} enum defines exactly which headers are supported. Currently
+ * the only supported header is {@link CosmosHeaderName#WORKLOAD_ID}
+ * ({@code x-ms-cosmos-workload-id}).
*
* This restriction exists because in Direct mode (RNTBD), only headers with explicit
- * encoding support are sent on the wire. Unknown headers are silently dropped. The allowlist
- * ensures consistent behavior across both Gateway and Direct modes.
+ * encoding support are sent on the wire. The enum ensures consistent behavior across
+ * both Gateway and Direct modes.
*
* If the same header is also set on request options (e.g.,
- * {@code CosmosItemRequestOptions.setHeader(String, String)}),
+ * {@code CosmosItemRequestOptions.setAdditionalHeaders(Map)}),
* the request-level value takes precedence over the client-level value.
*
- * @param customHeaders map of header name to value
+ * @param additionalHeaders map of {@link CosmosHeaderName} to value
* @return current CosmosClientBuilder
- * @throws IllegalArgumentException if any header key is not in the allowlist, or if the
- * workload-id value is not a valid integer
- */
- public CosmosClientBuilder customHeaders(Map customHeaders) {
- if (customHeaders != null) {
- for (Map.Entry entry : customHeaders.entrySet()) {
- String key = entry.getKey();
- String value = entry.getValue();
-
- if (!ALLOWED_CUSTOM_HEADERS.contains(key)) {
- throw new IllegalArgumentException(
- "Header '" + key + "' is not allowed in customHeaders. "
- + "Allowed headers: " + ALLOWED_CUSTOM_HEADERS);
- }
+ * @throws IllegalArgumentException if the workload-id value is not a valid integer
+ */
+ public CosmosClientBuilder additionalHeaders(Map additionalHeaders) {
+ CosmosHeaderName.validateAdditionalHeaders(additionalHeaders);
+ this.additionalHeaders = additionalHeaders;
+ return this;
+ }
- // Validate workload-id value is a valid integer (range validation is left to the backend)
- if (HttpConstants.HttpHeaders.WORKLOAD_ID.equals(key) && value != null) {
- try {
- Integer.parseInt(value);
- } catch (NumberFormatException e) {
- throw new IllegalArgumentException(
- "Invalid value '" + value + "' for header '" + key
- + "'. The value must be a valid integer.", e);
- }
- }
- }
+ /**
+ * Gets the additional headers configured on this builder, converted to a
+ * {@code Map} for internal use.
+ * @return the additional headers map with string keys, or null if not set
+ */
+ Map getAdditionalHeaders() {
+ if (this.additionalHeaders == null) {
+ return null;
}
- this.customHeaders = customHeaders;
- return this;
+ Map result = new HashMap<>();
+ for (Map.Entry entry : this.additionalHeaders.entrySet()) {
+ result.put(entry.getKey().getHeaderName(), entry.getValue());
+ }
+ return result;
}
/**
- * Gets the custom headers configured on this builder.
- * @return the custom headers map, or null if not set
+ * Gets the additional headers configured on this builder in their original
+ * {@code Map} form.
+ * @return the additional headers map, or null if not set
*/
- Map getCustomHeaders() {
- return this.customHeaders;
+ Map getAdditionalHeadersRaw() {
+ return this.additionalHeaders;
}
/**
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java
index 7953721019c5..81a8f2826f04 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java
@@ -116,7 +116,7 @@ class Builder {
private boolean isRegionScopedSessionCapturingEnabled;
private boolean isPerPartitionAutomaticFailoverEnabled;
private List operationPolicies;
- private Map customHeaders;
+ private Map additionalHeaders;
public Builder withServiceEndpoint(String serviceEndpoint) {
try {
@@ -289,8 +289,8 @@ public Builder withPerPartitionAutomaticFailoverEnabled(boolean isPerPartitionAu
return this;
}
- public Builder withCustomHeaders(Map customHeaders) {
- this.customHeaders = customHeaders;
+ public Builder withAdditionalHeaders(Map additionalHeaders) {
+ this.additionalHeaders = additionalHeaders;
return this;
}
@@ -335,7 +335,7 @@ public AsyncDocumentClient build() {
isRegionScopedSessionCapturingEnabled,
operationPolicies,
isPerPartitionAutomaticFailoverEnabled,
- customHeaders);
+ additionalHeaders);
client.init(state, null);
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java
index c28400e86973..b13b6c02009c 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java
@@ -294,7 +294,7 @@ public class RxDocumentClientImpl implements AsyncDocumentClient, IAuthorization
private final AtomicReference cachedCosmosAsyncClientSnapshot;
private CosmosEndToEndOperationLatencyPolicyConfig ppafEnforcedE2ELatencyPolicyConfigForReads;
private Consumer perPartitionFailoverConfigModifier;
- private Map customHeaders;
+ private Map additionalHeaders;
public RxDocumentClientImpl(URI serviceEndpoint,
String masterKeyOrResourceToken,
@@ -421,7 +421,7 @@ public RxDocumentClientImpl(URI serviceEndpoint,
boolean isRegionScopedSessionCapturingEnabled,
List operationPolicies,
boolean isPerPartitionAutomaticFailoverEnabled,
- Map customHeaders) {
+ Map additionalHeaders) {
this(
serviceEndpoint,
masterKeyOrResourceToken,
@@ -448,7 +448,7 @@ public RxDocumentClientImpl(URI serviceEndpoint,
this.cosmosAuthorizationTokenResolver = cosmosAuthorizationTokenResolver;
this.operationPolicies = operationPolicies;
- this.customHeaders = customHeaders;
+ this.additionalHeaders = additionalHeaders;
}
private RxDocumentClientImpl(URI serviceEndpoint,
@@ -865,7 +865,7 @@ public void init(CosmosClientMetadataCachesSnapshot metadataCachesSnapshot, Func
this.globalEndpointManager,
this.reactorHttpClient,
this.apiType,
- this.customHeaders);
+ this.additionalHeaders);
this.thinProxy = createThinProxy(this.sessionContainer,
this.consistencyLevel,
@@ -983,7 +983,7 @@ private void initializeDirectConnectivity() {
null,
this.connectionPolicy,
this.apiType,
- this.customHeaders);
+ this.additionalHeaders);
this.storeClientFactory = new StoreClientFactory(
this.addressResolver,
@@ -1028,7 +1028,7 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer,
GlobalEndpointManager globalEndpointManager,
HttpClient httpClient,
ApiType apiType,
- Map customHeaders) {
+ Map additionalHeaders) {
return new RxGatewayStoreModel(
this,
sessionContainer,
@@ -1038,7 +1038,7 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer,
globalEndpointManager,
httpClient,
apiType,
- customHeaders);
+ additionalHeaders);
}
ThinClientStoreModel createThinProxy(ISessionContainer sessionContainer,
@@ -1956,9 +1956,9 @@ public void validateAndLogNonDefaultReadConsistencyStrategy(String readConsisten
private Map getRequestHeaders(RequestOptions options, ResourceType resourceType, OperationType operationType) {
Map headers = new HashMap<>();
- // Apply client-level custom headers first (e.g., workload-id from CosmosClientBuilder.customHeaders())
- if (this.customHeaders != null && !this.customHeaders.isEmpty()) {
- headers.putAll(this.customHeaders);
+ // Apply client-level additional headers first (e.g., workload-id from CosmosClientBuilder.additionalHeaders())
+ if (this.additionalHeaders != null && !this.additionalHeaders.isEmpty()) {
+ headers.putAll(this.additionalHeaders);
}
if (this.useMultipleWriteLocations) {
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java
index 7c761335b782..119c7309256d 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java
@@ -124,7 +124,7 @@ public GatewayAddressCache(
ConnectionPolicy connectionPolicy,
ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor,
GatewayServerErrorInjector gatewayServerErrorInjector,
- Map customHeaders) {
+ Map additionalHeaders) {
this.clientContext = clientContext;
try {
@@ -166,10 +166,10 @@ public GatewayAddressCache(
HttpConstants.HttpHeaders.SDK_SUPPORTED_CAPABILITIES,
HttpConstants.SDKSupportedCapabilities.SUPPORTED_CAPABILITIES);
- // Apply client-level custom headers (e.g., workload-id) to metadata requests
+ // Apply client-level additional headers (e.g., workload-id) to metadata requests
// Use putIfAbsent to ensure SDK system headers (USER_AGENT, VERSION, etc.) are not overwritten
- if (customHeaders != null && !customHeaders.isEmpty()) {
- for (Map.Entry entry : customHeaders.entrySet()) {
+ if (additionalHeaders != null && !additionalHeaders.isEmpty()) {
+ for (Map.Entry entry : additionalHeaders.entrySet()) {
this.defaultRequestHeaders.putIfAbsent(entry.getKey(), entry.getValue());
}
}
@@ -198,7 +198,7 @@ public GatewayAddressCache(
ConnectionPolicy connectionPolicy,
ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor,
GatewayServerErrorInjector gatewayServerErrorInjector,
- Map customHeaders) {
+ Map additionalHeaders) {
this(clientContext,
serviceEndpoint,
protocol,
@@ -211,7 +211,7 @@ public GatewayAddressCache(
connectionPolicy,
proactiveOpenConnectionsProcessor,
gatewayServerErrorInjector,
- customHeaders);
+ additionalHeaders);
}
@Override
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java
index 2fd5287da028..6dded78b8a07 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java
@@ -62,7 +62,7 @@ public class GlobalAddressResolver implements IAddressResolver {
private ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor;
private ConnectionPolicy connectionPolicy;
private GatewayServerErrorInjector gatewayServerErrorInjector;
- private final Map customHeaders;
+ private final Map additionalHeaders;
public GlobalAddressResolver(
DiagnosticsClientContext diagnosticsClientContext,
@@ -76,7 +76,7 @@ public GlobalAddressResolver(
GatewayServiceConfigurationReader serviceConfigReader,
ConnectionPolicy connectionPolicy,
ApiType apiType,
- Map customHeaders) {
+ Map additionalHeaders) {
this.diagnosticsClientContext = diagnosticsClientContext;
this.httpClient = httpClient;
this.endpointManager = endpointManager;
@@ -88,7 +88,7 @@ public GlobalAddressResolver(
this.serviceConfigReader = serviceConfigReader;
this.tcpConnectionEndpointRediscoveryEnabled = connectionPolicy.isTcpConnectionEndpointRediscoveryEnabled();
this.connectionPolicy = connectionPolicy;
- this.customHeaders = customHeaders;
+ this.additionalHeaders = additionalHeaders;
int maxBackupReadEndpoints = (connectionPolicy.isReadRequestsFallbackEnabled()) ? GlobalAddressResolver.MaxBackupReadRegions : 0;
this.maxEndpoints = maxBackupReadEndpoints + 2; // for write and alternate write getEndpoint (during failover)
@@ -294,7 +294,7 @@ private EndpointCache getOrAddEndpoint(URI endpoint) {
this.connectionPolicy,
this.proactiveOpenConnectionsProcessor,
this.gatewayServerErrorInjector,
- this.customHeaders);
+ this.additionalHeaders);
AddressResolver addressResolver = new AddressResolver();
addressResolver.initializeCaches(this.collectionCache, this.routingMapProvider, gatewayAddressCache);
EndpointCache cache = new EndpointCache();
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java
index 3183fe59bdea..ff3b335c8824 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java
@@ -4,6 +4,7 @@
package com.azure.cosmos.models;
import com.azure.cosmos.ConsistencyLevel;
+import com.azure.cosmos.CosmosHeaderName;
import com.azure.cosmos.CosmosDiagnosticsThresholds;
import com.azure.cosmos.CosmosEndToEndOperationLatencyPolicyConfig;
import com.azure.cosmos.CosmosItemSerializer;
@@ -154,17 +155,38 @@ RequestOptions toRequestOptions() {
}
/**
- * Sets a custom header to be included with this specific request.
+ * Sets additional headers to be included with this specific request.
*
+ * The {@link CosmosHeaderName} enum defines exactly which headers are supported.
* This allows per-request header customization, such as setting a workload ID
* that overrides the client-level default set via
- * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}.
+ * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}.
+ *
+ * If the same header is also set at the client level, the request-level value
+ * takes precedence.
+ *
+ * @param additionalHeaders map of {@link CosmosHeaderName} to value
+ * @return the CosmosBatchRequestOptions.
+ * @throws IllegalArgumentException if the workload-id value is not a valid integer
+ */
+ public CosmosBatchRequestOptions setAdditionalHeaders(Map additionalHeaders) {
+ CosmosHeaderName.validateAdditionalHeaders(additionalHeaders);
+ if (additionalHeaders != null) {
+ for (Map.Entry entry : additionalHeaders.entrySet()) {
+ this.setHeader(entry.getKey().getHeaderName(), entry.getValue());
+ }
+ }
+ return this;
+ }
+
+ /**
+ * Sets a header to be included with this specific request.
*
- * @param name the header name (e.g., "x-ms-cosmos-workload-id")
- * @param value the header value (e.g., "20")
+ * @param name the header name
+ * @param value the header value
* @return the CosmosBatchRequestOptions.
*/
- public CosmosBatchRequestOptions setHeader(String name, String value) {
+ CosmosBatchRequestOptions setHeader(String name, String value) {
if (this.customOptions == null) {
this.customOptions = new HashMap<>();
}
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java
index cd688f8a0da6..588f74cea99e 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java
@@ -4,6 +4,7 @@
package com.azure.cosmos.models;
import com.azure.cosmos.CosmosEndToEndOperationLatencyPolicyConfig;
+import com.azure.cosmos.CosmosHeaderName;
import com.azure.cosmos.CosmosItemSerializer;
import com.azure.cosmos.implementation.CosmosBulkExecutionOptionsImpl;
import com.azure.cosmos.implementation.ImplementationBridgeHelpers;
@@ -257,17 +258,38 @@ void setOperationContextAndListenerTuple(OperationContextAndListenerTuple operat
}
/**
- * Sets a custom header to be included with this specific request.
+ * Sets additional headers to be included with this specific request.
*
+ * The {@link CosmosHeaderName} enum defines exactly which headers are supported.
* This allows per-request header customization, such as setting a workload ID
* that overrides the client-level default set via
- * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}.
+ * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}.
+ *
+ * If the same header is also set at the client level, the request-level value
+ * takes precedence.
+ *
+ * @param additionalHeaders map of {@link CosmosHeaderName} to value
+ * @return the CosmosBulkExecutionOptions.
+ * @throws IllegalArgumentException if the workload-id value is not a valid integer
+ */
+ public CosmosBulkExecutionOptions setAdditionalHeaders(Map additionalHeaders) {
+ CosmosHeaderName.validateAdditionalHeaders(additionalHeaders);
+ if (additionalHeaders != null) {
+ for (Map.Entry entry : additionalHeaders.entrySet()) {
+ this.actualRequestOptions.setHeader(entry.getKey().getHeaderName(), entry.getValue());
+ }
+ }
+ return this;
+ }
+
+ /**
+ * Sets a header to be included with this specific request.
*
- * @param name the header name (e.g., "x-ms-cosmos-workload-id")
- * @param value the header value (e.g., "20")
+ * @param name the header name
+ * @param value the header value
* @return the CosmosBulkExecutionOptions.
*/
- public CosmosBulkExecutionOptions setHeader(String name, String value) {
+ CosmosBulkExecutionOptions setHeader(String name, String value) {
this.actualRequestOptions.setHeader(name, value);
return this;
}
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java
index a1b675f2ffd8..b78499cd185c 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java
@@ -4,6 +4,7 @@
package com.azure.cosmos.models;
import com.azure.cosmos.ConsistencyLevel;
+import com.azure.cosmos.CosmosHeaderName;
import com.azure.cosmos.CosmosDiagnosticsThresholds;
import com.azure.cosmos.CosmosItemSerializer;
import com.azure.cosmos.ReadConsistencyStrategy;
@@ -564,17 +565,38 @@ public List getExcludedRegions() {
}
/**
- * Sets a custom header to be included with this specific request.
+ * Sets additional headers to be included with this specific request.
*
+ * The {@link CosmosHeaderName} enum defines exactly which headers are supported.
* This allows per-request header customization, such as setting a workload ID
* that overrides the client-level default set via
- * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}.
+ * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}.
+ *
+ * If the same header is also set at the client level, the request-level value
+ * takes precedence.
+ *
+ * @param additionalHeaders map of {@link CosmosHeaderName} to value
+ * @return the CosmosChangeFeedRequestOptions.
+ * @throws IllegalArgumentException if the workload-id value is not a valid integer
+ */
+ public CosmosChangeFeedRequestOptions setAdditionalHeaders(Map additionalHeaders) {
+ CosmosHeaderName.validateAdditionalHeaders(additionalHeaders);
+ if (additionalHeaders != null) {
+ for (Map.Entry entry : additionalHeaders.entrySet()) {
+ this.actualRequestOptions.setHeader(entry.getKey().getHeaderName(), entry.getValue());
+ }
+ }
+ return this;
+ }
+
+ /**
+ * Sets a header to be included with this specific request.
*
- * @param name the header name (e.g., "x-ms-cosmos-workload-id")
- * @param value the header value (e.g., "20")
+ * @param name the header name
+ * @param value the header value
* @return the CosmosChangeFeedRequestOptions.
*/
- public CosmosChangeFeedRequestOptions setHeader(String name, String value) {
+ CosmosChangeFeedRequestOptions setHeader(String name, String value) {
this.actualRequestOptions.setHeader(name, value);
return this;
}
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java
index fbc540e5baeb..d5447034274c 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java
@@ -3,6 +3,7 @@
package com.azure.cosmos.models;
import com.azure.cosmos.ConsistencyLevel;
+import com.azure.cosmos.CosmosHeaderName;
import com.azure.cosmos.CosmosClientBuilder;
import com.azure.cosmos.CosmosDiagnosticsThresholds;
import com.azure.cosmos.CosmosEndToEndOperationLatencyPolicyConfig;
@@ -566,17 +567,38 @@ public CosmosItemRequestOptions setThresholdForDiagnosticsOnTracer(Duration thre
}
/**
- * Sets a custom header to be included with this specific request.
+ * Sets additional headers to be included with this specific request.
*
+ * The {@link CosmosHeaderName} enum defines exactly which headers are supported.
* This allows per-request header customization, such as setting a workload ID
* that overrides the client-level default set via
- * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}.
+ * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}.
+ *
+ * If the same header is also set at the client level, the request-level value
+ * takes precedence.
+ *
+ * @param additionalHeaders map of {@link CosmosHeaderName} to value
+ * @return the CosmosItemRequestOptions.
+ * @throws IllegalArgumentException if the workload-id value is not a valid integer
+ */
+ public CosmosItemRequestOptions setAdditionalHeaders(Map additionalHeaders) {
+ CosmosHeaderName.validateAdditionalHeaders(additionalHeaders);
+ if (additionalHeaders != null) {
+ for (Map.Entry entry : additionalHeaders.entrySet()) {
+ this.setHeader(entry.getKey().getHeaderName(), entry.getValue());
+ }
+ }
+ return this;
+ }
+
+ /**
+ * Sets a header to be included with this specific request.
*
- * @param name the header name (e.g., "x-ms-cosmos-workload-id")
- * @param value the header value (e.g., "20")
+ * @param name the header name
+ * @param value the header value
* @return the CosmosItemRequestOptions.
*/
- public CosmosItemRequestOptions setHeader(String name, String value) {
+ CosmosItemRequestOptions setHeader(String name, String value) {
if (this.customOptions == null) {
this.customOptions = new HashMap<>();
}
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java
index f0de81bbf823..ae998a39c360 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java
@@ -4,6 +4,7 @@
package com.azure.cosmos.models;
import com.azure.cosmos.ConsistencyLevel;
+import com.azure.cosmos.CosmosHeaderName;
import com.azure.cosmos.CosmosDiagnostics;
import com.azure.cosmos.CosmosDiagnosticsThresholds;
import com.azure.cosmos.CosmosEndToEndOperationLatencyPolicyConfig;
@@ -261,17 +262,38 @@ public CosmosQueryRequestOptions setExcludedRegions(List excludeRegions)
}
/**
- * Sets a custom header to be included with this specific request.
+ * Sets additional headers to be included with this specific request.
*
+ * The {@link CosmosHeaderName} enum defines exactly which headers are supported.
* This allows per-request header customization, such as setting a workload ID
* that overrides the client-level default set via
- * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}.
+ * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}.
+ *
+ * If the same header is also set at the client level, the request-level value
+ * takes precedence.
+ *
+ * @param additionalHeaders map of {@link CosmosHeaderName} to value
+ * @return the CosmosQueryRequestOptions.
+ * @throws IllegalArgumentException if the workload-id value is not a valid integer
+ */
+ public CosmosQueryRequestOptions setAdditionalHeaders(Map additionalHeaders) {
+ CosmosHeaderName.validateAdditionalHeaders(additionalHeaders);
+ if (additionalHeaders != null) {
+ for (Map.Entry entry : additionalHeaders.entrySet()) {
+ this.actualRequestOptions.setHeader(entry.getKey().getHeaderName(), entry.getValue());
+ }
+ }
+ return this;
+ }
+
+ /**
+ * Sets a header to be included with this specific request.
*
- * @param name the header name (e.g., "x-ms-cosmos-workload-id")
- * @param value the header value (e.g., "20")
+ * @param name the header name
+ * @param value the header value
* @return the CosmosQueryRequestOptions.
*/
- public CosmosQueryRequestOptions setHeader(String name, String value) {
+ CosmosQueryRequestOptions setHeader(String name, String value) {
this.actualRequestOptions.setHeader(name, value);
return this;
}
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java
index de2d769f789b..98505acdab98 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java
@@ -4,6 +4,7 @@
package com.azure.cosmos.models;
import com.azure.cosmos.ConsistencyLevel;
+import com.azure.cosmos.CosmosHeaderName;
import com.azure.cosmos.CosmosDiagnosticsThresholds;
import com.azure.cosmos.CosmosEndToEndOperationLatencyPolicyConfig;
import com.azure.cosmos.CosmosItemSerializer;
@@ -15,6 +16,7 @@
import java.time.Duration;
import java.util.List;
+import java.util.Map;
import java.util.Set;
/**
@@ -367,17 +369,38 @@ public Set getKeywordIdentifiers() {
}
/**
- * Sets a custom header to be included with this specific request.
+ * Sets additional headers to be included with this specific request.
*
+ * The {@link CosmosHeaderName} enum defines exactly which headers are supported.
* This allows per-request header customization, such as setting a workload ID
* that overrides the client-level default set via
- * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}.
+ * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}.
+ *
+ * If the same header is also set at the client level, the request-level value
+ * takes precedence.
+ *
+ * @param additionalHeaders map of {@link CosmosHeaderName} to value
+ * @return the CosmosReadManyRequestOptions.
+ * @throws IllegalArgumentException if the workload-id value is not a valid integer
+ */
+ public CosmosReadManyRequestOptions setAdditionalHeaders(Map additionalHeaders) {
+ CosmosHeaderName.validateAdditionalHeaders(additionalHeaders);
+ if (additionalHeaders != null) {
+ for (Map.Entry entry : additionalHeaders.entrySet()) {
+ this.actualRequestOptions.setHeader(entry.getKey().getHeaderName(), entry.getValue());
+ }
+ }
+ return this;
+ }
+
+ /**
+ * Sets a header to be included with this specific request.
*
- * @param name the header name (e.g., "x-ms-cosmos-workload-id")
- * @param value the header value (e.g., "20")
+ * @param name the header name
+ * @param value the header value
* @return the CosmosReadManyRequestOptions.
*/
- public CosmosReadManyRequestOptions setHeader(String name, String value) {
+ CosmosReadManyRequestOptions setHeader(String name, String value) {
this.actualRequestOptions.setHeader(name, value);
return this;
}
From b13d0e92656f6103cb427e623c57f2f0db304131 Mon Sep 17 00:00:00 2001
From: dibahlfi <106994927+dibahlfi@users.noreply.github.com>
Date: Tue, 10 Mar 2026 12:32:05 -0500
Subject: [PATCH 07/13] fix: refactoring
---
.../spark/SparkE2EWorkloadIdITest.scala | 52 +--
.../azure/cosmos/AdditionalHeadersTests.java | 312 ++++++++++++++++++
.../com/azure/cosmos/CustomHeadersTests.java | 256 --------------
.../RxDocumentClientUnderTest.java | 4 +-
.../RxGatewayStoreModelTest.java | 26 +-
.../SpyClientUnderTestFactory.java | 4 +-
.../GatewayAddressCacheTest.java | 32 +-
.../rx/WorkloadIdDirectInterceptorTests.java | 223 +++++++++++++
.../azure/cosmos/rx/WorkloadIdE2ETests.java | 85 +++--
.../com/azure/cosmos/CosmosHeaderName.java | 99 ++++++
.../implementation/RxGatewayStoreModel.java | 14 +-
11 files changed, 755 insertions(+), 352 deletions(-)
create mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/AdditionalHeadersTests.java
delete mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java
create mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdDirectInterceptorTests.java
create mode 100644 sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosHeaderName.java
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala
index d9706d0709e5..0f0a1f200d3e 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala
@@ -10,14 +10,18 @@ import com.fasterxml.jackson.databind.node.ObjectNode
import java.util.UUID
/**
- * End-to-end integration tests for the custom headers (workload-id) feature in the Spark connector.
+ * Integration tests (smoke tests) for the additional headers (workload-id) feature in the Spark connector.
+ * These are smoke tests — they verify that Spark DataFrame read and write operations succeed
+ * (no errors, correct data) when the `spark.cosmos.additionalHeaders` configuration is set.
+ * They do NOT assert that the workload-id header is actually present on the wire request.
+ * Wire-level header propagation is verified by:
+ * - Java SDK unit tests: RxGatewayStoreModelTest, GatewayAddressCacheTest
+ * - Java SDK integration tests: WorkloadIdE2ETests (interceptor-based wire assertions)
*
- * These tests verify that the `spark.cosmos.customHeaders` configuration option correctly flows
- * through the Spark connector pipeline into CosmosClientBuilder.customHeaders(), ensuring that
- * custom HTTP headers (such as x-ms-cosmos-workload-id) are applied to all Cosmos DB operations
- * initiated via Spark DataFrames (reads and writes).
- *
- * Requires the Cosmos DB Emulator running
+ * Test cases:
+ * 1. Read with workload-id header — Spark read succeeds, correct item returned
+ * 2. Write with workload-id header — Spark write succeeds, item verified via SDK read-back
+ * 3. No additionalHeaders (regression) — operations succeed without the config set
*/
class SparkE2EWorkloadIdITest
extends IntegrationSpec
@@ -32,10 +36,12 @@ class SparkE2EWorkloadIdITest
//scalastyle:off magic.number
//scalastyle:off null
- // Verifies that a Spark DataFrame read operation succeeds when spark.cosmos.customHeaders
- // is configured with a workload-id header. The header should be passed through to the
- // CosmosAsyncClient via CosmosClientBuilder.customHeaders() without affecting read behavior.
- "spark query with customHeaders" can "read items with workload-id header" in {
+ // Integration smoke test #1: Spark read with workload-id header.
+ // Creates an item via SDK, then reads it back via Spark DataFrame with
+ // spark.cosmos.additionalHeaders set to {"x-ms-cosmos-workload-id": "15"}.
+ // Verifies the read succeeds and returns the correct item.
+ // This proves the header flows through the Spark config pipeline without causing errors.
+ "spark query with additionalHeaders" can "read items with workload-id header" in {
val cosmosEndpoint = TestConfigurations.HOST
val cosmosMasterKey = TestConfigurations.MASTER_KEY
@@ -58,7 +64,7 @@ class SparkE2EWorkloadIdITest
"spark.cosmos.accountKey" -> cosmosMasterKey,
"spark.cosmos.database" -> cosmosDatabase,
"spark.cosmos.container" -> cosmosContainer,
- "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "15"}""",
+ "spark.cosmos.additionalHeaders" -> """{"x-ms-cosmos-workload-id": "15"}""",
"spark.cosmos.read.partitioning.strategy" -> "Restrictive"
)
@@ -70,10 +76,12 @@ class SparkE2EWorkloadIdITest
item.getAs[String]("id") shouldEqual id
}
- // Verifies that a Spark DataFrame write operation succeeds when spark.cosmos.customHeaders
- // is configured with a workload-id header. The item is written via Spark and then verified
- // via a direct SDK read to confirm the write was persisted correctly.
- "spark write with customHeaders" can "write items with workload-id header" in {
+ // Integration smoke test #2: Spark write with workload-id header.
+ // Writes an item via Spark DataFrame with spark.cosmos.additionalHeaders set to
+ // {"x-ms-cosmos-workload-id": "20"}, then reads it back via SDK to confirm
+ // write was persisted correctly.
+ // This proves the header flows through the Spark write pipeline without causing errors.
+ "spark write with additionalHeaders" can "write items with workload-id header" in {
val cosmosEndpoint = TestConfigurations.HOST
val cosmosMasterKey = TestConfigurations.MASTER_KEY
@@ -91,7 +99,7 @@ class SparkE2EWorkloadIdITest
"spark.cosmos.accountKey" -> cosmosMasterKey,
"spark.cosmos.database" -> cosmosDatabase,
"spark.cosmos.container" -> cosmosContainer,
- "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "20"}""",
+ "spark.cosmos.additionalHeaders" -> """{"x-ms-cosmos-workload-id": "20"}""",
"spark.cosmos.write.strategy" -> "ItemOverwrite",
"spark.cosmos.write.bulk.enabled" -> "false",
"spark.cosmos.serialization.inclusionMode" -> "NonDefault"
@@ -110,10 +118,12 @@ class SparkE2EWorkloadIdITest
readItem.getItem.get("name").textValue() shouldEqual "testWriteItem"
}
- // Regression test: verifies that Spark read operations continue to work correctly when
- // spark.cosmos.customHeaders is NOT specified. Ensures that the feature addition does not
- // break existing behavior for clients that do not use custom headers.
- "spark operations without customHeaders" can "still succeed" in {
+ // Integration smoke test #3: Regression test — no additionalHeaders configured.
+ // Verifies that Spark read operations continue to work correctly when
+ // spark.cosmos.additionalHeaders is NOT specified in the config map.
+ // This ensures the feature addition does not break existing Spark jobs
+ // that don't use additional headers.
+ "spark operations without additionalHeaders" can "still succeed" in {
val cosmosEndpoint = TestConfigurations.HOST
val cosmosMasterKey = TestConfigurations.MASTER_KEY
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/AdditionalHeadersTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/AdditionalHeadersTests.java
new file mode 100644
index 000000000000..9596988db621
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/AdditionalHeadersTests.java
@@ -0,0 +1,312 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.cosmos;
+
+import com.azure.cosmos.implementation.HttpConstants;
+import com.azure.cosmos.models.CosmosBatchRequestOptions;
+import com.azure.cosmos.models.CosmosBulkExecutionOptions;
+import com.azure.cosmos.models.CosmosChangeFeedRequestOptions;
+import com.azure.cosmos.models.CosmosItemRequestOptions;
+import com.azure.cosmos.models.CosmosQueryRequestOptions;
+import com.azure.cosmos.models.CosmosReadManyRequestOptions;
+import com.azure.cosmos.models.FeedRange;
+import org.testng.annotations.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Unit tests for the additional headers (workload-id) feature on CosmosClientBuilder and request options classes.
+ *
+ * These tests verify the public API surface: builder fluent methods, getter behavior,
+ * null/empty handling, CosmosHeaderName enum, and that setAdditionalHeaders() is publicly accessible
+ * on all request options classes.
+ */
+public class AdditionalHeadersTests {
+
+ /**
+ * Verifies that additional headers (e.g., workload-id) set via CosmosClientBuilder.additionalHeaders()
+ * are stored correctly and retrievable via getAdditionalHeaders().
+ */
+ @Test(groups = { "unit" })
+ public void additionalHeadersSetOnBuilder() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, "25");
+
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .additionalHeaders(headers);
+
+ assertThat(builder.getAdditionalHeaders())
+ .containsEntry("x-ms-cosmos-workload-id", "25");
+ }
+
+ /**
+ * Verifies that passing null to additionalHeaders() does not throw and that
+ * getAdditionalHeaders() returns null, ensuring graceful null handling.
+ */
+ @Test(groups = { "unit" })
+ public void additionalHeadersNullHandledGracefully() {
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .additionalHeaders(null);
+
+ assertThat(builder.getAdditionalHeaders()).isNull();
+ }
+
+ /**
+ * Verifies that passing an empty map to additionalHeaders() is accepted and
+ * getAdditionalHeaders() returns an empty (not null) map.
+ */
+ @Test(groups = { "unit" })
+ public void additionalHeadersEmptyMapHandled() {
+ Map emptyHeaders = new HashMap<>();
+
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .additionalHeaders(emptyHeaders);
+
+ assertThat(builder.getAdditionalHeaders()).isEmpty();
+ }
+
+ /**
+ * Verifies that CosmosHeaderName.WORKLOAD_ID maps to the correct header string.
+ */
+ @Test(groups = { "unit" })
+ public void cosmosHeaderNameWorkloadIdValue() {
+ assertThat(CosmosHeaderName.WORKLOAD_ID.getHeaderName())
+ .isEqualTo("x-ms-cosmos-workload-id");
+ }
+
+ /**
+ * Verifies that CosmosHeaderName.fromString() resolves known header strings to the
+ * correct enum value. This is used by the Spark connector to convert config strings
+ * to enum keys.
+ */
+ @Test(groups = { "unit" })
+ public void cosmosHeaderNameFromStringResolvesKnownHeader() {
+ CosmosHeaderName name = CosmosHeaderName.fromString("x-ms-cosmos-workload-id");
+ assertThat(name).isEqualTo(CosmosHeaderName.WORKLOAD_ID);
+ }
+
+ /**
+ * Verifies that CosmosHeaderName.fromString() is case-insensitive.
+ */
+ @Test(groups = { "unit" })
+ public void cosmosHeaderNameFromStringIsCaseInsensitive() {
+ CosmosHeaderName name = CosmosHeaderName.fromString("X-MS-COSMOS-WORKLOAD-ID");
+ assertThat(name).isEqualTo(CosmosHeaderName.WORKLOAD_ID);
+ }
+
+ /**
+ * Verifies that CosmosHeaderName.fromString() throws IllegalArgumentException
+ * for unknown header strings. This is the runtime equivalent of the compile-time
+ * safety provided by the enum — used when converting from Spark JSON config.
+ */
+ @Test(groups = { "unit" })
+ public void cosmosHeaderNameFromStringRejectsUnknownHeader() {
+ assertThatThrownBy(() -> CosmosHeaderName.fromString("x-ms-custom-header"))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("x-ms-custom-header")
+ .hasMessageContaining("Unknown header");
+ }
+
+ /**
+ * Verifies that setAdditionalHeaders() is publicly accessible on CosmosItemRequestOptions
+ * and supports fluent chaining for per-request header overrides on CRUD operations.
+ */
+ @Test(groups = { "unit" })
+ public void setAdditionalHeadersOnItemRequestOptions() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, "15");
+
+ CosmosItemRequestOptions options = new CosmosItemRequestOptions()
+ .setAdditionalHeaders(headers);
+
+ assertThat(options).isNotNull();
+ }
+
+ /**
+ * Verifies that setAdditionalHeaders() is publicly accessible on CosmosBatchRequestOptions
+ * and supports fluent chaining for per-request header overrides on batch operations.
+ */
+ @Test(groups = { "unit" })
+ public void setAdditionalHeadersOnBatchRequestOptions() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, "20");
+
+ CosmosBatchRequestOptions options = new CosmosBatchRequestOptions()
+ .setAdditionalHeaders(headers);
+
+ assertThat(options).isNotNull();
+ }
+
+ /**
+ * Verifies that setAdditionalHeaders() is publicly accessible on CosmosChangeFeedRequestOptions
+ * and supports fluent chaining for per-request header overrides on change feed operations.
+ */
+ @Test(groups = { "unit" })
+ public void setAdditionalHeadersOnChangeFeedRequestOptions() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, "25");
+
+ CosmosChangeFeedRequestOptions options = CosmosChangeFeedRequestOptions
+ .createForProcessingFromBeginning(FeedRange.forFullRange())
+ .setAdditionalHeaders(headers);
+
+ assertThat(options).isNotNull();
+ }
+
+ /**
+ * Verifies that setAdditionalHeaders() is publicly accessible on CosmosBulkExecutionOptions
+ * and supports fluent chaining for per-request header overrides on bulk ingestion operations.
+ */
+ @Test(groups = { "unit" })
+ public void setAdditionalHeadersOnBulkExecutionOptions() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, "30");
+
+ CosmosBulkExecutionOptions options = new CosmosBulkExecutionOptions()
+ .setAdditionalHeaders(headers);
+
+ assertThat(options).isNotNull();
+ }
+
+ /**
+ * Verifies that setAdditionalHeaders() is publicly accessible on CosmosQueryRequestOptions
+ * and supports fluent chaining for per-request header overrides on query operations.
+ */
+ @Test(groups = { "unit" })
+ public void setAdditionalHeadersOnQueryRequestOptions() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, "35");
+
+ CosmosQueryRequestOptions options = new CosmosQueryRequestOptions()
+ .setAdditionalHeaders(headers);
+
+ assertThat(options).isNotNull();
+ }
+
+ /**
+ * Verifies that setAdditionalHeaders() is publicly accessible on CosmosReadManyRequestOptions
+ * and supports fluent chaining for per-request header overrides on read-many operations.
+ */
+ @Test(groups = { "unit" })
+ public void setAdditionalHeadersOnReadManyRequestOptions() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, "40");
+
+ CosmosReadManyRequestOptions options = new CosmosReadManyRequestOptions()
+ .setAdditionalHeaders(headers);
+
+ assertThat(options).isNotNull();
+ }
+
+ /**
+ * Verifies that the WORKLOAD_ID constant in HttpConstants.HttpHeaders is defined
+ * with the correct canonical header name "x-ms-cosmos-workload-id" as expected
+ * by the Cosmos DB service.
+ */
+ @Test(groups = { "unit" })
+ public void workloadIdHttpHeaderConstant() {
+ assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id");
+ }
+
+ /**
+ * Verifies that a non-numeric workload-id value is rejected at builder level with
+ * IllegalArgumentException. This covers both Gateway and Direct modes consistently
+ * (unlike RntbdRequestHeaders.addWorkloadId() which only covers Direct mode).
+ */
+ @Test(groups = { "unit" })
+ public void nonNumericWorkloadIdRejectedAtBuilderLevel() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, "abc");
+
+ assertThatThrownBy(() -> new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .additionalHeaders(headers))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("abc")
+ .hasMessageContaining("valid integer");
+ }
+
+ /**
+ * Verifies that out-of-range workload-id values (e.g., 51) are accepted by the SDK.
+ * Range validation [1, 50] is the backend's responsibility — the SDK only validates
+ * that the value is a valid integer. This avoids hardcoding a range the backend team
+ * might change in the future.
+ */
+ @Test(groups = { "unit" })
+ public void outOfRangeWorkloadIdAcceptedByBuilder() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, "51");
+
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .additionalHeaders(headers);
+
+ assertThat(builder.getAdditionalHeaders())
+ .containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "51");
+ }
+
+ /**
+ * Verifies that a non-numeric workload-id value is rejected at request-options level
+ * (CosmosItemRequestOptions) with IllegalArgumentException, ensuring validation is
+ * symmetric with the builder level.
+ */
+ @Test(groups = { "unit" })
+ public void nonNumericWorkloadIdRejectedAtItemRequestOptionsLevel() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, "abc");
+
+ assertThatThrownBy(() -> new CosmosItemRequestOptions()
+ .setAdditionalHeaders(headers))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("abc")
+ .hasMessageContaining("valid integer");
+ }
+
+ /**
+ * Verifies that a non-numeric workload-id value is rejected at request-options level
+ * (CosmosBulkExecutionOptions) with IllegalArgumentException, ensuring validation is
+ * symmetric with the builder level.
+ */
+ @Test(groups = { "unit" })
+ public void nonNumericWorkloadIdRejectedAtBulkExecutionOptionsLevel() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, "not-a-number");
+
+ assertThatThrownBy(() -> new CosmosBulkExecutionOptions()
+ .setAdditionalHeaders(headers))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("not-a-number")
+ .hasMessageContaining("valid integer");
+ }
+
+ /**
+ * Verifies that a valid workload-id value passes builder validation.
+ */
+ @Test(groups = { "unit" })
+ public void validWorkloadIdAcceptedByBuilder() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, "15");
+
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .additionalHeaders(headers);
+
+ assertThat(builder.getAdditionalHeaders())
+ .containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "15");
+ }
+}
+
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java
deleted file mode 100644
index 19eb03744d1a..000000000000
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java
+++ /dev/null
@@ -1,256 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-package com.azure.cosmos;
-
-import com.azure.cosmos.implementation.HttpConstants;
-import com.azure.cosmos.models.CosmosBatchRequestOptions;
-import com.azure.cosmos.models.CosmosBulkExecutionOptions;
-import com.azure.cosmos.models.CosmosChangeFeedRequestOptions;
-import com.azure.cosmos.models.CosmosItemRequestOptions;
-import com.azure.cosmos.models.CosmosQueryRequestOptions;
-import com.azure.cosmos.models.CosmosReadManyRequestOptions;
-import com.azure.cosmos.models.FeedRange;
-import org.testng.annotations.Test;
-
-import java.util.HashMap;
-import java.util.Map;
-
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.api.Assertions.assertThatThrownBy;
-
-/**
- * Unit tests for the custom headers (workload-id) feature on CosmosClientBuilder and request options classes.
- *
- * These tests verify the public API surface: builder fluent methods, getter behavior,
- * null/empty handling, and that setHeader() is publicly accessible on all request options classes.
- */
-public class CustomHeadersTests {
-
- /**
- * Verifies that custom headers (e.g., workload-id) set via CosmosClientBuilder.customHeaders()
- * are stored correctly and retrievable via getCustomHeaders().
- */
- @Test(groups = { "unit" })
- public void customHeadersSetOnBuilder() {
- Map headers = new HashMap<>();
- headers.put("x-ms-cosmos-workload-id", "25");
-
- CosmosClientBuilder builder = new CosmosClientBuilder()
- .endpoint("https://test.documents.azure.com:443/")
- .key("dGVzdEtleQ==")
- .customHeaders(headers);
-
- assertThat(builder.getCustomHeaders()).containsEntry("x-ms-cosmos-workload-id", "25");
- }
-
- /**
- * Verifies that passing null to customHeaders() does not throw and that
- * getCustomHeaders() returns null, ensuring graceful null handling.
- */
- @Test(groups = { "unit" })
- public void customHeadersNullHandledGracefully() {
- CosmosClientBuilder builder = new CosmosClientBuilder()
- .endpoint("https://test.documents.azure.com:443/")
- .key("dGVzdEtleQ==")
- .customHeaders(null);
-
- assertThat(builder.getCustomHeaders()).isNull();
- }
-
- /**
- * Verifies that passing an empty map to customHeaders() is accepted and
- * getCustomHeaders() returns an empty (not null) map.
- */
- @Test(groups = { "unit" })
- public void customHeadersEmptyMapHandled() {
- Map emptyHeaders = new HashMap<>();
-
- CosmosClientBuilder builder = new CosmosClientBuilder()
- .endpoint("https://test.documents.azure.com:443/")
- .key("dGVzdEtleQ==")
- .customHeaders(emptyHeaders);
-
- assertThat(builder.getCustomHeaders()).isEmpty();
- }
-
- /**
- * Verifies that headers not in the allowlist are rejected with IllegalArgumentException.
- * This ensures consistent behavior across Gateway and Direct modes — only headers with
- * RNTBD encoding support are allowed.
- */
- @Test(groups = { "unit" })
- public void unknownHeaderRejectedByAllowlist() {
- Map headers = new HashMap<>();
- headers.put("x-ms-custom-header", "value");
-
- assertThatThrownBy(() -> new CosmosClientBuilder()
- .endpoint("https://test.documents.azure.com:443/")
- .key("dGVzdEtleQ==")
- .customHeaders(headers))
- .isInstanceOf(IllegalArgumentException.class)
- .hasMessageContaining("x-ms-custom-header")
- .hasMessageContaining("not allowed");
- }
-
- /**
- * Verifies that a map containing both an allowed header and a disallowed header
- * is rejected — the entire map must pass the allowlist check.
- */
- @Test(groups = { "unit" })
- public void mixedAllowedAndDisallowedHeadersRejected() {
- Map headers = new HashMap<>();
- headers.put("x-ms-cosmos-workload-id", "15");
- headers.put("x-ms-custom-header", "value");
-
- assertThatThrownBy(() -> new CosmosClientBuilder()
- .endpoint("https://test.documents.azure.com:443/")
- .key("dGVzdEtleQ==")
- .customHeaders(headers))
- .isInstanceOf(IllegalArgumentException.class)
- .hasMessageContaining("x-ms-custom-header");
- }
-
- /**
- * Verifies that setHeader() is publicly accessible on CosmosItemRequestOptions
- * (previously package-private) and supports fluent chaining for per-request
- * header overrides on CRUD operations.
- */
- @Test(groups = { "unit" })
- public void setHeaderOnItemRequestOptionsIsPublic() {
- CosmosItemRequestOptions options = new CosmosItemRequestOptions()
- .setHeader("x-ms-cosmos-workload-id", "15");
-
- assertThat(options).isNotNull();
- }
-
- /**
- * Verifies that setHeader() is publicly accessible on CosmosBatchRequestOptions
- * (previously package-private) and supports fluent chaining for per-request
- * header overrides on batch operations.
- */
- @Test(groups = { "unit" })
- public void setHeaderOnBatchRequestOptionsIsPublic() {
- CosmosBatchRequestOptions options = new CosmosBatchRequestOptions()
- .setHeader("x-ms-cosmos-workload-id", "20");
-
- assertThat(options).isNotNull();
- }
-
- /**
- * Verifies that setHeader() is publicly accessible on CosmosChangeFeedRequestOptions
- * (previously package-private) and supports fluent chaining for per-request
- * header overrides on change feed operations.
- */
- @Test(groups = { "unit" })
- public void setHeaderOnChangeFeedRequestOptionsIsPublic() {
- CosmosChangeFeedRequestOptions options = CosmosChangeFeedRequestOptions
- .createForProcessingFromBeginning(FeedRange.forFullRange())
- .setHeader("x-ms-cosmos-workload-id", "25");
-
- assertThat(options).isNotNull();
- }
-
- /**
- * Verifies that setHeader() is publicly accessible on CosmosBulkExecutionOptions
- * (previously package-private) and supports fluent chaining for per-request
- * header overrides on bulk ingestion operations.
- */
- @Test(groups = { "unit" })
- public void setHeaderOnBulkExecutionOptionsIsPublic() {
- CosmosBulkExecutionOptions options = new CosmosBulkExecutionOptions()
- .setHeader("x-ms-cosmos-workload-id", "30");
-
- assertThat(options).isNotNull();
- }
-
- /**
- * Verifies that the new delegating setHeader() method on CosmosQueryRequestOptions
- * is publicly accessible and supports fluent chaining for per-request header
- * overrides on query operations.
- */
- @Test(groups = { "unit" })
- public void setHeaderOnQueryRequestOptionsIsPublic() {
- CosmosQueryRequestOptions options = new CosmosQueryRequestOptions()
- .setHeader("x-ms-cosmos-workload-id", "35");
-
- assertThat(options).isNotNull();
- }
-
- /**
- * Verifies that the new delegating setHeader() method on CosmosReadManyRequestOptions
- * is publicly accessible and supports fluent chaining for per-request header
- * overrides on read-many operations.
- */
- @Test(groups = { "unit" })
- public void setHeaderOnReadManyRequestOptionsIsPublic() {
- CosmosReadManyRequestOptions options = new CosmosReadManyRequestOptions()
- .setHeader("x-ms-cosmos-workload-id", "40");
-
- assertThat(options).isNotNull();
- }
-
- /**
- * Verifies that the WORKLOAD_ID constant in HttpConstants.HttpHeaders is defined
- * with the correct canonical header name "x-ms-cosmos-workload-id" as expected
- * by the Cosmos DB service.
- */
- @Test(groups = { "unit" })
- public void workloadIdHttpHeaderConstant() {
- assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id");
- }
-
- /**
- * Verifies that a non-numeric workload-id value is rejected at builder level with
- * IllegalArgumentException. This covers both Gateway and Direct modes consistently
- * (unlike RntbdRequestHeaders.addWorkloadId() which only covers Direct mode).
- */
- @Test(groups = { "unit" })
- public void nonNumericWorkloadIdRejectedAtBuilderLevel() {
- Map headers = new HashMap<>();
- headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "abc");
-
- assertThatThrownBy(() -> new CosmosClientBuilder()
- .endpoint("https://test.documents.azure.com:443/")
- .key("dGVzdEtleQ==")
- .customHeaders(headers))
- .isInstanceOf(IllegalArgumentException.class)
- .hasMessageContaining("abc")
- .hasMessageContaining("valid integer");
- }
-
- /**
- * Verifies that out-of-range workload-id values (e.g., 51) are accepted by the SDK.
- * Range validation [1, 50] is the backend's responsibility — the SDK only validates
- * that the value is a valid integer. This avoids hardcoding a range the backend team
- * might change in the future.
- */
- @Test(groups = { "unit" })
- public void outOfRangeWorkloadIdAcceptedByBuilder() {
- Map headers = new HashMap<>();
- headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "51");
-
- CosmosClientBuilder builder = new CosmosClientBuilder()
- .endpoint("https://test.documents.azure.com:443/")
- .key("dGVzdEtleQ==")
- .customHeaders(headers);
-
- assertThat(builder.getCustomHeaders()).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "51");
- }
-
- /**
- * Verifies that a valid workload-id value passes builder validation.
- */
- @Test(groups = { "unit" })
- public void validWorkloadIdAcceptedByBuilder() {
- Map headers = new HashMap<>();
- headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "15");
-
- CosmosClientBuilder builder = new CosmosClientBuilder()
- .endpoint("https://test.documents.azure.com:443/")
- .key("dGVzdEtleQ==")
- .customHeaders(headers);
-
- assertThat(builder.getCustomHeaders()).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "15");
- }
-}
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java
index d5f8b92ac7a6..b4cb04e0ae3e 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java
@@ -77,7 +77,7 @@ RxGatewayStoreModel createRxGatewayProxy(
GlobalPartitionEndpointManagerForPerPartitionCircuitBreaker globalPartitionEndpointManagerForPerPartitionCircuitBreaker,
HttpClient rxOrigClient,
ApiType apiType,
- Map customHeaders) {
+ Map additionalHeaders) {
origHttpClient = rxOrigClient;
spyHttpClient = Mockito.spy(rxOrigClient);
@@ -96,6 +96,6 @@ RxGatewayStoreModel createRxGatewayProxy(
globalEndpointManager,
spyHttpClient,
apiType,
- customHeaders);
+ additionalHeaders);
}
}
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java
index 587844f4043a..edf16806489c 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java
@@ -436,12 +436,12 @@ private boolean runCancelAfterRetainIteration() throws Exception {
}
/**
- * Verifies that client-level customHeaders (e.g., workload-id) are injected into
+ * Verifies that client-level additionalHeaders (e.g., workload-id) are injected into
* outgoing HTTP requests by performRequest(). This covers metadata requests
* (collection cache, partition key range) that don't go through getRequestHeaders().
*/
@Test(groups = "unit")
- public void customHeadersInjectedInPerformRequest() throws Exception {
+ public void additionalHeadersInjectedInPerformRequest() throws Exception {
DiagnosticsClientContext clientContext = mockDiagnosticsClientContext();
ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class);
GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class);
@@ -453,8 +453,8 @@ public void customHeadersInjectedInPerformRequest() throws Exception {
ArgumentCaptor httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException()));
- Map customHeaders = new HashMap<>();
- customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25");
+ Map additionalHeaders = new HashMap<>();
+ additionalHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25");
RxGatewayStoreModel storeModel = new RxGatewayStoreModel(
clientContext,
@@ -465,9 +465,9 @@ public void customHeadersInjectedInPerformRequest() throws Exception {
globalEndpointManager,
httpClient,
null,
- customHeaders);
+ additionalHeaders);
- // Simulate a metadata request (e.g., collection cache lookup) — no customHeaders on the request itself
+ // Simulate a metadata request (e.g., collection cache lookup) — no additionalHeaders on the request itself
RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName(
clientContext,
OperationType.Read,
@@ -490,12 +490,12 @@ public void customHeadersInjectedInPerformRequest() throws Exception {
}
/**
- * Verifies that request-level headers take precedence over client-level customHeaders.
+ * Verifies that request-level headers take precedence over client-level additionalHeaders.
* If a request already has workload-id set (e.g., via getRequestHeaders()), performRequest()
* should NOT overwrite it.
*/
@Test(groups = "unit")
- public void requestLevelHeadersTakePrecedenceOverCustomHeaders() throws Exception {
+ public void requestLevelHeadersTakePrecedenceOverAdditionalHeaders() throws Exception {
DiagnosticsClientContext clientContext = mockDiagnosticsClientContext();
ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class);
GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class);
@@ -507,8 +507,8 @@ public void requestLevelHeadersTakePrecedenceOverCustomHeaders() throws Exceptio
ArgumentCaptor httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException()));
- Map customHeaders = new HashMap<>();
- customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "10");
+ Map additionalHeaders = new HashMap<>();
+ additionalHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "10");
RxGatewayStoreModel storeModel = new RxGatewayStoreModel(
clientContext,
@@ -519,7 +519,7 @@ public void requestLevelHeadersTakePrecedenceOverCustomHeaders() throws Exceptio
globalEndpointManager,
httpClient,
null,
- customHeaders);
+ additionalHeaders);
RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName(
clientContext,
@@ -547,11 +547,11 @@ public void requestLevelHeadersTakePrecedenceOverCustomHeaders() throws Exceptio
}
/**
- * Verifies that when customHeaders is null, performRequest() still works normally
+ * Verifies that when additionalHeaders is null, performRequest() still works normally
* without injecting any extra headers.
*/
@Test(groups = "unit")
- public void nullCustomHeadersDoesNotAffectPerformRequest() throws Exception {
+ public void nullAdditionalHeadersDoesNotAffectPerformRequest() throws Exception {
DiagnosticsClientContext clientContext = mockDiagnosticsClientContext();
ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class);
GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class);
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java
index 775b74785630..aa9029576b24 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java
@@ -128,7 +128,7 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer,
GlobalEndpointManager globalEndpointManager,
HttpClient rxClient,
ApiType apiType,
- Map customHeaders) {
+ Map additionalHeaders) {
this.origRxGatewayStoreModel = super.createRxGatewayProxy(
sessionContainer,
consistencyLevel,
@@ -137,7 +137,7 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer,
globalEndpointManager,
rxClient,
apiType,
- customHeaders);
+ additionalHeaders);
this.requests = Collections.synchronizedList(new ArrayList<>());
this.spyRxGatewayStoreModel = Mockito.spy(this.origRxGatewayStoreModel);
this.initRequestCapture();
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java
index 71d78aefdb2d..2d81311d8417 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java
@@ -1647,16 +1647,16 @@ private HttpClientUnderTestWrapper getHttpClientUnderTestWrapper(Configs configs
}
/**
- * Verifies that client-level customHeaders (e.g., workload-id) are included in
+ * Verifies that client-level additionalHeaders (e.g., workload-id) are included in
* GatewayAddressCache's defaultRequestHeaders, which are sent on every address
* resolution request.
*/
@Test(groups = { "unit" })
- public void customHeadersIncludedInDefaultRequestHeaders() throws Exception {
+ public void additionalHeadersIncludedInDefaultRequestHeaders() throws Exception {
URI serviceEndpoint = new URI("https://localhost");
- Map customHeaders = new HashMap<>();
- customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25");
+ Map additionalHeaders = new HashMap<>();
+ additionalHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25");
GatewayAddressCache cache = new GatewayAddressCache(
mockDiagnosticsClientContext(),
@@ -1670,7 +1670,7 @@ public void customHeadersIncludedInDefaultRequestHeaders() throws Exception {
null,
null,
null,
- customHeaders);
+ additionalHeaders);
Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders");
defaultRequestHeadersField.setAccessible(true);
@@ -1681,18 +1681,18 @@ public void customHeadersIncludedInDefaultRequestHeaders() throws Exception {
}
/**
- * Verifies that customHeaders do NOT overwrite SDK system headers (USER_AGENT, VERSION, etc.)
+ * Verifies that additionalHeaders do NOT overwrite SDK system headers (USER_AGENT, VERSION, etc.)
* in GatewayAddressCache's defaultRequestHeaders. putIfAbsent is used so SDK headers
- * set before customHeaders are preserved.
+ * set before additionalHeaders are preserved.
*/
@Test(groups = { "unit" })
- public void customHeadersDoNotOverwriteSdkSystemHeaders() throws Exception {
+ public void additionalHeadersDoNotOverwriteSdkSystemHeaders() throws Exception {
URI serviceEndpoint = new URI("https://localhost");
- Map customHeaders = new HashMap<>();
- customHeaders.put(HttpConstants.HttpHeaders.USER_AGENT, "malicious-agent");
- customHeaders.put(HttpConstants.HttpHeaders.VERSION, "bad-version");
- customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25");
+ Map additionalHeaders = new HashMap<>();
+ additionalHeaders.put(HttpConstants.HttpHeaders.USER_AGENT, "malicious-agent");
+ additionalHeaders.put(HttpConstants.HttpHeaders.VERSION, "bad-version");
+ additionalHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25");
GatewayAddressCache cache = new GatewayAddressCache(
mockDiagnosticsClientContext(),
@@ -1706,7 +1706,7 @@ public void customHeadersDoNotOverwriteSdkSystemHeaders() throws Exception {
null,
null,
null,
- customHeaders);
+ additionalHeaders);
Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders");
defaultRequestHeadersField.setAccessible(true);
@@ -1716,16 +1716,16 @@ public void customHeadersDoNotOverwriteSdkSystemHeaders() throws Exception {
// SDK headers should NOT be overwritten
assertThat(defaultRequestHeaders.get(HttpConstants.HttpHeaders.USER_AGENT)).isNotEqualTo("malicious-agent");
assertThat(defaultRequestHeaders.get(HttpConstants.HttpHeaders.VERSION)).isEqualTo(HttpConstants.Versions.CURRENT_VERSION);
- // Custom header should still be added
+ // Additional header should still be added
assertThat(defaultRequestHeaders).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "25");
}
/**
- * Verifies that when customHeaders is null, GatewayAddressCache's defaultRequestHeaders
+ * Verifies that when additionalHeaders is null, GatewayAddressCache's defaultRequestHeaders
* contains only SDK system headers and no extra entries.
*/
@Test(groups = { "unit" })
- public void nullCustomHeadersDoesNotAffectDefaultRequestHeaders() throws Exception {
+ public void nullAdditionalHeadersDoesNotAffectDefaultRequestHeaders() throws Exception {
URI serviceEndpoint = new URI("https://localhost");
GatewayAddressCache cache = new GatewayAddressCache(
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdDirectInterceptorTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdDirectInterceptorTests.java
new file mode 100644
index 000000000000..2fe594e60831
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdDirectInterceptorTests.java
@@ -0,0 +1,223 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.cosmos.rx;
+
+import com.azure.cosmos.CosmosAsyncClient;
+import com.azure.cosmos.CosmosAsyncContainer;
+import com.azure.cosmos.CosmosAsyncDatabase;
+import com.azure.cosmos.CosmosClientBuilder;
+import com.azure.cosmos.CosmosHeaderName;
+import com.azure.cosmos.TestObject;
+import com.azure.cosmos.implementation.HttpConstants;
+import com.azure.cosmos.implementation.ResourceType;
+import com.azure.cosmos.implementation.RxDocumentServiceRequest;
+import com.azure.cosmos.models.CosmosContainerProperties;
+import com.azure.cosmos.models.CosmosItemRequestOptions;
+import com.azure.cosmos.models.PartitionKey;
+import com.azure.cosmos.models.PartitionKeyDefinition;
+import com.azure.cosmos.test.implementation.interceptor.CosmosInterceptorHelper;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Factory;
+import org.testng.annotations.Test;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentLinkedQueue;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Interceptor-based integration tests for the workload-id / additional headers feature
+ * running in Direct mode (RNTBD) .
+ *
+ * These tests use {@link CosmosInterceptorHelper#registerTransportClientInterceptor} to
+ * capture {@link RxDocumentServiceRequest} objects at the TransportClient level and assert
+ * that the {@code x-ms-cosmos-workload-id} header is present with the correct value on
+ * wire requests. The interceptor only fires in Direct mode because Gateway mode data-plane
+ * requests go through {@code RxGatewayStoreModel → HttpClient}, bypassing the
+ * TransportClient entirely.
+ *
+ * Uses {@code @Factory(dataProvider = "simpleClientBuildersWithJustDirectTcp")} to ensure
+ * all tests run exclusively in Direct/TCP mode where the transport interceptor is active.
+ *
+ * Gateway-mode header injection is separately verified by:
+ *
+ * Unit tests in {@code RxGatewayStoreModelTest} (data-plane)
+ * Unit tests in {@code GatewayAddressCacheTest} (metadata requests)
+ * Smoke tests in {@link WorkloadIdE2ETests} (Gateway mode end-to-end)
+ *
+ */
+public class WorkloadIdDirectInterceptorTests extends TestSuiteBase {
+
+ private static final String DATABASE_ID = "workloadIdDirectTestDb-" + UUID.randomUUID();
+ private static final String CONTAINER_ID = "workloadIdDirectTestContainer-" + UUID.randomUUID();
+
+ private CosmosAsyncClient clientWithWorkloadId;
+ private CosmosAsyncDatabase database;
+ private CosmosAsyncContainer container;
+
+ @Factory(dataProvider = "simpleClientBuildersWithJustDirectTcp")
+ public WorkloadIdDirectInterceptorTests(CosmosClientBuilder clientBuilder) {
+ super(clientBuilder);
+ }
+
+ @BeforeClass(groups = { "emulator" }, timeOut = SETUP_TIMEOUT)
+ public void beforeClass() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, "15");
+
+ clientWithWorkloadId = getClientBuilder()
+ .additionalHeaders(headers)
+ .buildAsyncClient();
+
+ database = createDatabase(clientWithWorkloadId, DATABASE_ID);
+
+ PartitionKeyDefinition partitionKeyDef = new PartitionKeyDefinition();
+ ArrayList paths = new ArrayList<>();
+ paths.add("/mypk");
+ partitionKeyDef.setPaths(paths);
+ CosmosContainerProperties containerProperties = new CosmosContainerProperties(CONTAINER_ID, partitionKeyDef);
+ database.createContainer(containerProperties).block();
+ container = database.getContainer(CONTAINER_ID);
+ }
+
+ /**
+ * Verifies that the client-level workload-id header (value "15") is present on the
+ * {@link RxDocumentServiceRequest} for create (data-plane) operation in Direct mode.
+ *
+ * Registers a transport client interceptor that captures every request before it hits
+ * the RNTBD wire. After performing a createItem, asserts that at least one captured
+ * request with {@code ResourceType.Document} carries the {@code x-ms-cosmos-workload-id}
+ * header with value {@code "15"}.
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void verifyWorkloadIdHeaderPresentOnDataPlaneRequest() {
+ ConcurrentLinkedQueue capturedRequests = new ConcurrentLinkedQueue<>();
+
+ CosmosInterceptorHelper.registerTransportClientInterceptor(
+ clientWithWorkloadId,
+ (request, storeResponse) -> {
+ capturedRequests.add(request);
+ return storeResponse;
+ }
+ );
+
+ TestObject doc = TestObject.create();
+ container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block();
+
+ // In Direct mode, the interceptor MUST capture requests — fail if it didn't
+ assertThat(capturedRequests)
+ .as("Transport interceptor should capture requests in Direct mode")
+ .isNotEmpty();
+
+ // Assert that at least one Document-type request carries the workload-id header
+ boolean foundWorkloadIdOnDocument = capturedRequests.stream()
+ .filter(r -> r.getResourceType() == ResourceType.Document)
+ .anyMatch(r -> "15".equals(r.getHeaders().get(HttpConstants.HttpHeaders.WORKLOAD_ID)));
+
+ assertThat(foundWorkloadIdOnDocument)
+ .as("Expected workload-id header '15' on at least one Document request")
+ .isTrue();
+ }
+
+ /**
+ * Verifies that a per-request workload-id override (value "30") is present on the wire
+ * request instead of the client-level default (value "15").
+ *
+ * This confirms that the request-level header set via
+ * {@link CosmosItemRequestOptions#setAdditionalHeaders(Map)} takes precedence over
+ * the client-level header set via {@link CosmosClientBuilder#additionalHeaders(Map)}
+ * in Direct mode (RNTBD).
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void verifyRequestLevelOverrideOnWire() {
+ ConcurrentLinkedQueue capturedRequests = new ConcurrentLinkedQueue<>();
+
+ CosmosInterceptorHelper.registerTransportClientInterceptor(
+ clientWithWorkloadId,
+ (request, storeResponse) -> {
+ capturedRequests.add(request);
+ return storeResponse;
+ }
+ );
+
+ Map requestHeaders = new HashMap<>();
+ requestHeaders.put(CosmosHeaderName.WORKLOAD_ID, "30");
+
+ CosmosItemRequestOptions options = new CosmosItemRequestOptions()
+ .setAdditionalHeaders(requestHeaders);
+
+ TestObject doc = TestObject.create();
+ container.createItem(doc, new PartitionKey(doc.getMypk()), options).block();
+
+ // In Direct mode, the interceptor MUST capture requests — fail if it didn't
+ assertThat(capturedRequests)
+ .as("Transport interceptor should capture requests in Direct mode")
+ .isNotEmpty();
+
+ // Assert that the Document request carries the overridden value "30", not the client default "15"
+ boolean foundOverriddenWorkloadId = capturedRequests.stream()
+ .filter(r -> r.getResourceType() == ResourceType.Document)
+ .anyMatch(r -> "30".equals(r.getHeaders().get(HttpConstants.HttpHeaders.WORKLOAD_ID)));
+
+ assertThat(foundOverriddenWorkloadId)
+ .as("Expected workload-id header '30' (request-level override) on Document request")
+ .isTrue();
+ }
+
+ /**
+ * Negative test: verifies that a client created WITHOUT additional headers does NOT
+ * have the workload-id header on wire requests in Direct mode. Ensures the header is
+ * only present when explicitly configured.
+ */
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT)
+ public void verifyNoWorkloadIdHeaderWhenNotConfigured() {
+ CosmosAsyncClient clientWithoutHeaders = copyCosmosClientBuilder(getClientBuilder())
+ .buildAsyncClient();
+
+ try {
+ ConcurrentLinkedQueue capturedRequests = new ConcurrentLinkedQueue<>();
+
+ CosmosInterceptorHelper.registerTransportClientInterceptor(
+ clientWithoutHeaders,
+ (request, storeResponse) -> {
+ capturedRequests.add(request);
+ return storeResponse;
+ }
+ );
+
+ CosmosAsyncContainer c = clientWithoutHeaders
+ .getDatabase(DATABASE_ID)
+ .getContainer(CONTAINER_ID);
+
+ TestObject doc = TestObject.create();
+ c.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block();
+
+ // In Direct mode, the interceptor MUST capture requests
+ assertThat(capturedRequests)
+ .as("Transport interceptor should capture requests in Direct mode")
+ .isNotEmpty();
+
+ // Assert that NO Document-type request carries the workload-id header
+ boolean anyDocRequestHasWorkloadId = capturedRequests.stream()
+ .filter(r -> r.getResourceType() == ResourceType.Document)
+ .anyMatch(r -> r.getHeaders().containsKey(HttpConstants.HttpHeaders.WORKLOAD_ID));
+
+ assertThat(anyDocRequestHasWorkloadId)
+ .as("Expected NO workload-id header on Document requests when not configured")
+ .isFalse();
+ } finally {
+ safeClose(clientWithoutHeaders);
+ }
+ }
+
+ @AfterClass(groups = { "emulator" }, timeOut = SHUTDOWN_TIMEOUT, alwaysRun = true)
+ public void afterClass() {
+ safeDeleteDatabase(database);
+ safeClose(clientWithWorkloadId);
+ }
+}
+
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
index 3bf2fdafce7c..c41d8293f88f 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java
@@ -6,8 +6,8 @@
import com.azure.cosmos.CosmosAsyncContainer;
import com.azure.cosmos.CosmosAsyncDatabase;
import com.azure.cosmos.CosmosClientBuilder;
+import com.azure.cosmos.CosmosHeaderName;
import com.azure.cosmos.TestObject;
-import com.azure.cosmos.implementation.HttpConstants;
import com.azure.cosmos.models.CosmosContainerProperties;
import com.azure.cosmos.models.CosmosItemRequestOptions;
import com.azure.cosmos.models.CosmosItemResponse;
@@ -27,13 +27,21 @@
import static org.assertj.core.api.Assertions.assertThat;
/**
- * End-to-end integration tests for the custom headers / workload-id feature.
+ * End-to-end smoke tests for the additional headers / workload-id feature in Gateway mode.
*
- * Test type: EMULATOR INTEGRATION TEST — requires the Cosmos DB Emulator to be running locally.
+ * Test type: EMULATOR INTEGRATION TEST — requires a Cosmos DB account or emulator.
*
* Uses {@code @Factory(dataProvider = "simpleClientBuilderGatewaySession")} to run all tests
- * against both Gateway mode (HTTP headers) and Direct mode (RNTBD binary token 0x00DC),
- * ensuring the workload-id header is correctly encoded and sent in both transport paths.
+ * in Gateway mode with Session consistency.
+ *
+ * These are smoke tests — they verify CRUD/query operations succeed (status code
+ * 200/201/204) when the workload-id header is set. They prove the header doesn't break
+ * anything but do NOT assert the header is actually present on the wire request.
+ *
+ * For wire-level assertion tests that verify the header is actually present on
+ * {@link com.azure.cosmos.implementation.RxDocumentServiceRequest}, see
+ * {@link WorkloadIdDirectInterceptorTests} which runs in Direct mode (RNTBD) using
+ * the transport client interceptor.
*/
public class WorkloadIdE2ETests extends TestSuiteBase {
@@ -51,11 +59,11 @@ public WorkloadIdE2ETests(CosmosClientBuilder clientBuilder) {
@BeforeClass(groups = { "emulator" }, timeOut = SETUP_TIMEOUT)
public void beforeClass() {
- Map headers = new HashMap<>();
- headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "15");
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, "15");
clientWithWorkloadId = getClientBuilder()
- .customHeaders(headers)
+ .additionalHeaders(headers)
.buildAsyncClient();
database = createDatabase(clientWithWorkloadId, DATABASE_ID);
@@ -71,7 +79,7 @@ public void beforeClass() {
/**
* verifies that a create (POST) operation succeeds when the client
- * has a workload-id custom header set at the builder level. Confirms the header
+ * has a workload-id additional header set at the builder level. Confirms the header
* flows through the request pipeline without causing errors.
*/
@Test(groups = { "emulator" }, timeOut = TIMEOUT)
@@ -145,7 +153,7 @@ public void deleteItemWithClientLevelWorkloadId() {
/**
* Verifies that a per-request workload-id header override via
- * {@code CosmosItemRequestOptions.setHeader()} works. The request-level header
+ * {@code CosmosItemRequestOptions.setAdditionalHeaders()} works. The request-level header
* (value "30") should take precedence over the client-level default (value "15").
*/
@Test(groups = { "emulator" }, timeOut = TIMEOUT)
@@ -153,8 +161,11 @@ public void createItemWithRequestLevelWorkloadIdOverride() {
// Verify per-request header override works — request-level should take precedence
TestObject doc = TestObject.create();
+ Map requestHeaders = new HashMap<>();
+ requestHeaders.put(CosmosHeaderName.WORKLOAD_ID, "30");
+
CosmosItemRequestOptions options = new CosmosItemRequestOptions()
- .setHeader(HttpConstants.HttpHeaders.WORKLOAD_ID, "30");
+ .setAdditionalHeaders(requestHeaders);
CosmosItemResponse response = container
.createItem(doc, new PartitionKey(doc.getMypk()), options)
@@ -166,7 +177,7 @@ public void createItemWithRequestLevelWorkloadIdOverride() {
/**
* Verifies that a cross-partition query operation succeeds when the client has a
- * workload-id custom header. Confirms the header flows correctly through the
+ * workload-id additional header. Confirms the header flows correctly through the
* query pipeline and does not affect result correctness.
*/
@Test(groups = { "emulator" }, timeOut = TIMEOUT)
@@ -187,7 +198,7 @@ public void queryItemsWithClientLevelWorkloadId() {
/**
* Verifies that a per-request workload-id header override on
- * {@code CosmosQueryRequestOptions.setHeader()} works for query operations.
+ * {@code CosmosQueryRequestOptions.setAdditionalHeaders()} works for query operations.
* The request-level header (value "42") should take precedence over the
* client-level default.
*/
@@ -197,8 +208,11 @@ public void queryItemsWithRequestLevelWorkloadIdOverride() {
TestObject doc = TestObject.create();
container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block();
+ Map requestHeaders = new HashMap<>();
+ requestHeaders.put(CosmosHeaderName.WORKLOAD_ID, "42");
+
CosmosQueryRequestOptions queryOptions = new CosmosQueryRequestOptions()
- .setHeader(HttpConstants.HttpHeaders.WORKLOAD_ID, "42");
+ .setAdditionalHeaders(requestHeaders);
long count = container
.queryItems("SELECT * FROM c WHERE c.id = '" + doc.getId() + "'", queryOptions, TestObject.class)
@@ -210,13 +224,13 @@ public void queryItemsWithRequestLevelWorkloadIdOverride() {
}
/**
- * Regression test: verifies that a client created without any custom headers
- * continues to work normally. Ensures the custom headers feature does not
+ * Regression test: verifies that a client created without any additional headers
+ * continues to work normally. Ensures the additional headers feature does not
* introduce regressions for clients that do not use it.
*/
@Test(groups = { "emulator" }, timeOut = TIMEOUT)
- public void clientWithNoCustomHeadersStillWorks() {
- // Verify that a client without custom headers works normally (no regression)
+ public void clientWithNoAdditionalHeadersStillWorks() {
+ // Verify that a client without additional headers works normally (no regression)
CosmosAsyncClient clientWithoutHeaders = copyCosmosClientBuilder(getClientBuilder())
.buildAsyncClient();
@@ -238,15 +252,15 @@ public void clientWithNoCustomHeadersStillWorks() {
}
/**
- * Verifies that a client created with an empty custom headers map works normally.
- * An empty map should behave identically to no custom headers — no errors,
+ * Verifies that a client created with an empty additional headers map works normally.
+ * An empty map should behave identically to no additional headers — no errors,
* no unexpected behavior.
*/
@Test(groups = { "emulator" }, timeOut = TIMEOUT)
- public void clientWithEmptyCustomHeaders() {
- // Verify that a client with empty custom headers map works normally
+ public void clientWithEmptyAdditionalHeaders() {
+ // Verify that a client with empty additional headers map works normally
CosmosAsyncClient clientWithEmptyHeaders = copyCosmosClientBuilder(getClientBuilder())
- .customHeaders(new HashMap<>())
+ .additionalHeaders(new HashMap<>())
.buildAsyncClient();
try {
@@ -266,20 +280,21 @@ public void clientWithEmptyCustomHeaders() {
}
}
+
/**
- * Verifies that unknown headers in customHeaders are rejected by the allowlist.
- * In Direct mode (RNTBD), unknown headers are silently dropped, so the allowlist
- * ensures consistent behavior across Gateway and Direct modes.
+ * Verifies that the {@link CosmosHeaderName} enum-based allowlist rejects unknown
+ * header names at client build time. Attempting to set an unrecognized header via
+ * {@code additionalHeaders()} should throw {@link IllegalArgumentException} from
+ * {@link CosmosHeaderName#fromString(String)}, preventing arbitrary headers from
+ * being sent.
*/
- @Test(groups = { "emulator" }, timeOut = TIMEOUT, expectedExceptions = IllegalArgumentException.class)
- public void unknownCustomHeadersRejectedByAllowlist() {
- Map headers = new HashMap<>();
- headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "20");
- headers.put("x-ms-custom-test-header", "test-value");
-
- // Should throw IllegalArgumentException due to unknown header
- copyCosmosClientBuilder(getClientBuilder())
- .customHeaders(headers);
+ @Test(groups = { "emulator" }, timeOut = TIMEOUT,
+ expectedExceptions = IllegalArgumentException.class)
+ public void unknownAdditionalHeadersRejectedByAllowlist() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, "15");
+ // Use fromString with an unknown header — should throw IllegalArgumentException
+ CosmosHeaderName.fromString("x-unknown-header");
}
@AfterClass(groups = { "emulator" }, timeOut = SHUTDOWN_TIMEOUT, alwaysRun = true)
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosHeaderName.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosHeaderName.java
new file mode 100644
index 000000000000..7be67765d706
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosHeaderName.java
@@ -0,0 +1,99 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.cosmos;
+
+import com.azure.cosmos.implementation.HttpConstants;
+
+import java.util.Arrays;
+import java.util.Map;
+
+/**
+ * Defines the set of additional headers that can be set on a {@link CosmosClientBuilder}
+ * via {@link CosmosClientBuilder#additionalHeaders(java.util.Map)}.
+ *
+ * Only headers with RNTBD encoding support are included in this enum, ensuring consistent
+ * behavior across both Gateway mode (HTTP) and Direct mode (RNTBD binary protocol).
+ */
+public enum CosmosHeaderName {
+
+ /**
+ * The workload ID header ({@code x-ms-cosmos-workload-id}).
+ *
+ * Valid values: a string representation of an integer (e.g., {@code "15"}).
+ * The service accepts values in the range 1–50 for Azure Monitor metrics attribution.
+ * The SDK validates that the value is a valid integer but does not enforce range limits —
+ * range validation is the backend's responsibility.
+ */
+ WORKLOAD_ID(HttpConstants.HttpHeaders.WORKLOAD_ID);
+
+ private final String headerName;
+
+ CosmosHeaderName(String headerName) {
+ this.headerName = headerName;
+ }
+
+ /**
+ * Gets the canonical HTTP header name string (e.g., {@code "x-ms-cosmos-workload-id"}).
+ *
+ * @return the header name string
+ */
+ public String getHeaderName() {
+ return this.headerName;
+ }
+
+ /**
+ * Converts a header name string to the corresponding {@link CosmosHeaderName} enum value.
+ *
+ * This is primarily used by the Spark connector, which parses header names from JSON
+ * configuration strings and needs to convert them to enum values before calling
+ * {@link CosmosClientBuilder#additionalHeaders(java.util.Map)}.
+ *
+ * @param headerName the header name string (e.g., {@code "x-ms-cosmos-workload-id"})
+ * @return the matching {@link CosmosHeaderName}
+ * @throws IllegalArgumentException if the header name does not match any known enum value
+ */
+ public static CosmosHeaderName fromString(String headerName) {
+ for (CosmosHeaderName name : values()) {
+ if (name.headerName.equalsIgnoreCase(headerName)) {
+ return name;
+ }
+ }
+ throw new IllegalArgumentException(
+ "Unknown header: '" + headerName + "'. Allowed headers: " + Arrays.toString(values()));
+ }
+
+ /**
+ * Validates all entries in an additional-headers map.
+ *
+ * Each {@link CosmosHeaderName} enum value carries its own validation rules. Currently:
+ *
+ * {@link #WORKLOAD_ID}: value must be a valid integer string
+ *
+ *
+ * This method is called by {@link CosmosClientBuilder#additionalHeaders(Map)} and
+ * by every request-options class's {@code setAdditionalHeaders} method, so the
+ * validation logic lives in one place.
+ *
+ * @param additionalHeaders the map to validate (may be null — no-op in that case)
+ * @throws IllegalArgumentException if any header value fails validation
+ */
+ public static void validateAdditionalHeaders(Map additionalHeaders) {
+ if (additionalHeaders == null) {
+ return;
+ }
+ for (Map.Entry entry : additionalHeaders.entrySet()) {
+ CosmosHeaderName key = entry.getKey();
+ String value = entry.getValue();
+
+ if (WORKLOAD_ID == key && value != null) {
+ try {
+ Integer.parseInt(value);
+ } catch (NumberFormatException e) {
+ throw new IllegalArgumentException(
+ "Invalid value '" + value + "' for header '" + key.getHeaderName()
+ + "'. The value must be a valid integer.", e);
+ }
+ }
+ }
+ }
+}
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java
index 35f6c64c0079..1a0311a90905 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java
@@ -91,7 +91,7 @@ public class RxGatewayStoreModel implements RxStoreModel, HttpTransportSerialize
private GatewayServiceConfigurationReader gatewayServiceConfigurationReader;
private RxClientCollectionCache collectionCache;
private GatewayServerErrorInjector gatewayServerErrorInjector;
- private final Map customHeaders;
+ private final Map additionalHeaders;
public RxGatewayStoreModel(
DiagnosticsClientContext clientContext,
@@ -102,7 +102,7 @@ public RxGatewayStoreModel(
GlobalEndpointManager globalEndpointManager,
HttpClient httpClient,
ApiType apiType,
- Map customHeaders) {
+ Map additionalHeaders) {
this.clientContext = clientContext;
@@ -118,7 +118,7 @@ public RxGatewayStoreModel(
this.httpClient = httpClient;
this.sessionContainer = sessionContainer;
- this.customHeaders = customHeaders;
+ this.additionalHeaders = additionalHeaders;
}
public RxGatewayStoreModel(RxGatewayStoreModel inner) {
@@ -130,7 +130,7 @@ public RxGatewayStoreModel(RxGatewayStoreModel inner) {
this.httpClient = inner.httpClient;
this.sessionContainer = inner.sessionContainer;
- this.customHeaders = inner.customHeaders;
+ this.additionalHeaders = inner.additionalHeaders;
}
protected Map getDefaultHeaders(
@@ -283,10 +283,10 @@ public Mono performRequest(RxDocumentServiceRequest r
request.requestContext.cosmosDiagnostics = clientContext.createDiagnostics();
}
- // Apply client-level custom headers (e.g., workload-id) to all requests
+ // Apply client-level additional headers (e.g., workload-id) to all requests
// including metadata requests (collection cache, partition key range, etc.)
- if (this.customHeaders != null && !this.customHeaders.isEmpty()) {
- for (Map.Entry entry : this.customHeaders.entrySet()) {
+ if (this.additionalHeaders != null && !this.additionalHeaders.isEmpty()) {
+ for (Map.Entry entry : this.additionalHeaders.entrySet()) {
// Only set if not already present — request-level headers take precedence
if (!request.getHeaders().containsKey(entry.getKey())) {
request.getHeaders().put(entry.getKey(), entry.getValue());
From a79808d6ca2be4a189fcf743780034c6b94fe02c Mon Sep 17 00:00:00 2001
From: dibahlfi <106994927+dibahlfi@users.noreply.github.com>
Date: Thu, 12 Mar 2026 16:47:52 -0500
Subject: [PATCH 08/13] fix: updated chnage log text
---
sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md | 2 +-
sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md | 2 +-
sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md | 2 +-
sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md | 2 +-
sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md | 2 +-
sdk/cosmos/azure-cosmos/CHANGELOG.md | 2 +-
6 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md
index 6d23a6f92cf2..31851ffaca31 100644
--- a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md
@@ -3,7 +3,7 @@
### 4.45.0-beta.1 (Unreleased)
#### Features Added
-* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
+* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md
index 49da84cfc8d4..070dbc48e1d9 100644
--- a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md
@@ -3,7 +3,7 @@
### 4.45.0-beta.1 (Unreleased)
#### Features Added
-* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
+* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md
index 1290801e56f2..122a54604f38 100644
--- a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md
@@ -3,7 +3,7 @@
### 4.45.0-beta.1 (Unreleased)
#### Features Added
-* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
+* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md
index 8660d2ba7b59..94ac0cfa7b0e 100644
--- a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md
@@ -3,7 +3,7 @@
### 4.45.0-beta.1 (Unreleased)
#### Features Added
-* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
+* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md
index 6149df9f05e4..6a42a70828b4 100644
--- a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md
@@ -3,7 +3,7 @@
### 4.45.0-beta.1 (Unreleased)
#### Features Added
-* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
+* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md
index 424aef4e3a30..c300f7185825 100644
--- a/sdk/cosmos/azure-cosmos/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md
@@ -4,7 +4,7 @@
#### Features Added
* Added support for N-Region synchronous commit feature - See [PR 47757](https://github.com/Azure/azure-sdk-for-java/pull/47757)
-* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
+* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128)
#### Breaking Changes
From 9ce5d98a4971524969315d137e6583c3e3ecb798 Mon Sep 17 00:00:00 2001
From: dibahlfi <106994927+dibahlfi@users.noreply.github.com>
Date: Thu, 12 Mar 2026 16:54:54 -0500
Subject: [PATCH 09/13] fix: addressing comments
---
.../src/main/java/com/azure/cosmos/CosmosClientBuilder.java | 4 +++-
.../directconnectivity/GatewayAddressCache.java | 6 ++++--
2 files changed, 7 insertions(+), 3 deletions(-)
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java
index 129b4b578c58..5b5e80639fed 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java
@@ -758,7 +758,9 @@ public CosmosClientBuilder userAgentSuffix(String userAgentSuffix) {
*/
public CosmosClientBuilder additionalHeaders(Map additionalHeaders) {
CosmosHeaderName.validateAdditionalHeaders(additionalHeaders);
- this.additionalHeaders = additionalHeaders;
+ this.additionalHeaders = additionalHeaders != null
+ ? new HashMap<>(additionalHeaders)
+ : null;
return this;
}
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java
index 119c7309256d..1c09f37e6fa8 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java
@@ -166,8 +166,10 @@ public GatewayAddressCache(
HttpConstants.HttpHeaders.SDK_SUPPORTED_CAPABILITIES,
HttpConstants.SDKSupportedCapabilities.SUPPORTED_CAPABILITIES);
- // Apply client-level additional headers (e.g., workload-id) to metadata requests
- // Use putIfAbsent to ensure SDK system headers (USER_AGENT, VERSION, etc.) are not overwritten
+ // Apply client-level additional headers (e.g., workload-id) to address resolution requests.
+ // Address resolution is the one metadata path that bypasses RxGatewayStoreModel (which
+ // handles all other metadata requests) — GatewayAddressCache calls httpClient.send() directly.
+ // Use putIfAbsent to ensure SDK system headers (USER_AGENT, VERSION, etc.) are not overwritten.
if (additionalHeaders != null && !additionalHeaders.isEmpty()) {
for (Map.Entry entry : additionalHeaders.entrySet()) {
this.defaultRequestHeaders.putIfAbsent(entry.getKey(), entry.getValue());
From 289d321751ed2ebc68bfe1312f6b341b0395dc03 Mon Sep 17 00:00:00 2001
From: dibahlfi <106994927+dibahlfi@users.noreply.github.com>
Date: Wed, 18 Mar 2026 15:28:53 -0500
Subject: [PATCH 10/13] fix: addressing comments
---
.../cosmos/spark/CosmosClientCache.scala | 6 +-
.../azure/cosmos/AdditionalHeadersTests.java | 312 ---------------
.../azure/cosmos/WorkloadIdHeaderTests.java | 367 ++++++++++++++++++
.../ThinClientStoreModelTest.java | 129 +++++-
.../azure/cosmos/rx/WorkloadIdE2ETests.java | 2 +-
.../com/azure/cosmos/CosmosClientBuilder.java | 6 +-
.../com/azure/cosmos/CosmosHeaderName.java | 103 ++++-
.../implementation/RxDocumentClientImpl.java | 9 +-
.../implementation/ThinClientStoreModel.java | 5 +-
.../models/CosmosBatchRequestOptions.java | 2 +-
.../models/CosmosBulkExecutionOptions.java | 2 +-
.../CosmosChangeFeedRequestOptions.java | 2 +-
.../models/CosmosContainerRequestOptions.java | 43 ++
.../models/CosmosDatabaseRequestOptions.java | 43 ++
.../models/CosmosItemRequestOptions.java | 2 +-
.../models/CosmosQueryRequestOptions.java | 2 +-
.../models/CosmosReadManyRequestOptions.java | 2 +-
.../CosmosStoredProcedureRequestOptions.java | 43 ++
18 files changed, 728 insertions(+), 352 deletions(-)
delete mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/AdditionalHeadersTests.java
create mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/WorkloadIdHeaderTests.java
diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala
index 3a0174a9a9a8..49b051a97276 100644
--- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala
+++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala
@@ -716,11 +716,11 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait {
// These headers are attached to every Cosmos DB request made by this client instance.
// Converts Map[String, String] from Spark config to Map[CosmosHeaderName, String] for the builder.
if (cosmosClientConfiguration.additionalHeaders.isDefined) {
- val enumHeaders = new java.util.HashMap[CosmosHeaderName, String]()
+ val headerMap = new java.util.HashMap[CosmosHeaderName, String]()
for ((key, value) <- cosmosClientConfiguration.additionalHeaders.get) {
- enumHeaders.put(CosmosHeaderName.fromString(key), value)
+ headerMap.put(CosmosHeaderName.fromString(key), value)
}
- builder.additionalHeaders(enumHeaders)
+ builder.additionalHeaders(headerMap)
}
var client = builder.buildAsyncClient()
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/AdditionalHeadersTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/AdditionalHeadersTests.java
deleted file mode 100644
index 9596988db621..000000000000
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/AdditionalHeadersTests.java
+++ /dev/null
@@ -1,312 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-package com.azure.cosmos;
-
-import com.azure.cosmos.implementation.HttpConstants;
-import com.azure.cosmos.models.CosmosBatchRequestOptions;
-import com.azure.cosmos.models.CosmosBulkExecutionOptions;
-import com.azure.cosmos.models.CosmosChangeFeedRequestOptions;
-import com.azure.cosmos.models.CosmosItemRequestOptions;
-import com.azure.cosmos.models.CosmosQueryRequestOptions;
-import com.azure.cosmos.models.CosmosReadManyRequestOptions;
-import com.azure.cosmos.models.FeedRange;
-import org.testng.annotations.Test;
-
-import java.util.HashMap;
-import java.util.Map;
-
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.api.Assertions.assertThatThrownBy;
-
-/**
- * Unit tests for the additional headers (workload-id) feature on CosmosClientBuilder and request options classes.
- *
- * These tests verify the public API surface: builder fluent methods, getter behavior,
- * null/empty handling, CosmosHeaderName enum, and that setAdditionalHeaders() is publicly accessible
- * on all request options classes.
- */
-public class AdditionalHeadersTests {
-
- /**
- * Verifies that additional headers (e.g., workload-id) set via CosmosClientBuilder.additionalHeaders()
- * are stored correctly and retrievable via getAdditionalHeaders().
- */
- @Test(groups = { "unit" })
- public void additionalHeadersSetOnBuilder() {
- Map headers = new HashMap<>();
- headers.put(CosmosHeaderName.WORKLOAD_ID, "25");
-
- CosmosClientBuilder builder = new CosmosClientBuilder()
- .endpoint("https://test.documents.azure.com:443/")
- .key("dGVzdEtleQ==")
- .additionalHeaders(headers);
-
- assertThat(builder.getAdditionalHeaders())
- .containsEntry("x-ms-cosmos-workload-id", "25");
- }
-
- /**
- * Verifies that passing null to additionalHeaders() does not throw and that
- * getAdditionalHeaders() returns null, ensuring graceful null handling.
- */
- @Test(groups = { "unit" })
- public void additionalHeadersNullHandledGracefully() {
- CosmosClientBuilder builder = new CosmosClientBuilder()
- .endpoint("https://test.documents.azure.com:443/")
- .key("dGVzdEtleQ==")
- .additionalHeaders(null);
-
- assertThat(builder.getAdditionalHeaders()).isNull();
- }
-
- /**
- * Verifies that passing an empty map to additionalHeaders() is accepted and
- * getAdditionalHeaders() returns an empty (not null) map.
- */
- @Test(groups = { "unit" })
- public void additionalHeadersEmptyMapHandled() {
- Map emptyHeaders = new HashMap<>();
-
- CosmosClientBuilder builder = new CosmosClientBuilder()
- .endpoint("https://test.documents.azure.com:443/")
- .key("dGVzdEtleQ==")
- .additionalHeaders(emptyHeaders);
-
- assertThat(builder.getAdditionalHeaders()).isEmpty();
- }
-
- /**
- * Verifies that CosmosHeaderName.WORKLOAD_ID maps to the correct header string.
- */
- @Test(groups = { "unit" })
- public void cosmosHeaderNameWorkloadIdValue() {
- assertThat(CosmosHeaderName.WORKLOAD_ID.getHeaderName())
- .isEqualTo("x-ms-cosmos-workload-id");
- }
-
- /**
- * Verifies that CosmosHeaderName.fromString() resolves known header strings to the
- * correct enum value. This is used by the Spark connector to convert config strings
- * to enum keys.
- */
- @Test(groups = { "unit" })
- public void cosmosHeaderNameFromStringResolvesKnownHeader() {
- CosmosHeaderName name = CosmosHeaderName.fromString("x-ms-cosmos-workload-id");
- assertThat(name).isEqualTo(CosmosHeaderName.WORKLOAD_ID);
- }
-
- /**
- * Verifies that CosmosHeaderName.fromString() is case-insensitive.
- */
- @Test(groups = { "unit" })
- public void cosmosHeaderNameFromStringIsCaseInsensitive() {
- CosmosHeaderName name = CosmosHeaderName.fromString("X-MS-COSMOS-WORKLOAD-ID");
- assertThat(name).isEqualTo(CosmosHeaderName.WORKLOAD_ID);
- }
-
- /**
- * Verifies that CosmosHeaderName.fromString() throws IllegalArgumentException
- * for unknown header strings. This is the runtime equivalent of the compile-time
- * safety provided by the enum — used when converting from Spark JSON config.
- */
- @Test(groups = { "unit" })
- public void cosmosHeaderNameFromStringRejectsUnknownHeader() {
- assertThatThrownBy(() -> CosmosHeaderName.fromString("x-ms-custom-header"))
- .isInstanceOf(IllegalArgumentException.class)
- .hasMessageContaining("x-ms-custom-header")
- .hasMessageContaining("Unknown header");
- }
-
- /**
- * Verifies that setAdditionalHeaders() is publicly accessible on CosmosItemRequestOptions
- * and supports fluent chaining for per-request header overrides on CRUD operations.
- */
- @Test(groups = { "unit" })
- public void setAdditionalHeadersOnItemRequestOptions() {
- Map headers = new HashMap<>();
- headers.put(CosmosHeaderName.WORKLOAD_ID, "15");
-
- CosmosItemRequestOptions options = new CosmosItemRequestOptions()
- .setAdditionalHeaders(headers);
-
- assertThat(options).isNotNull();
- }
-
- /**
- * Verifies that setAdditionalHeaders() is publicly accessible on CosmosBatchRequestOptions
- * and supports fluent chaining for per-request header overrides on batch operations.
- */
- @Test(groups = { "unit" })
- public void setAdditionalHeadersOnBatchRequestOptions() {
- Map headers = new HashMap<>();
- headers.put(CosmosHeaderName.WORKLOAD_ID, "20");
-
- CosmosBatchRequestOptions options = new CosmosBatchRequestOptions()
- .setAdditionalHeaders(headers);
-
- assertThat(options).isNotNull();
- }
-
- /**
- * Verifies that setAdditionalHeaders() is publicly accessible on CosmosChangeFeedRequestOptions
- * and supports fluent chaining for per-request header overrides on change feed operations.
- */
- @Test(groups = { "unit" })
- public void setAdditionalHeadersOnChangeFeedRequestOptions() {
- Map headers = new HashMap<>();
- headers.put(CosmosHeaderName.WORKLOAD_ID, "25");
-
- CosmosChangeFeedRequestOptions options = CosmosChangeFeedRequestOptions
- .createForProcessingFromBeginning(FeedRange.forFullRange())
- .setAdditionalHeaders(headers);
-
- assertThat(options).isNotNull();
- }
-
- /**
- * Verifies that setAdditionalHeaders() is publicly accessible on CosmosBulkExecutionOptions
- * and supports fluent chaining for per-request header overrides on bulk ingestion operations.
- */
- @Test(groups = { "unit" })
- public void setAdditionalHeadersOnBulkExecutionOptions() {
- Map headers = new HashMap<>();
- headers.put(CosmosHeaderName.WORKLOAD_ID, "30");
-
- CosmosBulkExecutionOptions options = new CosmosBulkExecutionOptions()
- .setAdditionalHeaders(headers);
-
- assertThat(options).isNotNull();
- }
-
- /**
- * Verifies that setAdditionalHeaders() is publicly accessible on CosmosQueryRequestOptions
- * and supports fluent chaining for per-request header overrides on query operations.
- */
- @Test(groups = { "unit" })
- public void setAdditionalHeadersOnQueryRequestOptions() {
- Map headers = new HashMap<>();
- headers.put(CosmosHeaderName.WORKLOAD_ID, "35");
-
- CosmosQueryRequestOptions options = new CosmosQueryRequestOptions()
- .setAdditionalHeaders(headers);
-
- assertThat(options).isNotNull();
- }
-
- /**
- * Verifies that setAdditionalHeaders() is publicly accessible on CosmosReadManyRequestOptions
- * and supports fluent chaining for per-request header overrides on read-many operations.
- */
- @Test(groups = { "unit" })
- public void setAdditionalHeadersOnReadManyRequestOptions() {
- Map headers = new HashMap<>();
- headers.put(CosmosHeaderName.WORKLOAD_ID, "40");
-
- CosmosReadManyRequestOptions options = new CosmosReadManyRequestOptions()
- .setAdditionalHeaders(headers);
-
- assertThat(options).isNotNull();
- }
-
- /**
- * Verifies that the WORKLOAD_ID constant in HttpConstants.HttpHeaders is defined
- * with the correct canonical header name "x-ms-cosmos-workload-id" as expected
- * by the Cosmos DB service.
- */
- @Test(groups = { "unit" })
- public void workloadIdHttpHeaderConstant() {
- assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id");
- }
-
- /**
- * Verifies that a non-numeric workload-id value is rejected at builder level with
- * IllegalArgumentException. This covers both Gateway and Direct modes consistently
- * (unlike RntbdRequestHeaders.addWorkloadId() which only covers Direct mode).
- */
- @Test(groups = { "unit" })
- public void nonNumericWorkloadIdRejectedAtBuilderLevel() {
- Map headers = new HashMap<>();
- headers.put(CosmosHeaderName.WORKLOAD_ID, "abc");
-
- assertThatThrownBy(() -> new CosmosClientBuilder()
- .endpoint("https://test.documents.azure.com:443/")
- .key("dGVzdEtleQ==")
- .additionalHeaders(headers))
- .isInstanceOf(IllegalArgumentException.class)
- .hasMessageContaining("abc")
- .hasMessageContaining("valid integer");
- }
-
- /**
- * Verifies that out-of-range workload-id values (e.g., 51) are accepted by the SDK.
- * Range validation [1, 50] is the backend's responsibility — the SDK only validates
- * that the value is a valid integer. This avoids hardcoding a range the backend team
- * might change in the future.
- */
- @Test(groups = { "unit" })
- public void outOfRangeWorkloadIdAcceptedByBuilder() {
- Map headers = new HashMap<>();
- headers.put(CosmosHeaderName.WORKLOAD_ID, "51");
-
- CosmosClientBuilder builder = new CosmosClientBuilder()
- .endpoint("https://test.documents.azure.com:443/")
- .key("dGVzdEtleQ==")
- .additionalHeaders(headers);
-
- assertThat(builder.getAdditionalHeaders())
- .containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "51");
- }
-
- /**
- * Verifies that a non-numeric workload-id value is rejected at request-options level
- * (CosmosItemRequestOptions) with IllegalArgumentException, ensuring validation is
- * symmetric with the builder level.
- */
- @Test(groups = { "unit" })
- public void nonNumericWorkloadIdRejectedAtItemRequestOptionsLevel() {
- Map headers = new HashMap<>();
- headers.put(CosmosHeaderName.WORKLOAD_ID, "abc");
-
- assertThatThrownBy(() -> new CosmosItemRequestOptions()
- .setAdditionalHeaders(headers))
- .isInstanceOf(IllegalArgumentException.class)
- .hasMessageContaining("abc")
- .hasMessageContaining("valid integer");
- }
-
- /**
- * Verifies that a non-numeric workload-id value is rejected at request-options level
- * (CosmosBulkExecutionOptions) with IllegalArgumentException, ensuring validation is
- * symmetric with the builder level.
- */
- @Test(groups = { "unit" })
- public void nonNumericWorkloadIdRejectedAtBulkExecutionOptionsLevel() {
- Map headers = new HashMap<>();
- headers.put(CosmosHeaderName.WORKLOAD_ID, "not-a-number");
-
- assertThatThrownBy(() -> new CosmosBulkExecutionOptions()
- .setAdditionalHeaders(headers))
- .isInstanceOf(IllegalArgumentException.class)
- .hasMessageContaining("not-a-number")
- .hasMessageContaining("valid integer");
- }
-
- /**
- * Verifies that a valid workload-id value passes builder validation.
- */
- @Test(groups = { "unit" })
- public void validWorkloadIdAcceptedByBuilder() {
- Map headers = new HashMap<>();
- headers.put(CosmosHeaderName.WORKLOAD_ID, "15");
-
- CosmosClientBuilder builder = new CosmosClientBuilder()
- .endpoint("https://test.documents.azure.com:443/")
- .key("dGVzdEtleQ==")
- .additionalHeaders(headers);
-
- assertThat(builder.getAdditionalHeaders())
- .containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "15");
- }
-}
-
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/WorkloadIdHeaderTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/WorkloadIdHeaderTests.java
new file mode 100644
index 000000000000..2a9c6f8e2a02
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/WorkloadIdHeaderTests.java
@@ -0,0 +1,367 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.cosmos;
+
+import com.azure.cosmos.implementation.HttpConstants;
+import com.azure.cosmos.implementation.ImplementationBridgeHelpers;
+import com.azure.cosmos.implementation.RequestOptions;
+import com.azure.cosmos.models.CosmosBatchRequestOptions;
+import com.azure.cosmos.models.CosmosBulkExecutionOptions;
+import com.azure.cosmos.models.CosmosChangeFeedRequestOptions;
+import com.azure.cosmos.models.CosmosContainerRequestOptions;
+import com.azure.cosmos.models.CosmosDatabaseRequestOptions;
+import com.azure.cosmos.models.CosmosItemRequestOptions;
+import com.azure.cosmos.models.CosmosQueryRequestOptions;
+import com.azure.cosmos.models.CosmosReadManyRequestOptions;
+import com.azure.cosmos.models.CosmosStoredProcedureRequestOptions;
+import com.azure.cosmos.models.FeedRange;
+import org.testng.annotations.DataProvider;
+import org.testng.annotations.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Comprehensive unit tests for the workload-id / additional headers feature.
+ *
+ * Covers three layers:
+ *
+ * Public API surface — {@link CosmosHeaderName} constants, {@code CosmosClientBuilder.additionalHeaders()},
+ * and that {@code setAdditionalHeaders()} is callable on every request options class.
+ * Validation — non-numeric workload-id rejected at builder and request-options levels;
+ * out-of-range values accepted (range enforcement is the backend's responsibility).
+ * Internal wiring — headers set via {@code setAdditionalHeaders()} actually reach
+ * {@code RequestOptions.getHeaders()}, which is what {@code RxDocumentClientImpl.getRequestHeaders()}
+ * reads to populate outbound request headers. Covers both data-plane and control-plane classes.
+ *
+ * E2E tests (SDK → Gateway/Backend → Jarvis) are in {@code WorkloadIdJarvisValidationTests}.
+ */
+public class WorkloadIdHeaderTests {
+
+ private static final String WORKLOAD_ID_HEADER = HttpConstants.HttpHeaders.WORKLOAD_ID;
+ private static final String TEST_WORKLOAD_ID = "42";
+
+ // ==============================================================================================
+ // 1. CosmosHeaderName constants
+ // ==============================================================================================
+
+ @Test(groups = { "unit" })
+ public void workloadIdHttpHeaderConstant() {
+ assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id");
+ }
+
+ @Test(groups = { "unit" })
+ public void cosmosHeaderNameWorkloadIdValue() {
+ assertThat(CosmosHeaderName.WORKLOAD_ID.getHeaderName()).isEqualTo("x-ms-cosmos-workload-id");
+ }
+
+ @Test(groups = { "unit" })
+ public void cosmosHeaderNameFromStringResolvesKnownHeader() {
+ assertThat(CosmosHeaderName.fromString("x-ms-cosmos-workload-id")).isEqualTo(CosmosHeaderName.WORKLOAD_ID);
+ }
+
+ @Test(groups = { "unit" })
+ public void cosmosHeaderNameFromStringIsCaseInsensitive() {
+ assertThat(CosmosHeaderName.fromString("X-MS-COSMOS-WORKLOAD-ID")).isEqualTo(CosmosHeaderName.WORKLOAD_ID);
+ }
+
+ @Test(groups = { "unit" })
+ public void cosmosHeaderNameFromStringRejectsUnknownHeader() {
+ assertThatThrownBy(() -> CosmosHeaderName.fromString("x-ms-custom-header"))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("x-ms-custom-header");
+ }
+
+ // ==============================================================================================
+ // 2. CosmosClientBuilder — additionalHeaders()
+ // ==============================================================================================
+
+ @Test(groups = { "unit" })
+ public void builderStoresAdditionalHeaders() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, "25");
+
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .additionalHeaders(headers);
+
+ assertThat(builder.getAdditionalHeaders())
+ .containsEntry(WORKLOAD_ID_HEADER, "25");
+ }
+
+ @Test(groups = { "unit" })
+ public void builderHandlesNullAdditionalHeaders() {
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .additionalHeaders(null);
+
+ assertThat(builder.getAdditionalHeaders()).isNull();
+ }
+
+ @Test(groups = { "unit" })
+ public void builderHandlesEmptyAdditionalHeaders() {
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .additionalHeaders(new HashMap<>());
+
+ assertThat(builder.getAdditionalHeaders()).isEmpty();
+ }
+
+ // ==============================================================================================
+ // 3. Validation — non-numeric rejected, out-of-range accepted
+ // ==============================================================================================
+
+ @Test(groups = { "unit" })
+ public void nonNumericWorkloadIdRejectedAtBuilderLevel() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, "abc");
+
+ assertThatThrownBy(() -> new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .additionalHeaders(headers))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("abc")
+ .hasMessageContaining("valid integer");
+ }
+
+ @Test(groups = { "unit" })
+ public void nonNumericWorkloadIdRejectedAtItemRequestOptionsLevel() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, "abc");
+
+ assertThatThrownBy(() -> new CosmosItemRequestOptions().setAdditionalHeaders(headers))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("abc");
+ }
+
+ @Test(groups = { "unit" })
+ public void nonNumericWorkloadIdRejectedAtBulkExecutionOptionsLevel() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, "not-a-number");
+
+ assertThatThrownBy(() -> new CosmosBulkExecutionOptions().setAdditionalHeaders(headers))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("not-a-number");
+ }
+
+ @Test(groups = { "unit" })
+ public void nonNumericWorkloadIdRejectedAtDatabaseRequestOptionsLevel() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, "not-a-number");
+
+ assertThatThrownBy(() -> new CosmosDatabaseRequestOptions().setAdditionalHeaders(headers))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("not-a-number");
+ }
+
+ /** Range validation [1, 50] is the backend's responsibility — SDK only validates integer format. */
+ @Test(groups = { "unit" })
+ public void outOfRangeWorkloadIdAcceptedByBuilder() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, "51");
+
+ CosmosClientBuilder builder = new CosmosClientBuilder()
+ .endpoint("https://test.documents.azure.com:443/")
+ .key("dGVzdEtleQ==")
+ .additionalHeaders(headers);
+
+ assertThat(builder.getAdditionalHeaders())
+ .containsEntry(WORKLOAD_ID_HEADER, "51");
+ }
+
+ // ==============================================================================================
+ // 4. Internal wiring — workload-id reaches RequestOptions.getHeaders()
+ // This is the acceptance proof that the header will be on the wire.
+ // ==============================================================================================
+
+ /**
+ * Data provider covering request options classes that have a no-arg {@code toRequestOptions()}.
+ * Includes both data-plane and control-plane classes.
+ */
+ @DataProvider(name = "requestOptionsWithToRequestOptions")
+ public Object[][] requestOptionsWithToRequestOptions() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, TEST_WORKLOAD_ID);
+
+ return new Object[][] {
+ // Data-plane
+ { "CosmosItemRequestOptions",
+ reflectToRequestOptions(new CosmosItemRequestOptions().setAdditionalHeaders(headers)) },
+ { "CosmosBatchRequestOptions",
+ reflectToRequestOptions(new CosmosBatchRequestOptions().setAdditionalHeaders(headers)) },
+ // Control-plane
+ { "CosmosDatabaseRequestOptions",
+ reflectToRequestOptions(new CosmosDatabaseRequestOptions().setAdditionalHeaders(headers)) },
+ { "CosmosContainerRequestOptions",
+ reflectToRequestOptions(new CosmosContainerRequestOptions().setAdditionalHeaders(headers)) },
+ { "CosmosStoredProcedureRequestOptions",
+ reflectToRequestOptions(new CosmosStoredProcedureRequestOptions().setAdditionalHeaders(headers)) },
+ };
+ }
+
+ @Test(groups = { "unit" }, dataProvider = "requestOptionsWithToRequestOptions")
+ public void workloadIdReachesOutboundHeaders(String optionsClassName, RequestOptions requestOptions) {
+ assertThat(requestOptions.getHeaders())
+ .as("workload-id should reach RequestOptions.getHeaders() for " + optionsClassName)
+ .isNotNull()
+ .containsEntry(WORKLOAD_ID_HEADER, TEST_WORKLOAD_ID);
+ }
+
+ /** CosmosQueryRequestOptions uses a bridge accessor pattern (no simple toRequestOptions()). */
+ @Test(groups = { "unit" })
+ public void workloadIdReachesQueryRequestOptions() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, TEST_WORKLOAD_ID);
+
+ RequestOptions opts = ImplementationBridgeHelpers
+ .CosmosQueryRequestOptionsHelper
+ .getCosmosQueryRequestOptionsAccessor()
+ .toRequestOptions(new CosmosQueryRequestOptions().setAdditionalHeaders(headers));
+
+ assertThat(opts.getHeaders()).containsEntry(WORKLOAD_ID_HEADER, TEST_WORKLOAD_ID);
+ }
+
+ /** CosmosChangeFeedRequestOptions uses a bridge accessor that exposes getHeaders(). */
+ @Test(groups = { "unit" })
+ public void workloadIdReachesChangeFeedRequestOptions() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, TEST_WORKLOAD_ID);
+
+ Map extracted = ImplementationBridgeHelpers
+ .CosmosChangeFeedRequestOptionsHelper
+ .getCosmosChangeFeedRequestOptionsAccessor()
+ .getHeaders(
+ CosmosChangeFeedRequestOptions
+ .createForProcessingFromBeginning(FeedRange.forFullRange())
+ .setAdditionalHeaders(headers));
+
+ assertThat(extracted).containsEntry(WORKLOAD_ID_HEADER, TEST_WORKLOAD_ID);
+ }
+
+ /** CosmosReadManyRequestOptions uses getImpl().applyToRequestOptions(). */
+ @Test(groups = { "unit" })
+ public void workloadIdReachesReadManyRequestOptions() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, TEST_WORKLOAD_ID);
+
+ RequestOptions opts = ImplementationBridgeHelpers
+ .CosmosReadManyRequestOptionsHelper
+ .getCosmosReadManyRequestOptionsAccessor()
+ .getImpl(new CosmosReadManyRequestOptions().setAdditionalHeaders(headers))
+ .applyToRequestOptions(new RequestOptions());
+
+ assertThat(opts.getHeaders()).containsEntry(WORKLOAD_ID_HEADER, TEST_WORKLOAD_ID);
+ }
+
+ /** CosmosBulkExecutionOptions delegates to an internal CosmosBatchRequestOptions. */
+ @Test(groups = { "unit" })
+ public void workloadIdReachesBulkExecutionOptions() {
+ Map headers = new HashMap<>();
+ headers.put(CosmosHeaderName.WORKLOAD_ID, TEST_WORKLOAD_ID);
+
+ Map extracted = reflectGetHeaders(
+ new CosmosBulkExecutionOptions().setAdditionalHeaders(headers));
+
+ assertThat(extracted).containsEntry(WORKLOAD_ID_HEADER, TEST_WORKLOAD_ID);
+ }
+
+ // ==============================================================================================
+ // 5. Null/empty additionalHeaders does not inject workload-id into outbound headers
+ // ==============================================================================================
+
+ @Test(groups = { "unit" })
+ public void nullAdditionalHeadersDoesNotInjectWorkloadId() {
+ assertNoWorkloadId(reflectToRequestOptions(new CosmosItemRequestOptions().setAdditionalHeaders(null)));
+ assertNoWorkloadId(reflectToRequestOptions(new CosmosDatabaseRequestOptions().setAdditionalHeaders(null)));
+ assertNoWorkloadId(reflectToRequestOptions(new CosmosContainerRequestOptions().setAdditionalHeaders(null)));
+ assertNoWorkloadId(reflectToRequestOptions(new CosmosStoredProcedureRequestOptions().setAdditionalHeaders(null)));
+ }
+
+ @Test(groups = { "unit" })
+ public void emptyAdditionalHeadersDoesNotInjectWorkloadId() {
+ Map empty = new HashMap<>();
+ assertNoWorkloadId(reflectToRequestOptions(new CosmosItemRequestOptions().setAdditionalHeaders(empty)));
+ assertNoWorkloadId(reflectToRequestOptions(new CosmosDatabaseRequestOptions().setAdditionalHeaders(empty)));
+ assertNoWorkloadId(reflectToRequestOptions(new CosmosContainerRequestOptions().setAdditionalHeaders(empty)));
+ assertNoWorkloadId(reflectToRequestOptions(new CosmosStoredProcedureRequestOptions().setAdditionalHeaders(empty)));
+ }
+
+ // ==============================================================================================
+ // 6. Coverage matrix guard — setAdditionalHeaders() exists on all expected classes
+ // ==============================================================================================
+
+ /** Guard against accidental removal of setAdditionalHeaders() from any request options class. */
+ @Test(groups = { "unit" })
+ public void allExpectedRequestOptionsClassesSupportAdditionalHeaders() {
+ Class>[] expectedClasses = {
+ // Data-plane
+ CosmosItemRequestOptions.class,
+ CosmosBatchRequestOptions.class,
+ CosmosBulkExecutionOptions.class,
+ CosmosQueryRequestOptions.class,
+ CosmosReadManyRequestOptions.class,
+ CosmosChangeFeedRequestOptions.class,
+ // Control-plane
+ CosmosDatabaseRequestOptions.class,
+ CosmosContainerRequestOptions.class,
+ CosmosStoredProcedureRequestOptions.class,
+ };
+
+ for (Class> clazz : expectedClasses) {
+ assertThat(hasSetAdditionalHeaders(clazz))
+ .as(clazz.getSimpleName() + " should have setAdditionalHeaders()")
+ .isTrue();
+ }
+ }
+
+ // ==============================================================================================
+ // Helpers
+ // ==============================================================================================
+
+ private static RequestOptions reflectToRequestOptions(Object cosmosRequestOptions) {
+ try {
+ java.lang.reflect.Method m = cosmosRequestOptions.getClass().getDeclaredMethod("toRequestOptions");
+ m.setAccessible(true);
+ return (RequestOptions) m.invoke(cosmosRequestOptions);
+ } catch (Exception e) {
+ throw new RuntimeException(
+ "Failed to call toRequestOptions() on " + cosmosRequestOptions.getClass().getSimpleName(), e);
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ private static Map reflectGetHeaders(Object cosmosRequestOptions) {
+ try {
+ java.lang.reflect.Method m = cosmosRequestOptions.getClass().getDeclaredMethod("getHeaders");
+ m.setAccessible(true);
+ return (Map) m.invoke(cosmosRequestOptions);
+ } catch (Exception e) {
+ throw new RuntimeException(
+ "Failed to call getHeaders() on " + cosmosRequestOptions.getClass().getSimpleName(), e);
+ }
+ }
+
+ private static void assertNoWorkloadId(RequestOptions options) {
+ Map headers = options.getHeaders();
+ if (headers != null) {
+ assertThat(headers).doesNotContainKey(WORKLOAD_ID_HEADER);
+ }
+ }
+
+ private static boolean hasSetAdditionalHeaders(Class> clazz) {
+ try {
+ clazz.getMethod("setAdditionalHeaders", Map.class);
+ return true;
+ } catch (NoSuchMethodException e) {
+ return false;
+ }
+ }
+}
+
diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ThinClientStoreModelTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ThinClientStoreModelTest.java
index 64cd7fe37115..1b5e3670ace0 100644
--- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ThinClientStoreModelTest.java
+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ThinClientStoreModelTest.java
@@ -2,14 +2,19 @@
import com.azure.cosmos.ConsistencyLevel;
import com.azure.cosmos.implementation.http.HttpClient;
+import com.azure.cosmos.implementation.http.HttpRequest;
import com.azure.cosmos.implementation.routing.RegionalRoutingContext;
import io.netty.channel.ConnectTimeoutException;
+import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.testng.annotations.Test;
import reactor.core.publisher.Mono;
import java.net.URI;
+import java.util.HashMap;
+import java.util.Map;
+import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
public class ThinClientStoreModelTest {
@@ -44,7 +49,8 @@ public void testThinClientStoreModel() throws Exception {
ConsistencyLevel.SESSION,
new UserAgentContainer(),
globalEndpointManager,
- httpClient);
+ httpClient,
+ null);
RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName(
clientContext,
@@ -58,4 +64,125 @@ public void testThinClientStoreModel() throws Exception {
//no-op
}
}
+
+ /**
+ * Verifies that additionalHeaders (e.g., workload-id) passed to ThinClientStoreModel's
+ * constructor are correctly propagated to the parent RxGatewayStoreModel and injected
+ * into outgoing requests via performRequest().
+ *
+ * This is for the workload-id feature: ThinClientStoreModel extends
+ * RxGatewayStoreModel, and the additionalHeaders must flow through the constructor
+ * chain so that performRequest() injects them into every request — including
+ * metadata requests (collection cache, PKRange cache, etc.).
+ */
+ @Test(groups = "unit")
+ public void testAdditionalHeadersFlowThroughThinClientStoreModel() throws Exception {
+ DiagnosticsClientContext clientContext = Mockito.mock(DiagnosticsClientContext.class);
+ Mockito.doReturn(new DiagnosticsClientContext.DiagnosticsClientConfig()).when(clientContext).getConfig();
+ Mockito
+ .doReturn(ImplementationBridgeHelpers
+ .CosmosDiagnosticsHelper
+ .getCosmosDiagnosticsAccessor()
+ .create(clientContext, 1d))
+ .when(clientContext).createDiagnostics();
+
+ String sdkGlobalSessionToken = "1#100#1=20#2=5#3=30";
+ ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class);
+ Mockito.doReturn(sdkGlobalSessionToken).when(sessionContainer).resolveGlobalSessionToken(any());
+
+ GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class);
+ Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost:8080")))
+ .when(globalEndpointManager).resolveServiceEndpoint(any());
+
+ // Capture the HttpRequest sent by performRequest() to verify headers
+ HttpClient httpClient = Mockito.mock(HttpClient.class);
+ ArgumentCaptor