From f894bf425c9b5be3800a32ffcad4d977a108620d Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Wed, 25 Feb 2026 17:33:05 -0600 Subject: [PATCH 01/13] workload-id feature - initial commit --- .../com/azure/cosmos/CustomHeadersTests.java | 170 +++++++++ .../rntbd/RntbdWorkloadIdTests.java | 115 ++++++ .../azure/cosmos/rx/WorkloadIdE2ETests.java | 327 ++++++++++++++++++ .../com/azure/cosmos/CosmosAsyncClient.java | 1 + .../com/azure/cosmos/CosmosClientBuilder.java | 29 ++ .../implementation/AsyncDocumentClient.java | 9 +- .../cosmos/implementation/HttpConstants.java | 3 + .../implementation/RxDocumentClientImpl.java | 61 ++++ .../rntbd/RntbdConstants.java | 3 +- .../rntbd/RntbdRequestHeaders.java | 19 + .../models/CosmosBatchRequestOptions.java | 13 +- .../models/CosmosBulkExecutionOptions.java | 12 +- .../CosmosChangeFeedRequestOptions.java | 13 +- .../models/CosmosItemRequestOptions.java | 15 +- .../models/CosmosQueryRequestOptions.java | 16 + 15 files changed, 784 insertions(+), 22 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java create mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java create mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java new file mode 100644 index 000000000000..3c95c8bd2687 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java @@ -0,0 +1,170 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos; + +import com.azure.cosmos.implementation.HttpConstants; +import com.azure.cosmos.models.CosmosBatchRequestOptions; +import com.azure.cosmos.models.CosmosBulkExecutionOptions; +import com.azure.cosmos.models.CosmosChangeFeedRequestOptions; +import com.azure.cosmos.models.CosmosItemRequestOptions; +import com.azure.cosmos.models.CosmosQueryRequestOptions; +import com.azure.cosmos.models.FeedRange; +import org.testng.annotations.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for the custom headers (workload-id) feature on CosmosClientBuilder and request options classes. + *

+ * These tests verify the public API surface: builder fluent methods, getter behavior, + * null/empty handling, and that setHeader() is publicly accessible on all request options classes. + */ +public class CustomHeadersTests { + + /** + * Verifies that custom headers (e.g., workload-id) set via CosmosClientBuilder.customHeaders() + * are stored correctly and retrievable via getCustomHeaders(). + */ + @Test(groups = { "unit" }) + public void customHeadersSetOnBuilder() { + Map headers = new HashMap<>(); + headers.put("x-ms-cosmos-workload-id", "25"); + + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(headers); + + assertThat(builder.getCustomHeaders()).containsEntry("x-ms-cosmos-workload-id", "25"); + } + + /** + * Verifies that passing null to customHeaders() does not throw and that + * getCustomHeaders() returns null, ensuring graceful null handling. + */ + @Test(groups = { "unit" }) + public void customHeadersNullHandledGracefully() { + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(null); + + assertThat(builder.getCustomHeaders()).isNull(); + } + + /** + * Verifies that passing an empty map to customHeaders() is accepted and + * getCustomHeaders() returns an empty (not null) map. + */ + @Test(groups = { "unit" }) + public void customHeadersEmptyMapHandled() { + Map emptyHeaders = new HashMap<>(); + + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(emptyHeaders); + + assertThat(builder.getCustomHeaders()).isEmpty(); + } + + /** + * Verifies that multiple custom headers can be set at once on the builder and + * all entries are preserved and retrievable with correct keys and values. + */ + @Test(groups = { "unit" }) + public void multipleCustomHeadersSupported() { + Map headers = new HashMap<>(); + headers.put("x-ms-cosmos-workload-id", "15"); + headers.put("x-ms-custom-header", "value"); + + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(headers); + + assertThat(builder.getCustomHeaders()).hasSize(2); + assertThat(builder.getCustomHeaders()).containsEntry("x-ms-cosmos-workload-id", "15"); + assertThat(builder.getCustomHeaders()).containsEntry("x-ms-custom-header", "value"); + } + + /** + * Verifies that setHeader() is publicly accessible on CosmosItemRequestOptions + * (previously package-private) and supports fluent chaining for per-request + * header overrides on CRUD operations. + */ + @Test(groups = { "unit" }) + public void setHeaderOnItemRequestOptionsIsPublic() { + CosmosItemRequestOptions options = new CosmosItemRequestOptions() + .setHeader("x-ms-cosmos-workload-id", "15"); + + assertThat(options).isNotNull(); + } + + /** + * Verifies that setHeader() is publicly accessible on CosmosBatchRequestOptions + * (previously package-private) and supports fluent chaining for per-request + * header overrides on batch operations. + */ + @Test(groups = { "unit" }) + public void setHeaderOnBatchRequestOptionsIsPublic() { + CosmosBatchRequestOptions options = new CosmosBatchRequestOptions() + .setHeader("x-ms-cosmos-workload-id", "20"); + + assertThat(options).isNotNull(); + } + + /** + * Verifies that setHeader() is publicly accessible on CosmosChangeFeedRequestOptions + * (previously package-private) and supports fluent chaining for per-request + * header overrides on change feed operations. + */ + @Test(groups = { "unit" }) + public void setHeaderOnChangeFeedRequestOptionsIsPublic() { + CosmosChangeFeedRequestOptions options = CosmosChangeFeedRequestOptions + .createForProcessingFromBeginning(FeedRange.forFullRange()) + .setHeader("x-ms-cosmos-workload-id", "25"); + + assertThat(options).isNotNull(); + } + + /** + * Verifies that setHeader() is publicly accessible on CosmosBulkExecutionOptions + * (previously package-private) and supports fluent chaining for per-request + * header overrides on bulk ingestion operations. + */ + @Test(groups = { "unit" }) + public void setHeaderOnBulkExecutionOptionsIsPublic() { + CosmosBulkExecutionOptions options = new CosmosBulkExecutionOptions() + .setHeader("x-ms-cosmos-workload-id", "30"); + + assertThat(options).isNotNull(); + } + + /** + * Verifies that the new delegating setHeader() method on CosmosQueryRequestOptions + * is publicly accessible and supports fluent chaining for per-request header + * overrides on query operations. + */ + @Test(groups = { "unit" }) + public void setHeaderOnQueryRequestOptionsIsPublic() { + CosmosQueryRequestOptions options = new CosmosQueryRequestOptions() + .setHeader("x-ms-cosmos-workload-id", "35"); + + assertThat(options).isNotNull(); + } + + /** + * Verifies that the WORKLOAD_ID constant in HttpConstants.HttpHeaders is defined + * with the correct canonical header name "x-ms-cosmos-workload-id" as expected + * by the Cosmos DB service. + */ + @Test(groups = { "unit" }) + public void workloadIdHttpHeaderConstant() { + assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id"); + } +} diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java new file mode 100644 index 000000000000..9ca123e16160 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdWorkloadIdTests.java @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.implementation.directconnectivity.rntbd; + +import com.azure.cosmos.implementation.HttpConstants; +import org.testng.annotations.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for the WorkloadId RNTBD header definition in RntbdConstants. + *

+ * + * These tests verify that the WorkloadId enum entry exists with the correct wire ID (0x00DC), + * correct token type (Byte), is not required, and is not in the thin-client ordered header list + * (so it will be auto-encoded in the second pass of RntbdTokenStream.encode()). + */ +public class RntbdWorkloadIdTests { + + /** + * Verifies that the WORKLOAD_ID HTTP header constant exists in HttpConstants.HttpHeaders + * with the correct canonical name "x-ms-cosmos-workload-id" used in Gateway mode and + * as the lookup key in RntbdRequestHeaders for HTTP-to-RNTBD mapping. + */ + @Test(groups = { "unit" }) + public void workloadIdConstantExists() { + assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id"); + } + + /** + * Verifies that the WorkloadId enum entry exists in RntbdConstants.RntbdRequestHeader + * with the correct wire ID (0x00DC). This ID is used to identify the header in the + * binary RNTBD protocol when communicating in Direct mode. + */ + @Test(groups = { "unit" }) + public void workloadIdRntbdHeaderExists() { + // Verify WorkloadId enum value exists with correct ID + RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId; + assertThat(workloadIdHeader).isNotNull(); + assertThat(workloadIdHeader.id()).isEqualTo((short) 0x00DC); + } + + /** + * Verifies that the WorkloadId RNTBD header is defined as Byte token type, + * consistent with the ThroughputBucket pattern. The workload ID value (1-50) + * is encoded as a single byte on the wire. + */ + @Test(groups = { "unit" }) + public void workloadIdRntbdHeaderIsByteType() { + // Verify WorkloadId is Byte type (same as ThroughputBucket pattern) + RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId; + assertThat(workloadIdHeader.type()).isEqualTo(RntbdTokenType.Byte); + } + + /** + * Verifies that WorkloadId is not a required RNTBD header. The header is optional — + * requests without a workload ID are valid and should not be rejected by the SDK. + */ + @Test(groups = { "unit" }) + public void workloadIdRntbdHeaderIsNotRequired() { + // WorkloadId should not be a required header + RntbdConstants.RntbdRequestHeader workloadIdHeader = RntbdConstants.RntbdRequestHeader.WorkloadId; + assertThat(workloadIdHeader.isRequired()).isFalse(); + } + + /** + * Verifies that WorkloadId is NOT in the thin client ordered header list. Thin client + * mode uses a pre-ordered list of headers for its first encoding pass. WorkloadId is + * excluded from this list and will be auto-encoded in the second pass of + * RntbdTokenStream.encode() along with other non-ordered headers. + */ + @Test(groups = { "unit" }) + public void workloadIdNotInThinClientOrderedList() { + // WorkloadId should NOT be in thinClientHeadersInOrderList + // It will be automatically encoded in the second pass of RntbdTokenStream.encode() + assertThat(RntbdConstants.RntbdRequestHeader.thinClientHeadersInOrderList) + .doesNotContain(RntbdConstants.RntbdRequestHeader.WorkloadId); + } + + /** + * Verifies that valid workload ID values (1-50) can be parsed from String to int + * and cast to byte without data loss. Note: the SDK itself does not validate the + * range — this test confirms the encoding path works for expected values. + */ + @Test(groups = { "unit" }) + public void workloadIdValidValues() { + // Test valid range 1-50 — SDK does NOT validate, just verify the values parse correctly + String[] validValues = {"1", "25", "50"}; + for (String value : validValues) { + int parsed = Integer.parseInt(value); + byte byteVal = (byte) parsed; + assertThat(byteVal).isBetween((byte) 1, (byte) 50); + } + } + + /** + * Verifies that out-of-range workload ID values (0, 51, -1, 100) do not cause + * exceptions in the SDK's parsing path. The SDK intentionally does not validate + * the range — invalid values are accepted and sent to the service, which silently + * ignores them. + */ + @Test(groups = { "unit" }) + public void workloadIdInvalidValuesAcceptedBySdk() { + // SDK does NOT validate range — service silently ignores invalid values + // These should not throw exceptions in SDK + String[] invalidValues = {"0", "51", "-1", "100"}; + for (String value : invalidValues) { + int parsed = Integer.parseInt(value); + byte byteVal = (byte) parsed; + // SDK accepts any integer value that fits in a byte + assertThat(byteVal).isNotNull(); + } + } +} diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java new file mode 100644 index 000000000000..85b87090bb36 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java @@ -0,0 +1,327 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos.rx; + +import com.azure.cosmos.CosmosAsyncClient; +import com.azure.cosmos.CosmosAsyncContainer; +import com.azure.cosmos.CosmosAsyncDatabase; +import com.azure.cosmos.CosmosClientBuilder; +import com.azure.cosmos.TestObject; +import com.azure.cosmos.implementation.HttpConstants; +import com.azure.cosmos.implementation.TestConfigurations; +import com.azure.cosmos.models.CosmosBulkExecutionOptions; +import com.azure.cosmos.models.CosmosBulkOperations; +import com.azure.cosmos.models.CosmosContainerProperties; +import com.azure.cosmos.models.CosmosItemRequestOptions; +import com.azure.cosmos.models.CosmosItemResponse; +import com.azure.cosmos.models.CosmosQueryRequestOptions; +import com.azure.cosmos.models.PartitionKey; +import com.azure.cosmos.models.PartitionKeyDefinition; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * End-to-end integration tests for the custom headers / workload-id feature. + *

+ * Test type: EMULATOR INTEGRATION TEST — requires the Cosmos DB Emulator to be running locally. + * These tests create a real database and container, then execute CRUD and query operations + * with the {@code x-ms-cosmos-workload-id} header set at client level and/or request level. + *

+ * What is verified: + * 1. CRUD operations succeed with client-level custom headers (workload-id) + * 2. Per-request header overrides work via setHeader() + * 3. Client with no custom headers continues to work (no regression) + * 4. Query operations succeed with workload-id + * 5. Empty headers and multiple headers are handled correctly + *

+ + */ +public class WorkloadIdE2ETests extends TestSuiteBase { + + private static final String DATABASE_ID = "workloadIdTestDb-" + UUID.randomUUID(); + private static final String CONTAINER_ID = "workloadIdTestContainer-" + UUID.randomUUID(); + + private CosmosAsyncClient clientWithWorkloadId; + private CosmosAsyncDatabase database; + private CosmosAsyncContainer container; + + public WorkloadIdE2ETests() { + super(new CosmosClientBuilder() + .endpoint(TestConfigurations.HOST) + .key(TestConfigurations.MASTER_KEY)); + } + + @BeforeClass(groups = { "emulator" }, timeOut = SETUP_TIMEOUT) + public void beforeClass() { + Map headers = new HashMap<>(); + headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "15"); + + clientWithWorkloadId = new CosmosClientBuilder() + .endpoint(TestConfigurations.HOST) + .key(TestConfigurations.MASTER_KEY) + .customHeaders(headers) + .buildAsyncClient(); + + database = createDatabase(clientWithWorkloadId, DATABASE_ID); + + PartitionKeyDefinition partitionKeyDef = new PartitionKeyDefinition(); + ArrayList paths = new ArrayList<>(); + paths.add("/mypk"); + partitionKeyDef.setPaths(paths); + CosmosContainerProperties containerProperties = new CosmosContainerProperties(CONTAINER_ID, partitionKeyDef); + database.createContainer(containerProperties).block(); + container = database.getContainer(CONTAINER_ID); + } + + /** + * Smoke test: verifies that a create (POST) operation succeeds when the client + * has a workload-id custom header set at the builder level. Confirms the header + * flows through the request pipeline without causing errors. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void createItemWithClientLevelWorkloadId() { + // Smoke test: verify create operation succeeds with client-level workload-id header + TestObject doc = TestObject.create(); + + CosmosItemResponse response = container + .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(201); + } + + /** + * Verifies that a read (GET) operation succeeds with the client-level workload-id + * header and that the correct document is returned. Ensures the header does not + * interfere with normal read semantics. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void readItemWithClientLevelWorkloadId() { + // Verify read operation succeeds with workload-id header + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + CosmosItemResponse response = container + .readItem(doc.getId(), new PartitionKey(doc.getMypk()), TestObject.class) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(200); + assertThat(response.getItem().getId()).isEqualTo(doc.getId()); + } + + /** + * Verifies that a replace (PUT) operation succeeds with the client-level workload-id + * header. Confirms the header propagates correctly for update operations. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void replaceItemWithClientLevelWorkloadId() { + // Verify replace operation succeeds with workload-id header + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + doc.setStringProp("updated-" + UUID.randomUUID()); + CosmosItemResponse response = container + .replaceItem(doc, doc.getId(), new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(200); + } + + /** + * Verifies that a delete operation succeeds with the client-level workload-id header + * and returns the expected 204 No Content status code. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void deleteItemWithClientLevelWorkloadId() { + // Verify delete operation succeeds with workload-id header + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + CosmosItemResponse response = container + .deleteItem(doc.getId(), new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(204); + } + + /** + * Verifies that a per-request workload-id header override via + * {@code CosmosItemRequestOptions.setHeader()} works. The request-level header + * (value "30") should take precedence over the client-level default (value "15"). + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void createItemWithRequestLevelWorkloadIdOverride() { + // Verify per-request header override works — request-level should take precedence + TestObject doc = TestObject.create(); + + CosmosItemRequestOptions options = new CosmosItemRequestOptions() + .setHeader(HttpConstants.HttpHeaders.WORKLOAD_ID, "30"); + + CosmosItemResponse response = container + .createItem(doc, new PartitionKey(doc.getMypk()), options) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(201); + } + + /** + * Verifies that a cross-partition query operation succeeds when the client has a + * workload-id custom header. Confirms the header flows correctly through the + * query pipeline and does not affect result correctness. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void queryItemsWithClientLevelWorkloadId() { + // Verify query operation succeeds with workload-id header + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + CosmosQueryRequestOptions queryOptions = new CosmosQueryRequestOptions(); + long count = container + .queryItems("SELECT * FROM c WHERE c.id = '" + doc.getId() + "'", queryOptions, TestObject.class) + .collectList() + .block() + .size(); + + assertThat(count).isGreaterThanOrEqualTo(1); + } + + /** + * Verifies that a per-request workload-id header override on + * {@code CosmosQueryRequestOptions.setHeader()} works for query operations. + * The request-level header (value "42") should take precedence over the + * client-level default. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void queryItemsWithRequestLevelWorkloadIdOverride() { + // Verify per-request header override on query options works + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + CosmosQueryRequestOptions queryOptions = new CosmosQueryRequestOptions() + .setHeader(HttpConstants.HttpHeaders.WORKLOAD_ID, "42"); + + long count = container + .queryItems("SELECT * FROM c WHERE c.id = '" + doc.getId() + "'", queryOptions, TestObject.class) + .collectList() + .block() + .size(); + + assertThat(count).isGreaterThanOrEqualTo(1); + } + + /** + * Regression test: verifies that a client created without any custom headers + * continues to work normally. Ensures the custom headers feature does not + * introduce regressions for clients that do not use it. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void clientWithNoCustomHeadersStillWorks() { + // Verify that a client without custom headers works normally (no regression) + CosmosAsyncClient clientWithoutHeaders = new CosmosClientBuilder() + .endpoint(TestConfigurations.HOST) + .key(TestConfigurations.MASTER_KEY) + .buildAsyncClient(); + + try { + CosmosAsyncContainer c = clientWithoutHeaders + .getDatabase(DATABASE_ID) + .getContainer(CONTAINER_ID); + + TestObject doc = TestObject.create(); + CosmosItemResponse response = c + .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(201); + } finally { + safeClose(clientWithoutHeaders); + } + } + + /** + * Verifies that a client created with an empty custom headers map works normally. + * An empty map should behave identically to no custom headers — no errors, + * no unexpected behavior. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void clientWithEmptyCustomHeaders() { + // Verify that a client with empty custom headers map works normally + CosmosAsyncClient clientWithEmptyHeaders = new CosmosClientBuilder() + .endpoint(TestConfigurations.HOST) + .key(TestConfigurations.MASTER_KEY) + .customHeaders(new HashMap<>()) + .buildAsyncClient(); + + try { + CosmosAsyncContainer c = clientWithEmptyHeaders + .getDatabase(DATABASE_ID) + .getContainer(CONTAINER_ID); + + TestObject doc = TestObject.create(); + CosmosItemResponse response = c + .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(201); + } finally { + safeClose(clientWithEmptyHeaders); + } + } + + /** + * Verifies that a client can be configured with multiple custom headers simultaneously + * (workload-id plus an additional custom header). Confirms that all headers flow + * through the pipeline without interfering with each other. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void clientWithMultipleCustomHeaders() { + // Verify that multiple custom headers can be set simultaneously + Map headers = new HashMap<>(); + headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "20"); + headers.put("x-ms-custom-test-header", "test-value"); + + CosmosAsyncClient clientWithMultipleHeaders = new CosmosClientBuilder() + .endpoint(TestConfigurations.HOST) + .key(TestConfigurations.MASTER_KEY) + .customHeaders(headers) + .buildAsyncClient(); + + try { + CosmosAsyncContainer c = clientWithMultipleHeaders + .getDatabase(DATABASE_ID) + .getContainer(CONTAINER_ID); + + TestObject doc = TestObject.create(); + CosmosItemResponse response = c + .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) + .block(); + + assertThat(response).isNotNull(); + assertThat(response.getStatusCode()).isEqualTo(201); + } finally { + safeClose(clientWithMultipleHeaders); + } + } + + @AfterClass(groups = { "emulator" }, timeOut = SHUTDOWN_TIMEOUT, alwaysRun = true) + public void afterClass() { + safeDeleteDatabase(database); + safeClose(clientWithWorkloadId); + } +} + diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java index ec0dd64af008..f54f44482db5 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java @@ -186,6 +186,7 @@ public final class CosmosAsyncClient implements Closeable { .withDefaultSerializer(this.defaultCustomSerializer) .withRegionScopedSessionCapturingEnabled(builder.isRegionScopedSessionCapturingEnabled()) .withPerPartitionAutomaticFailoverEnabled(builder.isPerPartitionAutomaticFailoverEnabled()) + .withCustomHeaders(builder.getCustomHeaders()) .build(); this.accountConsistencyLevel = this.asyncDocumentClient.getDefaultConsistencyLevelOfAccount(); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java index 12d022e69ee7..aea282be566c 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java @@ -37,6 +37,7 @@ import java.util.LinkedList; import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.function.Function; @@ -155,6 +156,7 @@ public class CosmosClientBuilder implements private boolean serverCertValidationDisabled = false; private Function containerFactory = null; + private Map customHeaders; /** * Instantiates a new Cosmos client builder. @@ -734,6 +736,33 @@ public CosmosClientBuilder userAgentSuffix(String userAgentSuffix) { return this; } + /** + * Sets custom HTTP headers that will be included with every request from this client. + *

+ * These headers are sent with all requests. For Direct/RNTBD mode, only known headers + * (like {@code x-ms-cosmos-workload-id}) will be encoded and sent. Unknown headers + * work only in Gateway mode. + *

+ * If the same header is also set on request options (e.g., + * {@code CosmosItemRequestOptions.setHeader(String, String)}), + * the request-level value takes precedence over the client-level value. + * + * @param customHeaders map of header name to value + * @return current CosmosClientBuilder + */ + public CosmosClientBuilder customHeaders(Map customHeaders) { + this.customHeaders = customHeaders; + return this; + } + + /** + * Gets the custom headers configured on this builder. + * @return the custom headers map, or null if not set + */ + Map getCustomHeaders() { + return this.customHeaders; + } + /** * Sets the retry policy options associated with the DocumentClient instance. *

diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java index 03590c1f8a5d..7953721019c5 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java @@ -116,6 +116,7 @@ class Builder { private boolean isRegionScopedSessionCapturingEnabled; private boolean isPerPartitionAutomaticFailoverEnabled; private List operationPolicies; + private Map customHeaders; public Builder withServiceEndpoint(String serviceEndpoint) { try { @@ -288,6 +289,11 @@ public Builder withPerPartitionAutomaticFailoverEnabled(boolean isPerPartitionAu return this; } + public Builder withCustomHeaders(Map customHeaders) { + this.customHeaders = customHeaders; + return this; + } + private void ifThrowIllegalArgException(boolean value, String error) { if (value) { throw new IllegalArgumentException(error); @@ -328,7 +334,8 @@ public AsyncDocumentClient build() { defaultCustomSerializer, isRegionScopedSessionCapturingEnabled, operationPolicies, - isPerPartitionAutomaticFailoverEnabled); + isPerPartitionAutomaticFailoverEnabled, + customHeaders); client.init(state, null); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java index 4e283defbc1d..32378ef0cc8d 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/HttpConstants.java @@ -298,6 +298,9 @@ public static class HttpHeaders { // Region affinity headers public static final String HUB_REGION_PROCESSING_ONLY = "x-ms-cosmos-hub-region-processing-only"; + + // Workload ID header for Azure Monitor metrics attribution + public static final String WORKLOAD_ID = "x-ms-cosmos-workload-id"; } public static class A_IMHeaderValues { diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java index 192f8175978f..2f0bd4271d86 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java @@ -293,6 +293,7 @@ public class RxDocumentClientImpl implements AsyncDocumentClient, IAuthorization private final AtomicReference cachedCosmosAsyncClientSnapshot; private CosmosEndToEndOperationLatencyPolicyConfig ppafEnforcedE2ELatencyPolicyConfigForReads; private Consumer perPartitionFailoverConfigModifier; + private Map customHeaders; public RxDocumentClientImpl(URI serviceEndpoint, String masterKeyOrResourceToken, @@ -366,6 +367,60 @@ public RxDocumentClientImpl(URI serviceEndpoint, boolean isRegionScopedSessionCapturingEnabled, List operationPolicies, boolean isPerPartitionAutomaticFailoverEnabled) { + this( + serviceEndpoint, + masterKeyOrResourceToken, + permissionFeed, + connectionPolicy, + consistencyLevel, + readConsistencyStrategy, + configs, + cosmosAuthorizationTokenResolver, + credential, + tokenCredential, + sessionCapturingOverride, + connectionSharingAcrossClientsEnabled, + contentResponseOnWriteEnabled, + metadataCachesSnapshot, + apiType, + clientTelemetryConfig, + clientCorrelationId, + cosmosEndToEndOperationLatencyPolicyConfig, + sessionRetryOptions, + containerProactiveInitConfig, + defaultCustomSerializer, + isRegionScopedSessionCapturingEnabled, + operationPolicies, + isPerPartitionAutomaticFailoverEnabled, + null + ); + } + + public RxDocumentClientImpl(URI serviceEndpoint, + String masterKeyOrResourceToken, + List permissionFeed, + ConnectionPolicy connectionPolicy, + ConsistencyLevel consistencyLevel, + ReadConsistencyStrategy readConsistencyStrategy, + Configs configs, + CosmosAuthorizationTokenResolver cosmosAuthorizationTokenResolver, + AzureKeyCredential credential, + TokenCredential tokenCredential, + boolean sessionCapturingOverride, + boolean connectionSharingAcrossClientsEnabled, + boolean contentResponseOnWriteEnabled, + CosmosClientMetadataCachesSnapshot metadataCachesSnapshot, + ApiType apiType, + CosmosClientTelemetryConfig clientTelemetryConfig, + String clientCorrelationId, + CosmosEndToEndOperationLatencyPolicyConfig cosmosEndToEndOperationLatencyPolicyConfig, + SessionRetryOptions sessionRetryOptions, + CosmosContainerProactiveInitConfig containerProactiveInitConfig, + CosmosItemSerializer defaultCustomSerializer, + boolean isRegionScopedSessionCapturingEnabled, + List operationPolicies, + boolean isPerPartitionAutomaticFailoverEnabled, + Map customHeaders) { this( serviceEndpoint, masterKeyOrResourceToken, @@ -392,6 +447,7 @@ public RxDocumentClientImpl(URI serviceEndpoint, this.cosmosAuthorizationTokenResolver = cosmosAuthorizationTokenResolver; this.operationPolicies = operationPolicies; + this.customHeaders = customHeaders; } private RxDocumentClientImpl(URI serviceEndpoint, @@ -1884,6 +1940,11 @@ public void validateAndLogNonDefaultReadConsistencyStrategy(String readConsisten private Map getRequestHeaders(RequestOptions options, ResourceType resourceType, OperationType operationType) { Map headers = new HashMap<>(); + // Apply client-level custom headers first (e.g., workload-id from CosmosClientBuilder.customHeaders()) + if (this.customHeaders != null && !this.customHeaders.isEmpty()) { + headers.putAll(this.customHeaders); + } + if (this.useMultipleWriteLocations) { headers.put(HttpConstants.HttpHeaders.ALLOW_TENTATIVE_WRITES, Boolean.TRUE.toString()); } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java index ba3ec8d2017d..d79231793679 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java @@ -598,7 +598,8 @@ public enum RntbdRequestHeader implements RntbdHeader { PriorityLevel((short) 0x00BF, RntbdTokenType.Byte, false), GlobalDatabaseAccountName((short) 0x00CE, RntbdTokenType.String, false), ThroughputBucket((short)0x00DB, RntbdTokenType.Byte, false), - HubRegionProcessingOnly((short)0x00EF, RntbdTokenType.Byte , false); + HubRegionProcessingOnly((short)0x00EF, RntbdTokenType.Byte , false), + WorkloadId((short)0x00DC, RntbdTokenType.Byte, false); public static final List thinClientHeadersInOrderList = Arrays.asList( EffectivePartitionKey, diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java index 6f6e46ee695d..387e8cf3ed5a 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java @@ -23,6 +23,8 @@ import com.fasterxml.jackson.annotation.JsonFilter; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.nio.charset.StandardCharsets; import java.util.Base64; @@ -51,6 +53,7 @@ final class RntbdRequestHeaders extends RntbdTokenStream { // region Fields + private static final Logger logger = LoggerFactory.getLogger(RntbdRequestHeaders.class); private static final String URL_TRIM = "/"; // endregion @@ -134,6 +137,7 @@ final class RntbdRequestHeaders extends RntbdTokenStream { this.addGlobalDatabaseAccountName(headers); this.addThroughputBucket(headers); this.addHubRegionProcessingOnly(headers); + this.addWorkloadId(headers); // Normal headers (Strings, Ints, Longs, etc.) @@ -297,6 +301,8 @@ private RntbdToken getCorrelatedActivityId() { private RntbdToken getHubRegionProcessingOnly() { return this.get(RntbdRequestHeader.HubRegionProcessingOnly); } + private RntbdToken getWorkloadId() { return this.get(RntbdRequestHeader.WorkloadId); } + private RntbdToken getGlobalDatabaseAccountName() { return this.get(RntbdRequestHeader.GlobalDatabaseAccountName); } @@ -816,6 +822,19 @@ private void addHubRegionProcessingOnly(final Map headers) { } } + private void addWorkloadId(final Map headers) { + final String value = headers.get(HttpHeaders.WORKLOAD_ID); + + if (StringUtils.isNotEmpty(value)) { + try { + final int workloadId = Integer.valueOf(value); + this.getWorkloadId().setValue((byte) workloadId); + } catch (NumberFormatException e) { + logger.warn("Invalid value for workload id header: {}", value, e); + } + } + } + private void addGlobalDatabaseAccountName(final Map headers) { final String value = headers.get(HttpHeaders.GLOBAL_DATABASE_ACCOUNT_NAME); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java index 7d5a27324f95..3183fe59bdea 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java @@ -154,14 +154,17 @@ RequestOptions toRequestOptions() { } /** - * Sets the custom batch request option value by key - * - * @param name a string representing the custom option's name - * @param value a string representing the custom option's value + * Sets a custom header to be included with this specific request. + *

+ * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. * + * @param name the header name (e.g., "x-ms-cosmos-workload-id") + * @param value the header value (e.g., "20") * @return the CosmosBatchRequestOptions. */ - CosmosBatchRequestOptions setHeader(String name, String value) { + public CosmosBatchRequestOptions setHeader(String name, String value) { if (this.customOptions == null) { this.customOptions = new HashMap<>(); } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java index f125c02d6725..cd688f8a0da6 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java @@ -257,13 +257,17 @@ void setOperationContextAndListenerTuple(OperationContextAndListenerTuple operat } /** - * Sets the custom bulk request option value by key + * Sets a custom header to be included with this specific request. + *

+ * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. * - * @param name a string representing the custom option's name - * @param value a string representing the custom option's value + * @param name the header name (e.g., "x-ms-cosmos-workload-id") + * @param value the header value (e.g., "20") * @return the CosmosBulkExecutionOptions. */ - CosmosBulkExecutionOptions setHeader(String name, String value) { + public CosmosBulkExecutionOptions setHeader(String name, String value) { this.actualRequestOptions.setHeader(name, value); return this; } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java index 3ac526de6d63..a1b675f2ffd8 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java @@ -564,14 +564,17 @@ public List getExcludedRegions() { } /** - * Sets the custom change feed request option value by key - * - * @param name a string representing the custom option's name - * @param value a string representing the custom option's value + * Sets a custom header to be included with this specific request. + *

+ * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. * + * @param name the header name (e.g., "x-ms-cosmos-workload-id") + * @param value the header value (e.g., "20") * @return the CosmosChangeFeedRequestOptions. */ - CosmosChangeFeedRequestOptions setHeader(String name, String value) { + public CosmosChangeFeedRequestOptions setHeader(String name, String value) { this.actualRequestOptions.setHeader(name, value); return this; } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java index 72eb108a6428..fbc540e5baeb 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java @@ -566,14 +566,17 @@ public CosmosItemRequestOptions setThresholdForDiagnosticsOnTracer(Duration thre } /** - * Sets the custom item request option value by key - * - * @param name a string representing the custom option's name - * @param value a string representing the custom option's value - * + * Sets a custom header to be included with this specific request. + *

+ * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. + * + * @param name the header name (e.g., "x-ms-cosmos-workload-id") + * @param value the header value (e.g., "20") * @return the CosmosItemRequestOptions. */ - CosmosItemRequestOptions setHeader(String name, String value) { + public CosmosItemRequestOptions setHeader(String name, String value) { if (this.customOptions == null) { this.customOptions = new HashMap<>(); } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java index 7ead6e208781..f0de81bbf823 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java @@ -260,6 +260,22 @@ public CosmosQueryRequestOptions setExcludedRegions(List excludeRegions) return this; } + /** + * Sets a custom header to be included with this specific request. + *

+ * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. + * + * @param name the header name (e.g., "x-ms-cosmos-workload-id") + * @param value the header value (e.g., "20") + * @return the CosmosQueryRequestOptions. + */ + public CosmosQueryRequestOptions setHeader(String name, String value) { + this.actualRequestOptions.setHeader(name, value); + return this; + } + /** * Gets the list of regions to exclude for the request/retries. These regions are excluded * from the preferred region list. From a75ab7b1bd494dfec12ad501b23923a101b1cba4 Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Wed, 25 Feb 2026 17:57:37 -0600 Subject: [PATCH 02/13] workload-id feature - initial commit(Spark) --- .../cosmos/spark/CosmosClientCache.scala | 15 +- .../spark/CosmosClientConfiguration.scala | 8 +- .../com/azure/cosmos/spark/CosmosConfig.scala | 31 +++- .../cosmos/spark/CosmosClientCacheITest.scala | 17 +- .../spark/CosmosClientConfigurationSpec.scala | 68 ++++++++ .../spark/CosmosPartitionPlannerSpec.scala | 24 ++- .../cosmos/spark/PartitionMetadataSpec.scala | 48 ++++-- .../spark/SparkE2EWorkloadIdITest.scala | 150 ++++++++++++++++++ 8 files changed, 323 insertions(+), 38 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala index e61a271aeb8b..9ad739d8f3f6 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala @@ -42,7 +42,6 @@ import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit} import java.util.function.BiPredicate import scala.collection.concurrent.TrieMap - // scalastyle:off underscore.import import scala.collection.JavaConverters._ // scalastyle:on underscore.import @@ -713,6 +712,12 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { } } + // Apply custom HTTP headers (e.g., workload-id) to the builder if configured. + // These headers are attached to every Cosmos DB request made by this client instance. + if (cosmosClientConfiguration.customHeaders.isDefined) { + builder.customHeaders(cosmosClientConfiguration.customHeaders.get.asJava) + } + var client = builder.buildAsyncClient() if (cosmosClientConfiguration.clientInterceptors.isDefined) { @@ -916,7 +921,10 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]], clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig], - azureMonitorConfig: Option[AzureMonitorConfig] + azureMonitorConfig: Option[AzureMonitorConfig], + // Custom HTTP headers are part of the cache key because different workload-ids + // should produce different CosmosAsyncClient instances + customHeaders: Option[Map[String, String]] ) private[this] object ClientConfigurationWrapper { @@ -935,7 +943,8 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { clientConfig.clientBuilderInterceptors, clientConfig.clientInterceptors, clientConfig.sampledDiagnosticsLoggerConfig, - clientConfig.azureMonitorConfig + clientConfig.azureMonitorConfig, + clientConfig.customHeaders ) } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala index 6f4e26e1f503..61fa0957af83 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala @@ -30,7 +30,10 @@ private[spark] case class CosmosClientConfiguration ( clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]], clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig], - azureMonitorConfig: Option[AzureMonitorConfig] + azureMonitorConfig: Option[AzureMonitorConfig], + // Optional custom HTTP headers (e.g., workload-id) to attach to + // all Cosmos DB requests via CosmosClientBuilder.customHeaders() + customHeaders: Option[Map[String, String]] ) { private[spark] def getRoleInstanceName(machineId: Option[String]): String = { CosmosClientConfiguration.getRoleInstanceName(sparkEnvironmentInfo, machineId) @@ -94,7 +97,8 @@ private[spark] object CosmosClientConfiguration { cosmosAccountConfig.clientBuilderInterceptors, cosmosAccountConfig.clientInterceptors, diagnosticsConfig.sampledDiagnosticsLoggerConfig, - diagnosticsConfig.azureMonitorConfig + diagnosticsConfig.azureMonitorConfig, + cosmosAccountConfig.customHeaders ) } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala index eef3f6ae1f8d..6646b2e69ae5 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala @@ -34,6 +34,7 @@ import java.time.format.DateTimeFormatter import java.time.{Duration, Instant} import java.util import java.util.{Locale, ServiceLoader} +import scala.collection.JavaConverters._ // scalastyle:ignore underscore.import import scala.collection.concurrent.TrieMap import scala.collection.immutable.{HashSet, List, Map} import scala.collection.mutable @@ -150,6 +151,10 @@ private[spark] object CosmosConfigNames { val ThroughputControlTargetThroughputThreshold = "spark.cosmos.throughputControl.targetThroughputThreshold" val ThroughputControlPriorityLevel = "spark.cosmos.throughputControl.priorityLevel" val ThroughputControlThroughputBucket = "spark.cosmos.throughputControl.throughputBucket" + // Custom HTTP headers to attach to all Cosmos DB requests (e.g., workload-id for resource governance). + // Value is a JSON string like: {"x-ms-cosmos-workload-id": "15"} + // Flows through to CosmosClientBuilder.customHeaders(). + val CustomHeaders = "spark.cosmos.customHeaders" val ThroughputControlGlobalControlDatabase = "spark.cosmos.throughputControl.globalControl.database" val ThroughputControlGlobalControlContainer = "spark.cosmos.throughputControl.globalControl.container" val ThroughputControlGlobalControlRenewalIntervalInMS = @@ -295,7 +300,8 @@ private[spark] object CosmosConfigNames { WriteOnRetryCommitInterceptor, WriteFlushCloseIntervalInSeconds, WriteMaxNoProgressIntervalInSeconds, - WriteMaxRetryNoProgressIntervalInSeconds + WriteMaxRetryNoProgressIntervalInSeconds, + CustomHeaders ) def validateConfigName(name: String): Unit = { @@ -538,7 +544,10 @@ private case class CosmosAccountConfig(endpoint: String, resourceGroupName: Option[String], azureEnvironmentEndpoints: java.util.Map[String, String], clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]], - clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], + clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], + // Optional custom HTTP headers (e.g., workload-id) parsed from + // spark.cosmos.customHeaders JSON config, passed to CosmosClientBuilder + customHeaders: Option[Map[String, String]] ) private object CosmosAccountConfig extends BasicLoggingTrait { @@ -719,6 +728,19 @@ private object CosmosAccountConfig extends BasicLoggingTrait { parseFromStringFunction = clientInterceptorFQDN => clientInterceptorFQDN, helpMessage = "CosmosAsyncClient interceptors (comma separated) - FQDNs of the service implementing the 'CosmosClientInterceptor' trait.") + // Config entry for custom HTTP headers (e.g., workload-id). Parses a JSON string like + // {"x-ms-cosmos-workload-id": "15"} into a Scala Map[String, String] using Jackson. + // These headers are passed to CosmosClientBuilder.customHeaders() in CosmosClientCache. + private val CustomHeadersConfig = CosmosConfigEntry[Map[String, String]]( + key = CosmosConfigNames.CustomHeaders, + mandatory = false, + parseFromStringFunction = headersJson => { + val mapper = new com.fasterxml.jackson.databind.ObjectMapper() + val typeRef = new com.fasterxml.jackson.core.`type`.TypeReference[java.util.Map[String, String]]() {} + mapper.readValue(headersJson, typeRef).asScala.toMap + }, + helpMessage = "Optional custom headers as JSON map. Example: {\"x-ms-cosmos-workload-id\": \"15\"}") + private[spark] def parseProactiveConnectionInitConfigs(config: String): java.util.List[CosmosContainerIdentity] = { val result = new java.util.ArrayList[CosmosContainerIdentity] try { @@ -753,6 +775,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait { val tenantIdOpt = CosmosConfigEntry.parse(cfg, TenantId) val clientBuilderInterceptors = CosmosConfigEntry.parse(cfg, ClientBuilderInterceptors) val clientInterceptors = CosmosConfigEntry.parse(cfg, ClientInterceptors) + // Parse optional custom HTTP headers from JSON config (e.g., {"x-ms-cosmos-workload-id": "15"}) + val customHeaders = CosmosConfigEntry.parse(cfg, CustomHeadersConfig) val disableTcpConnectionEndpointRediscovery = CosmosConfigEntry.parse(cfg, DisableTcpConnectionEndpointRediscovery) val preferredRegionsListOpt = CosmosConfigEntry.parse(cfg, PreferredRegionsList) @@ -864,7 +888,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait { resourceGroupNameOpt, azureEnvironmentOpt.get, if (clientBuilderInterceptorsList.nonEmpty) { Some(clientBuilderInterceptorsList.toList) } else { None }, - if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None }) + if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None }, + customHeaders) } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala index ccf36791dc96..4d542c44612e 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala @@ -64,7 +64,8 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) ), ( @@ -91,7 +92,8 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) ), ( @@ -118,7 +120,8 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) ), ( @@ -145,7 +148,8 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) ) ) @@ -179,8 +183,9 @@ class CosmosClientCacheITest clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None - ) + azureMonitorConfig = None, + customHeaders = None + ) logInfo(s"TestCase: {$testCaseName}") diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala index 7fcc601ba016..a0627c0cf3dd 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala @@ -408,4 +408,72 @@ class CosmosClientConfigurationSpec extends UnitSpec { configuration.applicationName shouldEqual s"${CosmosConstants.userAgentSuffix}|$sparkEnvironmentInfo|${ManagementFactory.getRuntimeMXBean.getName}|$myApp".replace("@", " ") configuration.azureMonitorConfig shouldEqual None } + + // Verifies that the spark.cosmos.customHeaders configuration option correctly parses + // a JSON string containing a single workload-id header into a Map[String, String] on + // CosmosClientConfiguration. This is the primary use case for the workload-id feature. + it should "parse customHeaders JSON" in { + val userConfig = Map( + "spark.cosmos.accountEndpoint" -> "https://localhost:8081", + "spark.cosmos.accountKey" -> "xyz", + "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "15"}""" + ) + + val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT + val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + + configuration.customHeaders shouldBe defined + configuration.customHeaders.get("x-ms-cosmos-workload-id") shouldEqual "15" + } + + // Verifies that when spark.cosmos.customHeaders is not specified in the config map, + // CosmosClientConfiguration.customHeaders is None. This ensures backward compatibility — + // existing Spark jobs that don't set customHeaders continue to work without changes. + it should "handle missing customHeaders" in { + val userConfig = Map( + "spark.cosmos.accountEndpoint" -> "https://localhost:8081", + "spark.cosmos.accountKey" -> "xyz" + ) + + val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT + val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + + configuration.customHeaders shouldBe None + } + + // Verifies that spark.cosmos.customHeaders correctly parses a JSON string containing + // multiple headers into a Map with all entries preserved. This supports use cases where + // multiple custom headers need to be sent alongside workload-id. + it should "parse multiple custom headers" in { + val userConfig = Map( + "spark.cosmos.accountEndpoint" -> "https://localhost:8081", + "spark.cosmos.accountKey" -> "xyz", + "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "20", "x-custom-header": "value"}""" + ) + + val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT + val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + + configuration.customHeaders shouldBe defined + configuration.customHeaders.get should have size 2 + configuration.customHeaders.get("x-ms-cosmos-workload-id") shouldEqual "20" + configuration.customHeaders.get("x-custom-header") shouldEqual "value" + } + + // Verifies that spark.cosmos.customHeaders handles an empty JSON object ("{}") gracefully, + // resulting in a defined but empty Map. This ensures the parser doesn't fail on edge cases + // and that no headers are injected when the JSON object is empty. + it should "handle empty customHeaders JSON" in { + val userConfig = Map( + "spark.cosmos.accountEndpoint" -> "https://localhost:8081", + "spark.cosmos.accountKey" -> "xyz", + "spark.cosmos.customHeaders" -> "{}" + ) + + val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT + val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + + configuration.customHeaders shouldBe defined + configuration.customHeaders.get shouldBe empty + } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala index 6ef90b55989d..ab73dc4e54d3 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala @@ -39,7 +39,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -116,7 +117,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -193,7 +195,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -270,7 +273,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -345,7 +349,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -436,7 +441,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -510,7 +516,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -576,7 +583,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala index dfd14c36c80f..65274bee2b19 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala @@ -38,7 +38,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) private[this] val contCfg = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -84,7 +85,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -169,7 +171,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -254,7 +257,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -321,7 +325,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -383,7 +388,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -439,7 +445,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -495,7 +502,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -551,7 +559,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -607,7 +616,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -686,7 +696,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -747,7 +758,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -803,7 +815,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -876,7 +889,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -949,7 +963,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -1027,7 +1042,8 @@ class PartitionMetadataSpec extends UnitSpec { clientBuilderInterceptors = None, clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, - azureMonitorConfig = None + azureMonitorConfig = None, + customHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala new file mode 100644 index 000000000000..d9706d0709e5 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos.spark + +import com.azure.cosmos.implementation.TestConfigurations +import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.databind.node.ObjectNode + +import java.util.UUID + +/** + * End-to-end integration tests for the custom headers (workload-id) feature in the Spark connector. + * + * These tests verify that the `spark.cosmos.customHeaders` configuration option correctly flows + * through the Spark connector pipeline into CosmosClientBuilder.customHeaders(), ensuring that + * custom HTTP headers (such as x-ms-cosmos-workload-id) are applied to all Cosmos DB operations + * initiated via Spark DataFrames (reads and writes). + * + * Requires the Cosmos DB Emulator running + */ +class SparkE2EWorkloadIdITest + extends IntegrationSpec + with Spark + with CosmosClient + with AutoCleanableCosmosContainer + with BasicLoggingTrait { + + val objectMapper = new ObjectMapper() + + //scalastyle:off multiple.string.literals + //scalastyle:off magic.number + //scalastyle:off null + + // Verifies that a Spark DataFrame read operation succeeds when spark.cosmos.customHeaders + // is configured with a workload-id header. The header should be passed through to the + // CosmosAsyncClient via CosmosClientBuilder.customHeaders() without affecting read behavior. + "spark query with customHeaders" can "read items with workload-id header" in { + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + + val id = UUID.randomUUID().toString + val rawItem = + s""" + | { + | "id" : "$id", + | "name" : "testItem" + | } + |""".stripMargin + + val objectNode = objectMapper.readValue(rawItem, classOf[ObjectNode]) + + val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer) + container.createItem(objectNode).block() + + val cfg = Map( + "spark.cosmos.accountEndpoint" -> cosmosEndpoint, + "spark.cosmos.accountKey" -> cosmosMasterKey, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> cosmosContainer, + "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "15"}""", + "spark.cosmos.read.partitioning.strategy" -> "Restrictive" + ) + + val df = spark.read.format("cosmos.oltp").options(cfg).load() + val rowsArray = df.where(s"id = '$id'").collect() + rowsArray should have size 1 + + val item = rowsArray(0) + item.getAs[String]("id") shouldEqual id + } + + // Verifies that a Spark DataFrame write operation succeeds when spark.cosmos.customHeaders + // is configured with a workload-id header. The item is written via Spark and then verified + // via a direct SDK read to confirm the write was persisted correctly. + "spark write with customHeaders" can "write items with workload-id header" in { + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + + val id = UUID.randomUUID().toString + val rawItem = + s""" + | { + | "id" : "$id", + | "name" : "testWriteItem" + | } + |""".stripMargin + + val cfg = Map( + "spark.cosmos.accountEndpoint" -> cosmosEndpoint, + "spark.cosmos.accountKey" -> cosmosMasterKey, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> cosmosContainer, + "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "20"}""", + "spark.cosmos.write.strategy" -> "ItemOverwrite", + "spark.cosmos.write.bulk.enabled" -> "false", + "spark.cosmos.serialization.inclusionMode" -> "NonDefault" + ) + + val spark_session = spark + import spark_session.implicits._ + val df = spark.read.json(Seq(rawItem).toDS()) + + df.write.format("cosmos.oltp").options(cfg).mode("Append").save() + + // Verify the item was written by reading it back via the SDK directly + val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer) + val readItem = container.readItem(id, new com.azure.cosmos.models.PartitionKey(id), classOf[ObjectNode]).block() + readItem.getItem.get("id").textValue() shouldEqual id + readItem.getItem.get("name").textValue() shouldEqual "testWriteItem" + } + + // Regression test: verifies that Spark read operations continue to work correctly when + // spark.cosmos.customHeaders is NOT specified. Ensures that the feature addition does not + // break existing behavior for clients that do not use custom headers. + "spark operations without customHeaders" can "still succeed" in { + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + + val id = UUID.randomUUID().toString + val rawItem = + s""" + | { + | "id" : "$id", + | "name" : "noHeadersItem" + | } + |""".stripMargin + + val objectNode = objectMapper.readValue(rawItem, classOf[ObjectNode]) + val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer) + container.createItem(objectNode).block() + + val cfg = Map( + "spark.cosmos.accountEndpoint" -> cosmosEndpoint, + "spark.cosmos.accountKey" -> cosmosMasterKey, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> cosmosContainer, + "spark.cosmos.read.partitioning.strategy" -> "Restrictive" + ) + + val df = spark.read.format("cosmos.oltp").options(cfg).load() + val rowsArray = df.where(s"id = '$id'").collect() + rowsArray should have size 1 + rowsArray(0).getAs[String]("id") shouldEqual id + } + + //scalastyle:on magic.number + //scalastyle:on multiple.string.literals + //scalastyle:on null +} From 1bbf6b9bb962fec1f576fbf3a4b0863a0a132099 Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Wed, 25 Feb 2026 19:08:04 -0600 Subject: [PATCH 03/13] workload-id feature - cleaning up comments --- .../com/azure/cosmos/rx/WorkloadIdE2ETests.java | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java index 85b87090bb36..a57b4d9d9a0b 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java @@ -32,17 +32,6 @@ * End-to-end integration tests for the custom headers / workload-id feature. *

* Test type: EMULATOR INTEGRATION TEST — requires the Cosmos DB Emulator to be running locally. - * These tests create a real database and container, then execute CRUD and query operations - * with the {@code x-ms-cosmos-workload-id} header set at client level and/or request level. - *

- * What is verified: - * 1. CRUD operations succeed with client-level custom headers (workload-id) - * 2. Per-request header overrides work via setHeader() - * 3. Client with no custom headers continues to work (no regression) - * 4. Query operations succeed with workload-id - * 5. Empty headers and multiple headers are handled correctly - *

- */ public class WorkloadIdE2ETests extends TestSuiteBase { @@ -82,13 +71,12 @@ public void beforeClass() { } /** - * Smoke test: verifies that a create (POST) operation succeeds when the client + * verifies that a create (POST) operation succeeds when the client * has a workload-id custom header set at the builder level. Confirms the header * flows through the request pipeline without causing errors. */ @Test(groups = { "emulator" }, timeOut = TIMEOUT) public void createItemWithClientLevelWorkloadId() { - // Smoke test: verify create operation succeeds with client-level workload-id header TestObject doc = TestObject.create(); CosmosItemResponse response = container From 7aeb8b4e030958f77e14983fc5f0d3475333effb Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Wed, 25 Feb 2026 19:20:36 -0600 Subject: [PATCH 04/13] workload-id feature - addressing copilot comments --- sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md | 1 + sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md | 1 + sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md | 1 + sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md | 1 + sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md | 1 + sdk/cosmos/azure-cosmos/CHANGELOG.md | 1 + .../directconnectivity/rntbd/RntbdConstants.java | 4 ++-- .../directconnectivity/rntbd/RntbdRequestHeaders.java | 2 +- 8 files changed, 9 insertions(+), 3 deletions(-) diff --git a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md index de77b69485a8..2eeccc69dada 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.44.0-beta.1 (Unreleased) #### Features Added +* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md index 80072357c58f..6e72286dfab6 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.44.0-beta.1 (Unreleased) #### Features Added +* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md index b905a025a1e6..d0d80af466d6 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.44.0-beta.1 (Unreleased) #### Features Added +* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md index e17cecdfdac2..ae444a9c399f 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.44.0-beta.1 (Unreleased) #### Features Added +* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md index e9be63ef89bd..ae910280495b 100644 --- a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.44.0-beta.1 (Unreleased) #### Features Added +* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 8475b83dc5ee..ea4fdb82e1dc 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -4,6 +4,7 @@ #### Features Added * Added support for N-Region synchronous commit feature - See [PR 47757](https://github.com/Azure/azure-sdk-for-java/pull/47757) +* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java index d79231793679..d75bf5dc88e1 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdConstants.java @@ -598,8 +598,8 @@ public enum RntbdRequestHeader implements RntbdHeader { PriorityLevel((short) 0x00BF, RntbdTokenType.Byte, false), GlobalDatabaseAccountName((short) 0x00CE, RntbdTokenType.String, false), ThroughputBucket((short)0x00DB, RntbdTokenType.Byte, false), - HubRegionProcessingOnly((short)0x00EF, RntbdTokenType.Byte , false), - WorkloadId((short)0x00DC, RntbdTokenType.Byte, false); + WorkloadId((short)0x00DC, RntbdTokenType.Byte, false), + HubRegionProcessingOnly((short)0x00EF, RntbdTokenType.Byte , false); public static final List thinClientHeadersInOrderList = Arrays.asList( EffectivePartitionKey, diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java index 387e8cf3ed5a..46f8060387fc 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestHeaders.java @@ -827,7 +827,7 @@ private void addWorkloadId(final Map headers) { if (StringUtils.isNotEmpty(value)) { try { - final int workloadId = Integer.valueOf(value); + final int workloadId = Integer.parseInt(value); this.getWorkloadId().setValue((byte) workloadId); } catch (NumberFormatException e) { logger.warn("Invalid value for workload id header: {}", value, e); From 6d00d49320caa42432aa26725befe529e5934ad4 Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Mon, 2 Mar 2026 16:38:14 -0600 Subject: [PATCH 05/13] workload-id feature - addressing comments --- .../com/azure/cosmos/spark/CosmosConfig.scala | 13 +- .../spark/CosmosClientConfigurationSpec.scala | 13 +- .../com/azure/cosmos/CustomHeadersTests.java | 104 ++++++++++- .../RxDocumentClientUnderTest.java | 7 +- .../RxGatewayStoreModelTest.java | 172 +++++++++++++++++- .../SpyClientUnderTestFactory.java | 7 +- .../GatewayAddressCacheTest.java | 125 +++++++++++++ .../GlobalAddressResolverTest.java | 3 +- .../azure/cosmos/rx/WorkloadIdE2ETests.java | 62 ++----- .../com/azure/cosmos/CosmosClientBuilder.java | 50 ++++- .../implementation/RxDocumentClientImpl.java | 12 +- .../implementation/RxGatewayStoreModel.java | 17 +- .../implementation/ThinClientStoreModel.java | 3 +- .../GatewayAddressCache.java | 17 +- .../GlobalAddressResolver.java | 8 +- .../models/CosmosReadManyRequestOptions.java | 16 ++ 16 files changed, 546 insertions(+), 83 deletions(-) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala index 6646b2e69ae5..62802d23b14c 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala @@ -7,7 +7,7 @@ import com.azure.core.management.AzureEnvironment import com.azure.cosmos.{CosmosAsyncClient, CosmosClientBuilder, ReadConsistencyStrategy, spark} import com.azure.cosmos.implementation.batch.BatchRequestResponseConstants import com.azure.cosmos.implementation.routing.LocationHelper -import com.azure.cosmos.implementation.{Configs, SparkBridgeImplementationInternal, Strings} +import com.azure.cosmos.implementation.{Configs, SparkBridgeImplementationInternal, Strings, Utils} import com.azure.cosmos.models.{CosmosChangeFeedRequestOptions, CosmosContainerIdentity, CosmosParameterizedQuery, DedicatedGatewayRequestOptions, FeedRange, PartitionKeyDefinition} import com.azure.cosmos.spark.ChangeFeedModes.ChangeFeedMode import com.azure.cosmos.spark.ChangeFeedStartFromModes.{ChangeFeedStartFromMode, PointInTime} @@ -735,9 +735,14 @@ private object CosmosAccountConfig extends BasicLoggingTrait { key = CosmosConfigNames.CustomHeaders, mandatory = false, parseFromStringFunction = headersJson => { - val mapper = new com.fasterxml.jackson.databind.ObjectMapper() - val typeRef = new com.fasterxml.jackson.core.`type`.TypeReference[java.util.Map[String, String]]() {} - mapper.readValue(headersJson, typeRef).asScala.toMap + try { + val typeRef = new com.fasterxml.jackson.core.`type`.TypeReference[java.util.Map[String, String]]() {} + Utils.getSimpleObjectMapperWithAllowDuplicates.readValue(headersJson, typeRef).asScala.toMap + } catch { + case e: Exception => throw new IllegalArgumentException( + s"Invalid JSON for '${CosmosConfigNames.CustomHeaders}': '$headersJson'. " + + "Expected format: {\"x-ms-cosmos-workload-id\": \"15\"}", e) + } }, helpMessage = "Optional custom headers as JSON map. Example: {\"x-ms-cosmos-workload-id\": \"15\"}") diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala index a0627c0cf3dd..377425189f07 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala @@ -441,10 +441,11 @@ class CosmosClientConfigurationSpec extends UnitSpec { configuration.customHeaders shouldBe None } - // Verifies that spark.cosmos.customHeaders correctly parses a JSON string containing - // multiple headers into a Map with all entries preserved. This supports use cases where - // multiple custom headers need to be sent alongside workload-id. - it should "parse multiple custom headers" in { + // Verifies that spark.cosmos.customHeaders rejects unknown headers at the parsing level. + // Only headers in CosmosClientBuilder's allowlist are permitted. In Direct mode (RNTBD), + // unknown headers are silently dropped, so the allowlist ensures consistent behavior + // across Gateway and Direct modes. + it should "reject unknown custom headers" in { val userConfig = Map( "spark.cosmos.accountEndpoint" -> "https://localhost:8081", "spark.cosmos.accountKey" -> "xyz", @@ -454,10 +455,10 @@ class CosmosClientConfigurationSpec extends UnitSpec { val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + // Parsing succeeds — the JSON is valid and CosmosClientConfiguration stores it as-is. + // The allowlist validation happens later in CosmosClientBuilder.customHeaders() configuration.customHeaders shouldBe defined configuration.customHeaders.get should have size 2 - configuration.customHeaders.get("x-ms-cosmos-workload-id") shouldEqual "20" - configuration.customHeaders.get("x-custom-header") shouldEqual "value" } // Verifies that spark.cosmos.customHeaders handles an empty JSON object ("{}") gracefully, diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java index 3c95c8bd2687..19eb03744d1a 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java @@ -9,6 +9,7 @@ import com.azure.cosmos.models.CosmosChangeFeedRequestOptions; import com.azure.cosmos.models.CosmosItemRequestOptions; import com.azure.cosmos.models.CosmosQueryRequestOptions; +import com.azure.cosmos.models.CosmosReadManyRequestOptions; import com.azure.cosmos.models.FeedRange; import org.testng.annotations.Test; @@ -16,6 +17,7 @@ import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * Unit tests for the custom headers (workload-id) feature on CosmosClientBuilder and request options classes. @@ -73,23 +75,40 @@ public void customHeadersEmptyMapHandled() { } /** - * Verifies that multiple custom headers can be set at once on the builder and - * all entries are preserved and retrievable with correct keys and values. + * Verifies that headers not in the allowlist are rejected with IllegalArgumentException. + * This ensures consistent behavior across Gateway and Direct modes — only headers with + * RNTBD encoding support are allowed. */ @Test(groups = { "unit" }) - public void multipleCustomHeadersSupported() { + public void unknownHeaderRejectedByAllowlist() { Map headers = new HashMap<>(); - headers.put("x-ms-cosmos-workload-id", "15"); headers.put("x-ms-custom-header", "value"); - CosmosClientBuilder builder = new CosmosClientBuilder() + assertThatThrownBy(() -> new CosmosClientBuilder() .endpoint("https://test.documents.azure.com:443/") .key("dGVzdEtleQ==") - .customHeaders(headers); + .customHeaders(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("x-ms-custom-header") + .hasMessageContaining("not allowed"); + } + + /** + * Verifies that a map containing both an allowed header and a disallowed header + * is rejected — the entire map must pass the allowlist check. + */ + @Test(groups = { "unit" }) + public void mixedAllowedAndDisallowedHeadersRejected() { + Map headers = new HashMap<>(); + headers.put("x-ms-cosmos-workload-id", "15"); + headers.put("x-ms-custom-header", "value"); - assertThat(builder.getCustomHeaders()).hasSize(2); - assertThat(builder.getCustomHeaders()).containsEntry("x-ms-cosmos-workload-id", "15"); - assertThat(builder.getCustomHeaders()).containsEntry("x-ms-custom-header", "value"); + assertThatThrownBy(() -> new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("x-ms-custom-header"); } /** @@ -158,6 +177,19 @@ public void setHeaderOnQueryRequestOptionsIsPublic() { assertThat(options).isNotNull(); } + /** + * Verifies that the new delegating setHeader() method on CosmosReadManyRequestOptions + * is publicly accessible and supports fluent chaining for per-request header + * overrides on read-many operations. + */ + @Test(groups = { "unit" }) + public void setHeaderOnReadManyRequestOptionsIsPublic() { + CosmosReadManyRequestOptions options = new CosmosReadManyRequestOptions() + .setHeader("x-ms-cosmos-workload-id", "40"); + + assertThat(options).isNotNull(); + } + /** * Verifies that the WORKLOAD_ID constant in HttpConstants.HttpHeaders is defined * with the correct canonical header name "x-ms-cosmos-workload-id" as expected @@ -167,4 +199,58 @@ public void setHeaderOnQueryRequestOptionsIsPublic() { public void workloadIdHttpHeaderConstant() { assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id"); } + + /** + * Verifies that a non-numeric workload-id value is rejected at builder level with + * IllegalArgumentException. This covers both Gateway and Direct modes consistently + * (unlike RntbdRequestHeaders.addWorkloadId() which only covers Direct mode). + */ + @Test(groups = { "unit" }) + public void nonNumericWorkloadIdRejectedAtBuilderLevel() { + Map headers = new HashMap<>(); + headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "abc"); + + assertThatThrownBy(() -> new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("abc") + .hasMessageContaining("valid integer"); + } + + /** + * Verifies that out-of-range workload-id values (e.g., 51) are accepted by the SDK. + * Range validation [1, 50] is the backend's responsibility — the SDK only validates + * that the value is a valid integer. This avoids hardcoding a range the backend team + * might change in the future. + */ + @Test(groups = { "unit" }) + public void outOfRangeWorkloadIdAcceptedByBuilder() { + Map headers = new HashMap<>(); + headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "51"); + + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(headers); + + assertThat(builder.getCustomHeaders()).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "51"); + } + + /** + * Verifies that a valid workload-id value passes builder validation. + */ + @Test(groups = { "unit" }) + public void validWorkloadIdAcceptedByBuilder() { + Map headers = new HashMap<>(); + headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "15"); + + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .customHeaders(headers); + + assertThat(builder.getCustomHeaders()).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "15"); + } } diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java index a9f5cb35549c..d5f8b92ac7a6 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java @@ -19,6 +19,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import static org.mockito.Mockito.doAnswer; @@ -75,7 +76,8 @@ RxGatewayStoreModel createRxGatewayProxy( GlobalEndpointManager globalEndpointManager, GlobalPartitionEndpointManagerForPerPartitionCircuitBreaker globalPartitionEndpointManagerForPerPartitionCircuitBreaker, HttpClient rxOrigClient, - ApiType apiType) { + ApiType apiType, + Map customHeaders) { origHttpClient = rxOrigClient; spyHttpClient = Mockito.spy(rxOrigClient); @@ -93,6 +95,7 @@ RxGatewayStoreModel createRxGatewayProxy( userAgentContainer, globalEndpointManager, spyHttpClient, - apiType); + apiType, + customHeaders); } } diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java index 54440ecfabc5..587844f4043a 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java @@ -27,6 +27,8 @@ import java.net.SocketException; import java.net.URI; import java.time.Duration; +import java.util.HashMap; +import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import static com.azure.cosmos.implementation.TestUtils.mockDiagnosticsClientContext; @@ -102,6 +104,7 @@ public void readTimeout() throws Exception { userAgentContainer, globalEndpointManager, httpClient, + null, null); storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader); @@ -146,6 +149,7 @@ public void serviceUnavailable() throws Exception { userAgentContainer, globalEndpointManager, httpClient, + null, null); storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader); @@ -205,7 +209,8 @@ public void applySessionToken( new UserAgentContainer(), globalEndpointManager, httpClient, - apiType); + apiType, + null); storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader); RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( @@ -277,7 +282,8 @@ public void validateApiType() throws Exception { new UserAgentContainer(), globalEndpointManager, httpClient, - apiType); + apiType, + null); RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( clientContext, @@ -391,6 +397,7 @@ private boolean runCancelAfterRetainIteration() throws Exception { new UserAgentContainer(), globalEndpointManager, httpClient, + null, null); storeModel.setGatewayServiceConfigurationReader(gatewayServiceConfigurationReader); @@ -428,6 +435,167 @@ private boolean runCancelAfterRetainIteration() throws Exception { return false; } + /** + * Verifies that client-level customHeaders (e.g., workload-id) are injected into + * outgoing HTTP requests by performRequest(). This covers metadata requests + * (collection cache, partition key range) that don't go through getRequestHeaders(). + */ + @Test(groups = "unit") + public void customHeadersInjectedInPerformRequest() throws Exception { + DiagnosticsClientContext clientContext = mockDiagnosticsClientContext(); + ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class); + GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class); + + Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost"))) + .when(globalEndpointManager).resolveServiceEndpoint(any()); + + HttpClient httpClient = Mockito.mock(HttpClient.class); + ArgumentCaptor httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException())); + + Map customHeaders = new HashMap<>(); + customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + + RxGatewayStoreModel storeModel = new RxGatewayStoreModel( + clientContext, + sessionContainer, + ConsistencyLevel.SESSION, + QueryCompatibilityMode.Default, + new UserAgentContainer(), + globalEndpointManager, + httpClient, + null, + customHeaders); + + // Simulate a metadata request (e.g., collection cache lookup) — no customHeaders on the request itself + RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( + clientContext, + OperationType.Read, + "/dbs/db/colls/col", + ResourceType.DocumentCollection); + dsr.requestContext = new DocumentServiceRequestContext(); + dsr.requestContext.regionalRoutingContextToRoute = new RegionalRoutingContext(new URI("https://localhost")); + + try { + storeModel.performRequest(dsr).block(); + fail("Request should fail"); + } catch (Exception e) { + // expected + } + + Mockito.verify(httpClient).send(httpClientRequestCaptor.capture(), any()); + HttpRequest httpRequest = httpClientRequestCaptor.getValue(); + HttpHeaders headers = ReflectionUtils.getHttpHeaders(httpRequest); + assertThat(headers.toMap().get(HttpConstants.HttpHeaders.WORKLOAD_ID)).isEqualTo("25"); + } + + /** + * Verifies that request-level headers take precedence over client-level customHeaders. + * If a request already has workload-id set (e.g., via getRequestHeaders()), performRequest() + * should NOT overwrite it. + */ + @Test(groups = "unit") + public void requestLevelHeadersTakePrecedenceOverCustomHeaders() throws Exception { + DiagnosticsClientContext clientContext = mockDiagnosticsClientContext(); + ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class); + GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class); + + Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost"))) + .when(globalEndpointManager).resolveServiceEndpoint(any()); + + HttpClient httpClient = Mockito.mock(HttpClient.class); + ArgumentCaptor httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException())); + + Map customHeaders = new HashMap<>(); + customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "10"); + + RxGatewayStoreModel storeModel = new RxGatewayStoreModel( + clientContext, + sessionContainer, + ConsistencyLevel.SESSION, + QueryCompatibilityMode.Default, + new UserAgentContainer(), + globalEndpointManager, + httpClient, + null, + customHeaders); + + RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( + clientContext, + OperationType.Read, + "/dbs/db/colls/col/docs/doc1", + ResourceType.Document); + dsr.requestContext = new DocumentServiceRequestContext(); + dsr.requestContext.regionalRoutingContextToRoute = new RegionalRoutingContext(new URI("https://localhost")); + + // Simulate request-level header already set (e.g., by getRequestHeaders()) + dsr.getHeaders().put(HttpConstants.HttpHeaders.WORKLOAD_ID, "42"); + + try { + storeModel.performRequest(dsr).block(); + fail("Request should fail"); + } catch (Exception e) { + // expected + } + + Mockito.verify(httpClient).send(httpClientRequestCaptor.capture(), any()); + HttpRequest httpRequest = httpClientRequestCaptor.getValue(); + HttpHeaders headers = ReflectionUtils.getHttpHeaders(httpRequest); + // Request-level header "42" should win over client-level "10" + assertThat(headers.toMap().get(HttpConstants.HttpHeaders.WORKLOAD_ID)).isEqualTo("42"); + } + + /** + * Verifies that when customHeaders is null, performRequest() still works normally + * without injecting any extra headers. + */ + @Test(groups = "unit") + public void nullCustomHeadersDoesNotAffectPerformRequest() throws Exception { + DiagnosticsClientContext clientContext = mockDiagnosticsClientContext(); + ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class); + GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class); + + Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost"))) + .when(globalEndpointManager).resolveServiceEndpoint(any()); + + HttpClient httpClient = Mockito.mock(HttpClient.class); + ArgumentCaptor httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException())); + + RxGatewayStoreModel storeModel = new RxGatewayStoreModel( + clientContext, + sessionContainer, + ConsistencyLevel.SESSION, + QueryCompatibilityMode.Default, + new UserAgentContainer(), + globalEndpointManager, + httpClient, + null, + null); + + RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( + clientContext, + OperationType.Read, + "/dbs/db/colls/col", + ResourceType.DocumentCollection); + dsr.requestContext = new DocumentServiceRequestContext(); + dsr.requestContext.regionalRoutingContextToRoute = new RegionalRoutingContext(new URI("https://localhost")); + + try { + storeModel.performRequest(dsr).block(); + fail("Request should fail"); + } catch (Exception e) { + // expected + } + + Mockito.verify(httpClient).send(httpClientRequestCaptor.capture(), any()); + HttpRequest httpRequest = httpClientRequestCaptor.getValue(); + HttpHeaders headers = ReflectionUtils.getHttpHeaders(httpRequest); + // No workload-id header should be present + assertThat(headers.toMap().get(HttpConstants.HttpHeaders.WORKLOAD_ID)).isNull(); + } + enum SessionTokenType { NONE, // no session token applied USER, // userControlled session token diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java index b06d6f89b8e9..775b74785630 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java @@ -25,6 +25,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; import java.util.stream.Collectors; @@ -126,7 +127,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer, UserAgentContainer userAgentContainer, GlobalEndpointManager globalEndpointManager, HttpClient rxClient, - ApiType apiType) { + ApiType apiType, + Map customHeaders) { this.origRxGatewayStoreModel = super.createRxGatewayProxy( sessionContainer, consistencyLevel, @@ -134,7 +136,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer, userAgentContainer, globalEndpointManager, rxClient, - apiType); + apiType, + customHeaders); this.requests = Collections.synchronizedList(new ArrayList<>()); this.spyRxGatewayStoreModel = Mockito.spy(this.origRxGatewayStoreModel); this.initRequestCapture(); diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java index 172c00f799bc..9b938d0a1520 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java @@ -57,6 +57,7 @@ import java.util.Arrays; import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.UUID; import java.util.concurrent.TimeUnit; @@ -145,6 +146,7 @@ public void getServerAddressesViaGateway(List partitionKeyRangeIds, null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); for (int i = 0; i < 2; i++) { @@ -186,6 +188,7 @@ public void getMasterAddressesViaGatewayAsync(Protocol protocol) throws Exceptio null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); for (int i = 0; i < 2; i++) { @@ -238,6 +241,7 @@ public void tryGetAddresses_ForDataPartitions(String partitionKeyRangeId, String null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); RxDocumentServiceRequest req = @@ -296,6 +300,7 @@ public void tryGetAddresses_ForDataPartitions_AddressCachedByOpenAsync_NoHttpReq null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); String collectionRid = createdCollection.getResourceId(); @@ -366,6 +371,7 @@ public void tryGetAddresses_ForDataPartitions_ForceRefresh( null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); String collectionRid = createdCollection.getResourceId(); @@ -472,6 +478,7 @@ public void tryGetAddresses_ForDataPartitions_Suboptimal_Refresh( null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); String collectionRid = createdCollection.getResourceId(); @@ -614,6 +621,7 @@ public void tryGetAddresses_ForMasterPartition(Protocol protocol) throws Excepti null, null, null, + null, null); RxDocumentServiceRequest req = @@ -666,6 +674,7 @@ public void tryGetAddresses_ForMasterPartition_MasterPartitionAddressAlreadyCach null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); RxDocumentServiceRequest req = @@ -717,6 +726,7 @@ public void tryGetAddresses_ForMasterPartition_ForceRefresh() throws Exception { null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); RxDocumentServiceRequest req = @@ -775,6 +785,7 @@ public void tryGetAddresses_SuboptimalMasterPartition_NotStaleEnough_NoRefresh() null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); GatewayAddressCache spyCache = Mockito.spy(origCache); @@ -873,6 +884,7 @@ public void tryGetAddresses_SuboptimalMasterPartition_Stale_DoRefresh() throws E null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); GatewayAddressCache spyCache = Mockito.spy(origCache); @@ -990,6 +1002,7 @@ public void tryGetAddress_replicaValidationTests(boolean replicaValidationEnable null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); RxDocumentServiceRequest req = @@ -1152,6 +1165,7 @@ public void tryGetAddress_failedEndpointTests() throws Exception { null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); RxDocumentServiceRequest req = @@ -1214,6 +1228,7 @@ public void tryGetAddress_unhealthyStatus_forceRefresh() throws Exception { null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); RxDocumentServiceRequest req = @@ -1304,6 +1319,7 @@ public void tryGetAddress_repeatedlySetUnhealthyStatus_forceRefresh() throws Int null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); RxDocumentServiceRequest req = @@ -1396,6 +1412,7 @@ public void validateReplicaAddressesTests(boolean isCollectionUnderWarmUpFlow) t null, ConnectionPolicy.getDefaultPolicy(), proactiveOpenConnectionsProcessorMock, + null, null); Mockito.when(proactiveOpenConnectionsProcessorMock.submitOpenConnectionTaskOutsideLoop(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.anyInt())).thenReturn(dummyOpenConnectionsTask); @@ -1495,6 +1512,7 @@ public void mergeAddressesTests() throws URISyntaxException, NoSuchMethodExcepti null, ConnectionPolicy.getDefaultPolicy(), null, + null, null); // connected status @@ -1628,6 +1646,113 @@ private HttpClientUnderTestWrapper getHttpClientUnderTestWrapper(Configs configs return new HttpClientUnderTestWrapper(origHttpClient); } + /** + * Verifies that client-level customHeaders (e.g., workload-id) are included in + * GatewayAddressCache's defaultRequestHeaders, which are sent on every address + * resolution request. + */ + @Test(groups = { "unit" }) + public void customHeadersIncludedInDefaultRequestHeaders() throws Exception { + URI serviceEndpoint = new URI("https://localhost"); + + Map customHeaders = new HashMap<>(); + customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + + GatewayAddressCache cache = new GatewayAddressCache( + mockDiagnosticsClientContext(), + serviceEndpoint, + Protocol.HTTPS, + Mockito.mock(IAuthorizationTokenProvider.class), + null, + Mockito.mock(HttpClient.class), + null, + null, + null, + null, + null, + customHeaders); + + Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders"); + defaultRequestHeadersField.setAccessible(true); + @SuppressWarnings("unchecked") + HashMap defaultRequestHeaders = (HashMap) defaultRequestHeadersField.get(cache); + + assertThat(defaultRequestHeaders).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + } + + /** + * Verifies that customHeaders do NOT overwrite SDK system headers (USER_AGENT, VERSION, etc.) + * in GatewayAddressCache's defaultRequestHeaders. putIfAbsent is used so SDK headers + * set before customHeaders are preserved. + */ + @Test(groups = { "unit" }) + public void customHeadersDoNotOverwriteSdkSystemHeaders() throws Exception { + URI serviceEndpoint = new URI("https://localhost"); + + Map customHeaders = new HashMap<>(); + customHeaders.put(HttpConstants.HttpHeaders.USER_AGENT, "malicious-agent"); + customHeaders.put(HttpConstants.HttpHeaders.VERSION, "bad-version"); + customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + + GatewayAddressCache cache = new GatewayAddressCache( + mockDiagnosticsClientContext(), + serviceEndpoint, + Protocol.HTTPS, + Mockito.mock(IAuthorizationTokenProvider.class), + null, + Mockito.mock(HttpClient.class), + null, + null, + null, + null, + null, + customHeaders); + + Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders"); + defaultRequestHeadersField.setAccessible(true); + @SuppressWarnings("unchecked") + HashMap defaultRequestHeaders = (HashMap) defaultRequestHeadersField.get(cache); + + // SDK headers should NOT be overwritten + assertThat(defaultRequestHeaders.get(HttpConstants.HttpHeaders.USER_AGENT)).isNotEqualTo("malicious-agent"); + assertThat(defaultRequestHeaders.get(HttpConstants.HttpHeaders.VERSION)).isEqualTo(HttpConstants.Versions.CURRENT_VERSION); + // Custom header should still be added + assertThat(defaultRequestHeaders).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + } + + /** + * Verifies that when customHeaders is null, GatewayAddressCache's defaultRequestHeaders + * contains only SDK system headers and no extra entries. + */ + @Test(groups = { "unit" }) + public void nullCustomHeadersDoesNotAffectDefaultRequestHeaders() throws Exception { + URI serviceEndpoint = new URI("https://localhost"); + + GatewayAddressCache cache = new GatewayAddressCache( + mockDiagnosticsClientContext(), + serviceEndpoint, + Protocol.HTTPS, + Mockito.mock(IAuthorizationTokenProvider.class), + null, + Mockito.mock(HttpClient.class), + null, + null, + null, + null, + null, + null); + + Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders"); + defaultRequestHeadersField.setAccessible(true); + @SuppressWarnings("unchecked") + HashMap defaultRequestHeaders = (HashMap) defaultRequestHeadersField.get(cache); + + // Should only contain SDK system headers, no workload-id + assertThat(defaultRequestHeaders).containsKey(HttpConstants.HttpHeaders.USER_AGENT); + assertThat(defaultRequestHeaders).containsKey(HttpConstants.HttpHeaders.VERSION); + assertThat(defaultRequestHeaders).doesNotContainKey(HttpConstants.HttpHeaders.WORKLOAD_ID); + } + public String getNameBasedCollectionLink() { return "dbs/" + createdDatabase.getId() + "/colls/" + createdCollection.getId(); } diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java index 331be53cc7af..5879e7d3e61c 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolverTest.java @@ -110,7 +110,7 @@ public void resolveAsync() throws Exception { GlobalAddressResolver globalAddressResolver = new GlobalAddressResolver(mockDiagnosticsClientContext(), httpClient, endpointManager, Protocol.HTTPS, authorizationTokenProvider, collectionCache, routingMapProvider, userAgentContainer, - serviceConfigReader, connectionPolicy, null); + serviceConfigReader, connectionPolicy, null, null); RxDocumentServiceRequest request; request = RxDocumentServiceRequest.createFromName(mockDiagnosticsClientContext(), OperationType.Read, @@ -145,6 +145,7 @@ public void submitOpenConnectionTasksAndInitCaches() { userAgentContainer, serviceConfigReader, connectionPolicy, + null, null); GlobalAddressResolver.EndpointCache endpointCache = new GlobalAddressResolver.EndpointCache(); GatewayAddressCache gatewayAddressCache = Mockito.mock(GatewayAddressCache.class); diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java index a57b4d9d9a0b..3bf2fdafce7c 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java @@ -8,9 +8,6 @@ import com.azure.cosmos.CosmosClientBuilder; import com.azure.cosmos.TestObject; import com.azure.cosmos.implementation.HttpConstants; -import com.azure.cosmos.implementation.TestConfigurations; -import com.azure.cosmos.models.CosmosBulkExecutionOptions; -import com.azure.cosmos.models.CosmosBulkOperations; import com.azure.cosmos.models.CosmosContainerProperties; import com.azure.cosmos.models.CosmosItemRequestOptions; import com.azure.cosmos.models.CosmosItemResponse; @@ -19,6 +16,7 @@ import com.azure.cosmos.models.PartitionKeyDefinition; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; +import org.testng.annotations.Factory; import org.testng.annotations.Test; import java.util.ArrayList; @@ -32,6 +30,10 @@ * End-to-end integration tests for the custom headers / workload-id feature. *

* Test type: EMULATOR INTEGRATION TEST — requires the Cosmos DB Emulator to be running locally. + *

+ * Uses {@code @Factory(dataProvider = "simpleClientBuilderGatewaySession")} to run all tests + * against both Gateway mode (HTTP headers) and Direct mode (RNTBD binary token 0x00DC), + * ensuring the workload-id header is correctly encoded and sent in both transport paths. */ public class WorkloadIdE2ETests extends TestSuiteBase { @@ -42,10 +44,9 @@ public class WorkloadIdE2ETests extends TestSuiteBase { private CosmosAsyncDatabase database; private CosmosAsyncContainer container; - public WorkloadIdE2ETests() { - super(new CosmosClientBuilder() - .endpoint(TestConfigurations.HOST) - .key(TestConfigurations.MASTER_KEY)); + @Factory(dataProvider = "simpleClientBuilderGatewaySession") + public WorkloadIdE2ETests(CosmosClientBuilder clientBuilder) { + super(clientBuilder); } @BeforeClass(groups = { "emulator" }, timeOut = SETUP_TIMEOUT) @@ -53,9 +54,7 @@ public void beforeClass() { Map headers = new HashMap<>(); headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "15"); - clientWithWorkloadId = new CosmosClientBuilder() - .endpoint(TestConfigurations.HOST) - .key(TestConfigurations.MASTER_KEY) + clientWithWorkloadId = getClientBuilder() .customHeaders(headers) .buildAsyncClient(); @@ -218,9 +217,7 @@ public void queryItemsWithRequestLevelWorkloadIdOverride() { @Test(groups = { "emulator" }, timeOut = TIMEOUT) public void clientWithNoCustomHeadersStillWorks() { // Verify that a client without custom headers works normally (no regression) - CosmosAsyncClient clientWithoutHeaders = new CosmosClientBuilder() - .endpoint(TestConfigurations.HOST) - .key(TestConfigurations.MASTER_KEY) + CosmosAsyncClient clientWithoutHeaders = copyCosmosClientBuilder(getClientBuilder()) .buildAsyncClient(); try { @@ -248,9 +245,7 @@ public void clientWithNoCustomHeadersStillWorks() { @Test(groups = { "emulator" }, timeOut = TIMEOUT) public void clientWithEmptyCustomHeaders() { // Verify that a client with empty custom headers map works normally - CosmosAsyncClient clientWithEmptyHeaders = new CosmosClientBuilder() - .endpoint(TestConfigurations.HOST) - .key(TestConfigurations.MASTER_KEY) + CosmosAsyncClient clientWithEmptyHeaders = copyCosmosClientBuilder(getClientBuilder()) .customHeaders(new HashMap<>()) .buildAsyncClient(); @@ -272,38 +267,19 @@ public void clientWithEmptyCustomHeaders() { } /** - * Verifies that a client can be configured with multiple custom headers simultaneously - * (workload-id plus an additional custom header). Confirms that all headers flow - * through the pipeline without interfering with each other. + * Verifies that unknown headers in customHeaders are rejected by the allowlist. + * In Direct mode (RNTBD), unknown headers are silently dropped, so the allowlist + * ensures consistent behavior across Gateway and Direct modes. */ - @Test(groups = { "emulator" }, timeOut = TIMEOUT) - public void clientWithMultipleCustomHeaders() { - // Verify that multiple custom headers can be set simultaneously + @Test(groups = { "emulator" }, timeOut = TIMEOUT, expectedExceptions = IllegalArgumentException.class) + public void unknownCustomHeadersRejectedByAllowlist() { Map headers = new HashMap<>(); headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "20"); headers.put("x-ms-custom-test-header", "test-value"); - CosmosAsyncClient clientWithMultipleHeaders = new CosmosClientBuilder() - .endpoint(TestConfigurations.HOST) - .key(TestConfigurations.MASTER_KEY) - .customHeaders(headers) - .buildAsyncClient(); - - try { - CosmosAsyncContainer c = clientWithMultipleHeaders - .getDatabase(DATABASE_ID) - .getContainer(CONTAINER_ID); - - TestObject doc = TestObject.create(); - CosmosItemResponse response = c - .createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()) - .block(); - - assertThat(response).isNotNull(); - assertThat(response.getStatusCode()).isEqualTo(201); - } finally { - safeClose(clientWithMultipleHeaders); - } + // Should throw IllegalArgumentException due to unknown header + copyCosmosClientBuilder(getClientBuilder()) + .customHeaders(headers); } @AfterClass(groups = { "emulator" }, timeOut = SHUTDOWN_TIMEOUT, alwaysRun = true) diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java index aea282be566c..e4cdf6ca1e30 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java @@ -13,6 +13,7 @@ import com.azure.cosmos.implementation.ConnectionPolicy; import com.azure.cosmos.implementation.CosmosClientMetadataCachesSnapshot; import com.azure.cosmos.implementation.DiagnosticsProvider; +import com.azure.cosmos.implementation.HttpConstants; import com.azure.cosmos.implementation.Strings; import com.azure.cosmos.implementation.WriteRetryPolicy; import com.azure.cosmos.implementation.apachecommons.collections.list.UnmodifiableList; @@ -158,6 +159,20 @@ public class CosmosClientBuilder implements private Function containerFactory = null; private Map customHeaders; + /** + * Allowlist of headers permitted in {@link #customHeaders(Map)}. + *

+ * In Direct mode (RNTBD), only headers with explicit encoding support in + * {@code RntbdRequestHeaders} are sent on the wire. Unknown headers are silently dropped. + * This allowlist ensures consistent behavior across Gateway and Direct modes - if a header + * is allowed here, it works in both modes. To add a new allowed header, you must also add + * RNTBD encoding support ({@code RntbdConstants.RntbdRequestHeader} enum entry + + * {@code RntbdRequestHeaders.addXxx()} method). + */ + private static final Set ALLOWED_CUSTOM_HEADERS = Collections.unmodifiableSet( + new HashSet<>(Collections.singletonList(HttpConstants.HttpHeaders.WORKLOAD_ID)) + ); + /** * Instantiates a new Cosmos client builder. */ @@ -739,9 +754,13 @@ public CosmosClientBuilder userAgentSuffix(String userAgentSuffix) { /** * Sets custom HTTP headers that will be included with every request from this client. *

- * These headers are sent with all requests. For Direct/RNTBD mode, only known headers - * (like {@code x-ms-cosmos-workload-id}) will be encoded and sent. Unknown headers - * work only in Gateway mode. + * Only headers in the SDK's allowlist are permitted. Currently the only allowed header is + * {@code x-ms-cosmos-workload-id}. Passing any other header key will throw + * {@link IllegalArgumentException}. + *

+ * This restriction exists because in Direct mode (RNTBD), only headers with explicit + * encoding support are sent on the wire. Unknown headers are silently dropped. The allowlist + * ensures consistent behavior across both Gateway and Direct modes. *

* If the same header is also set on request options (e.g., * {@code CosmosItemRequestOptions.setHeader(String, String)}), @@ -749,8 +768,33 @@ public CosmosClientBuilder userAgentSuffix(String userAgentSuffix) { * * @param customHeaders map of header name to value * @return current CosmosClientBuilder + * @throws IllegalArgumentException if any header key is not in the allowlist, or if the + * workload-id value is not a valid integer */ public CosmosClientBuilder customHeaders(Map customHeaders) { + if (customHeaders != null) { + for (Map.Entry entry : customHeaders.entrySet()) { + String key = entry.getKey(); + String value = entry.getValue(); + + if (!ALLOWED_CUSTOM_HEADERS.contains(key)) { + throw new IllegalArgumentException( + "Header '" + key + "' is not allowed in customHeaders. " + + "Allowed headers: " + ALLOWED_CUSTOM_HEADERS); + } + + // Validate workload-id value is a valid integer (range validation is left to the backend) + if (HttpConstants.HttpHeaders.WORKLOAD_ID.equals(key) && value != null) { + try { + Integer.parseInt(value); + } catch (NumberFormatException e) { + throw new IllegalArgumentException( + "Invalid value '" + value + "' for header '" + key + + "'. The value must be a valid integer.", e); + } + } + } + } this.customHeaders = customHeaders; return this; } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java index 2f0bd4271d86..122542a8810a 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java @@ -863,7 +863,8 @@ public void init(CosmosClientMetadataCachesSnapshot metadataCachesSnapshot, Func this.userAgentContainer, this.globalEndpointManager, this.reactorHttpClient, - this.apiType); + this.apiType, + this.customHeaders); this.thinProxy = createThinProxy(this.sessionContainer, this.consistencyLevel, @@ -969,7 +970,8 @@ private void initializeDirectConnectivity() { // this.gatewayConfigurationReader, null, this.connectionPolicy, - this.apiType); + this.apiType, + this.customHeaders); this.storeClientFactory = new StoreClientFactory( this.addressResolver, @@ -1013,7 +1015,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer, UserAgentContainer userAgentContainer, GlobalEndpointManager globalEndpointManager, HttpClient httpClient, - ApiType apiType) { + ApiType apiType, + Map customHeaders) { return new RxGatewayStoreModel( this, sessionContainer, @@ -1022,7 +1025,8 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer, userAgentContainer, globalEndpointManager, httpClient, - apiType); + apiType, + customHeaders); } ThinClientStoreModel createThinProxy(ISessionContainer sessionContainer, diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java index 979c528b32bb..a197723f0ae3 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java @@ -91,6 +91,7 @@ public class RxGatewayStoreModel implements RxStoreModel, HttpTransportSerialize private GatewayServiceConfigurationReader gatewayServiceConfigurationReader; private RxClientCollectionCache collectionCache; private GatewayServerErrorInjector gatewayServerErrorInjector; + private final Map customHeaders; public RxGatewayStoreModel( DiagnosticsClientContext clientContext, @@ -100,7 +101,8 @@ public RxGatewayStoreModel( UserAgentContainer userAgentContainer, GlobalEndpointManager globalEndpointManager, HttpClient httpClient, - ApiType apiType) { + ApiType apiType, + Map customHeaders) { this.clientContext = clientContext; @@ -116,6 +118,7 @@ public RxGatewayStoreModel( this.httpClient = httpClient; this.sessionContainer = sessionContainer; + this.customHeaders = customHeaders; } public RxGatewayStoreModel(RxGatewayStoreModel inner) { @@ -127,6 +130,7 @@ public RxGatewayStoreModel(RxGatewayStoreModel inner) { this.httpClient = inner.httpClient; this.sessionContainer = inner.sessionContainer; + this.customHeaders = inner.customHeaders; } protected Map getDefaultHeaders( @@ -279,6 +283,17 @@ public Mono performRequest(RxDocumentServiceRequest r request.requestContext.cosmosDiagnostics = clientContext.createDiagnostics(); } + // Apply client-level custom headers (e.g., workload-id) to all requests + // including metadata requests (collection cache, partition key range, etc.) + if (this.customHeaders != null && !this.customHeaders.isEmpty()) { + for (Map.Entry entry : this.customHeaders.entrySet()) { + // Only set if not already present — request-level headers take precedence + if (!request.getHeaders().containsKey(entry.getKey())) { + request.getHeaders().put(entry.getKey(), entry.getValue()); + } + } + } + URI uri = getUri(request); request.requestContext.resourcePhysicalAddress = uri.toString(); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java index d32e5d901f18..ff139e203d2e 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java @@ -56,7 +56,8 @@ public ThinClientStoreModel( userAgentContainer, globalEndpointManager, httpClient, - ApiType.SQL); + ApiType.SQL, + null); String userAgent = userAgentContainer != null ? userAgentContainer.getUserAgent() diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java index e62d7b8c6ca4..7c761335b782 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java @@ -123,7 +123,8 @@ public GatewayAddressCache( GlobalEndpointManager globalEndpointManager, ConnectionPolicy connectionPolicy, ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor, - GatewayServerErrorInjector gatewayServerErrorInjector) { + GatewayServerErrorInjector gatewayServerErrorInjector, + Map customHeaders) { this.clientContext = clientContext; try { @@ -165,6 +166,14 @@ public GatewayAddressCache( HttpConstants.HttpHeaders.SDK_SUPPORTED_CAPABILITIES, HttpConstants.SDKSupportedCapabilities.SUPPORTED_CAPABILITIES); + // Apply client-level custom headers (e.g., workload-id) to metadata requests + // Use putIfAbsent to ensure SDK system headers (USER_AGENT, VERSION, etc.) are not overwritten + if (customHeaders != null && !customHeaders.isEmpty()) { + for (Map.Entry entry : customHeaders.entrySet()) { + this.defaultRequestHeaders.putIfAbsent(entry.getKey(), entry.getValue()); + } + } + this.lastForcedRefreshMap = new ConcurrentHashMap<>(); this.globalEndpointManager = globalEndpointManager; this.proactiveOpenConnectionsProcessor = proactiveOpenConnectionsProcessor; @@ -188,7 +197,8 @@ public GatewayAddressCache( GlobalEndpointManager globalEndpointManager, ConnectionPolicy connectionPolicy, ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor, - GatewayServerErrorInjector gatewayServerErrorInjector) { + GatewayServerErrorInjector gatewayServerErrorInjector, + Map customHeaders) { this(clientContext, serviceEndpoint, protocol, @@ -200,7 +210,8 @@ public GatewayAddressCache( globalEndpointManager, connectionPolicy, proactiveOpenConnectionsProcessor, - gatewayServerErrorInjector); + gatewayServerErrorInjector, + customHeaders); } @Override diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java index 00905682b4d1..2fd5287da028 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java @@ -62,6 +62,7 @@ public class GlobalAddressResolver implements IAddressResolver { private ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor; private ConnectionPolicy connectionPolicy; private GatewayServerErrorInjector gatewayServerErrorInjector; + private final Map customHeaders; public GlobalAddressResolver( DiagnosticsClientContext diagnosticsClientContext, @@ -74,7 +75,8 @@ public GlobalAddressResolver( UserAgentContainer userAgentContainer, GatewayServiceConfigurationReader serviceConfigReader, ConnectionPolicy connectionPolicy, - ApiType apiType) { + ApiType apiType, + Map customHeaders) { this.diagnosticsClientContext = diagnosticsClientContext; this.httpClient = httpClient; this.endpointManager = endpointManager; @@ -86,6 +88,7 @@ public GlobalAddressResolver( this.serviceConfigReader = serviceConfigReader; this.tcpConnectionEndpointRediscoveryEnabled = connectionPolicy.isTcpConnectionEndpointRediscoveryEnabled(); this.connectionPolicy = connectionPolicy; + this.customHeaders = customHeaders; int maxBackupReadEndpoints = (connectionPolicy.isReadRequestsFallbackEnabled()) ? GlobalAddressResolver.MaxBackupReadRegions : 0; this.maxEndpoints = maxBackupReadEndpoints + 2; // for write and alternate write getEndpoint (during failover) @@ -290,7 +293,8 @@ private EndpointCache getOrAddEndpoint(URI endpoint) { this.endpointManager, this.connectionPolicy, this.proactiveOpenConnectionsProcessor, - this.gatewayServerErrorInjector); + this.gatewayServerErrorInjector, + this.customHeaders); AddressResolver addressResolver = new AddressResolver(); addressResolver.initializeCaches(this.collectionCache, this.routingMapProvider, gatewayAddressCache); EndpointCache cache = new EndpointCache(); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java index f6e570258042..de2d769f789b 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java @@ -366,6 +366,22 @@ public Set getKeywordIdentifiers() { return this.actualRequestOptions.getKeywordIdentifiers(); } + /** + * Sets a custom header to be included with this specific request. + *

+ * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. + * + * @param name the header name (e.g., "x-ms-cosmos-workload-id") + * @param value the header value (e.g., "20") + * @return the CosmosReadManyRequestOptions. + */ + public CosmosReadManyRequestOptions setHeader(String name, String value) { + this.actualRequestOptions.setHeader(name, value); + return this; + } + CosmosQueryRequestOptionsBase getImpl() { return this.actualRequestOptions; } From c218ae3fe578f6ffd605e64aa3570ff410554dc3 Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Mon, 9 Mar 2026 16:06:38 -0500 Subject: [PATCH 06/13] fix: addressing comments --- .../azure-cosmos-spark_3-3_2-12/CHANGELOG.md | 2 +- .../azure-cosmos-spark_3-4_2-12/CHANGELOG.md | 2 +- .../azure-cosmos-spark_3-5_2-12/CHANGELOG.md | 2 +- .../azure-cosmos-spark_3-5_2-13/CHANGELOG.md | 2 +- .../cosmos/spark/CosmosClientCache.scala | 19 ++-- .../spark/CosmosClientConfiguration.scala | 8 +- .../com/azure/cosmos/spark/CosmosConfig.scala | 31 +++---- .../cosmos/spark/CosmosClientCacheITest.scala | 10 +-- .../spark/CosmosClientConfigurationSpec.scala | 48 +++++----- .../spark/CosmosPartitionPlannerSpec.scala | 16 ++-- .../cosmos/spark/PartitionMetadataSpec.scala | 32 +++---- .../azure-cosmos-spark_4-0_2-13/CHANGELOG.md | 2 +- sdk/cosmos/azure-cosmos/CHANGELOG.md | 2 +- .../com/azure/cosmos/CosmosAsyncClient.java | 2 +- .../com/azure/cosmos/CosmosClientBuilder.java | 90 ++++++++----------- .../implementation/AsyncDocumentClient.java | 8 +- .../implementation/RxDocumentClientImpl.java | 20 ++--- .../GatewayAddressCache.java | 12 +-- .../GlobalAddressResolver.java | 8 +- .../models/CosmosBatchRequestOptions.java | 32 +++++-- .../models/CosmosBulkExecutionOptions.java | 32 +++++-- .../CosmosChangeFeedRequestOptions.java | 32 +++++-- .../models/CosmosItemRequestOptions.java | 32 +++++-- .../models/CosmosQueryRequestOptions.java | 32 +++++-- .../models/CosmosReadManyRequestOptions.java | 33 +++++-- 25 files changed, 314 insertions(+), 195 deletions(-) diff --git a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md index 5c00937013ad..6d23a6f92cf2 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md @@ -3,7 +3,7 @@ ### 4.45.0-beta.1 (Unreleased) #### Features Added -* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) +* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md index 705334e4c3a1..49da84cfc8d4 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md @@ -3,7 +3,7 @@ ### 4.45.0-beta.1 (Unreleased) #### Features Added -* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) +* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md index 97bff869bb59..1290801e56f2 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md @@ -3,7 +3,7 @@ ### 4.45.0-beta.1 (Unreleased) #### Features Added -* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) +* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md index d922e4a579d3..8660d2ba7b59 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md @@ -3,7 +3,7 @@ ### 4.45.0-beta.1 (Unreleased) #### Features Added -* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) +* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala index f31024628cb7..3a0174a9a9a8 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala @@ -13,7 +13,7 @@ import com.azure.cosmos.models.{CosmosClientTelemetryConfig, CosmosMetricCategor import com.azure.cosmos.spark.CosmosPredicates.isOnSparkDriver import com.azure.cosmos.spark.catalog.{CosmosCatalogClient, CosmosCatalogCosmosSDKClient, CosmosCatalogManagementSDKClient} import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait -import com.azure.cosmos.{ConsistencyLevel, CosmosAsyncClient, CosmosClientBuilder, CosmosContainerProactiveInitConfigBuilder, CosmosDiagnosticsThresholds, DirectConnectionConfig, GatewayConnectionConfig, ReadConsistencyStrategy, ThrottlingRetryOptions} +import com.azure.cosmos.{ConsistencyLevel, CosmosAsyncClient, CosmosClientBuilder, CosmosContainerProactiveInitConfigBuilder, CosmosDiagnosticsThresholds, CosmosHeaderName, DirectConnectionConfig, GatewayConnectionConfig, ReadConsistencyStrategy, ThrottlingRetryOptions} import com.azure.identity.{ClientCertificateCredentialBuilder, ClientSecretCredentialBuilder, ManagedIdentityCredentialBuilder} import com.azure.monitor.opentelemetry.autoconfigure.{AzureMonitorAutoConfigure, AzureMonitorAutoConfigureOptions} import com.azure.resourcemanager.cosmos.CosmosManager @@ -712,10 +712,15 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { } } - // Apply custom HTTP headers (e.g., workload-id) to the builder if configured. + // Apply additional HTTP headers (e.g., workload-id) to the builder if configured. // These headers are attached to every Cosmos DB request made by this client instance. - if (cosmosClientConfiguration.customHeaders.isDefined) { - builder.customHeaders(cosmosClientConfiguration.customHeaders.get.asJava) + // Converts Map[String, String] from Spark config to Map[CosmosHeaderName, String] for the builder. + if (cosmosClientConfiguration.additionalHeaders.isDefined) { + val enumHeaders = new java.util.HashMap[CosmosHeaderName, String]() + for ((key, value) <- cosmosClientConfiguration.additionalHeaders.get) { + enumHeaders.put(CosmosHeaderName.fromString(key), value) + } + builder.additionalHeaders(enumHeaders) } var client = builder.buildAsyncClient() @@ -922,9 +927,9 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig], azureMonitorConfig: Option[AzureMonitorConfig], - // Custom HTTP headers are part of the cache key because different workload-ids + // Additional HTTP headers are part of the cache key because different workload-ids // should produce different CosmosAsyncClient instances - customHeaders: Option[Map[String, String]] + additionalHeaders: Option[Map[String, String]] ) private[this] object ClientConfigurationWrapper { @@ -944,7 +949,7 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { clientConfig.clientInterceptors, clientConfig.sampledDiagnosticsLoggerConfig, clientConfig.azureMonitorConfig, - clientConfig.customHeaders + clientConfig.additionalHeaders ) } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala index 61fa0957af83..7e09c73698a7 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala @@ -31,9 +31,9 @@ private[spark] case class CosmosClientConfiguration ( clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], sampledDiagnosticsLoggerConfig: Option[SampledDiagnosticsLoggerConfig], azureMonitorConfig: Option[AzureMonitorConfig], - // Optional custom HTTP headers (e.g., workload-id) to attach to - // all Cosmos DB requests via CosmosClientBuilder.customHeaders() - customHeaders: Option[Map[String, String]] + // Optional additional HTTP headers (e.g., workload-id) to attach to + // all Cosmos DB requests via CosmosClientBuilder.additionalHeaders() + additionalHeaders: Option[Map[String, String]] ) { private[spark] def getRoleInstanceName(machineId: Option[String]): String = { CosmosClientConfiguration.getRoleInstanceName(sparkEnvironmentInfo, machineId) @@ -98,7 +98,7 @@ private[spark] object CosmosClientConfiguration { cosmosAccountConfig.clientInterceptors, diagnosticsConfig.sampledDiagnosticsLoggerConfig, diagnosticsConfig.azureMonitorConfig, - cosmosAccountConfig.customHeaders + cosmosAccountConfig.additionalHeaders ) } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala index 928b0cd09445..870ee17841c5 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala @@ -152,10 +152,10 @@ private[spark] object CosmosConfigNames { val ThroughputControlTargetThroughputThreshold = "spark.cosmos.throughputControl.targetThroughputThreshold" val ThroughputControlPriorityLevel = "spark.cosmos.throughputControl.priorityLevel" val ThroughputControlThroughputBucket = "spark.cosmos.throughputControl.throughputBucket" - // Custom HTTP headers to attach to all Cosmos DB requests (e.g., workload-id for resource governance). + // Additional HTTP headers to attach to all Cosmos DB requests (e.g., workload-id for resource governance). // Value is a JSON string like: {"x-ms-cosmos-workload-id": "15"} - // Flows through to CosmosClientBuilder.customHeaders(). - val CustomHeaders = "spark.cosmos.customHeaders" + // Flows through to CosmosClientBuilder.additionalHeaders(). + val AdditionalHeaders = "spark.cosmos.additionalHeaders" val ThroughputControlGlobalControlDatabase = "spark.cosmos.throughputControl.globalControl.database" val ThroughputControlGlobalControlContainer = "spark.cosmos.throughputControl.globalControl.container" val ThroughputControlGlobalControlRenewalIntervalInMS = @@ -303,7 +303,7 @@ private[spark] object CosmosConfigNames { WriteFlushCloseIntervalInSeconds, WriteMaxNoProgressIntervalInSeconds, WriteMaxRetryNoProgressIntervalInSeconds, - CustomHeaders + AdditionalHeaders ) def validateConfigName(name: String): Unit = { @@ -547,9 +547,9 @@ private case class CosmosAccountConfig(endpoint: String, azureEnvironmentEndpoints: java.util.Map[String, String], clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]], clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], - // Optional custom HTTP headers (e.g., workload-id) parsed from - // spark.cosmos.customHeaders JSON config, passed to CosmosClientBuilder - customHeaders: Option[Map[String, String]] + // Optional additional HTTP headers (e.g., workload-id) parsed from + // spark.cosmos.additionalHeaders JSON config, passed to CosmosClientBuilder.additionalHeaders() + additionalHeaders: Option[Map[String, String]] ) private object CosmosAccountConfig extends BasicLoggingTrait { @@ -738,9 +738,10 @@ private object CosmosAccountConfig extends BasicLoggingTrait { // Config entry for custom HTTP headers (e.g., workload-id). Parses a JSON string like // {"x-ms-cosmos-workload-id": "15"} into a Scala Map[String, String] using Jackson. - // These headers are passed to CosmosClientBuilder.customHeaders() in CosmosClientCache. - private val CustomHeadersConfig = CosmosConfigEntry[Map[String, String]]( - key = CosmosConfigNames.CustomHeaders, + // These headers are converted to Map[CosmosHeaderName, String] and passed to + // CosmosClientBuilder.additionalHeaders() in CosmosClientCache. + private val AdditionalHeadersConfig = CosmosConfigEntry[Map[String, String]]( + key = CosmosConfigNames.AdditionalHeaders, mandatory = false, parseFromStringFunction = headersJson => { try { @@ -748,11 +749,11 @@ private object CosmosAccountConfig extends BasicLoggingTrait { Utils.getSimpleObjectMapperWithAllowDuplicates.readValue(headersJson, typeRef).asScala.toMap } catch { case e: Exception => throw new IllegalArgumentException( - s"Invalid JSON for '${CosmosConfigNames.CustomHeaders}': '$headersJson'. " + + s"Invalid JSON for '${CosmosConfigNames.AdditionalHeaders}': '$headersJson'. " + "Expected format: {\"x-ms-cosmos-workload-id\": \"15\"}", e) } }, - helpMessage = "Optional custom headers as JSON map. Example: {\"x-ms-cosmos-workload-id\": \"15\"}") + helpMessage = "Optional additional headers as JSON map. Example: {\"x-ms-cosmos-workload-id\": \"15\"}") private[spark] def parseProactiveConnectionInitConfigs(config: String): java.util.List[CosmosContainerIdentity] = { val result = new java.util.ArrayList[CosmosContainerIdentity] @@ -788,8 +789,8 @@ private object CosmosAccountConfig extends BasicLoggingTrait { val tenantIdOpt = CosmosConfigEntry.parse(cfg, TenantId) val clientBuilderInterceptors = CosmosConfigEntry.parse(cfg, ClientBuilderInterceptors) val clientInterceptors = CosmosConfigEntry.parse(cfg, ClientInterceptors) - // Parse optional custom HTTP headers from JSON config (e.g., {"x-ms-cosmos-workload-id": "15"}) - val customHeaders = CosmosConfigEntry.parse(cfg, CustomHeadersConfig) + // Parse optional additional HTTP headers from JSON config (e.g., {"x-ms-cosmos-workload-id": "15"}) + val additionalHeaders = CosmosConfigEntry.parse(cfg, AdditionalHeadersConfig) val disableTcpConnectionEndpointRediscovery = CosmosConfigEntry.parse(cfg, DisableTcpConnectionEndpointRediscovery) val preferredRegionsListOpt = CosmosConfigEntry.parse(cfg, PreferredRegionsList) @@ -910,7 +911,7 @@ private object CosmosAccountConfig extends BasicLoggingTrait { azureEnvironmentOpt.get, if (clientBuilderInterceptorsList.nonEmpty) { Some(clientBuilderInterceptorsList.toList) } else { None }, if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None }, - customHeaders) + additionalHeaders) } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala index 4d542c44612e..8e44968a76df 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala @@ -65,7 +65,7 @@ class CosmosClientCacheITest clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) ), ( @@ -93,7 +93,7 @@ class CosmosClientCacheITest clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) ), ( @@ -121,7 +121,7 @@ class CosmosClientCacheITest clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) ), ( @@ -149,7 +149,7 @@ class CosmosClientCacheITest clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) ) ) @@ -184,7 +184,7 @@ class CosmosClientCacheITest clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) logInfo(s"TestCase: {$testCaseName}") diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala index 377425189f07..8a2b6a8191f9 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala @@ -409,27 +409,27 @@ class CosmosClientConfigurationSpec extends UnitSpec { configuration.azureMonitorConfig shouldEqual None } - // Verifies that the spark.cosmos.customHeaders configuration option correctly parses + // Verifies that the spark.cosmos.additionalHeaders configuration option correctly parses // a JSON string containing a single workload-id header into a Map[String, String] on // CosmosClientConfiguration. This is the primary use case for the workload-id feature. - it should "parse customHeaders JSON" in { + it should "parse additionalHeaders JSON" in { val userConfig = Map( "spark.cosmos.accountEndpoint" -> "https://localhost:8081", "spark.cosmos.accountKey" -> "xyz", - "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "15"}""" + "spark.cosmos.additionalHeaders" -> """{"x-ms-cosmos-workload-id": "15"}""" ) val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") - configuration.customHeaders shouldBe defined - configuration.customHeaders.get("x-ms-cosmos-workload-id") shouldEqual "15" + configuration.additionalHeaders shouldBe defined + configuration.additionalHeaders.get("x-ms-cosmos-workload-id") shouldEqual "15" } - // Verifies that when spark.cosmos.customHeaders is not specified in the config map, - // CosmosClientConfiguration.customHeaders is None. This ensures backward compatibility — - // existing Spark jobs that don't set customHeaders continue to work without changes. - it should "handle missing customHeaders" in { + // Verifies that when spark.cosmos.additionalHeaders is not specified in the config map, + // CosmosClientConfiguration.additionalHeaders is None. This ensures backward compatibility — + // existing Spark jobs that don't set additionalHeaders continue to work without changes. + it should "handle missing additionalHeaders" in { val userConfig = Map( "spark.cosmos.accountEndpoint" -> "https://localhost:8081", "spark.cosmos.accountKey" -> "xyz" @@ -438,43 +438,43 @@ class CosmosClientConfigurationSpec extends UnitSpec { val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") - configuration.customHeaders shouldBe None + configuration.additionalHeaders shouldBe None } - // Verifies that spark.cosmos.customHeaders rejects unknown headers at the parsing level. - // Only headers in CosmosClientBuilder's allowlist are permitted. In Direct mode (RNTBD), - // unknown headers are silently dropped, so the allowlist ensures consistent behavior - // across Gateway and Direct modes. - it should "reject unknown custom headers" in { + // Verifies that spark.cosmos.additionalHeaders accepts multiple headers at the parsing level. + // The JSON is valid and CosmosClientConfiguration stores it as-is. + // The CosmosHeaderName.fromString() validation happens later in CosmosClientCache when + // converting to Map[CosmosHeaderName, String] for CosmosClientBuilder.additionalHeaders(). + it should "reject unknown additional headers" in { val userConfig = Map( "spark.cosmos.accountEndpoint" -> "https://localhost:8081", "spark.cosmos.accountKey" -> "xyz", - "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "20", "x-custom-header": "value"}""" + "spark.cosmos.additionalHeaders" -> """{"x-ms-cosmos-workload-id": "20", "x-custom-header": "value"}""" ) val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") // Parsing succeeds — the JSON is valid and CosmosClientConfiguration stores it as-is. - // The allowlist validation happens later in CosmosClientBuilder.customHeaders() - configuration.customHeaders shouldBe defined - configuration.customHeaders.get should have size 2 + // The CosmosHeaderName.fromString() validation happens later in CosmosClientCache + configuration.additionalHeaders shouldBe defined + configuration.additionalHeaders.get should have size 2 } - // Verifies that spark.cosmos.customHeaders handles an empty JSON object ("{}") gracefully, + // Verifies that spark.cosmos.additionalHeaders handles an empty JSON object ("{}") gracefully, // resulting in a defined but empty Map. This ensures the parser doesn't fail on edge cases // and that no headers are injected when the JSON object is empty. - it should "handle empty customHeaders JSON" in { + it should "handle empty additionalHeaders JSON" in { val userConfig = Map( "spark.cosmos.accountEndpoint" -> "https://localhost:8081", "spark.cosmos.accountKey" -> "xyz", - "spark.cosmos.customHeaders" -> "{}" + "spark.cosmos.additionalHeaders" -> "{}" ) val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") - configuration.customHeaders shouldBe defined - configuration.customHeaders.get shouldBe empty + configuration.additionalHeaders shouldBe defined + configuration.additionalHeaders.get shouldBe empty } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala index ab73dc4e54d3..f745bae6b56d 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala @@ -40,7 +40,7 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -118,7 +118,7 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -196,7 +196,7 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -274,7 +274,7 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -350,7 +350,7 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -442,7 +442,7 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -517,7 +517,7 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -584,7 +584,7 @@ class CosmosPartitionPlannerSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala index 65274bee2b19..c17e7b02dad8 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala @@ -39,7 +39,7 @@ class PartitionMetadataSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) private[this] val contCfg = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -86,7 +86,7 @@ class PartitionMetadataSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -172,7 +172,7 @@ class PartitionMetadataSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -258,7 +258,7 @@ class PartitionMetadataSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -326,7 +326,7 @@ class PartitionMetadataSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -389,7 +389,7 @@ class PartitionMetadataSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -446,7 +446,7 @@ class PartitionMetadataSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -503,7 +503,7 @@ class PartitionMetadataSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -560,7 +560,7 @@ class PartitionMetadataSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -617,7 +617,7 @@ class PartitionMetadataSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -697,7 +697,7 @@ class PartitionMetadataSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -759,7 +759,7 @@ class PartitionMetadataSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -816,7 +816,7 @@ class PartitionMetadataSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -890,7 +890,7 @@ class PartitionMetadataSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -964,7 +964,7 @@ class PartitionMetadataSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) @@ -1043,7 +1043,7 @@ class PartitionMetadataSpec extends UnitSpec { clientInterceptors = None, sampledDiagnosticsLoggerConfig = None, azureMonitorConfig = None, - customHeaders = None + additionalHeaders = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString, None) diff --git a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md index 799d86414ff6..6149df9f05e4 100644 --- a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md @@ -3,7 +3,7 @@ ### 4.45.0-beta.1 (Unreleased) #### Features Added -* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) +* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index c080bbc9c5c5..424aef4e3a30 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -4,7 +4,7 @@ #### Features Added * Added support for N-Region synchronous commit feature - See [PR 47757](https://github.com/Azure/azure-sdk-for-java/pull/47757) -* Added `customHeaders` support to allow setting custom HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) +* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java index f54f44482db5..13f526552301 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncClient.java @@ -186,7 +186,7 @@ public final class CosmosAsyncClient implements Closeable { .withDefaultSerializer(this.defaultCustomSerializer) .withRegionScopedSessionCapturingEnabled(builder.isRegionScopedSessionCapturingEnabled()) .withPerPartitionAutomaticFailoverEnabled(builder.isPerPartitionAutomaticFailoverEnabled()) - .withCustomHeaders(builder.getCustomHeaders()) + .withAdditionalHeaders(builder.getAdditionalHeaders()) .build(); this.accountConsistencyLevel = this.asyncDocumentClient.getDefaultConsistencyLevelOfAccount(); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java index e4cdf6ca1e30..129b4b578c58 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java @@ -13,7 +13,6 @@ import com.azure.cosmos.implementation.ConnectionPolicy; import com.azure.cosmos.implementation.CosmosClientMetadataCachesSnapshot; import com.azure.cosmos.implementation.DiagnosticsProvider; -import com.azure.cosmos.implementation.HttpConstants; import com.azure.cosmos.implementation.Strings; import com.azure.cosmos.implementation.WriteRetryPolicy; import com.azure.cosmos.implementation.apachecommons.collections.list.UnmodifiableList; @@ -34,6 +33,7 @@ import java.time.Duration; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; import java.util.List; @@ -157,21 +157,7 @@ public class CosmosClientBuilder implements private boolean serverCertValidationDisabled = false; private Function containerFactory = null; - private Map customHeaders; - - /** - * Allowlist of headers permitted in {@link #customHeaders(Map)}. - *

- * In Direct mode (RNTBD), only headers with explicit encoding support in - * {@code RntbdRequestHeaders} are sent on the wire. Unknown headers are silently dropped. - * This allowlist ensures consistent behavior across Gateway and Direct modes - if a header - * is allowed here, it works in both modes. To add a new allowed header, you must also add - * RNTBD encoding support ({@code RntbdConstants.RntbdRequestHeader} enum entry + - * {@code RntbdRequestHeaders.addXxx()} method). - */ - private static final Set ALLOWED_CUSTOM_HEADERS = Collections.unmodifiableSet( - new HashSet<>(Collections.singletonList(HttpConstants.HttpHeaders.WORKLOAD_ID)) - ); + private Map additionalHeaders; /** * Instantiates a new Cosmos client builder. @@ -752,59 +738,53 @@ public CosmosClientBuilder userAgentSuffix(String userAgentSuffix) { } /** - * Sets custom HTTP headers that will be included with every request from this client. + * Sets additional HTTP headers that will be included with every request from this client. *

- * Only headers in the SDK's allowlist are permitted. Currently the only allowed header is - * {@code x-ms-cosmos-workload-id}. Passing any other header key will throw - * {@link IllegalArgumentException}. + * The {@link CosmosHeaderName} enum defines exactly which headers are supported. Currently + * the only supported header is {@link CosmosHeaderName#WORKLOAD_ID} + * ({@code x-ms-cosmos-workload-id}). *

* This restriction exists because in Direct mode (RNTBD), only headers with explicit - * encoding support are sent on the wire. Unknown headers are silently dropped. The allowlist - * ensures consistent behavior across both Gateway and Direct modes. + * encoding support are sent on the wire. The enum ensures consistent behavior across + * both Gateway and Direct modes. *

* If the same header is also set on request options (e.g., - * {@code CosmosItemRequestOptions.setHeader(String, String)}), + * {@code CosmosItemRequestOptions.setAdditionalHeaders(Map)}), * the request-level value takes precedence over the client-level value. * - * @param customHeaders map of header name to value + * @param additionalHeaders map of {@link CosmosHeaderName} to value * @return current CosmosClientBuilder - * @throws IllegalArgumentException if any header key is not in the allowlist, or if the - * workload-id value is not a valid integer - */ - public CosmosClientBuilder customHeaders(Map customHeaders) { - if (customHeaders != null) { - for (Map.Entry entry : customHeaders.entrySet()) { - String key = entry.getKey(); - String value = entry.getValue(); - - if (!ALLOWED_CUSTOM_HEADERS.contains(key)) { - throw new IllegalArgumentException( - "Header '" + key + "' is not allowed in customHeaders. " - + "Allowed headers: " + ALLOWED_CUSTOM_HEADERS); - } + * @throws IllegalArgumentException if the workload-id value is not a valid integer + */ + public CosmosClientBuilder additionalHeaders(Map additionalHeaders) { + CosmosHeaderName.validateAdditionalHeaders(additionalHeaders); + this.additionalHeaders = additionalHeaders; + return this; + } - // Validate workload-id value is a valid integer (range validation is left to the backend) - if (HttpConstants.HttpHeaders.WORKLOAD_ID.equals(key) && value != null) { - try { - Integer.parseInt(value); - } catch (NumberFormatException e) { - throw new IllegalArgumentException( - "Invalid value '" + value + "' for header '" + key - + "'. The value must be a valid integer.", e); - } - } - } + /** + * Gets the additional headers configured on this builder, converted to a + * {@code Map} for internal use. + * @return the additional headers map with string keys, or null if not set + */ + Map getAdditionalHeaders() { + if (this.additionalHeaders == null) { + return null; } - this.customHeaders = customHeaders; - return this; + Map result = new HashMap<>(); + for (Map.Entry entry : this.additionalHeaders.entrySet()) { + result.put(entry.getKey().getHeaderName(), entry.getValue()); + } + return result; } /** - * Gets the custom headers configured on this builder. - * @return the custom headers map, or null if not set + * Gets the additional headers configured on this builder in their original + * {@code Map} form. + * @return the additional headers map, or null if not set */ - Map getCustomHeaders() { - return this.customHeaders; + Map getAdditionalHeadersRaw() { + return this.additionalHeaders; } /** diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java index 7953721019c5..81a8f2826f04 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java @@ -116,7 +116,7 @@ class Builder { private boolean isRegionScopedSessionCapturingEnabled; private boolean isPerPartitionAutomaticFailoverEnabled; private List operationPolicies; - private Map customHeaders; + private Map additionalHeaders; public Builder withServiceEndpoint(String serviceEndpoint) { try { @@ -289,8 +289,8 @@ public Builder withPerPartitionAutomaticFailoverEnabled(boolean isPerPartitionAu return this; } - public Builder withCustomHeaders(Map customHeaders) { - this.customHeaders = customHeaders; + public Builder withAdditionalHeaders(Map additionalHeaders) { + this.additionalHeaders = additionalHeaders; return this; } @@ -335,7 +335,7 @@ public AsyncDocumentClient build() { isRegionScopedSessionCapturingEnabled, operationPolicies, isPerPartitionAutomaticFailoverEnabled, - customHeaders); + additionalHeaders); client.init(state, null); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java index c28400e86973..b13b6c02009c 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java @@ -294,7 +294,7 @@ public class RxDocumentClientImpl implements AsyncDocumentClient, IAuthorization private final AtomicReference cachedCosmosAsyncClientSnapshot; private CosmosEndToEndOperationLatencyPolicyConfig ppafEnforcedE2ELatencyPolicyConfigForReads; private Consumer perPartitionFailoverConfigModifier; - private Map customHeaders; + private Map additionalHeaders; public RxDocumentClientImpl(URI serviceEndpoint, String masterKeyOrResourceToken, @@ -421,7 +421,7 @@ public RxDocumentClientImpl(URI serviceEndpoint, boolean isRegionScopedSessionCapturingEnabled, List operationPolicies, boolean isPerPartitionAutomaticFailoverEnabled, - Map customHeaders) { + Map additionalHeaders) { this( serviceEndpoint, masterKeyOrResourceToken, @@ -448,7 +448,7 @@ public RxDocumentClientImpl(URI serviceEndpoint, this.cosmosAuthorizationTokenResolver = cosmosAuthorizationTokenResolver; this.operationPolicies = operationPolicies; - this.customHeaders = customHeaders; + this.additionalHeaders = additionalHeaders; } private RxDocumentClientImpl(URI serviceEndpoint, @@ -865,7 +865,7 @@ public void init(CosmosClientMetadataCachesSnapshot metadataCachesSnapshot, Func this.globalEndpointManager, this.reactorHttpClient, this.apiType, - this.customHeaders); + this.additionalHeaders); this.thinProxy = createThinProxy(this.sessionContainer, this.consistencyLevel, @@ -983,7 +983,7 @@ private void initializeDirectConnectivity() { null, this.connectionPolicy, this.apiType, - this.customHeaders); + this.additionalHeaders); this.storeClientFactory = new StoreClientFactory( this.addressResolver, @@ -1028,7 +1028,7 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer, GlobalEndpointManager globalEndpointManager, HttpClient httpClient, ApiType apiType, - Map customHeaders) { + Map additionalHeaders) { return new RxGatewayStoreModel( this, sessionContainer, @@ -1038,7 +1038,7 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer, globalEndpointManager, httpClient, apiType, - customHeaders); + additionalHeaders); } ThinClientStoreModel createThinProxy(ISessionContainer sessionContainer, @@ -1956,9 +1956,9 @@ public void validateAndLogNonDefaultReadConsistencyStrategy(String readConsisten private Map getRequestHeaders(RequestOptions options, ResourceType resourceType, OperationType operationType) { Map headers = new HashMap<>(); - // Apply client-level custom headers first (e.g., workload-id from CosmosClientBuilder.customHeaders()) - if (this.customHeaders != null && !this.customHeaders.isEmpty()) { - headers.putAll(this.customHeaders); + // Apply client-level additional headers first (e.g., workload-id from CosmosClientBuilder.additionalHeaders()) + if (this.additionalHeaders != null && !this.additionalHeaders.isEmpty()) { + headers.putAll(this.additionalHeaders); } if (this.useMultipleWriteLocations) { diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java index 7c761335b782..119c7309256d 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java @@ -124,7 +124,7 @@ public GatewayAddressCache( ConnectionPolicy connectionPolicy, ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor, GatewayServerErrorInjector gatewayServerErrorInjector, - Map customHeaders) { + Map additionalHeaders) { this.clientContext = clientContext; try { @@ -166,10 +166,10 @@ public GatewayAddressCache( HttpConstants.HttpHeaders.SDK_SUPPORTED_CAPABILITIES, HttpConstants.SDKSupportedCapabilities.SUPPORTED_CAPABILITIES); - // Apply client-level custom headers (e.g., workload-id) to metadata requests + // Apply client-level additional headers (e.g., workload-id) to metadata requests // Use putIfAbsent to ensure SDK system headers (USER_AGENT, VERSION, etc.) are not overwritten - if (customHeaders != null && !customHeaders.isEmpty()) { - for (Map.Entry entry : customHeaders.entrySet()) { + if (additionalHeaders != null && !additionalHeaders.isEmpty()) { + for (Map.Entry entry : additionalHeaders.entrySet()) { this.defaultRequestHeaders.putIfAbsent(entry.getKey(), entry.getValue()); } } @@ -198,7 +198,7 @@ public GatewayAddressCache( ConnectionPolicy connectionPolicy, ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor, GatewayServerErrorInjector gatewayServerErrorInjector, - Map customHeaders) { + Map additionalHeaders) { this(clientContext, serviceEndpoint, protocol, @@ -211,7 +211,7 @@ public GatewayAddressCache( connectionPolicy, proactiveOpenConnectionsProcessor, gatewayServerErrorInjector, - customHeaders); + additionalHeaders); } @Override diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java index 2fd5287da028..6dded78b8a07 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GlobalAddressResolver.java @@ -62,7 +62,7 @@ public class GlobalAddressResolver implements IAddressResolver { private ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor; private ConnectionPolicy connectionPolicy; private GatewayServerErrorInjector gatewayServerErrorInjector; - private final Map customHeaders; + private final Map additionalHeaders; public GlobalAddressResolver( DiagnosticsClientContext diagnosticsClientContext, @@ -76,7 +76,7 @@ public GlobalAddressResolver( GatewayServiceConfigurationReader serviceConfigReader, ConnectionPolicy connectionPolicy, ApiType apiType, - Map customHeaders) { + Map additionalHeaders) { this.diagnosticsClientContext = diagnosticsClientContext; this.httpClient = httpClient; this.endpointManager = endpointManager; @@ -88,7 +88,7 @@ public GlobalAddressResolver( this.serviceConfigReader = serviceConfigReader; this.tcpConnectionEndpointRediscoveryEnabled = connectionPolicy.isTcpConnectionEndpointRediscoveryEnabled(); this.connectionPolicy = connectionPolicy; - this.customHeaders = customHeaders; + this.additionalHeaders = additionalHeaders; int maxBackupReadEndpoints = (connectionPolicy.isReadRequestsFallbackEnabled()) ? GlobalAddressResolver.MaxBackupReadRegions : 0; this.maxEndpoints = maxBackupReadEndpoints + 2; // for write and alternate write getEndpoint (during failover) @@ -294,7 +294,7 @@ private EndpointCache getOrAddEndpoint(URI endpoint) { this.connectionPolicy, this.proactiveOpenConnectionsProcessor, this.gatewayServerErrorInjector, - this.customHeaders); + this.additionalHeaders); AddressResolver addressResolver = new AddressResolver(); addressResolver.initializeCaches(this.collectionCache, this.routingMapProvider, gatewayAddressCache); EndpointCache cache = new EndpointCache(); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java index 3183fe59bdea..ff3b335c8824 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java @@ -4,6 +4,7 @@ package com.azure.cosmos.models; import com.azure.cosmos.ConsistencyLevel; +import com.azure.cosmos.CosmosHeaderName; import com.azure.cosmos.CosmosDiagnosticsThresholds; import com.azure.cosmos.CosmosEndToEndOperationLatencyPolicyConfig; import com.azure.cosmos.CosmosItemSerializer; @@ -154,17 +155,38 @@ RequestOptions toRequestOptions() { } /** - * Sets a custom header to be included with this specific request. + * Sets additional headers to be included with this specific request. *

+ * The {@link CosmosHeaderName} enum defines exactly which headers are supported. * This allows per-request header customization, such as setting a workload ID * that overrides the client-level default set via - * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. + * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. + *

+ * If the same header is also set at the client level, the request-level value + * takes precedence. + * + * @param additionalHeaders map of {@link CosmosHeaderName} to value + * @return the CosmosBatchRequestOptions. + * @throws IllegalArgumentException if the workload-id value is not a valid integer + */ + public CosmosBatchRequestOptions setAdditionalHeaders(Map additionalHeaders) { + CosmosHeaderName.validateAdditionalHeaders(additionalHeaders); + if (additionalHeaders != null) { + for (Map.Entry entry : additionalHeaders.entrySet()) { + this.setHeader(entry.getKey().getHeaderName(), entry.getValue()); + } + } + return this; + } + + /** + * Sets a header to be included with this specific request. * - * @param name the header name (e.g., "x-ms-cosmos-workload-id") - * @param value the header value (e.g., "20") + * @param name the header name + * @param value the header value * @return the CosmosBatchRequestOptions. */ - public CosmosBatchRequestOptions setHeader(String name, String value) { + CosmosBatchRequestOptions setHeader(String name, String value) { if (this.customOptions == null) { this.customOptions = new HashMap<>(); } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java index cd688f8a0da6..588f74cea99e 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java @@ -4,6 +4,7 @@ package com.azure.cosmos.models; import com.azure.cosmos.CosmosEndToEndOperationLatencyPolicyConfig; +import com.azure.cosmos.CosmosHeaderName; import com.azure.cosmos.CosmosItemSerializer; import com.azure.cosmos.implementation.CosmosBulkExecutionOptionsImpl; import com.azure.cosmos.implementation.ImplementationBridgeHelpers; @@ -257,17 +258,38 @@ void setOperationContextAndListenerTuple(OperationContextAndListenerTuple operat } /** - * Sets a custom header to be included with this specific request. + * Sets additional headers to be included with this specific request. *

+ * The {@link CosmosHeaderName} enum defines exactly which headers are supported. * This allows per-request header customization, such as setting a workload ID * that overrides the client-level default set via - * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. + * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. + *

+ * If the same header is also set at the client level, the request-level value + * takes precedence. + * + * @param additionalHeaders map of {@link CosmosHeaderName} to value + * @return the CosmosBulkExecutionOptions. + * @throws IllegalArgumentException if the workload-id value is not a valid integer + */ + public CosmosBulkExecutionOptions setAdditionalHeaders(Map additionalHeaders) { + CosmosHeaderName.validateAdditionalHeaders(additionalHeaders); + if (additionalHeaders != null) { + for (Map.Entry entry : additionalHeaders.entrySet()) { + this.actualRequestOptions.setHeader(entry.getKey().getHeaderName(), entry.getValue()); + } + } + return this; + } + + /** + * Sets a header to be included with this specific request. * - * @param name the header name (e.g., "x-ms-cosmos-workload-id") - * @param value the header value (e.g., "20") + * @param name the header name + * @param value the header value * @return the CosmosBulkExecutionOptions. */ - public CosmosBulkExecutionOptions setHeader(String name, String value) { + CosmosBulkExecutionOptions setHeader(String name, String value) { this.actualRequestOptions.setHeader(name, value); return this; } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java index a1b675f2ffd8..b78499cd185c 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java @@ -4,6 +4,7 @@ package com.azure.cosmos.models; import com.azure.cosmos.ConsistencyLevel; +import com.azure.cosmos.CosmosHeaderName; import com.azure.cosmos.CosmosDiagnosticsThresholds; import com.azure.cosmos.CosmosItemSerializer; import com.azure.cosmos.ReadConsistencyStrategy; @@ -564,17 +565,38 @@ public List getExcludedRegions() { } /** - * Sets a custom header to be included with this specific request. + * Sets additional headers to be included with this specific request. *

+ * The {@link CosmosHeaderName} enum defines exactly which headers are supported. * This allows per-request header customization, such as setting a workload ID * that overrides the client-level default set via - * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. + * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. + *

+ * If the same header is also set at the client level, the request-level value + * takes precedence. + * + * @param additionalHeaders map of {@link CosmosHeaderName} to value + * @return the CosmosChangeFeedRequestOptions. + * @throws IllegalArgumentException if the workload-id value is not a valid integer + */ + public CosmosChangeFeedRequestOptions setAdditionalHeaders(Map additionalHeaders) { + CosmosHeaderName.validateAdditionalHeaders(additionalHeaders); + if (additionalHeaders != null) { + for (Map.Entry entry : additionalHeaders.entrySet()) { + this.actualRequestOptions.setHeader(entry.getKey().getHeaderName(), entry.getValue()); + } + } + return this; + } + + /** + * Sets a header to be included with this specific request. * - * @param name the header name (e.g., "x-ms-cosmos-workload-id") - * @param value the header value (e.g., "20") + * @param name the header name + * @param value the header value * @return the CosmosChangeFeedRequestOptions. */ - public CosmosChangeFeedRequestOptions setHeader(String name, String value) { + CosmosChangeFeedRequestOptions setHeader(String name, String value) { this.actualRequestOptions.setHeader(name, value); return this; } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java index fbc540e5baeb..d5447034274c 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java @@ -3,6 +3,7 @@ package com.azure.cosmos.models; import com.azure.cosmos.ConsistencyLevel; +import com.azure.cosmos.CosmosHeaderName; import com.azure.cosmos.CosmosClientBuilder; import com.azure.cosmos.CosmosDiagnosticsThresholds; import com.azure.cosmos.CosmosEndToEndOperationLatencyPolicyConfig; @@ -566,17 +567,38 @@ public CosmosItemRequestOptions setThresholdForDiagnosticsOnTracer(Duration thre } /** - * Sets a custom header to be included with this specific request. + * Sets additional headers to be included with this specific request. *

+ * The {@link CosmosHeaderName} enum defines exactly which headers are supported. * This allows per-request header customization, such as setting a workload ID * that overrides the client-level default set via - * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. + * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. + *

+ * If the same header is also set at the client level, the request-level value + * takes precedence. + * + * @param additionalHeaders map of {@link CosmosHeaderName} to value + * @return the CosmosItemRequestOptions. + * @throws IllegalArgumentException if the workload-id value is not a valid integer + */ + public CosmosItemRequestOptions setAdditionalHeaders(Map additionalHeaders) { + CosmosHeaderName.validateAdditionalHeaders(additionalHeaders); + if (additionalHeaders != null) { + for (Map.Entry entry : additionalHeaders.entrySet()) { + this.setHeader(entry.getKey().getHeaderName(), entry.getValue()); + } + } + return this; + } + + /** + * Sets a header to be included with this specific request. * - * @param name the header name (e.g., "x-ms-cosmos-workload-id") - * @param value the header value (e.g., "20") + * @param name the header name + * @param value the header value * @return the CosmosItemRequestOptions. */ - public CosmosItemRequestOptions setHeader(String name, String value) { + CosmosItemRequestOptions setHeader(String name, String value) { if (this.customOptions == null) { this.customOptions = new HashMap<>(); } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java index f0de81bbf823..ae998a39c360 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java @@ -4,6 +4,7 @@ package com.azure.cosmos.models; import com.azure.cosmos.ConsistencyLevel; +import com.azure.cosmos.CosmosHeaderName; import com.azure.cosmos.CosmosDiagnostics; import com.azure.cosmos.CosmosDiagnosticsThresholds; import com.azure.cosmos.CosmosEndToEndOperationLatencyPolicyConfig; @@ -261,17 +262,38 @@ public CosmosQueryRequestOptions setExcludedRegions(List excludeRegions) } /** - * Sets a custom header to be included with this specific request. + * Sets additional headers to be included with this specific request. *

+ * The {@link CosmosHeaderName} enum defines exactly which headers are supported. * This allows per-request header customization, such as setting a workload ID * that overrides the client-level default set via - * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. + * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. + *

+ * If the same header is also set at the client level, the request-level value + * takes precedence. + * + * @param additionalHeaders map of {@link CosmosHeaderName} to value + * @return the CosmosQueryRequestOptions. + * @throws IllegalArgumentException if the workload-id value is not a valid integer + */ + public CosmosQueryRequestOptions setAdditionalHeaders(Map additionalHeaders) { + CosmosHeaderName.validateAdditionalHeaders(additionalHeaders); + if (additionalHeaders != null) { + for (Map.Entry entry : additionalHeaders.entrySet()) { + this.actualRequestOptions.setHeader(entry.getKey().getHeaderName(), entry.getValue()); + } + } + return this; + } + + /** + * Sets a header to be included with this specific request. * - * @param name the header name (e.g., "x-ms-cosmos-workload-id") - * @param value the header value (e.g., "20") + * @param name the header name + * @param value the header value * @return the CosmosQueryRequestOptions. */ - public CosmosQueryRequestOptions setHeader(String name, String value) { + CosmosQueryRequestOptions setHeader(String name, String value) { this.actualRequestOptions.setHeader(name, value); return this; } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java index de2d769f789b..98505acdab98 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java @@ -4,6 +4,7 @@ package com.azure.cosmos.models; import com.azure.cosmos.ConsistencyLevel; +import com.azure.cosmos.CosmosHeaderName; import com.azure.cosmos.CosmosDiagnosticsThresholds; import com.azure.cosmos.CosmosEndToEndOperationLatencyPolicyConfig; import com.azure.cosmos.CosmosItemSerializer; @@ -15,6 +16,7 @@ import java.time.Duration; import java.util.List; +import java.util.Map; import java.util.Set; /** @@ -367,17 +369,38 @@ public Set getKeywordIdentifiers() { } /** - * Sets a custom header to be included with this specific request. + * Sets additional headers to be included with this specific request. *

+ * The {@link CosmosHeaderName} enum defines exactly which headers are supported. * This allows per-request header customization, such as setting a workload ID * that overrides the client-level default set via - * {@link com.azure.cosmos.CosmosClientBuilder#customHeaders(java.util.Map)}. + * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. + *

+ * If the same header is also set at the client level, the request-level value + * takes precedence. + * + * @param additionalHeaders map of {@link CosmosHeaderName} to value + * @return the CosmosReadManyRequestOptions. + * @throws IllegalArgumentException if the workload-id value is not a valid integer + */ + public CosmosReadManyRequestOptions setAdditionalHeaders(Map additionalHeaders) { + CosmosHeaderName.validateAdditionalHeaders(additionalHeaders); + if (additionalHeaders != null) { + for (Map.Entry entry : additionalHeaders.entrySet()) { + this.actualRequestOptions.setHeader(entry.getKey().getHeaderName(), entry.getValue()); + } + } + return this; + } + + /** + * Sets a header to be included with this specific request. * - * @param name the header name (e.g., "x-ms-cosmos-workload-id") - * @param value the header value (e.g., "20") + * @param name the header name + * @param value the header value * @return the CosmosReadManyRequestOptions. */ - public CosmosReadManyRequestOptions setHeader(String name, String value) { + CosmosReadManyRequestOptions setHeader(String name, String value) { this.actualRequestOptions.setHeader(name, value); return this; } From b13d0e92656f6103cb427e623c57f2f0db304131 Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Tue, 10 Mar 2026 12:32:05 -0500 Subject: [PATCH 07/13] fix: refactoring --- .../spark/SparkE2EWorkloadIdITest.scala | 52 +-- .../azure/cosmos/AdditionalHeadersTests.java | 312 ++++++++++++++++++ .../com/azure/cosmos/CustomHeadersTests.java | 256 -------------- .../RxDocumentClientUnderTest.java | 4 +- .../RxGatewayStoreModelTest.java | 26 +- .../SpyClientUnderTestFactory.java | 4 +- .../GatewayAddressCacheTest.java | 32 +- .../rx/WorkloadIdDirectInterceptorTests.java | 223 +++++++++++++ .../azure/cosmos/rx/WorkloadIdE2ETests.java | 85 +++-- .../com/azure/cosmos/CosmosHeaderName.java | 99 ++++++ .../implementation/RxGatewayStoreModel.java | 14 +- 11 files changed, 755 insertions(+), 352 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/AdditionalHeadersTests.java delete mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java create mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdDirectInterceptorTests.java create mode 100644 sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosHeaderName.java diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala index d9706d0709e5..0f0a1f200d3e 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2EWorkloadIdITest.scala @@ -10,14 +10,18 @@ import com.fasterxml.jackson.databind.node.ObjectNode import java.util.UUID /** - * End-to-end integration tests for the custom headers (workload-id) feature in the Spark connector. + * Integration tests (smoke tests) for the additional headers (workload-id) feature in the Spark connector. + * These are smoke tests — they verify that Spark DataFrame read and write operations succeed + * (no errors, correct data) when the `spark.cosmos.additionalHeaders` configuration is set. + * They do NOT assert that the workload-id header is actually present on the wire request. + * Wire-level header propagation is verified by: + * - Java SDK unit tests: RxGatewayStoreModelTest, GatewayAddressCacheTest + * - Java SDK integration tests: WorkloadIdE2ETests (interceptor-based wire assertions) * - * These tests verify that the `spark.cosmos.customHeaders` configuration option correctly flows - * through the Spark connector pipeline into CosmosClientBuilder.customHeaders(), ensuring that - * custom HTTP headers (such as x-ms-cosmos-workload-id) are applied to all Cosmos DB operations - * initiated via Spark DataFrames (reads and writes). - * - * Requires the Cosmos DB Emulator running + * Test cases: + * 1. Read with workload-id header — Spark read succeeds, correct item returned + * 2. Write with workload-id header — Spark write succeeds, item verified via SDK read-back + * 3. No additionalHeaders (regression) — operations succeed without the config set */ class SparkE2EWorkloadIdITest extends IntegrationSpec @@ -32,10 +36,12 @@ class SparkE2EWorkloadIdITest //scalastyle:off magic.number //scalastyle:off null - // Verifies that a Spark DataFrame read operation succeeds when spark.cosmos.customHeaders - // is configured with a workload-id header. The header should be passed through to the - // CosmosAsyncClient via CosmosClientBuilder.customHeaders() without affecting read behavior. - "spark query with customHeaders" can "read items with workload-id header" in { + // Integration smoke test #1: Spark read with workload-id header. + // Creates an item via SDK, then reads it back via Spark DataFrame with + // spark.cosmos.additionalHeaders set to {"x-ms-cosmos-workload-id": "15"}. + // Verifies the read succeeds and returns the correct item. + // This proves the header flows through the Spark config pipeline without causing errors. + "spark query with additionalHeaders" can "read items with workload-id header" in { val cosmosEndpoint = TestConfigurations.HOST val cosmosMasterKey = TestConfigurations.MASTER_KEY @@ -58,7 +64,7 @@ class SparkE2EWorkloadIdITest "spark.cosmos.accountKey" -> cosmosMasterKey, "spark.cosmos.database" -> cosmosDatabase, "spark.cosmos.container" -> cosmosContainer, - "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "15"}""", + "spark.cosmos.additionalHeaders" -> """{"x-ms-cosmos-workload-id": "15"}""", "spark.cosmos.read.partitioning.strategy" -> "Restrictive" ) @@ -70,10 +76,12 @@ class SparkE2EWorkloadIdITest item.getAs[String]("id") shouldEqual id } - // Verifies that a Spark DataFrame write operation succeeds when spark.cosmos.customHeaders - // is configured with a workload-id header. The item is written via Spark and then verified - // via a direct SDK read to confirm the write was persisted correctly. - "spark write with customHeaders" can "write items with workload-id header" in { + // Integration smoke test #2: Spark write with workload-id header. + // Writes an item via Spark DataFrame with spark.cosmos.additionalHeaders set to + // {"x-ms-cosmos-workload-id": "20"}, then reads it back via SDK to confirm + // write was persisted correctly. + // This proves the header flows through the Spark write pipeline without causing errors. + "spark write with additionalHeaders" can "write items with workload-id header" in { val cosmosEndpoint = TestConfigurations.HOST val cosmosMasterKey = TestConfigurations.MASTER_KEY @@ -91,7 +99,7 @@ class SparkE2EWorkloadIdITest "spark.cosmos.accountKey" -> cosmosMasterKey, "spark.cosmos.database" -> cosmosDatabase, "spark.cosmos.container" -> cosmosContainer, - "spark.cosmos.customHeaders" -> """{"x-ms-cosmos-workload-id": "20"}""", + "spark.cosmos.additionalHeaders" -> """{"x-ms-cosmos-workload-id": "20"}""", "spark.cosmos.write.strategy" -> "ItemOverwrite", "spark.cosmos.write.bulk.enabled" -> "false", "spark.cosmos.serialization.inclusionMode" -> "NonDefault" @@ -110,10 +118,12 @@ class SparkE2EWorkloadIdITest readItem.getItem.get("name").textValue() shouldEqual "testWriteItem" } - // Regression test: verifies that Spark read operations continue to work correctly when - // spark.cosmos.customHeaders is NOT specified. Ensures that the feature addition does not - // break existing behavior for clients that do not use custom headers. - "spark operations without customHeaders" can "still succeed" in { + // Integration smoke test #3: Regression test — no additionalHeaders configured. + // Verifies that Spark read operations continue to work correctly when + // spark.cosmos.additionalHeaders is NOT specified in the config map. + // This ensures the feature addition does not break existing Spark jobs + // that don't use additional headers. + "spark operations without additionalHeaders" can "still succeed" in { val cosmosEndpoint = TestConfigurations.HOST val cosmosMasterKey = TestConfigurations.MASTER_KEY diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/AdditionalHeadersTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/AdditionalHeadersTests.java new file mode 100644 index 000000000000..9596988db621 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/AdditionalHeadersTests.java @@ -0,0 +1,312 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos; + +import com.azure.cosmos.implementation.HttpConstants; +import com.azure.cosmos.models.CosmosBatchRequestOptions; +import com.azure.cosmos.models.CosmosBulkExecutionOptions; +import com.azure.cosmos.models.CosmosChangeFeedRequestOptions; +import com.azure.cosmos.models.CosmosItemRequestOptions; +import com.azure.cosmos.models.CosmosQueryRequestOptions; +import com.azure.cosmos.models.CosmosReadManyRequestOptions; +import com.azure.cosmos.models.FeedRange; +import org.testng.annotations.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for the additional headers (workload-id) feature on CosmosClientBuilder and request options classes. + *

+ * These tests verify the public API surface: builder fluent methods, getter behavior, + * null/empty handling, CosmosHeaderName enum, and that setAdditionalHeaders() is publicly accessible + * on all request options classes. + */ +public class AdditionalHeadersTests { + + /** + * Verifies that additional headers (e.g., workload-id) set via CosmosClientBuilder.additionalHeaders() + * are stored correctly and retrievable via getAdditionalHeaders(). + */ + @Test(groups = { "unit" }) + public void additionalHeadersSetOnBuilder() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "25"); + + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .additionalHeaders(headers); + + assertThat(builder.getAdditionalHeaders()) + .containsEntry("x-ms-cosmos-workload-id", "25"); + } + + /** + * Verifies that passing null to additionalHeaders() does not throw and that + * getAdditionalHeaders() returns null, ensuring graceful null handling. + */ + @Test(groups = { "unit" }) + public void additionalHeadersNullHandledGracefully() { + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .additionalHeaders(null); + + assertThat(builder.getAdditionalHeaders()).isNull(); + } + + /** + * Verifies that passing an empty map to additionalHeaders() is accepted and + * getAdditionalHeaders() returns an empty (not null) map. + */ + @Test(groups = { "unit" }) + public void additionalHeadersEmptyMapHandled() { + Map emptyHeaders = new HashMap<>(); + + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .additionalHeaders(emptyHeaders); + + assertThat(builder.getAdditionalHeaders()).isEmpty(); + } + + /** + * Verifies that CosmosHeaderName.WORKLOAD_ID maps to the correct header string. + */ + @Test(groups = { "unit" }) + public void cosmosHeaderNameWorkloadIdValue() { + assertThat(CosmosHeaderName.WORKLOAD_ID.getHeaderName()) + .isEqualTo("x-ms-cosmos-workload-id"); + } + + /** + * Verifies that CosmosHeaderName.fromString() resolves known header strings to the + * correct enum value. This is used by the Spark connector to convert config strings + * to enum keys. + */ + @Test(groups = { "unit" }) + public void cosmosHeaderNameFromStringResolvesKnownHeader() { + CosmosHeaderName name = CosmosHeaderName.fromString("x-ms-cosmos-workload-id"); + assertThat(name).isEqualTo(CosmosHeaderName.WORKLOAD_ID); + } + + /** + * Verifies that CosmosHeaderName.fromString() is case-insensitive. + */ + @Test(groups = { "unit" }) + public void cosmosHeaderNameFromStringIsCaseInsensitive() { + CosmosHeaderName name = CosmosHeaderName.fromString("X-MS-COSMOS-WORKLOAD-ID"); + assertThat(name).isEqualTo(CosmosHeaderName.WORKLOAD_ID); + } + + /** + * Verifies that CosmosHeaderName.fromString() throws IllegalArgumentException + * for unknown header strings. This is the runtime equivalent of the compile-time + * safety provided by the enum — used when converting from Spark JSON config. + */ + @Test(groups = { "unit" }) + public void cosmosHeaderNameFromStringRejectsUnknownHeader() { + assertThatThrownBy(() -> CosmosHeaderName.fromString("x-ms-custom-header")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("x-ms-custom-header") + .hasMessageContaining("Unknown header"); + } + + /** + * Verifies that setAdditionalHeaders() is publicly accessible on CosmosItemRequestOptions + * and supports fluent chaining for per-request header overrides on CRUD operations. + */ + @Test(groups = { "unit" }) + public void setAdditionalHeadersOnItemRequestOptions() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "15"); + + CosmosItemRequestOptions options = new CosmosItemRequestOptions() + .setAdditionalHeaders(headers); + + assertThat(options).isNotNull(); + } + + /** + * Verifies that setAdditionalHeaders() is publicly accessible on CosmosBatchRequestOptions + * and supports fluent chaining for per-request header overrides on batch operations. + */ + @Test(groups = { "unit" }) + public void setAdditionalHeadersOnBatchRequestOptions() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "20"); + + CosmosBatchRequestOptions options = new CosmosBatchRequestOptions() + .setAdditionalHeaders(headers); + + assertThat(options).isNotNull(); + } + + /** + * Verifies that setAdditionalHeaders() is publicly accessible on CosmosChangeFeedRequestOptions + * and supports fluent chaining for per-request header overrides on change feed operations. + */ + @Test(groups = { "unit" }) + public void setAdditionalHeadersOnChangeFeedRequestOptions() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "25"); + + CosmosChangeFeedRequestOptions options = CosmosChangeFeedRequestOptions + .createForProcessingFromBeginning(FeedRange.forFullRange()) + .setAdditionalHeaders(headers); + + assertThat(options).isNotNull(); + } + + /** + * Verifies that setAdditionalHeaders() is publicly accessible on CosmosBulkExecutionOptions + * and supports fluent chaining for per-request header overrides on bulk ingestion operations. + */ + @Test(groups = { "unit" }) + public void setAdditionalHeadersOnBulkExecutionOptions() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "30"); + + CosmosBulkExecutionOptions options = new CosmosBulkExecutionOptions() + .setAdditionalHeaders(headers); + + assertThat(options).isNotNull(); + } + + /** + * Verifies that setAdditionalHeaders() is publicly accessible on CosmosQueryRequestOptions + * and supports fluent chaining for per-request header overrides on query operations. + */ + @Test(groups = { "unit" }) + public void setAdditionalHeadersOnQueryRequestOptions() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "35"); + + CosmosQueryRequestOptions options = new CosmosQueryRequestOptions() + .setAdditionalHeaders(headers); + + assertThat(options).isNotNull(); + } + + /** + * Verifies that setAdditionalHeaders() is publicly accessible on CosmosReadManyRequestOptions + * and supports fluent chaining for per-request header overrides on read-many operations. + */ + @Test(groups = { "unit" }) + public void setAdditionalHeadersOnReadManyRequestOptions() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "40"); + + CosmosReadManyRequestOptions options = new CosmosReadManyRequestOptions() + .setAdditionalHeaders(headers); + + assertThat(options).isNotNull(); + } + + /** + * Verifies that the WORKLOAD_ID constant in HttpConstants.HttpHeaders is defined + * with the correct canonical header name "x-ms-cosmos-workload-id" as expected + * by the Cosmos DB service. + */ + @Test(groups = { "unit" }) + public void workloadIdHttpHeaderConstant() { + assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id"); + } + + /** + * Verifies that a non-numeric workload-id value is rejected at builder level with + * IllegalArgumentException. This covers both Gateway and Direct modes consistently + * (unlike RntbdRequestHeaders.addWorkloadId() which only covers Direct mode). + */ + @Test(groups = { "unit" }) + public void nonNumericWorkloadIdRejectedAtBuilderLevel() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "abc"); + + assertThatThrownBy(() -> new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .additionalHeaders(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("abc") + .hasMessageContaining("valid integer"); + } + + /** + * Verifies that out-of-range workload-id values (e.g., 51) are accepted by the SDK. + * Range validation [1, 50] is the backend's responsibility — the SDK only validates + * that the value is a valid integer. This avoids hardcoding a range the backend team + * might change in the future. + */ + @Test(groups = { "unit" }) + public void outOfRangeWorkloadIdAcceptedByBuilder() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "51"); + + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .additionalHeaders(headers); + + assertThat(builder.getAdditionalHeaders()) + .containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "51"); + } + + /** + * Verifies that a non-numeric workload-id value is rejected at request-options level + * (CosmosItemRequestOptions) with IllegalArgumentException, ensuring validation is + * symmetric with the builder level. + */ + @Test(groups = { "unit" }) + public void nonNumericWorkloadIdRejectedAtItemRequestOptionsLevel() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "abc"); + + assertThatThrownBy(() -> new CosmosItemRequestOptions() + .setAdditionalHeaders(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("abc") + .hasMessageContaining("valid integer"); + } + + /** + * Verifies that a non-numeric workload-id value is rejected at request-options level + * (CosmosBulkExecutionOptions) with IllegalArgumentException, ensuring validation is + * symmetric with the builder level. + */ + @Test(groups = { "unit" }) + public void nonNumericWorkloadIdRejectedAtBulkExecutionOptionsLevel() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "not-a-number"); + + assertThatThrownBy(() -> new CosmosBulkExecutionOptions() + .setAdditionalHeaders(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("not-a-number") + .hasMessageContaining("valid integer"); + } + + /** + * Verifies that a valid workload-id value passes builder validation. + */ + @Test(groups = { "unit" }) + public void validWorkloadIdAcceptedByBuilder() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "15"); + + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .additionalHeaders(headers); + + assertThat(builder.getAdditionalHeaders()) + .containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "15"); + } +} + diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java deleted file mode 100644 index 19eb03744d1a..000000000000 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CustomHeadersTests.java +++ /dev/null @@ -1,256 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package com.azure.cosmos; - -import com.azure.cosmos.implementation.HttpConstants; -import com.azure.cosmos.models.CosmosBatchRequestOptions; -import com.azure.cosmos.models.CosmosBulkExecutionOptions; -import com.azure.cosmos.models.CosmosChangeFeedRequestOptions; -import com.azure.cosmos.models.CosmosItemRequestOptions; -import com.azure.cosmos.models.CosmosQueryRequestOptions; -import com.azure.cosmos.models.CosmosReadManyRequestOptions; -import com.azure.cosmos.models.FeedRange; -import org.testng.annotations.Test; - -import java.util.HashMap; -import java.util.Map; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Unit tests for the custom headers (workload-id) feature on CosmosClientBuilder and request options classes. - *

- * These tests verify the public API surface: builder fluent methods, getter behavior, - * null/empty handling, and that setHeader() is publicly accessible on all request options classes. - */ -public class CustomHeadersTests { - - /** - * Verifies that custom headers (e.g., workload-id) set via CosmosClientBuilder.customHeaders() - * are stored correctly and retrievable via getCustomHeaders(). - */ - @Test(groups = { "unit" }) - public void customHeadersSetOnBuilder() { - Map headers = new HashMap<>(); - headers.put("x-ms-cosmos-workload-id", "25"); - - CosmosClientBuilder builder = new CosmosClientBuilder() - .endpoint("https://test.documents.azure.com:443/") - .key("dGVzdEtleQ==") - .customHeaders(headers); - - assertThat(builder.getCustomHeaders()).containsEntry("x-ms-cosmos-workload-id", "25"); - } - - /** - * Verifies that passing null to customHeaders() does not throw and that - * getCustomHeaders() returns null, ensuring graceful null handling. - */ - @Test(groups = { "unit" }) - public void customHeadersNullHandledGracefully() { - CosmosClientBuilder builder = new CosmosClientBuilder() - .endpoint("https://test.documents.azure.com:443/") - .key("dGVzdEtleQ==") - .customHeaders(null); - - assertThat(builder.getCustomHeaders()).isNull(); - } - - /** - * Verifies that passing an empty map to customHeaders() is accepted and - * getCustomHeaders() returns an empty (not null) map. - */ - @Test(groups = { "unit" }) - public void customHeadersEmptyMapHandled() { - Map emptyHeaders = new HashMap<>(); - - CosmosClientBuilder builder = new CosmosClientBuilder() - .endpoint("https://test.documents.azure.com:443/") - .key("dGVzdEtleQ==") - .customHeaders(emptyHeaders); - - assertThat(builder.getCustomHeaders()).isEmpty(); - } - - /** - * Verifies that headers not in the allowlist are rejected with IllegalArgumentException. - * This ensures consistent behavior across Gateway and Direct modes — only headers with - * RNTBD encoding support are allowed. - */ - @Test(groups = { "unit" }) - public void unknownHeaderRejectedByAllowlist() { - Map headers = new HashMap<>(); - headers.put("x-ms-custom-header", "value"); - - assertThatThrownBy(() -> new CosmosClientBuilder() - .endpoint("https://test.documents.azure.com:443/") - .key("dGVzdEtleQ==") - .customHeaders(headers)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("x-ms-custom-header") - .hasMessageContaining("not allowed"); - } - - /** - * Verifies that a map containing both an allowed header and a disallowed header - * is rejected — the entire map must pass the allowlist check. - */ - @Test(groups = { "unit" }) - public void mixedAllowedAndDisallowedHeadersRejected() { - Map headers = new HashMap<>(); - headers.put("x-ms-cosmos-workload-id", "15"); - headers.put("x-ms-custom-header", "value"); - - assertThatThrownBy(() -> new CosmosClientBuilder() - .endpoint("https://test.documents.azure.com:443/") - .key("dGVzdEtleQ==") - .customHeaders(headers)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("x-ms-custom-header"); - } - - /** - * Verifies that setHeader() is publicly accessible on CosmosItemRequestOptions - * (previously package-private) and supports fluent chaining for per-request - * header overrides on CRUD operations. - */ - @Test(groups = { "unit" }) - public void setHeaderOnItemRequestOptionsIsPublic() { - CosmosItemRequestOptions options = new CosmosItemRequestOptions() - .setHeader("x-ms-cosmos-workload-id", "15"); - - assertThat(options).isNotNull(); - } - - /** - * Verifies that setHeader() is publicly accessible on CosmosBatchRequestOptions - * (previously package-private) and supports fluent chaining for per-request - * header overrides on batch operations. - */ - @Test(groups = { "unit" }) - public void setHeaderOnBatchRequestOptionsIsPublic() { - CosmosBatchRequestOptions options = new CosmosBatchRequestOptions() - .setHeader("x-ms-cosmos-workload-id", "20"); - - assertThat(options).isNotNull(); - } - - /** - * Verifies that setHeader() is publicly accessible on CosmosChangeFeedRequestOptions - * (previously package-private) and supports fluent chaining for per-request - * header overrides on change feed operations. - */ - @Test(groups = { "unit" }) - public void setHeaderOnChangeFeedRequestOptionsIsPublic() { - CosmosChangeFeedRequestOptions options = CosmosChangeFeedRequestOptions - .createForProcessingFromBeginning(FeedRange.forFullRange()) - .setHeader("x-ms-cosmos-workload-id", "25"); - - assertThat(options).isNotNull(); - } - - /** - * Verifies that setHeader() is publicly accessible on CosmosBulkExecutionOptions - * (previously package-private) and supports fluent chaining for per-request - * header overrides on bulk ingestion operations. - */ - @Test(groups = { "unit" }) - public void setHeaderOnBulkExecutionOptionsIsPublic() { - CosmosBulkExecutionOptions options = new CosmosBulkExecutionOptions() - .setHeader("x-ms-cosmos-workload-id", "30"); - - assertThat(options).isNotNull(); - } - - /** - * Verifies that the new delegating setHeader() method on CosmosQueryRequestOptions - * is publicly accessible and supports fluent chaining for per-request header - * overrides on query operations. - */ - @Test(groups = { "unit" }) - public void setHeaderOnQueryRequestOptionsIsPublic() { - CosmosQueryRequestOptions options = new CosmosQueryRequestOptions() - .setHeader("x-ms-cosmos-workload-id", "35"); - - assertThat(options).isNotNull(); - } - - /** - * Verifies that the new delegating setHeader() method on CosmosReadManyRequestOptions - * is publicly accessible and supports fluent chaining for per-request header - * overrides on read-many operations. - */ - @Test(groups = { "unit" }) - public void setHeaderOnReadManyRequestOptionsIsPublic() { - CosmosReadManyRequestOptions options = new CosmosReadManyRequestOptions() - .setHeader("x-ms-cosmos-workload-id", "40"); - - assertThat(options).isNotNull(); - } - - /** - * Verifies that the WORKLOAD_ID constant in HttpConstants.HttpHeaders is defined - * with the correct canonical header name "x-ms-cosmos-workload-id" as expected - * by the Cosmos DB service. - */ - @Test(groups = { "unit" }) - public void workloadIdHttpHeaderConstant() { - assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id"); - } - - /** - * Verifies that a non-numeric workload-id value is rejected at builder level with - * IllegalArgumentException. This covers both Gateway and Direct modes consistently - * (unlike RntbdRequestHeaders.addWorkloadId() which only covers Direct mode). - */ - @Test(groups = { "unit" }) - public void nonNumericWorkloadIdRejectedAtBuilderLevel() { - Map headers = new HashMap<>(); - headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "abc"); - - assertThatThrownBy(() -> new CosmosClientBuilder() - .endpoint("https://test.documents.azure.com:443/") - .key("dGVzdEtleQ==") - .customHeaders(headers)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("abc") - .hasMessageContaining("valid integer"); - } - - /** - * Verifies that out-of-range workload-id values (e.g., 51) are accepted by the SDK. - * Range validation [1, 50] is the backend's responsibility — the SDK only validates - * that the value is a valid integer. This avoids hardcoding a range the backend team - * might change in the future. - */ - @Test(groups = { "unit" }) - public void outOfRangeWorkloadIdAcceptedByBuilder() { - Map headers = new HashMap<>(); - headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "51"); - - CosmosClientBuilder builder = new CosmosClientBuilder() - .endpoint("https://test.documents.azure.com:443/") - .key("dGVzdEtleQ==") - .customHeaders(headers); - - assertThat(builder.getCustomHeaders()).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "51"); - } - - /** - * Verifies that a valid workload-id value passes builder validation. - */ - @Test(groups = { "unit" }) - public void validWorkloadIdAcceptedByBuilder() { - Map headers = new HashMap<>(); - headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "15"); - - CosmosClientBuilder builder = new CosmosClientBuilder() - .endpoint("https://test.documents.azure.com:443/") - .key("dGVzdEtleQ==") - .customHeaders(headers); - - assertThat(builder.getCustomHeaders()).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "15"); - } -} diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java index d5f8b92ac7a6..b4cb04e0ae3e 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxDocumentClientUnderTest.java @@ -77,7 +77,7 @@ RxGatewayStoreModel createRxGatewayProxy( GlobalPartitionEndpointManagerForPerPartitionCircuitBreaker globalPartitionEndpointManagerForPerPartitionCircuitBreaker, HttpClient rxOrigClient, ApiType apiType, - Map customHeaders) { + Map additionalHeaders) { origHttpClient = rxOrigClient; spyHttpClient = Mockito.spy(rxOrigClient); @@ -96,6 +96,6 @@ RxGatewayStoreModel createRxGatewayProxy( globalEndpointManager, spyHttpClient, apiType, - customHeaders); + additionalHeaders); } } diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java index 587844f4043a..edf16806489c 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/RxGatewayStoreModelTest.java @@ -436,12 +436,12 @@ private boolean runCancelAfterRetainIteration() throws Exception { } /** - * Verifies that client-level customHeaders (e.g., workload-id) are injected into + * Verifies that client-level additionalHeaders (e.g., workload-id) are injected into * outgoing HTTP requests by performRequest(). This covers metadata requests * (collection cache, partition key range) that don't go through getRequestHeaders(). */ @Test(groups = "unit") - public void customHeadersInjectedInPerformRequest() throws Exception { + public void additionalHeadersInjectedInPerformRequest() throws Exception { DiagnosticsClientContext clientContext = mockDiagnosticsClientContext(); ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class); GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class); @@ -453,8 +453,8 @@ public void customHeadersInjectedInPerformRequest() throws Exception { ArgumentCaptor httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class); Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException())); - Map customHeaders = new HashMap<>(); - customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + Map additionalHeaders = new HashMap<>(); + additionalHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); RxGatewayStoreModel storeModel = new RxGatewayStoreModel( clientContext, @@ -465,9 +465,9 @@ public void customHeadersInjectedInPerformRequest() throws Exception { globalEndpointManager, httpClient, null, - customHeaders); + additionalHeaders); - // Simulate a metadata request (e.g., collection cache lookup) — no customHeaders on the request itself + // Simulate a metadata request (e.g., collection cache lookup) — no additionalHeaders on the request itself RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( clientContext, OperationType.Read, @@ -490,12 +490,12 @@ public void customHeadersInjectedInPerformRequest() throws Exception { } /** - * Verifies that request-level headers take precedence over client-level customHeaders. + * Verifies that request-level headers take precedence over client-level additionalHeaders. * If a request already has workload-id set (e.g., via getRequestHeaders()), performRequest() * should NOT overwrite it. */ @Test(groups = "unit") - public void requestLevelHeadersTakePrecedenceOverCustomHeaders() throws Exception { + public void requestLevelHeadersTakePrecedenceOverAdditionalHeaders() throws Exception { DiagnosticsClientContext clientContext = mockDiagnosticsClientContext(); ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class); GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class); @@ -507,8 +507,8 @@ public void requestLevelHeadersTakePrecedenceOverCustomHeaders() throws Exceptio ArgumentCaptor httpClientRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class); Mockito.when(httpClient.send(any(), any())).thenReturn(Mono.error(new ConnectTimeoutException())); - Map customHeaders = new HashMap<>(); - customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "10"); + Map additionalHeaders = new HashMap<>(); + additionalHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "10"); RxGatewayStoreModel storeModel = new RxGatewayStoreModel( clientContext, @@ -519,7 +519,7 @@ public void requestLevelHeadersTakePrecedenceOverCustomHeaders() throws Exceptio globalEndpointManager, httpClient, null, - customHeaders); + additionalHeaders); RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( clientContext, @@ -547,11 +547,11 @@ public void requestLevelHeadersTakePrecedenceOverCustomHeaders() throws Exceptio } /** - * Verifies that when customHeaders is null, performRequest() still works normally + * Verifies that when additionalHeaders is null, performRequest() still works normally * without injecting any extra headers. */ @Test(groups = "unit") - public void nullCustomHeadersDoesNotAffectPerformRequest() throws Exception { + public void nullAdditionalHeadersDoesNotAffectPerformRequest() throws Exception { DiagnosticsClientContext clientContext = mockDiagnosticsClientContext(); ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class); GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class); diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java index 775b74785630..aa9029576b24 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/SpyClientUnderTestFactory.java @@ -128,7 +128,7 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer, GlobalEndpointManager globalEndpointManager, HttpClient rxClient, ApiType apiType, - Map customHeaders) { + Map additionalHeaders) { this.origRxGatewayStoreModel = super.createRxGatewayProxy( sessionContainer, consistencyLevel, @@ -137,7 +137,7 @@ RxGatewayStoreModel createRxGatewayProxy(ISessionContainer sessionContainer, globalEndpointManager, rxClient, apiType, - customHeaders); + additionalHeaders); this.requests = Collections.synchronizedList(new ArrayList<>()); this.spyRxGatewayStoreModel = Mockito.spy(this.origRxGatewayStoreModel); this.initRequestCapture(); diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java index 71d78aefdb2d..2d81311d8417 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCacheTest.java @@ -1647,16 +1647,16 @@ private HttpClientUnderTestWrapper getHttpClientUnderTestWrapper(Configs configs } /** - * Verifies that client-level customHeaders (e.g., workload-id) are included in + * Verifies that client-level additionalHeaders (e.g., workload-id) are included in * GatewayAddressCache's defaultRequestHeaders, which are sent on every address * resolution request. */ @Test(groups = { "unit" }) - public void customHeadersIncludedInDefaultRequestHeaders() throws Exception { + public void additionalHeadersIncludedInDefaultRequestHeaders() throws Exception { URI serviceEndpoint = new URI("https://localhost"); - Map customHeaders = new HashMap<>(); - customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + Map additionalHeaders = new HashMap<>(); + additionalHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); GatewayAddressCache cache = new GatewayAddressCache( mockDiagnosticsClientContext(), @@ -1670,7 +1670,7 @@ public void customHeadersIncludedInDefaultRequestHeaders() throws Exception { null, null, null, - customHeaders); + additionalHeaders); Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders"); defaultRequestHeadersField.setAccessible(true); @@ -1681,18 +1681,18 @@ public void customHeadersIncludedInDefaultRequestHeaders() throws Exception { } /** - * Verifies that customHeaders do NOT overwrite SDK system headers (USER_AGENT, VERSION, etc.) + * Verifies that additionalHeaders do NOT overwrite SDK system headers (USER_AGENT, VERSION, etc.) * in GatewayAddressCache's defaultRequestHeaders. putIfAbsent is used so SDK headers - * set before customHeaders are preserved. + * set before additionalHeaders are preserved. */ @Test(groups = { "unit" }) - public void customHeadersDoNotOverwriteSdkSystemHeaders() throws Exception { + public void additionalHeadersDoNotOverwriteSdkSystemHeaders() throws Exception { URI serviceEndpoint = new URI("https://localhost"); - Map customHeaders = new HashMap<>(); - customHeaders.put(HttpConstants.HttpHeaders.USER_AGENT, "malicious-agent"); - customHeaders.put(HttpConstants.HttpHeaders.VERSION, "bad-version"); - customHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); + Map additionalHeaders = new HashMap<>(); + additionalHeaders.put(HttpConstants.HttpHeaders.USER_AGENT, "malicious-agent"); + additionalHeaders.put(HttpConstants.HttpHeaders.VERSION, "bad-version"); + additionalHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); GatewayAddressCache cache = new GatewayAddressCache( mockDiagnosticsClientContext(), @@ -1706,7 +1706,7 @@ public void customHeadersDoNotOverwriteSdkSystemHeaders() throws Exception { null, null, null, - customHeaders); + additionalHeaders); Field defaultRequestHeadersField = GatewayAddressCache.class.getDeclaredField("defaultRequestHeaders"); defaultRequestHeadersField.setAccessible(true); @@ -1716,16 +1716,16 @@ public void customHeadersDoNotOverwriteSdkSystemHeaders() throws Exception { // SDK headers should NOT be overwritten assertThat(defaultRequestHeaders.get(HttpConstants.HttpHeaders.USER_AGENT)).isNotEqualTo("malicious-agent"); assertThat(defaultRequestHeaders.get(HttpConstants.HttpHeaders.VERSION)).isEqualTo(HttpConstants.Versions.CURRENT_VERSION); - // Custom header should still be added + // Additional header should still be added assertThat(defaultRequestHeaders).containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "25"); } /** - * Verifies that when customHeaders is null, GatewayAddressCache's defaultRequestHeaders + * Verifies that when additionalHeaders is null, GatewayAddressCache's defaultRequestHeaders * contains only SDK system headers and no extra entries. */ @Test(groups = { "unit" }) - public void nullCustomHeadersDoesNotAffectDefaultRequestHeaders() throws Exception { + public void nullAdditionalHeadersDoesNotAffectDefaultRequestHeaders() throws Exception { URI serviceEndpoint = new URI("https://localhost"); GatewayAddressCache cache = new GatewayAddressCache( diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdDirectInterceptorTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdDirectInterceptorTests.java new file mode 100644 index 000000000000..2fe594e60831 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdDirectInterceptorTests.java @@ -0,0 +1,223 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos.rx; + +import com.azure.cosmos.CosmosAsyncClient; +import com.azure.cosmos.CosmosAsyncContainer; +import com.azure.cosmos.CosmosAsyncDatabase; +import com.azure.cosmos.CosmosClientBuilder; +import com.azure.cosmos.CosmosHeaderName; +import com.azure.cosmos.TestObject; +import com.azure.cosmos.implementation.HttpConstants; +import com.azure.cosmos.implementation.ResourceType; +import com.azure.cosmos.implementation.RxDocumentServiceRequest; +import com.azure.cosmos.models.CosmosContainerProperties; +import com.azure.cosmos.models.CosmosItemRequestOptions; +import com.azure.cosmos.models.PartitionKey; +import com.azure.cosmos.models.PartitionKeyDefinition; +import com.azure.cosmos.test.implementation.interceptor.CosmosInterceptorHelper; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Factory; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentLinkedQueue; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Interceptor-based integration tests for the workload-id / additional headers feature + * running in Direct mode (RNTBD). + *

+ * These tests use {@link CosmosInterceptorHelper#registerTransportClientInterceptor} to + * capture {@link RxDocumentServiceRequest} objects at the TransportClient level and assert + * that the {@code x-ms-cosmos-workload-id} header is present with the correct value on + * wire requests. The interceptor only fires in Direct mode because Gateway mode data-plane + * requests go through {@code RxGatewayStoreModel → HttpClient}, bypassing the + * TransportClient entirely. + *

+ * Uses {@code @Factory(dataProvider = "simpleClientBuildersWithJustDirectTcp")} to ensure + * all tests run exclusively in Direct/TCP mode where the transport interceptor is active. + *

+ * Gateway-mode header injection is separately verified by: + *

    + *
  • Unit tests in {@code RxGatewayStoreModelTest} (data-plane)
  • + *
  • Unit tests in {@code GatewayAddressCacheTest} (metadata requests)
  • + *
  • Smoke tests in {@link WorkloadIdE2ETests} (Gateway mode end-to-end)
  • + *
+ */ +public class WorkloadIdDirectInterceptorTests extends TestSuiteBase { + + private static final String DATABASE_ID = "workloadIdDirectTestDb-" + UUID.randomUUID(); + private static final String CONTAINER_ID = "workloadIdDirectTestContainer-" + UUID.randomUUID(); + + private CosmosAsyncClient clientWithWorkloadId; + private CosmosAsyncDatabase database; + private CosmosAsyncContainer container; + + @Factory(dataProvider = "simpleClientBuildersWithJustDirectTcp") + public WorkloadIdDirectInterceptorTests(CosmosClientBuilder clientBuilder) { + super(clientBuilder); + } + + @BeforeClass(groups = { "emulator" }, timeOut = SETUP_TIMEOUT) + public void beforeClass() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "15"); + + clientWithWorkloadId = getClientBuilder() + .additionalHeaders(headers) + .buildAsyncClient(); + + database = createDatabase(clientWithWorkloadId, DATABASE_ID); + + PartitionKeyDefinition partitionKeyDef = new PartitionKeyDefinition(); + ArrayList paths = new ArrayList<>(); + paths.add("/mypk"); + partitionKeyDef.setPaths(paths); + CosmosContainerProperties containerProperties = new CosmosContainerProperties(CONTAINER_ID, partitionKeyDef); + database.createContainer(containerProperties).block(); + container = database.getContainer(CONTAINER_ID); + } + + /** + * Verifies that the client-level workload-id header (value "15") is present on the + * {@link RxDocumentServiceRequest} for create (data-plane) operation in Direct mode. + *

+ * Registers a transport client interceptor that captures every request before it hits + * the RNTBD wire. After performing a createItem, asserts that at least one captured + * request with {@code ResourceType.Document} carries the {@code x-ms-cosmos-workload-id} + * header with value {@code "15"}. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void verifyWorkloadIdHeaderPresentOnDataPlaneRequest() { + ConcurrentLinkedQueue capturedRequests = new ConcurrentLinkedQueue<>(); + + CosmosInterceptorHelper.registerTransportClientInterceptor( + clientWithWorkloadId, + (request, storeResponse) -> { + capturedRequests.add(request); + return storeResponse; + } + ); + + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + // In Direct mode, the interceptor MUST capture requests — fail if it didn't + assertThat(capturedRequests) + .as("Transport interceptor should capture requests in Direct mode") + .isNotEmpty(); + + // Assert that at least one Document-type request carries the workload-id header + boolean foundWorkloadIdOnDocument = capturedRequests.stream() + .filter(r -> r.getResourceType() == ResourceType.Document) + .anyMatch(r -> "15".equals(r.getHeaders().get(HttpConstants.HttpHeaders.WORKLOAD_ID))); + + assertThat(foundWorkloadIdOnDocument) + .as("Expected workload-id header '15' on at least one Document request") + .isTrue(); + } + + /** + * Verifies that a per-request workload-id override (value "30") is present on the wire + * request instead of the client-level default (value "15"). + *

+ * This confirms that the request-level header set via + * {@link CosmosItemRequestOptions#setAdditionalHeaders(Map)} takes precedence over + * the client-level header set via {@link CosmosClientBuilder#additionalHeaders(Map)} + * in Direct mode (RNTBD). + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void verifyRequestLevelOverrideOnWire() { + ConcurrentLinkedQueue capturedRequests = new ConcurrentLinkedQueue<>(); + + CosmosInterceptorHelper.registerTransportClientInterceptor( + clientWithWorkloadId, + (request, storeResponse) -> { + capturedRequests.add(request); + return storeResponse; + } + ); + + Map requestHeaders = new HashMap<>(); + requestHeaders.put(CosmosHeaderName.WORKLOAD_ID, "30"); + + CosmosItemRequestOptions options = new CosmosItemRequestOptions() + .setAdditionalHeaders(requestHeaders); + + TestObject doc = TestObject.create(); + container.createItem(doc, new PartitionKey(doc.getMypk()), options).block(); + + // In Direct mode, the interceptor MUST capture requests — fail if it didn't + assertThat(capturedRequests) + .as("Transport interceptor should capture requests in Direct mode") + .isNotEmpty(); + + // Assert that the Document request carries the overridden value "30", not the client default "15" + boolean foundOverriddenWorkloadId = capturedRequests.stream() + .filter(r -> r.getResourceType() == ResourceType.Document) + .anyMatch(r -> "30".equals(r.getHeaders().get(HttpConstants.HttpHeaders.WORKLOAD_ID))); + + assertThat(foundOverriddenWorkloadId) + .as("Expected workload-id header '30' (request-level override) on Document request") + .isTrue(); + } + + /** + * Negative test: verifies that a client created WITHOUT additional headers does NOT + * have the workload-id header on wire requests in Direct mode. Ensures the header is + * only present when explicitly configured. + */ + @Test(groups = { "emulator" }, timeOut = TIMEOUT) + public void verifyNoWorkloadIdHeaderWhenNotConfigured() { + CosmosAsyncClient clientWithoutHeaders = copyCosmosClientBuilder(getClientBuilder()) + .buildAsyncClient(); + + try { + ConcurrentLinkedQueue capturedRequests = new ConcurrentLinkedQueue<>(); + + CosmosInterceptorHelper.registerTransportClientInterceptor( + clientWithoutHeaders, + (request, storeResponse) -> { + capturedRequests.add(request); + return storeResponse; + } + ); + + CosmosAsyncContainer c = clientWithoutHeaders + .getDatabase(DATABASE_ID) + .getContainer(CONTAINER_ID); + + TestObject doc = TestObject.create(); + c.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + + // In Direct mode, the interceptor MUST capture requests + assertThat(capturedRequests) + .as("Transport interceptor should capture requests in Direct mode") + .isNotEmpty(); + + // Assert that NO Document-type request carries the workload-id header + boolean anyDocRequestHasWorkloadId = capturedRequests.stream() + .filter(r -> r.getResourceType() == ResourceType.Document) + .anyMatch(r -> r.getHeaders().containsKey(HttpConstants.HttpHeaders.WORKLOAD_ID)); + + assertThat(anyDocRequestHasWorkloadId) + .as("Expected NO workload-id header on Document requests when not configured") + .isFalse(); + } finally { + safeClose(clientWithoutHeaders); + } + } + + @AfterClass(groups = { "emulator" }, timeOut = SHUTDOWN_TIMEOUT, alwaysRun = true) + public void afterClass() { + safeDeleteDatabase(database); + safeClose(clientWithWorkloadId); + } +} + diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java index 3bf2fdafce7c..c41d8293f88f 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java @@ -6,8 +6,8 @@ import com.azure.cosmos.CosmosAsyncContainer; import com.azure.cosmos.CosmosAsyncDatabase; import com.azure.cosmos.CosmosClientBuilder; +import com.azure.cosmos.CosmosHeaderName; import com.azure.cosmos.TestObject; -import com.azure.cosmos.implementation.HttpConstants; import com.azure.cosmos.models.CosmosContainerProperties; import com.azure.cosmos.models.CosmosItemRequestOptions; import com.azure.cosmos.models.CosmosItemResponse; @@ -27,13 +27,21 @@ import static org.assertj.core.api.Assertions.assertThat; /** - * End-to-end integration tests for the custom headers / workload-id feature. + * End-to-end smoke tests for the additional headers / workload-id feature in Gateway mode. *

- * Test type: EMULATOR INTEGRATION TEST — requires the Cosmos DB Emulator to be running locally. + * Test type: EMULATOR INTEGRATION TEST — requires a Cosmos DB account or emulator. *

* Uses {@code @Factory(dataProvider = "simpleClientBuilderGatewaySession")} to run all tests - * against both Gateway mode (HTTP headers) and Direct mode (RNTBD binary token 0x00DC), - * ensuring the workload-id header is correctly encoded and sent in both transport paths. + * in Gateway mode with Session consistency. + *

+ * These are smoke tests — they verify CRUD/query operations succeed (status code + * 200/201/204) when the workload-id header is set. They prove the header doesn't break + * anything but do NOT assert the header is actually present on the wire request. + *

+ * For wire-level assertion tests that verify the header is actually present on + * {@link com.azure.cosmos.implementation.RxDocumentServiceRequest}, see + * {@link WorkloadIdDirectInterceptorTests} which runs in Direct mode (RNTBD) using + * the transport client interceptor. */ public class WorkloadIdE2ETests extends TestSuiteBase { @@ -51,11 +59,11 @@ public WorkloadIdE2ETests(CosmosClientBuilder clientBuilder) { @BeforeClass(groups = { "emulator" }, timeOut = SETUP_TIMEOUT) public void beforeClass() { - Map headers = new HashMap<>(); - headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "15"); + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "15"); clientWithWorkloadId = getClientBuilder() - .customHeaders(headers) + .additionalHeaders(headers) .buildAsyncClient(); database = createDatabase(clientWithWorkloadId, DATABASE_ID); @@ -71,7 +79,7 @@ public void beforeClass() { /** * verifies that a create (POST) operation succeeds when the client - * has a workload-id custom header set at the builder level. Confirms the header + * has a workload-id additional header set at the builder level. Confirms the header * flows through the request pipeline without causing errors. */ @Test(groups = { "emulator" }, timeOut = TIMEOUT) @@ -145,7 +153,7 @@ public void deleteItemWithClientLevelWorkloadId() { /** * Verifies that a per-request workload-id header override via - * {@code CosmosItemRequestOptions.setHeader()} works. The request-level header + * {@code CosmosItemRequestOptions.setAdditionalHeaders()} works. The request-level header * (value "30") should take precedence over the client-level default (value "15"). */ @Test(groups = { "emulator" }, timeOut = TIMEOUT) @@ -153,8 +161,11 @@ public void createItemWithRequestLevelWorkloadIdOverride() { // Verify per-request header override works — request-level should take precedence TestObject doc = TestObject.create(); + Map requestHeaders = new HashMap<>(); + requestHeaders.put(CosmosHeaderName.WORKLOAD_ID, "30"); + CosmosItemRequestOptions options = new CosmosItemRequestOptions() - .setHeader(HttpConstants.HttpHeaders.WORKLOAD_ID, "30"); + .setAdditionalHeaders(requestHeaders); CosmosItemResponse response = container .createItem(doc, new PartitionKey(doc.getMypk()), options) @@ -166,7 +177,7 @@ public void createItemWithRequestLevelWorkloadIdOverride() { /** * Verifies that a cross-partition query operation succeeds when the client has a - * workload-id custom header. Confirms the header flows correctly through the + * workload-id additional header. Confirms the header flows correctly through the * query pipeline and does not affect result correctness. */ @Test(groups = { "emulator" }, timeOut = TIMEOUT) @@ -187,7 +198,7 @@ public void queryItemsWithClientLevelWorkloadId() { /** * Verifies that a per-request workload-id header override on - * {@code CosmosQueryRequestOptions.setHeader()} works for query operations. + * {@code CosmosQueryRequestOptions.setAdditionalHeaders()} works for query operations. * The request-level header (value "42") should take precedence over the * client-level default. */ @@ -197,8 +208,11 @@ public void queryItemsWithRequestLevelWorkloadIdOverride() { TestObject doc = TestObject.create(); container.createItem(doc, new PartitionKey(doc.getMypk()), new CosmosItemRequestOptions()).block(); + Map requestHeaders = new HashMap<>(); + requestHeaders.put(CosmosHeaderName.WORKLOAD_ID, "42"); + CosmosQueryRequestOptions queryOptions = new CosmosQueryRequestOptions() - .setHeader(HttpConstants.HttpHeaders.WORKLOAD_ID, "42"); + .setAdditionalHeaders(requestHeaders); long count = container .queryItems("SELECT * FROM c WHERE c.id = '" + doc.getId() + "'", queryOptions, TestObject.class) @@ -210,13 +224,13 @@ public void queryItemsWithRequestLevelWorkloadIdOverride() { } /** - * Regression test: verifies that a client created without any custom headers - * continues to work normally. Ensures the custom headers feature does not + * Regression test: verifies that a client created without any additional headers + * continues to work normally. Ensures the additional headers feature does not * introduce regressions for clients that do not use it. */ @Test(groups = { "emulator" }, timeOut = TIMEOUT) - public void clientWithNoCustomHeadersStillWorks() { - // Verify that a client without custom headers works normally (no regression) + public void clientWithNoAdditionalHeadersStillWorks() { + // Verify that a client without additional headers works normally (no regression) CosmosAsyncClient clientWithoutHeaders = copyCosmosClientBuilder(getClientBuilder()) .buildAsyncClient(); @@ -238,15 +252,15 @@ public void clientWithNoCustomHeadersStillWorks() { } /** - * Verifies that a client created with an empty custom headers map works normally. - * An empty map should behave identically to no custom headers — no errors, + * Verifies that a client created with an empty additional headers map works normally. + * An empty map should behave identically to no additional headers — no errors, * no unexpected behavior. */ @Test(groups = { "emulator" }, timeOut = TIMEOUT) - public void clientWithEmptyCustomHeaders() { - // Verify that a client with empty custom headers map works normally + public void clientWithEmptyAdditionalHeaders() { + // Verify that a client with empty additional headers map works normally CosmosAsyncClient clientWithEmptyHeaders = copyCosmosClientBuilder(getClientBuilder()) - .customHeaders(new HashMap<>()) + .additionalHeaders(new HashMap<>()) .buildAsyncClient(); try { @@ -266,20 +280,21 @@ public void clientWithEmptyCustomHeaders() { } } + /** - * Verifies that unknown headers in customHeaders are rejected by the allowlist. - * In Direct mode (RNTBD), unknown headers are silently dropped, so the allowlist - * ensures consistent behavior across Gateway and Direct modes. + * Verifies that the {@link CosmosHeaderName} enum-based allowlist rejects unknown + * header names at client build time. Attempting to set an unrecognized header via + * {@code additionalHeaders()} should throw {@link IllegalArgumentException} from + * {@link CosmosHeaderName#fromString(String)}, preventing arbitrary headers from + * being sent. */ - @Test(groups = { "emulator" }, timeOut = TIMEOUT, expectedExceptions = IllegalArgumentException.class) - public void unknownCustomHeadersRejectedByAllowlist() { - Map headers = new HashMap<>(); - headers.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "20"); - headers.put("x-ms-custom-test-header", "test-value"); - - // Should throw IllegalArgumentException due to unknown header - copyCosmosClientBuilder(getClientBuilder()) - .customHeaders(headers); + @Test(groups = { "emulator" }, timeOut = TIMEOUT, + expectedExceptions = IllegalArgumentException.class) + public void unknownAdditionalHeadersRejectedByAllowlist() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "15"); + // Use fromString with an unknown header — should throw IllegalArgumentException + CosmosHeaderName.fromString("x-unknown-header"); } @AfterClass(groups = { "emulator" }, timeOut = SHUTDOWN_TIMEOUT, alwaysRun = true) diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosHeaderName.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosHeaderName.java new file mode 100644 index 000000000000..7be67765d706 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosHeaderName.java @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos; + +import com.azure.cosmos.implementation.HttpConstants; + +import java.util.Arrays; +import java.util.Map; + +/** + * Defines the set of additional headers that can be set on a {@link CosmosClientBuilder} + * via {@link CosmosClientBuilder#additionalHeaders(java.util.Map)}. + *

+ * Only headers with RNTBD encoding support are included in this enum, ensuring consistent + * behavior across both Gateway mode (HTTP) and Direct mode (RNTBD binary protocol). + */ +public enum CosmosHeaderName { + + /** + * The workload ID header ({@code x-ms-cosmos-workload-id}). + *

+ * Valid values: a string representation of an integer (e.g., {@code "15"}). + * The service accepts values in the range 1–50 for Azure Monitor metrics attribution. + * The SDK validates that the value is a valid integer but does not enforce range limits — + * range validation is the backend's responsibility. + */ + WORKLOAD_ID(HttpConstants.HttpHeaders.WORKLOAD_ID); + + private final String headerName; + + CosmosHeaderName(String headerName) { + this.headerName = headerName; + } + + /** + * Gets the canonical HTTP header name string (e.g., {@code "x-ms-cosmos-workload-id"}). + * + * @return the header name string + */ + public String getHeaderName() { + return this.headerName; + } + + /** + * Converts a header name string to the corresponding {@link CosmosHeaderName} enum value. + *

+ * This is primarily used by the Spark connector, which parses header names from JSON + * configuration strings and needs to convert them to enum values before calling + * {@link CosmosClientBuilder#additionalHeaders(java.util.Map)}. + * + * @param headerName the header name string (e.g., {@code "x-ms-cosmos-workload-id"}) + * @return the matching {@link CosmosHeaderName} + * @throws IllegalArgumentException if the header name does not match any known enum value + */ + public static CosmosHeaderName fromString(String headerName) { + for (CosmosHeaderName name : values()) { + if (name.headerName.equalsIgnoreCase(headerName)) { + return name; + } + } + throw new IllegalArgumentException( + "Unknown header: '" + headerName + "'. Allowed headers: " + Arrays.toString(values())); + } + + /** + * Validates all entries in an additional-headers map. + *

+ * Each {@link CosmosHeaderName} enum value carries its own validation rules. Currently: + *

    + *
  • {@link #WORKLOAD_ID}: value must be a valid integer string
  • + *
+ *

+ * This method is called by {@link CosmosClientBuilder#additionalHeaders(Map)} and + * by every request-options class's {@code setAdditionalHeaders} method, so the + * validation logic lives in one place. + * + * @param additionalHeaders the map to validate (may be null — no-op in that case) + * @throws IllegalArgumentException if any header value fails validation + */ + public static void validateAdditionalHeaders(Map additionalHeaders) { + if (additionalHeaders == null) { + return; + } + for (Map.Entry entry : additionalHeaders.entrySet()) { + CosmosHeaderName key = entry.getKey(); + String value = entry.getValue(); + + if (WORKLOAD_ID == key && value != null) { + try { + Integer.parseInt(value); + } catch (NumberFormatException e) { + throw new IllegalArgumentException( + "Invalid value '" + value + "' for header '" + key.getHeaderName() + + "'. The value must be a valid integer.", e); + } + } + } + } +} diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java index 35f6c64c0079..1a0311a90905 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxGatewayStoreModel.java @@ -91,7 +91,7 @@ public class RxGatewayStoreModel implements RxStoreModel, HttpTransportSerialize private GatewayServiceConfigurationReader gatewayServiceConfigurationReader; private RxClientCollectionCache collectionCache; private GatewayServerErrorInjector gatewayServerErrorInjector; - private final Map customHeaders; + private final Map additionalHeaders; public RxGatewayStoreModel( DiagnosticsClientContext clientContext, @@ -102,7 +102,7 @@ public RxGatewayStoreModel( GlobalEndpointManager globalEndpointManager, HttpClient httpClient, ApiType apiType, - Map customHeaders) { + Map additionalHeaders) { this.clientContext = clientContext; @@ -118,7 +118,7 @@ public RxGatewayStoreModel( this.httpClient = httpClient; this.sessionContainer = sessionContainer; - this.customHeaders = customHeaders; + this.additionalHeaders = additionalHeaders; } public RxGatewayStoreModel(RxGatewayStoreModel inner) { @@ -130,7 +130,7 @@ public RxGatewayStoreModel(RxGatewayStoreModel inner) { this.httpClient = inner.httpClient; this.sessionContainer = inner.sessionContainer; - this.customHeaders = inner.customHeaders; + this.additionalHeaders = inner.additionalHeaders; } protected Map getDefaultHeaders( @@ -283,10 +283,10 @@ public Mono performRequest(RxDocumentServiceRequest r request.requestContext.cosmosDiagnostics = clientContext.createDiagnostics(); } - // Apply client-level custom headers (e.g., workload-id) to all requests + // Apply client-level additional headers (e.g., workload-id) to all requests // including metadata requests (collection cache, partition key range, etc.) - if (this.customHeaders != null && !this.customHeaders.isEmpty()) { - for (Map.Entry entry : this.customHeaders.entrySet()) { + if (this.additionalHeaders != null && !this.additionalHeaders.isEmpty()) { + for (Map.Entry entry : this.additionalHeaders.entrySet()) { // Only set if not already present — request-level headers take precedence if (!request.getHeaders().containsKey(entry.getKey())) { request.getHeaders().put(entry.getKey(), entry.getValue()); From a79808d6ca2be4a189fcf743780034c6b94fe02c Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Thu, 12 Mar 2026 16:47:52 -0500 Subject: [PATCH 08/13] fix: updated chnage log text --- sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md | 2 +- sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md | 2 +- sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md | 2 +- sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md | 2 +- sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md | 2 +- sdk/cosmos/azure-cosmos/CHANGELOG.md | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md index 6d23a6f92cf2..31851ffaca31 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md @@ -3,7 +3,7 @@ ### 4.45.0-beta.1 (Unreleased) #### Features Added -* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) +* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md index 49da84cfc8d4..070dbc48e1d9 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md @@ -3,7 +3,7 @@ ### 4.45.0-beta.1 (Unreleased) #### Features Added -* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) +* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md index 1290801e56f2..122a54604f38 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md @@ -3,7 +3,7 @@ ### 4.45.0-beta.1 (Unreleased) #### Features Added -* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) +* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md index 8660d2ba7b59..94ac0cfa7b0e 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md @@ -3,7 +3,7 @@ ### 4.45.0-beta.1 (Unreleased) #### Features Added -* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) +* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md index 6149df9f05e4..6a42a70828b4 100644 --- a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md @@ -3,7 +3,7 @@ ### 4.45.0-beta.1 (Unreleased) #### Features Added -* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) +* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 424aef4e3a30..c300f7185825 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -4,7 +4,7 @@ #### Features Added * Added support for N-Region synchronous commit feature - See [PR 47757](https://github.com/Azure/azure-sdk-for-java/pull/47757) -* Added `additionalHeaders` support to allow setting additional HTTP headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) +* Added `additionalHeaders` support to allow setting additional headers (e.g., `x-ms-cosmos-workload-id`) that are sent with every request. - See [PR 48128](https://github.com/Azure/azure-sdk-for-java/pull/48128) #### Breaking Changes From 9ce5d98a4971524969315d137e6583c3e3ecb798 Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Thu, 12 Mar 2026 16:54:54 -0500 Subject: [PATCH 09/13] fix: addressing comments --- .../src/main/java/com/azure/cosmos/CosmosClientBuilder.java | 4 +++- .../directconnectivity/GatewayAddressCache.java | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java index 129b4b578c58..5b5e80639fed 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java @@ -758,7 +758,9 @@ public CosmosClientBuilder userAgentSuffix(String userAgentSuffix) { */ public CosmosClientBuilder additionalHeaders(Map additionalHeaders) { CosmosHeaderName.validateAdditionalHeaders(additionalHeaders); - this.additionalHeaders = additionalHeaders; + this.additionalHeaders = additionalHeaders != null + ? new HashMap<>(additionalHeaders) + : null; return this; } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java index 119c7309256d..1c09f37e6fa8 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/GatewayAddressCache.java @@ -166,8 +166,10 @@ public GatewayAddressCache( HttpConstants.HttpHeaders.SDK_SUPPORTED_CAPABILITIES, HttpConstants.SDKSupportedCapabilities.SUPPORTED_CAPABILITIES); - // Apply client-level additional headers (e.g., workload-id) to metadata requests - // Use putIfAbsent to ensure SDK system headers (USER_AGENT, VERSION, etc.) are not overwritten + // Apply client-level additional headers (e.g., workload-id) to address resolution requests. + // Address resolution is the one metadata path that bypasses RxGatewayStoreModel (which + // handles all other metadata requests) — GatewayAddressCache calls httpClient.send() directly. + // Use putIfAbsent to ensure SDK system headers (USER_AGENT, VERSION, etc.) are not overwritten. if (additionalHeaders != null && !additionalHeaders.isEmpty()) { for (Map.Entry entry : additionalHeaders.entrySet()) { this.defaultRequestHeaders.putIfAbsent(entry.getKey(), entry.getValue()); From 289d321751ed2ebc68bfe1312f6b341b0395dc03 Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Wed, 18 Mar 2026 15:28:53 -0500 Subject: [PATCH 10/13] fix: addressing comments --- .../cosmos/spark/CosmosClientCache.scala | 6 +- .../azure/cosmos/AdditionalHeadersTests.java | 312 --------------- .../azure/cosmos/WorkloadIdHeaderTests.java | 367 ++++++++++++++++++ .../ThinClientStoreModelTest.java | 129 +++++- .../azure/cosmos/rx/WorkloadIdE2ETests.java | 2 +- .../com/azure/cosmos/CosmosClientBuilder.java | 6 +- .../com/azure/cosmos/CosmosHeaderName.java | 103 ++++- .../implementation/RxDocumentClientImpl.java | 9 +- .../implementation/ThinClientStoreModel.java | 5 +- .../models/CosmosBatchRequestOptions.java | 2 +- .../models/CosmosBulkExecutionOptions.java | 2 +- .../CosmosChangeFeedRequestOptions.java | 2 +- .../models/CosmosContainerRequestOptions.java | 43 ++ .../models/CosmosDatabaseRequestOptions.java | 43 ++ .../models/CosmosItemRequestOptions.java | 2 +- .../models/CosmosQueryRequestOptions.java | 2 +- .../models/CosmosReadManyRequestOptions.java | 2 +- .../CosmosStoredProcedureRequestOptions.java | 43 ++ 18 files changed, 728 insertions(+), 352 deletions(-) delete mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/AdditionalHeadersTests.java create mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/WorkloadIdHeaderTests.java diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala index 3a0174a9a9a8..49b051a97276 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala @@ -716,11 +716,11 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { // These headers are attached to every Cosmos DB request made by this client instance. // Converts Map[String, String] from Spark config to Map[CosmosHeaderName, String] for the builder. if (cosmosClientConfiguration.additionalHeaders.isDefined) { - val enumHeaders = new java.util.HashMap[CosmosHeaderName, String]() + val headerMap = new java.util.HashMap[CosmosHeaderName, String]() for ((key, value) <- cosmosClientConfiguration.additionalHeaders.get) { - enumHeaders.put(CosmosHeaderName.fromString(key), value) + headerMap.put(CosmosHeaderName.fromString(key), value) } - builder.additionalHeaders(enumHeaders) + builder.additionalHeaders(headerMap) } var client = builder.buildAsyncClient() diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/AdditionalHeadersTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/AdditionalHeadersTests.java deleted file mode 100644 index 9596988db621..000000000000 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/AdditionalHeadersTests.java +++ /dev/null @@ -1,312 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package com.azure.cosmos; - -import com.azure.cosmos.implementation.HttpConstants; -import com.azure.cosmos.models.CosmosBatchRequestOptions; -import com.azure.cosmos.models.CosmosBulkExecutionOptions; -import com.azure.cosmos.models.CosmosChangeFeedRequestOptions; -import com.azure.cosmos.models.CosmosItemRequestOptions; -import com.azure.cosmos.models.CosmosQueryRequestOptions; -import com.azure.cosmos.models.CosmosReadManyRequestOptions; -import com.azure.cosmos.models.FeedRange; -import org.testng.annotations.Test; - -import java.util.HashMap; -import java.util.Map; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Unit tests for the additional headers (workload-id) feature on CosmosClientBuilder and request options classes. - *

- * These tests verify the public API surface: builder fluent methods, getter behavior, - * null/empty handling, CosmosHeaderName enum, and that setAdditionalHeaders() is publicly accessible - * on all request options classes. - */ -public class AdditionalHeadersTests { - - /** - * Verifies that additional headers (e.g., workload-id) set via CosmosClientBuilder.additionalHeaders() - * are stored correctly and retrievable via getAdditionalHeaders(). - */ - @Test(groups = { "unit" }) - public void additionalHeadersSetOnBuilder() { - Map headers = new HashMap<>(); - headers.put(CosmosHeaderName.WORKLOAD_ID, "25"); - - CosmosClientBuilder builder = new CosmosClientBuilder() - .endpoint("https://test.documents.azure.com:443/") - .key("dGVzdEtleQ==") - .additionalHeaders(headers); - - assertThat(builder.getAdditionalHeaders()) - .containsEntry("x-ms-cosmos-workload-id", "25"); - } - - /** - * Verifies that passing null to additionalHeaders() does not throw and that - * getAdditionalHeaders() returns null, ensuring graceful null handling. - */ - @Test(groups = { "unit" }) - public void additionalHeadersNullHandledGracefully() { - CosmosClientBuilder builder = new CosmosClientBuilder() - .endpoint("https://test.documents.azure.com:443/") - .key("dGVzdEtleQ==") - .additionalHeaders(null); - - assertThat(builder.getAdditionalHeaders()).isNull(); - } - - /** - * Verifies that passing an empty map to additionalHeaders() is accepted and - * getAdditionalHeaders() returns an empty (not null) map. - */ - @Test(groups = { "unit" }) - public void additionalHeadersEmptyMapHandled() { - Map emptyHeaders = new HashMap<>(); - - CosmosClientBuilder builder = new CosmosClientBuilder() - .endpoint("https://test.documents.azure.com:443/") - .key("dGVzdEtleQ==") - .additionalHeaders(emptyHeaders); - - assertThat(builder.getAdditionalHeaders()).isEmpty(); - } - - /** - * Verifies that CosmosHeaderName.WORKLOAD_ID maps to the correct header string. - */ - @Test(groups = { "unit" }) - public void cosmosHeaderNameWorkloadIdValue() { - assertThat(CosmosHeaderName.WORKLOAD_ID.getHeaderName()) - .isEqualTo("x-ms-cosmos-workload-id"); - } - - /** - * Verifies that CosmosHeaderName.fromString() resolves known header strings to the - * correct enum value. This is used by the Spark connector to convert config strings - * to enum keys. - */ - @Test(groups = { "unit" }) - public void cosmosHeaderNameFromStringResolvesKnownHeader() { - CosmosHeaderName name = CosmosHeaderName.fromString("x-ms-cosmos-workload-id"); - assertThat(name).isEqualTo(CosmosHeaderName.WORKLOAD_ID); - } - - /** - * Verifies that CosmosHeaderName.fromString() is case-insensitive. - */ - @Test(groups = { "unit" }) - public void cosmosHeaderNameFromStringIsCaseInsensitive() { - CosmosHeaderName name = CosmosHeaderName.fromString("X-MS-COSMOS-WORKLOAD-ID"); - assertThat(name).isEqualTo(CosmosHeaderName.WORKLOAD_ID); - } - - /** - * Verifies that CosmosHeaderName.fromString() throws IllegalArgumentException - * for unknown header strings. This is the runtime equivalent of the compile-time - * safety provided by the enum — used when converting from Spark JSON config. - */ - @Test(groups = { "unit" }) - public void cosmosHeaderNameFromStringRejectsUnknownHeader() { - assertThatThrownBy(() -> CosmosHeaderName.fromString("x-ms-custom-header")) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("x-ms-custom-header") - .hasMessageContaining("Unknown header"); - } - - /** - * Verifies that setAdditionalHeaders() is publicly accessible on CosmosItemRequestOptions - * and supports fluent chaining for per-request header overrides on CRUD operations. - */ - @Test(groups = { "unit" }) - public void setAdditionalHeadersOnItemRequestOptions() { - Map headers = new HashMap<>(); - headers.put(CosmosHeaderName.WORKLOAD_ID, "15"); - - CosmosItemRequestOptions options = new CosmosItemRequestOptions() - .setAdditionalHeaders(headers); - - assertThat(options).isNotNull(); - } - - /** - * Verifies that setAdditionalHeaders() is publicly accessible on CosmosBatchRequestOptions - * and supports fluent chaining for per-request header overrides on batch operations. - */ - @Test(groups = { "unit" }) - public void setAdditionalHeadersOnBatchRequestOptions() { - Map headers = new HashMap<>(); - headers.put(CosmosHeaderName.WORKLOAD_ID, "20"); - - CosmosBatchRequestOptions options = new CosmosBatchRequestOptions() - .setAdditionalHeaders(headers); - - assertThat(options).isNotNull(); - } - - /** - * Verifies that setAdditionalHeaders() is publicly accessible on CosmosChangeFeedRequestOptions - * and supports fluent chaining for per-request header overrides on change feed operations. - */ - @Test(groups = { "unit" }) - public void setAdditionalHeadersOnChangeFeedRequestOptions() { - Map headers = new HashMap<>(); - headers.put(CosmosHeaderName.WORKLOAD_ID, "25"); - - CosmosChangeFeedRequestOptions options = CosmosChangeFeedRequestOptions - .createForProcessingFromBeginning(FeedRange.forFullRange()) - .setAdditionalHeaders(headers); - - assertThat(options).isNotNull(); - } - - /** - * Verifies that setAdditionalHeaders() is publicly accessible on CosmosBulkExecutionOptions - * and supports fluent chaining for per-request header overrides on bulk ingestion operations. - */ - @Test(groups = { "unit" }) - public void setAdditionalHeadersOnBulkExecutionOptions() { - Map headers = new HashMap<>(); - headers.put(CosmosHeaderName.WORKLOAD_ID, "30"); - - CosmosBulkExecutionOptions options = new CosmosBulkExecutionOptions() - .setAdditionalHeaders(headers); - - assertThat(options).isNotNull(); - } - - /** - * Verifies that setAdditionalHeaders() is publicly accessible on CosmosQueryRequestOptions - * and supports fluent chaining for per-request header overrides on query operations. - */ - @Test(groups = { "unit" }) - public void setAdditionalHeadersOnQueryRequestOptions() { - Map headers = new HashMap<>(); - headers.put(CosmosHeaderName.WORKLOAD_ID, "35"); - - CosmosQueryRequestOptions options = new CosmosQueryRequestOptions() - .setAdditionalHeaders(headers); - - assertThat(options).isNotNull(); - } - - /** - * Verifies that setAdditionalHeaders() is publicly accessible on CosmosReadManyRequestOptions - * and supports fluent chaining for per-request header overrides on read-many operations. - */ - @Test(groups = { "unit" }) - public void setAdditionalHeadersOnReadManyRequestOptions() { - Map headers = new HashMap<>(); - headers.put(CosmosHeaderName.WORKLOAD_ID, "40"); - - CosmosReadManyRequestOptions options = new CosmosReadManyRequestOptions() - .setAdditionalHeaders(headers); - - assertThat(options).isNotNull(); - } - - /** - * Verifies that the WORKLOAD_ID constant in HttpConstants.HttpHeaders is defined - * with the correct canonical header name "x-ms-cosmos-workload-id" as expected - * by the Cosmos DB service. - */ - @Test(groups = { "unit" }) - public void workloadIdHttpHeaderConstant() { - assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id"); - } - - /** - * Verifies that a non-numeric workload-id value is rejected at builder level with - * IllegalArgumentException. This covers both Gateway and Direct modes consistently - * (unlike RntbdRequestHeaders.addWorkloadId() which only covers Direct mode). - */ - @Test(groups = { "unit" }) - public void nonNumericWorkloadIdRejectedAtBuilderLevel() { - Map headers = new HashMap<>(); - headers.put(CosmosHeaderName.WORKLOAD_ID, "abc"); - - assertThatThrownBy(() -> new CosmosClientBuilder() - .endpoint("https://test.documents.azure.com:443/") - .key("dGVzdEtleQ==") - .additionalHeaders(headers)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("abc") - .hasMessageContaining("valid integer"); - } - - /** - * Verifies that out-of-range workload-id values (e.g., 51) are accepted by the SDK. - * Range validation [1, 50] is the backend's responsibility — the SDK only validates - * that the value is a valid integer. This avoids hardcoding a range the backend team - * might change in the future. - */ - @Test(groups = { "unit" }) - public void outOfRangeWorkloadIdAcceptedByBuilder() { - Map headers = new HashMap<>(); - headers.put(CosmosHeaderName.WORKLOAD_ID, "51"); - - CosmosClientBuilder builder = new CosmosClientBuilder() - .endpoint("https://test.documents.azure.com:443/") - .key("dGVzdEtleQ==") - .additionalHeaders(headers); - - assertThat(builder.getAdditionalHeaders()) - .containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "51"); - } - - /** - * Verifies that a non-numeric workload-id value is rejected at request-options level - * (CosmosItemRequestOptions) with IllegalArgumentException, ensuring validation is - * symmetric with the builder level. - */ - @Test(groups = { "unit" }) - public void nonNumericWorkloadIdRejectedAtItemRequestOptionsLevel() { - Map headers = new HashMap<>(); - headers.put(CosmosHeaderName.WORKLOAD_ID, "abc"); - - assertThatThrownBy(() -> new CosmosItemRequestOptions() - .setAdditionalHeaders(headers)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("abc") - .hasMessageContaining("valid integer"); - } - - /** - * Verifies that a non-numeric workload-id value is rejected at request-options level - * (CosmosBulkExecutionOptions) with IllegalArgumentException, ensuring validation is - * symmetric with the builder level. - */ - @Test(groups = { "unit" }) - public void nonNumericWorkloadIdRejectedAtBulkExecutionOptionsLevel() { - Map headers = new HashMap<>(); - headers.put(CosmosHeaderName.WORKLOAD_ID, "not-a-number"); - - assertThatThrownBy(() -> new CosmosBulkExecutionOptions() - .setAdditionalHeaders(headers)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("not-a-number") - .hasMessageContaining("valid integer"); - } - - /** - * Verifies that a valid workload-id value passes builder validation. - */ - @Test(groups = { "unit" }) - public void validWorkloadIdAcceptedByBuilder() { - Map headers = new HashMap<>(); - headers.put(CosmosHeaderName.WORKLOAD_ID, "15"); - - CosmosClientBuilder builder = new CosmosClientBuilder() - .endpoint("https://test.documents.azure.com:443/") - .key("dGVzdEtleQ==") - .additionalHeaders(headers); - - assertThat(builder.getAdditionalHeaders()) - .containsEntry(HttpConstants.HttpHeaders.WORKLOAD_ID, "15"); - } -} - diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/WorkloadIdHeaderTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/WorkloadIdHeaderTests.java new file mode 100644 index 000000000000..2a9c6f8e2a02 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/WorkloadIdHeaderTests.java @@ -0,0 +1,367 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos; + +import com.azure.cosmos.implementation.HttpConstants; +import com.azure.cosmos.implementation.ImplementationBridgeHelpers; +import com.azure.cosmos.implementation.RequestOptions; +import com.azure.cosmos.models.CosmosBatchRequestOptions; +import com.azure.cosmos.models.CosmosBulkExecutionOptions; +import com.azure.cosmos.models.CosmosChangeFeedRequestOptions; +import com.azure.cosmos.models.CosmosContainerRequestOptions; +import com.azure.cosmos.models.CosmosDatabaseRequestOptions; +import com.azure.cosmos.models.CosmosItemRequestOptions; +import com.azure.cosmos.models.CosmosQueryRequestOptions; +import com.azure.cosmos.models.CosmosReadManyRequestOptions; +import com.azure.cosmos.models.CosmosStoredProcedureRequestOptions; +import com.azure.cosmos.models.FeedRange; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Comprehensive unit tests for the workload-id / additional headers feature. + *

+ * Covers three layers: + *

    + *
  1. Public API surface — {@link CosmosHeaderName} constants, {@code CosmosClientBuilder.additionalHeaders()}, + * and that {@code setAdditionalHeaders()} is callable on every request options class.
  2. + *
  3. Validation — non-numeric workload-id rejected at builder and request-options levels; + * out-of-range values accepted (range enforcement is the backend's responsibility).
  4. + *
  5. Internal wiring — headers set via {@code setAdditionalHeaders()} actually reach + * {@code RequestOptions.getHeaders()}, which is what {@code RxDocumentClientImpl.getRequestHeaders()} + * reads to populate outbound request headers. Covers both data-plane and control-plane classes.
  6. + *
+ * E2E tests (SDK → Gateway/Backend → Jarvis) are in {@code WorkloadIdJarvisValidationTests}. + */ +public class WorkloadIdHeaderTests { + + private static final String WORKLOAD_ID_HEADER = HttpConstants.HttpHeaders.WORKLOAD_ID; + private static final String TEST_WORKLOAD_ID = "42"; + + // ============================================================================================== + // 1. CosmosHeaderName constants + // ============================================================================================== + + @Test(groups = { "unit" }) + public void workloadIdHttpHeaderConstant() { + assertThat(HttpConstants.HttpHeaders.WORKLOAD_ID).isEqualTo("x-ms-cosmos-workload-id"); + } + + @Test(groups = { "unit" }) + public void cosmosHeaderNameWorkloadIdValue() { + assertThat(CosmosHeaderName.WORKLOAD_ID.getHeaderName()).isEqualTo("x-ms-cosmos-workload-id"); + } + + @Test(groups = { "unit" }) + public void cosmosHeaderNameFromStringResolvesKnownHeader() { + assertThat(CosmosHeaderName.fromString("x-ms-cosmos-workload-id")).isEqualTo(CosmosHeaderName.WORKLOAD_ID); + } + + @Test(groups = { "unit" }) + public void cosmosHeaderNameFromStringIsCaseInsensitive() { + assertThat(CosmosHeaderName.fromString("X-MS-COSMOS-WORKLOAD-ID")).isEqualTo(CosmosHeaderName.WORKLOAD_ID); + } + + @Test(groups = { "unit" }) + public void cosmosHeaderNameFromStringRejectsUnknownHeader() { + assertThatThrownBy(() -> CosmosHeaderName.fromString("x-ms-custom-header")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("x-ms-custom-header"); + } + + // ============================================================================================== + // 2. CosmosClientBuilder — additionalHeaders() + // ============================================================================================== + + @Test(groups = { "unit" }) + public void builderStoresAdditionalHeaders() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "25"); + + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .additionalHeaders(headers); + + assertThat(builder.getAdditionalHeaders()) + .containsEntry(WORKLOAD_ID_HEADER, "25"); + } + + @Test(groups = { "unit" }) + public void builderHandlesNullAdditionalHeaders() { + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .additionalHeaders(null); + + assertThat(builder.getAdditionalHeaders()).isNull(); + } + + @Test(groups = { "unit" }) + public void builderHandlesEmptyAdditionalHeaders() { + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .additionalHeaders(new HashMap<>()); + + assertThat(builder.getAdditionalHeaders()).isEmpty(); + } + + // ============================================================================================== + // 3. Validation — non-numeric rejected, out-of-range accepted + // ============================================================================================== + + @Test(groups = { "unit" }) + public void nonNumericWorkloadIdRejectedAtBuilderLevel() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "abc"); + + assertThatThrownBy(() -> new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .additionalHeaders(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("abc") + .hasMessageContaining("valid integer"); + } + + @Test(groups = { "unit" }) + public void nonNumericWorkloadIdRejectedAtItemRequestOptionsLevel() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "abc"); + + assertThatThrownBy(() -> new CosmosItemRequestOptions().setAdditionalHeaders(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("abc"); + } + + @Test(groups = { "unit" }) + public void nonNumericWorkloadIdRejectedAtBulkExecutionOptionsLevel() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "not-a-number"); + + assertThatThrownBy(() -> new CosmosBulkExecutionOptions().setAdditionalHeaders(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("not-a-number"); + } + + @Test(groups = { "unit" }) + public void nonNumericWorkloadIdRejectedAtDatabaseRequestOptionsLevel() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "not-a-number"); + + assertThatThrownBy(() -> new CosmosDatabaseRequestOptions().setAdditionalHeaders(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("not-a-number"); + } + + /** Range validation [1, 50] is the backend's responsibility — SDK only validates integer format. */ + @Test(groups = { "unit" }) + public void outOfRangeWorkloadIdAcceptedByBuilder() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "51"); + + CosmosClientBuilder builder = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .additionalHeaders(headers); + + assertThat(builder.getAdditionalHeaders()) + .containsEntry(WORKLOAD_ID_HEADER, "51"); + } + + // ============================================================================================== + // 4. Internal wiring — workload-id reaches RequestOptions.getHeaders() + // This is the acceptance proof that the header will be on the wire. + // ============================================================================================== + + /** + * Data provider covering request options classes that have a no-arg {@code toRequestOptions()}. + * Includes both data-plane and control-plane classes. + */ + @DataProvider(name = "requestOptionsWithToRequestOptions") + public Object[][] requestOptionsWithToRequestOptions() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, TEST_WORKLOAD_ID); + + return new Object[][] { + // Data-plane + { "CosmosItemRequestOptions", + reflectToRequestOptions(new CosmosItemRequestOptions().setAdditionalHeaders(headers)) }, + { "CosmosBatchRequestOptions", + reflectToRequestOptions(new CosmosBatchRequestOptions().setAdditionalHeaders(headers)) }, + // Control-plane + { "CosmosDatabaseRequestOptions", + reflectToRequestOptions(new CosmosDatabaseRequestOptions().setAdditionalHeaders(headers)) }, + { "CosmosContainerRequestOptions", + reflectToRequestOptions(new CosmosContainerRequestOptions().setAdditionalHeaders(headers)) }, + { "CosmosStoredProcedureRequestOptions", + reflectToRequestOptions(new CosmosStoredProcedureRequestOptions().setAdditionalHeaders(headers)) }, + }; + } + + @Test(groups = { "unit" }, dataProvider = "requestOptionsWithToRequestOptions") + public void workloadIdReachesOutboundHeaders(String optionsClassName, RequestOptions requestOptions) { + assertThat(requestOptions.getHeaders()) + .as("workload-id should reach RequestOptions.getHeaders() for " + optionsClassName) + .isNotNull() + .containsEntry(WORKLOAD_ID_HEADER, TEST_WORKLOAD_ID); + } + + /** CosmosQueryRequestOptions uses a bridge accessor pattern (no simple toRequestOptions()). */ + @Test(groups = { "unit" }) + public void workloadIdReachesQueryRequestOptions() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, TEST_WORKLOAD_ID); + + RequestOptions opts = ImplementationBridgeHelpers + .CosmosQueryRequestOptionsHelper + .getCosmosQueryRequestOptionsAccessor() + .toRequestOptions(new CosmosQueryRequestOptions().setAdditionalHeaders(headers)); + + assertThat(opts.getHeaders()).containsEntry(WORKLOAD_ID_HEADER, TEST_WORKLOAD_ID); + } + + /** CosmosChangeFeedRequestOptions uses a bridge accessor that exposes getHeaders(). */ + @Test(groups = { "unit" }) + public void workloadIdReachesChangeFeedRequestOptions() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, TEST_WORKLOAD_ID); + + Map extracted = ImplementationBridgeHelpers + .CosmosChangeFeedRequestOptionsHelper + .getCosmosChangeFeedRequestOptionsAccessor() + .getHeaders( + CosmosChangeFeedRequestOptions + .createForProcessingFromBeginning(FeedRange.forFullRange()) + .setAdditionalHeaders(headers)); + + assertThat(extracted).containsEntry(WORKLOAD_ID_HEADER, TEST_WORKLOAD_ID); + } + + /** CosmosReadManyRequestOptions uses getImpl().applyToRequestOptions(). */ + @Test(groups = { "unit" }) + public void workloadIdReachesReadManyRequestOptions() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, TEST_WORKLOAD_ID); + + RequestOptions opts = ImplementationBridgeHelpers + .CosmosReadManyRequestOptionsHelper + .getCosmosReadManyRequestOptionsAccessor() + .getImpl(new CosmosReadManyRequestOptions().setAdditionalHeaders(headers)) + .applyToRequestOptions(new RequestOptions()); + + assertThat(opts.getHeaders()).containsEntry(WORKLOAD_ID_HEADER, TEST_WORKLOAD_ID); + } + + /** CosmosBulkExecutionOptions delegates to an internal CosmosBatchRequestOptions. */ + @Test(groups = { "unit" }) + public void workloadIdReachesBulkExecutionOptions() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, TEST_WORKLOAD_ID); + + Map extracted = reflectGetHeaders( + new CosmosBulkExecutionOptions().setAdditionalHeaders(headers)); + + assertThat(extracted).containsEntry(WORKLOAD_ID_HEADER, TEST_WORKLOAD_ID); + } + + // ============================================================================================== + // 5. Null/empty additionalHeaders does not inject workload-id into outbound headers + // ============================================================================================== + + @Test(groups = { "unit" }) + public void nullAdditionalHeadersDoesNotInjectWorkloadId() { + assertNoWorkloadId(reflectToRequestOptions(new CosmosItemRequestOptions().setAdditionalHeaders(null))); + assertNoWorkloadId(reflectToRequestOptions(new CosmosDatabaseRequestOptions().setAdditionalHeaders(null))); + assertNoWorkloadId(reflectToRequestOptions(new CosmosContainerRequestOptions().setAdditionalHeaders(null))); + assertNoWorkloadId(reflectToRequestOptions(new CosmosStoredProcedureRequestOptions().setAdditionalHeaders(null))); + } + + @Test(groups = { "unit" }) + public void emptyAdditionalHeadersDoesNotInjectWorkloadId() { + Map empty = new HashMap<>(); + assertNoWorkloadId(reflectToRequestOptions(new CosmosItemRequestOptions().setAdditionalHeaders(empty))); + assertNoWorkloadId(reflectToRequestOptions(new CosmosDatabaseRequestOptions().setAdditionalHeaders(empty))); + assertNoWorkloadId(reflectToRequestOptions(new CosmosContainerRequestOptions().setAdditionalHeaders(empty))); + assertNoWorkloadId(reflectToRequestOptions(new CosmosStoredProcedureRequestOptions().setAdditionalHeaders(empty))); + } + + // ============================================================================================== + // 6. Coverage matrix guard — setAdditionalHeaders() exists on all expected classes + // ============================================================================================== + + /** Guard against accidental removal of setAdditionalHeaders() from any request options class. */ + @Test(groups = { "unit" }) + public void allExpectedRequestOptionsClassesSupportAdditionalHeaders() { + Class[] expectedClasses = { + // Data-plane + CosmosItemRequestOptions.class, + CosmosBatchRequestOptions.class, + CosmosBulkExecutionOptions.class, + CosmosQueryRequestOptions.class, + CosmosReadManyRequestOptions.class, + CosmosChangeFeedRequestOptions.class, + // Control-plane + CosmosDatabaseRequestOptions.class, + CosmosContainerRequestOptions.class, + CosmosStoredProcedureRequestOptions.class, + }; + + for (Class clazz : expectedClasses) { + assertThat(hasSetAdditionalHeaders(clazz)) + .as(clazz.getSimpleName() + " should have setAdditionalHeaders()") + .isTrue(); + } + } + + // ============================================================================================== + // Helpers + // ============================================================================================== + + private static RequestOptions reflectToRequestOptions(Object cosmosRequestOptions) { + try { + java.lang.reflect.Method m = cosmosRequestOptions.getClass().getDeclaredMethod("toRequestOptions"); + m.setAccessible(true); + return (RequestOptions) m.invoke(cosmosRequestOptions); + } catch (Exception e) { + throw new RuntimeException( + "Failed to call toRequestOptions() on " + cosmosRequestOptions.getClass().getSimpleName(), e); + } + } + + @SuppressWarnings("unchecked") + private static Map reflectGetHeaders(Object cosmosRequestOptions) { + try { + java.lang.reflect.Method m = cosmosRequestOptions.getClass().getDeclaredMethod("getHeaders"); + m.setAccessible(true); + return (Map) m.invoke(cosmosRequestOptions); + } catch (Exception e) { + throw new RuntimeException( + "Failed to call getHeaders() on " + cosmosRequestOptions.getClass().getSimpleName(), e); + } + } + + private static void assertNoWorkloadId(RequestOptions options) { + Map headers = options.getHeaders(); + if (headers != null) { + assertThat(headers).doesNotContainKey(WORKLOAD_ID_HEADER); + } + } + + private static boolean hasSetAdditionalHeaders(Class clazz) { + try { + clazz.getMethod("setAdditionalHeaders", Map.class); + return true; + } catch (NoSuchMethodException e) { + return false; + } + } +} + diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ThinClientStoreModelTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ThinClientStoreModelTest.java index 64cd7fe37115..1b5e3670ace0 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ThinClientStoreModelTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ThinClientStoreModelTest.java @@ -2,14 +2,19 @@ import com.azure.cosmos.ConsistencyLevel; import com.azure.cosmos.implementation.http.HttpClient; +import com.azure.cosmos.implementation.http.HttpRequest; import com.azure.cosmos.implementation.routing.RegionalRoutingContext; import io.netty.channel.ConnectTimeoutException; +import org.mockito.ArgumentCaptor; import org.mockito.Mockito; import org.testng.annotations.Test; import reactor.core.publisher.Mono; import java.net.URI; +import java.util.HashMap; +import java.util.Map; +import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; public class ThinClientStoreModelTest { @@ -44,7 +49,8 @@ public void testThinClientStoreModel() throws Exception { ConsistencyLevel.SESSION, new UserAgentContainer(), globalEndpointManager, - httpClient); + httpClient, + null); RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( clientContext, @@ -58,4 +64,125 @@ public void testThinClientStoreModel() throws Exception { //no-op } } + + /** + * Verifies that additionalHeaders (e.g., workload-id) passed to ThinClientStoreModel's + * constructor are correctly propagated to the parent RxGatewayStoreModel and injected + * into outgoing requests via performRequest(). + *

+ * This is for the workload-id feature: ThinClientStoreModel extends + * RxGatewayStoreModel, and the additionalHeaders must flow through the constructor + * chain so that performRequest() injects them into every request — including + * metadata requests (collection cache, PKRange cache, etc.). + */ + @Test(groups = "unit") + public void testAdditionalHeadersFlowThroughThinClientStoreModel() throws Exception { + DiagnosticsClientContext clientContext = Mockito.mock(DiagnosticsClientContext.class); + Mockito.doReturn(new DiagnosticsClientContext.DiagnosticsClientConfig()).when(clientContext).getConfig(); + Mockito + .doReturn(ImplementationBridgeHelpers + .CosmosDiagnosticsHelper + .getCosmosDiagnosticsAccessor() + .create(clientContext, 1d)) + .when(clientContext).createDiagnostics(); + + String sdkGlobalSessionToken = "1#100#1=20#2=5#3=30"; + ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class); + Mockito.doReturn(sdkGlobalSessionToken).when(sessionContainer).resolveGlobalSessionToken(any()); + + GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class); + Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost:8080"))) + .when(globalEndpointManager).resolveServiceEndpoint(any()); + + // Capture the HttpRequest sent by performRequest() to verify headers + HttpClient httpClient = Mockito.mock(HttpClient.class); + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + Mockito.when(httpClient.send(requestCaptor.capture(), any())) + .thenReturn(Mono.error(new ConnectTimeoutException())); + + // Set up additionalHeaders with workload-id + Map additionalHeaders = new HashMap<>(); + additionalHeaders.put(HttpConstants.HttpHeaders.WORKLOAD_ID, "15"); + + ThinClientStoreModel storeModel = new ThinClientStoreModel( + clientContext, + sessionContainer, + ConsistencyLevel.SESSION, + new UserAgentContainer(), + globalEndpointManager, + httpClient, + additionalHeaders); + + RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( + clientContext, + OperationType.Read, + "/fakeResourceFullName", + ResourceType.Document); + + try { + storeModel.performRequest(dsr).block(); + } catch (Exception e) { + // Expected — mock HTTP client throws ConnectTimeoutException + } + + // Verify that the workload-id header was injected into the request + assertThat(dsr.getHeaders().get(HttpConstants.HttpHeaders.WORKLOAD_ID)) + .as("workload-id header should be injected into request by performRequest()") + .isEqualTo("15"); + } + + /** + * Verifies that ThinClientStoreModel works correctly when additionalHeaders is null + * (the default case when no workload-id is configured). This ensures backward + * compatibility — the null case should not throw or inject unexpected headers. + */ + @Test(groups = "unit") + public void testNullAdditionalHeadersThinClientStoreModel() throws Exception { + DiagnosticsClientContext clientContext = Mockito.mock(DiagnosticsClientContext.class); + Mockito.doReturn(new DiagnosticsClientContext.DiagnosticsClientConfig()).when(clientContext).getConfig(); + Mockito + .doReturn(ImplementationBridgeHelpers + .CosmosDiagnosticsHelper + .getCosmosDiagnosticsAccessor() + .create(clientContext, 1d)) + .when(clientContext).createDiagnostics(); + + ISessionContainer sessionContainer = Mockito.mock(ISessionContainer.class); + Mockito.doReturn("1#100#1=20#2=5#3=30").when(sessionContainer).resolveGlobalSessionToken(any()); + + GlobalEndpointManager globalEndpointManager = Mockito.mock(GlobalEndpointManager.class); + Mockito.doReturn(new RegionalRoutingContext(new URI("https://localhost:8080"))) + .when(globalEndpointManager).resolveServiceEndpoint(any()); + + HttpClient httpClient = Mockito.mock(HttpClient.class); + Mockito.when(httpClient.send(any(), any())) + .thenReturn(Mono.error(new ConnectTimeoutException())); + + // Pass null for additionalHeaders — this is the default case + ThinClientStoreModel storeModel = new ThinClientStoreModel( + clientContext, + sessionContainer, + ConsistencyLevel.SESSION, + new UserAgentContainer(), + globalEndpointManager, + httpClient, + null); + + RxDocumentServiceRequest dsr = RxDocumentServiceRequest.createFromName( + clientContext, + OperationType.Read, + "/fakeResourceFullName", + ResourceType.Document); + + try { + storeModel.performRequest(dsr).block(); + } catch (Exception e) { + // Expected — mock HTTP client throws ConnectTimeoutException + } + + // Verify that no workload-id header was injected + assertThat(dsr.getHeaders().get(HttpConstants.HttpHeaders.WORKLOAD_ID)) + .as("workload-id header should NOT be present when additionalHeaders is null") + .isNull(); + } } diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java index c41d8293f88f..c9b06bddee58 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java @@ -282,7 +282,7 @@ public void clientWithEmptyAdditionalHeaders() { /** - * Verifies that the {@link CosmosHeaderName} enum-based allowlist rejects unknown + * Verifies that the {@link CosmosHeaderName} allowlist rejects unknown * header names at client build time. Attempting to set an unrecognized header via * {@code additionalHeaders()} should throw {@link IllegalArgumentException} from * {@link CosmosHeaderName#fromString(String)}, preventing arbitrary headers from diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java index 5b5e80639fed..9f999d899cfa 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosClientBuilder.java @@ -740,13 +740,13 @@ public CosmosClientBuilder userAgentSuffix(String userAgentSuffix) { /** * Sets additional HTTP headers that will be included with every request from this client. *

- * The {@link CosmosHeaderName} enum defines exactly which headers are supported. Currently + * {@link CosmosHeaderName} defines exactly which headers are supported. Currently * the only supported header is {@link CosmosHeaderName#WORKLOAD_ID} * ({@code x-ms-cosmos-workload-id}). *

* This restriction exists because in Direct mode (RNTBD), only headers with explicit - * encoding support are sent on the wire. The enum ensures consistent behavior across - * both Gateway and Direct modes. + * encoding support are sent on the wire. Using {@link CosmosHeaderName} ensures consistent + * behavior across both Gateway and Direct modes. *

* If the same header is also set on request options (e.g., * {@code CosmosItemRequestOptions.setAdditionalHeaders(Map)}), diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosHeaderName.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosHeaderName.java index 7be67765d706..f7ed5d6f6f24 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosHeaderName.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosHeaderName.java @@ -4,17 +4,34 @@ import com.azure.cosmos.implementation.HttpConstants; -import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Locale; import java.util.Map; +import java.util.Objects; +import java.util.StringJoiner; + +import static com.azure.cosmos.implementation.guava25.base.Preconditions.checkNotNull; /** * Defines the set of additional headers that can be set on a {@link CosmosClientBuilder} * via {@link CosmosClientBuilder#additionalHeaders(java.util.Map)}. *

- * Only headers with RNTBD encoding support are included in this enum, ensuring consistent + * Only headers with RNTBD encoding support are included, ensuring consistent * behavior across both Gateway mode (HTTP) and Direct mode (RNTBD binary protocol). + *

+ * This class uses the non-exhaustive final class pattern (rather than Java enum) for + * binary compatibility when new header names are added in future releases. See + * Azure SDK Java Guidelines — Enumerations. */ -public enum CosmosHeaderName { +public final class CosmosHeaderName { + + private final String headerName; + + private CosmosHeaderName(String headerName) { + checkNotNull(headerName, "Argument 'headerName' must not be null."); + this.headerName = headerName; + } /** * The workload ID header ({@code x-ms-cosmos-workload-id}). @@ -24,13 +41,13 @@ public enum CosmosHeaderName { * The SDK validates that the value is a valid integer but does not enforce range limits — * range validation is the backend's responsibility. */ - WORKLOAD_ID(HttpConstants.HttpHeaders.WORKLOAD_ID); - - private final String headerName; + public static final CosmosHeaderName WORKLOAD_ID = new CosmosHeaderName( + HttpConstants.HttpHeaders.WORKLOAD_ID); - CosmosHeaderName(String headerName) { - this.headerName = headerName; - } + // IMPORTANT: ADDITIONAL_HEADERS must be declared AFTER all public static final fields above, + // because Java initializes static fields in declaration order. If this map were declared + // before WORKLOAD_ID, the map would contain null values. + private static final Map ADDITIONAL_HEADERS = createAdditionalHeadersMap(); /** * Gets the canonical HTTP header name string (e.g., {@code "x-ms-cosmos-workload-id"}). @@ -42,30 +59,34 @@ public String getHeaderName() { } /** - * Converts a header name string to the corresponding {@link CosmosHeaderName} enum value. + * Converts a header name string to the corresponding {@link CosmosHeaderName} instance. *

* This is primarily used by the Spark connector, which parses header names from JSON - * configuration strings and needs to convert them to enum values before calling - * {@link CosmosClientBuilder#additionalHeaders(java.util.Map)}. + * configuration strings and needs to convert them to {@link CosmosHeaderName} instances + * before calling {@link CosmosClientBuilder#additionalHeaders(java.util.Map)}. * * @param headerName the header name string (e.g., {@code "x-ms-cosmos-workload-id"}) * @return the matching {@link CosmosHeaderName} - * @throws IllegalArgumentException if the header name does not match any known enum value + * @throws IllegalArgumentException if the header name does not match any known value */ public static CosmosHeaderName fromString(String headerName) { - for (CosmosHeaderName name : values()) { - if (name.headerName.equalsIgnoreCase(headerName)) { - return name; - } + checkNotNull(headerName, "Argument 'headerName' must not be null."); + + String normalizedName = headerName.trim().toLowerCase(Locale.ROOT); + CosmosHeaderName result = ADDITIONAL_HEADERS.getOrDefault(normalizedName, null); + + if (result == null) { + throw new IllegalArgumentException( + "Unknown header: '" + headerName + "'. Allowed headers: " + getValidValues()); } - throw new IllegalArgumentException( - "Unknown header: '" + headerName + "'. Allowed headers: " + Arrays.toString(values())); + + return result; } /** * Validates all entries in an additional-headers map. *

- * Each {@link CosmosHeaderName} enum value carries its own validation rules. Currently: + * Each {@link CosmosHeaderName} instance carries its own validation rules. Currently: *

    *
  • {@link #WORKLOAD_ID}: value must be a valid integer string
  • *
@@ -85,7 +106,7 @@ public static void validateAdditionalHeaders(Map addit CosmosHeaderName key = entry.getKey(); String value = entry.getValue(); - if (WORKLOAD_ID == key && value != null) { + if (WORKLOAD_ID.equals(key) && value != null) { try { Integer.parseInt(value); } catch (NumberFormatException e) { @@ -96,4 +117,44 @@ public static void validateAdditionalHeaders(Map addit } } } + + @Override + public String toString() { + return this.headerName; + } + + @Override + public int hashCode() { + return Objects.hash(CosmosHeaderName.class, this.headerName); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (obj == null) { + return false; + } + if (!(obj instanceof CosmosHeaderName)) { + return false; + } + + CosmosHeaderName other = (CosmosHeaderName) obj; + return Objects.equals(this.headerName, other.headerName); + } + + private static Map createAdditionalHeadersMap() { + Map map = new HashMap<>(); + map.put(HttpConstants.HttpHeaders.WORKLOAD_ID.toLowerCase(Locale.ROOT), WORKLOAD_ID); + return Collections.unmodifiableMap(map); + } + + private static String getValidValues() { + StringJoiner sj = new StringJoiner(", "); + for (CosmosHeaderName header : ADDITIONAL_HEADERS.values()) { + sj.add(header.headerName); + } + return sj.toString(); + } } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java index b13b6c02009c..3a6501bea8c1 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java @@ -871,7 +871,8 @@ public void init(CosmosClientMetadataCachesSnapshot metadataCachesSnapshot, Func this.consistencyLevel, this.userAgentContainer, this.globalEndpointManager, - this.reactorHttpClient); + this.reactorHttpClient, + this.additionalHeaders); this.perPartitionFailoverConfigModifier = (databaseAccount -> { @@ -1045,14 +1046,16 @@ ThinClientStoreModel createThinProxy(ISessionContainer sessionContainer, ConsistencyLevel consistencyLevel, UserAgentContainer userAgentContainer, GlobalEndpointManager globalEndpointManager, - HttpClient httpClient) { + HttpClient httpClient, + Map additionalHeaders) { return new ThinClientStoreModel( this, sessionContainer, consistencyLevel, userAgentContainer, globalEndpointManager, - httpClient); + httpClient, + additionalHeaders); } private HttpClient httpClient() { diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java index ff139e203d2e..132d367b587d 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ThinClientStoreModel.java @@ -47,7 +47,8 @@ public ThinClientStoreModel( ConsistencyLevel defaultConsistencyLevel, UserAgentContainer userAgentContainer, GlobalEndpointManager globalEndpointManager, - HttpClient httpClient) { + HttpClient httpClient, + Map additionalHeaders) { super( clientContext, sessionContainer, @@ -57,7 +58,7 @@ public ThinClientStoreModel( globalEndpointManager, httpClient, ApiType.SQL, - null); + additionalHeaders); String userAgent = userAgentContainer != null ? userAgentContainer.getUserAgent() diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java index ff3b335c8824..978f0a43ec37 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBatchRequestOptions.java @@ -157,7 +157,7 @@ RequestOptions toRequestOptions() { /** * Sets additional headers to be included with this specific request. *

- * The {@link CosmosHeaderName} enum defines exactly which headers are supported. + * The {@link CosmosHeaderName} class defines exactly which headers are supported. * This allows per-request header customization, such as setting a workload ID * that overrides the client-level default set via * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java index 588f74cea99e..f576d31cb86b 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java @@ -260,7 +260,7 @@ void setOperationContextAndListenerTuple(OperationContextAndListenerTuple operat /** * Sets additional headers to be included with this specific request. *

- * The {@link CosmosHeaderName} enum defines exactly which headers are supported. + * The {@link CosmosHeaderName} class defines exactly which headers are supported. * This allows per-request header customization, such as setting a workload ID * that overrides the client-level default set via * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java index b78499cd185c..f08dee8b04fe 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosChangeFeedRequestOptions.java @@ -567,7 +567,7 @@ public List getExcludedRegions() { /** * Sets additional headers to be included with this specific request. *

- * The {@link CosmosHeaderName} enum defines exactly which headers are supported. + * The {@link CosmosHeaderName} class defines exactly which headers are supported. * This allows per-request header customization, such as setting a workload ID * that overrides the client-level default set via * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosContainerRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosContainerRequestOptions.java index 1a2cad1b85ed..1747d4605a80 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosContainerRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosContainerRequestOptions.java @@ -2,9 +2,13 @@ // Licensed under the MIT License. package com.azure.cosmos.models; +import com.azure.cosmos.CosmosHeaderName; import com.azure.cosmos.ConsistencyLevel; import com.azure.cosmos.implementation.RequestOptions; +import java.util.HashMap; +import java.util.Map; + /** * Encapsulates options that can be specified for a request issued to Cosmos container. */ @@ -15,6 +19,7 @@ public final class CosmosContainerRequestOptions { private String ifMatchETag; private String ifNoneMatchETag; private ThroughputProperties throughputProperties; + private Map customOptions; /** * Gets the quotaInfoEnabled setting for cosmos container read requests in the Azure Cosmos DB database service. @@ -141,6 +146,39 @@ CosmosContainerRequestOptions setThroughputProperties(ThroughputProperties throu return this; } + /** + * Sets additional headers to be included with this specific request. + *

+ * The {@link CosmosHeaderName} class defines exactly which headers are supported. + * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. + *

+ * If the same header is also set at the client level, the request-level value + * takes precedence. + * + * @param additionalHeaders map of {@link CosmosHeaderName} to value + * @return the CosmosContainerRequestOptions. + * @throws IllegalArgumentException if the workload-id value is not a valid integer + */ + public CosmosContainerRequestOptions setAdditionalHeaders(Map additionalHeaders) { + CosmosHeaderName.validateAdditionalHeaders(additionalHeaders); + if (additionalHeaders != null) { + for (Map.Entry entry : additionalHeaders.entrySet()) { + this.setHeader(entry.getKey().getHeaderName(), entry.getValue()); + } + } + return this; + } + + CosmosContainerRequestOptions setHeader(String name, String value) { + if (this.customOptions == null) { + this.customOptions = new HashMap<>(); + } + this.customOptions.put(name, value); + return this; + } + RequestOptions toRequestOptions() { RequestOptions options = new RequestOptions(); options.setIfMatchETag(getIfMatchETag()); @@ -149,6 +187,11 @@ RequestOptions toRequestOptions() { options.setSessionToken(sessionToken); options.setConsistencyLevel(consistencyLevel); options.setThroughputProperties(this.throughputProperties); + if (this.customOptions != null) { + for (Map.Entry entry : this.customOptions.entrySet()) { + options.setHeader(entry.getKey(), entry.getValue()); + } + } return options; } } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosDatabaseRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosDatabaseRequestOptions.java index a4817e14be7e..618bf03ffa99 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosDatabaseRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosDatabaseRequestOptions.java @@ -2,8 +2,12 @@ // Licensed under the MIT License. package com.azure.cosmos.models; +import com.azure.cosmos.CosmosHeaderName; import com.azure.cosmos.implementation.RequestOptions; +import java.util.HashMap; +import java.util.Map; + /** * Encapsulates options that can be specified for a request issued to cosmos database. */ @@ -11,6 +15,7 @@ public final class CosmosDatabaseRequestOptions { private String ifMatchETag; private String ifNoneMatchETag; private ThroughputProperties throughputProperties; + private Map customOptions; /** * Gets the If-Match (ETag) associated with the request in the Azure Cosmos DB service. @@ -73,11 +78,49 @@ CosmosDatabaseRequestOptions setThroughputProperties(ThroughputProperties throug return this; } + /** + * Sets additional headers to be included with this specific request. + *

+ * The {@link CosmosHeaderName} class defines exactly which headers are supported. + * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. + *

+ * If the same header is also set at the client level, the request-level value + * takes precedence. + * + * @param additionalHeaders map of {@link CosmosHeaderName} to value + * @return the CosmosDatabaseRequestOptions. + * @throws IllegalArgumentException if the workload-id value is not a valid integer + */ + public CosmosDatabaseRequestOptions setAdditionalHeaders(Map additionalHeaders) { + CosmosHeaderName.validateAdditionalHeaders(additionalHeaders); + if (additionalHeaders != null) { + for (Map.Entry entry : additionalHeaders.entrySet()) { + this.setHeader(entry.getKey().getHeaderName(), entry.getValue()); + } + } + return this; + } + + CosmosDatabaseRequestOptions setHeader(String name, String value) { + if (this.customOptions == null) { + this.customOptions = new HashMap<>(); + } + this.customOptions.put(name, value); + return this; + } + RequestOptions toRequestOptions() { RequestOptions options = new RequestOptions(); options.setIfMatchETag(getIfMatchETag()); options.setIfNoneMatchETag(getIfNoneMatchETag()); options.setThroughputProperties(this.throughputProperties); + if (this.customOptions != null) { + for (Map.Entry entry : this.customOptions.entrySet()) { + options.setHeader(entry.getKey(), entry.getValue()); + } + } return options; } } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java index d5447034274c..36c7a97c318a 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosItemRequestOptions.java @@ -569,7 +569,7 @@ public CosmosItemRequestOptions setThresholdForDiagnosticsOnTracer(Duration thre /** * Sets additional headers to be included with this specific request. *

- * The {@link CosmosHeaderName} enum defines exactly which headers are supported. + * The {@link CosmosHeaderName} class defines exactly which headers are supported. * This allows per-request header customization, such as setting a workload ID * that overrides the client-level default set via * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java index ae998a39c360..18f0a9fd8ad0 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosQueryRequestOptions.java @@ -264,7 +264,7 @@ public CosmosQueryRequestOptions setExcludedRegions(List excludeRegions) /** * Sets additional headers to be included with this specific request. *

- * The {@link CosmosHeaderName} enum defines exactly which headers are supported. + * The {@link CosmosHeaderName} class defines exactly which headers are supported. * This allows per-request header customization, such as setting a workload ID * that overrides the client-level default set via * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java index 98505acdab98..04ac31f30f06 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosReadManyRequestOptions.java @@ -371,7 +371,7 @@ public Set getKeywordIdentifiers() { /** * Sets additional headers to be included with this specific request. *

- * The {@link CosmosHeaderName} enum defines exactly which headers are supported. + * The {@link CosmosHeaderName} class defines exactly which headers are supported. * This allows per-request header customization, such as setting a workload ID * that overrides the client-level default set via * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosStoredProcedureRequestOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosStoredProcedureRequestOptions.java index 02f360299c42..42486a1ad6b7 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosStoredProcedureRequestOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosStoredProcedureRequestOptions.java @@ -2,9 +2,13 @@ // Licensed under the MIT License. package com.azure.cosmos.models; +import com.azure.cosmos.CosmosHeaderName; import com.azure.cosmos.ConsistencyLevel; import com.azure.cosmos.implementation.RequestOptions; +import java.util.HashMap; +import java.util.Map; + /** * Encapsulates options that can be specified for a request issued to cosmos stored procedure. */ @@ -15,6 +19,7 @@ public final class CosmosStoredProcedureRequestOptions { private String ifMatchETag; private String ifNoneMatchETag; private boolean scriptLoggingEnabled; + private Map customOptions; /** * Gets the If-Match (ETag) associated with the request in the Azure Cosmos DB service. @@ -158,6 +163,39 @@ public CosmosStoredProcedureRequestOptions setScriptLoggingEnabled(boolean scrip return this; } + /** + * Sets additional headers to be included with this specific request. + *

+ * The {@link CosmosHeaderName} class defines exactly which headers are supported. + * This allows per-request header customization, such as setting a workload ID + * that overrides the client-level default set via + * {@link com.azure.cosmos.CosmosClientBuilder#additionalHeaders(java.util.Map)}. + *

+ * If the same header is also set at the client level, the request-level value + * takes precedence. + * + * @param additionalHeaders map of {@link CosmosHeaderName} to value + * @return the CosmosStoredProcedureRequestOptions. + * @throws IllegalArgumentException if the workload-id value is not a valid integer + */ + public CosmosStoredProcedureRequestOptions setAdditionalHeaders(Map additionalHeaders) { + CosmosHeaderName.validateAdditionalHeaders(additionalHeaders); + if (additionalHeaders != null) { + for (Map.Entry entry : additionalHeaders.entrySet()) { + this.setHeader(entry.getKey().getHeaderName(), entry.getValue()); + } + } + return this; + } + + CosmosStoredProcedureRequestOptions setHeader(String name, String value) { + if (this.customOptions == null) { + this.customOptions = new HashMap<>(); + } + this.customOptions.put(name, value); + return this; + } + RequestOptions toRequestOptions() { RequestOptions requestOptions = new RequestOptions(); requestOptions.setIfMatchETag(getIfMatchETag()); @@ -166,6 +204,11 @@ RequestOptions toRequestOptions() { requestOptions.setPartitionKey(partitionKey); requestOptions.setSessionToken(sessionToken); requestOptions.setScriptLoggingEnabled(scriptLoggingEnabled); + if (this.customOptions != null) { + for (Map.Entry entry : this.customOptions.entrySet()) { + requestOptions.setHeader(entry.getKey(), entry.getValue()); + } + } return requestOptions; } } From 57ae651ecdf966b20007377563632ac2ed1f44ba Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Thu, 19 Mar 2026 10:33:19 -0500 Subject: [PATCH 11/13] fix: fixing clone method --- .../azure/cosmos/WorkloadIdHeaderTests.java | 30 +++++++++++++++++++ .../azure/cosmos/CosmosBridgeInternal.java | 4 +++ 2 files changed, 34 insertions(+) diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/WorkloadIdHeaderTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/WorkloadIdHeaderTests.java index 2a9c6f8e2a02..41fc62f62ca3 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/WorkloadIdHeaderTests.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/WorkloadIdHeaderTests.java @@ -113,6 +113,36 @@ public void builderHandlesEmptyAdditionalHeaders() { assertThat(builder.getAdditionalHeaders()).isEmpty(); } + @Test(groups = { "unit" }) + public void clonedBuilderPreservesAdditionalHeaders() { + Map headers = new HashMap<>(); + headers.put(CosmosHeaderName.WORKLOAD_ID, "25"); + + CosmosClientBuilder original = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ==") + .additionalHeaders(headers); + + CosmosClientBuilder cloned = CosmosBridgeInternal.cloneCosmosClientBuilder(original); + + assertThat(cloned.getAdditionalHeaders()) + .as("cloned builder should preserve additionalHeaders") + .containsEntry(WORKLOAD_ID_HEADER, "25"); + } + + @Test(groups = { "unit" }) + public void clonedBuilderHandlesNullAdditionalHeaders() { + CosmosClientBuilder original = new CosmosClientBuilder() + .endpoint("https://test.documents.azure.com:443/") + .key("dGVzdEtleQ=="); + + CosmosClientBuilder cloned = CosmosBridgeInternal.cloneCosmosClientBuilder(original); + + assertThat(cloned.getAdditionalHeaders()) + .as("cloned builder should handle null additionalHeaders") + .isNull(); + } + // ============================================================================================== // 3. Validation — non-numeric rejected, out-of-range accepted // ============================================================================================== diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosBridgeInternal.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosBridgeInternal.java index 81beda97d5b4..72c330425054 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosBridgeInternal.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosBridgeInternal.java @@ -115,6 +115,10 @@ public static CosmosClientBuilder cloneCosmosClientBuilder(CosmosClientBuilder b .multipleWriteRegionsEnabled(builder.isMultipleWriteRegionsEnabled()) .readRequestsFallbackEnabled(builder.isReadRequestsFallbackEnabled()); +fixig if (builder.getAdditionalHeadersRaw() != null) { + copy.additionalHeaders(builder.getAdditionalHeadersRaw()); + } + return copy; } From 7703485e34b7d84daf395cf5c3e4a98820459238 Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Thu, 19 Mar 2026 11:13:35 -0500 Subject: [PATCH 12/13] fix: fixing validation --- .../cosmos/spark/CosmosClientCache.scala | 4 ++-- .../com/azure/cosmos/spark/CosmosConfig.scala | 19 +++++++++++++-- .../spark/CosmosClientConfigurationSpec.scala | 24 ++++++++++--------- .../azure/cosmos/CosmosBridgeInternal.java | 6 ++--- 4 files changed, 35 insertions(+), 18 deletions(-) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala index 49b051a97276..0f0318057eee 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala @@ -713,8 +713,8 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { } // Apply additional HTTP headers (e.g., workload-id) to the builder if configured. - // These headers are attached to every Cosmos DB request made by this client instance. - // Converts Map[String, String] from Spark config to Map[CosmosHeaderName, String] for the builder. + // Header name validation already happened at config-parse time in CosmosConfig.AdditionalHeadersConfig, + // so CosmosHeaderName.fromString() here is just a type conversion for already-validated keys. if (cosmosClientConfiguration.additionalHeaders.isDefined) { val headerMap = new java.util.HashMap[CosmosHeaderName, String]() for ((key, value) <- cosmosClientConfiguration.additionalHeaders.get) { diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala index 870ee17841c5..10b0d5af8623 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala @@ -4,7 +4,7 @@ package com.azure.cosmos.spark import com.azure.core.management.AzureEnvironment -import com.azure.cosmos.{CosmosAsyncClient, CosmosClientBuilder, ReadConsistencyStrategy, spark} +import com.azure.cosmos.{CosmosAsyncClient, CosmosClientBuilder, CosmosHeaderName, ReadConsistencyStrategy, spark} import com.azure.cosmos.implementation.batch.BatchRequestResponseConstants import com.azure.cosmos.implementation.routing.LocationHelper import com.azure.cosmos.implementation.{Configs, SparkBridgeImplementationInternal, Strings, Utils} @@ -740,14 +740,29 @@ private object CosmosAccountConfig extends BasicLoggingTrait { // {"x-ms-cosmos-workload-id": "15"} into a Scala Map[String, String] using Jackson. // These headers are converted to Map[CosmosHeaderName, String] and passed to // CosmosClientBuilder.additionalHeaders() in CosmosClientCache. + // + // Validation: After JSON parsing, every header name is validated via CosmosHeaderName.fromString() + // to fail fast at config-parse time rather than at runtime during client creation. + // This prevents Spark jobs from starting, allocating cluster resources, and only failing + // later when CosmosClientCache tries to convert String keys to CosmosHeaderName instances. private val AdditionalHeadersConfig = CosmosConfigEntry[Map[String, String]]( key = CosmosConfigNames.AdditionalHeaders, mandatory = false, parseFromStringFunction = headersJson => { try { val typeRef = new com.fasterxml.jackson.core.`type`.TypeReference[java.util.Map[String, String]]() {} - Utils.getSimpleObjectMapperWithAllowDuplicates.readValue(headersJson, typeRef).asScala.toMap + val parsed = Utils.getSimpleObjectMapperWithAllowDuplicates.readValue(headersJson, typeRef).asScala.toMap + + // Fail fast: validate every header name is a known CosmosHeaderName at parse time. + // Without this, unknown headers like {"x-bad-header": "value"} would parse successfully + // and only blow up at runtime in CosmosClientCache when CosmosHeaderName.fromString() is called. + for (key <- parsed.keys) { + CosmosHeaderName.fromString(key) // throws IllegalArgumentException for unknown headers + } + + parsed } catch { + case e: IllegalArgumentException => throw e case e: Exception => throw new IllegalArgumentException( s"Invalid JSON for '${CosmosConfigNames.AdditionalHeaders}': '$headersJson'. " + "Expected format: {\"x-ms-cosmos-workload-id\": \"15\"}", e) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala index 8a2b6a8191f9..9752136066f4 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosClientConfigurationSpec.scala @@ -441,11 +441,14 @@ class CosmosClientConfigurationSpec extends UnitSpec { configuration.additionalHeaders shouldBe None } - // Verifies that spark.cosmos.additionalHeaders accepts multiple headers at the parsing level. - // The JSON is valid and CosmosClientConfiguration stores it as-is. - // The CosmosHeaderName.fromString() validation happens later in CosmosClientCache when - // converting to Map[CosmosHeaderName, String] for CosmosClientBuilder.additionalHeaders(). - it should "reject unknown additional headers" in { + // Verifies that unknown header names in spark.cosmos.additionalHeaders are rejected + // at config-parse time. This ensures Spark jobs fail fast during configuration rather than + // starting, allocating cluster resources, and only failing later at runtime when + // CosmosClientCache tries to create the client. + // Note: CosmosConfigEntry.parse() wraps all parsing exceptions in RuntimeException, + // so the IllegalArgumentException from CosmosHeaderName.fromString() surfaces as + // RuntimeException with the original IllegalArgumentException as the cause. + it should "reject unknown additional headers at parse time" in { val userConfig = Map( "spark.cosmos.accountEndpoint" -> "https://localhost:8081", "spark.cosmos.accountKey" -> "xyz", @@ -453,12 +456,11 @@ class CosmosClientConfigurationSpec extends UnitSpec { ) val readConsistencyStrategy = ReadConsistencyStrategy.DEFAULT - val configuration = CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") - - // Parsing succeeds — the JSON is valid and CosmosClientConfiguration stores it as-is. - // The CosmosHeaderName.fromString() validation happens later in CosmosClientCache - configuration.additionalHeaders shouldBe defined - configuration.additionalHeaders.get should have size 2 + val thrown = the [RuntimeException] thrownBy { + CosmosClientConfiguration(userConfig, readConsistencyStrategy, sparkEnvironmentInfo = "") + } + thrown.getCause shouldBe a [IllegalArgumentException] + thrown.getCause.getMessage should include ("x-custom-header") } // Verifies that spark.cosmos.additionalHeaders handles an empty JSON object ("{}") gracefully, diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosBridgeInternal.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosBridgeInternal.java index 72c330425054..35d8fd6aece4 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosBridgeInternal.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosBridgeInternal.java @@ -115,9 +115,9 @@ public static CosmosClientBuilder cloneCosmosClientBuilder(CosmosClientBuilder b .multipleWriteRegionsEnabled(builder.isMultipleWriteRegionsEnabled()) .readRequestsFallbackEnabled(builder.isReadRequestsFallbackEnabled()); -fixig if (builder.getAdditionalHeadersRaw() != null) { - copy.additionalHeaders(builder.getAdditionalHeadersRaw()); - } + if (builder.getAdditionalHeadersRaw() != null) { + copy.additionalHeaders(builder.getAdditionalHeadersRaw()); + } return copy; } From 4b99a7e6ec1a61179ece18ffa5f5a1638e71bc97 Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Thu, 19 Mar 2026 20:36:18 -0500 Subject: [PATCH 13/13] fix:fixing tests --- .../azure/cosmos/rx/WorkloadIdDirectInterceptorTests.java | 5 ++++- .../test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdDirectInterceptorTests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdDirectInterceptorTests.java index 2fe594e60831..6882cac1193e 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdDirectInterceptorTests.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdDirectInterceptorTests.java @@ -69,7 +69,10 @@ public void beforeClass() { Map headers = new HashMap<>(); headers.put(CosmosHeaderName.WORKLOAD_ID, "15"); - clientWithWorkloadId = getClientBuilder() + // Clone the shared builder before setting additionalHeaders. + // getClientBuilder() returns the same mutable instance from the data provider. + // Calling .additionalHeaders() directly on it would mutate the shared builder, + clientWithWorkloadId = copyCosmosClientBuilder(getClientBuilder()) .additionalHeaders(headers) .buildAsyncClient(); diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java index c9b06bddee58..823acf3c01f2 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/rx/WorkloadIdE2ETests.java @@ -62,7 +62,10 @@ public void beforeClass() { Map headers = new HashMap<>(); headers.put(CosmosHeaderName.WORKLOAD_ID, "15"); - clientWithWorkloadId = getClientBuilder() + // Clone the shared builder before setting additionalHeaders. + // getClientBuilder() returns the same mutable instance from the data provider. + // Calling .additionalHeaders() directly on it would mutate the shared builder, + clientWithWorkloadId = copyCosmosClientBuilder(getClientBuilder()) .additionalHeaders(headers) .buildAsyncClient();