From 8d4f01640e65f20d28b28cfca3a966f2424eb84c Mon Sep 17 00:00:00 2001 From: Divyansh Vijayvergia Date: Wed, 15 Apr 2026 13:08:54 +0000 Subject: [PATCH] add support for authentication through Azure MSI --- NEXT_CHANGELOG.md | 1 + .../sdk/core/AzureMsiCredentialsProvider.java | 93 +++++ .../sdk/core/DefaultCredentialsProvider.java | 1 + .../sdk/core/oauth/AzureMsiTokenSource.java | 113 ++++++ .../core/AzureMsiCredentialsProviderTest.java | 374 ++++++++++++++++++ 5 files changed, 582 insertions(+) create mode 100644 databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureMsiCredentialsProvider.java create mode 100644 databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureMsiTokenSource.java create mode 100644 databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureMsiCredentialsProviderTest.java diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index a5bde1314..7d4ee7474 100755 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -3,6 +3,7 @@ ## Release v0.104.0 ### New Features and Improvements +* Add support for authentication through Azure Managed Service Identity (MSI) via the new `azure-msi` credential provider. * Added automatic detection of AI coding agents (Antigravity, Claude Code, Cline, Codex, Copilot CLI, Cursor, Gemini CLI, OpenCode) in the user-agent string. The SDK now appends `agent/` to HTTP request headers when running inside a known AI agent environment. ### Bug Fixes diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureMsiCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureMsiCredentialsProvider.java new file mode 100644 index 000000000..dc3a41147 --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureMsiCredentialsProvider.java @@ -0,0 +1,93 @@ +package com.databricks.sdk.core; + +import com.databricks.sdk.core.oauth.AzureMsiTokenSource; +import com.databricks.sdk.core.oauth.CachedTokenSource; +import com.databricks.sdk.core.oauth.OAuthHeaderFactory; +import com.databricks.sdk.core.oauth.Token; +import com.databricks.sdk.core.utils.AzureUtils; +import com.databricks.sdk.support.InternalApi; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.util.HashMap; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Adds refreshed Azure Active Directory (AAD) tokens obtained via Azure Managed Service Identity + * (MSI) to every request. This provider authenticates using the Azure Instance Metadata Service + * (IMDS) endpoint, which is available on Azure VMs and other compute resources with managed + * identities enabled. + */ +@InternalApi +public class AzureMsiCredentialsProvider implements CredentialsProvider { + private static final Logger LOG = LoggerFactory.getLogger(AzureMsiCredentialsProvider.class); + private final ObjectMapper mapper = new ObjectMapper(); + + @Override + public String authType() { + return "azure-msi"; + } + + @Override + public OAuthHeaderFactory configure(DatabricksConfig config) { + if (!config.isAzure()) { + return null; + } + + if (!isAzureUseMsi(config)) { + return null; + } + + if (config.getAzureWorkspaceResourceId() == null && config.getHost() == null) { + return null; + } + + LOG.debug("Generating AAD token via Azure MSI"); + + AzureUtils.ensureHostPresent(config, mapper, this::tokenSourceFor); + + CachedTokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId()); + CachedTokenSource cloud = + tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); + + return OAuthHeaderFactory.fromSuppliers( + inner::getToken, + () -> { + Token token = inner.getToken(); + Map headers = new HashMap<>(); + headers.put("Authorization", "Bearer " + token.getAccessToken()); + AzureUtils.addSpManagementToken(cloud, headers); + AzureUtils.addWorkspaceResourceId(config, headers); + return headers; + }); + } + + /** + * Null-safe check for the azureUseMsi config flag. The underlying field is a boxed Boolean, but + * the getter auto-unboxes to primitive boolean, which would NPE when the field is null. This + * helper treats null as false. + */ + private static boolean isAzureUseMsi(DatabricksConfig config) { + try { + return config.getAzureUseMsi(); + } catch (NullPointerException e) { + return false; + } + } + + /** + * Creates a CachedTokenSource for the specified Azure resource using MSI authentication. + * + * @param config The DatabricksConfig instance containing the required authentication parameters. + * @param resource The Azure resource for which OAuth tokens need to be fetched. + * @return A CachedTokenSource instance capable of fetching OAuth tokens for the specified Azure + * resource. + */ + CachedTokenSource tokenSourceFor(DatabricksConfig config, String resource) { + AzureMsiTokenSource tokenSource = + new AzureMsiTokenSource(config.getHttpClient(), resource, config.getAzureClientId()); + return new CachedTokenSource.Builder(tokenSource) + .setAsyncDisabled(config.getDisableAsyncTokenRefresh()) + .build(); + } +} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java index 99716890f..4d4805b8a 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java @@ -179,6 +179,7 @@ private synchronized void addDefaultCredentialsProviders(DatabricksConfig config addOIDCCredentialsProviders(config); providers.add(new AzureGithubOidcCredentialsProvider()); + providers.add(new AzureMsiCredentialsProvider()); providers.add(new AzureServicePrincipalCredentialsProvider()); providers.add(new AzureCliCredentialsProvider()); providers.add(new ExternalBrowserCredentialsProvider()); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureMsiTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureMsiTokenSource.java new file mode 100644 index 000000000..efcaeee7d --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureMsiTokenSource.java @@ -0,0 +1,113 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Request; +import com.databricks.sdk.core.http.Response; +import com.databricks.sdk.support.InternalApi; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import java.time.Instant; + +/** + * A {@link TokenSource} that fetches OAuth tokens from the Azure Instance Metadata Service (IMDS) + * endpoint for Managed Service Identity (MSI) authentication. + * + *

This token source makes HTTP GET requests to the well-known IMDS endpoint at {@code + * http://169.254.169.254/metadata/identity/oauth2/token} to obtain access tokens for the specified + * Azure resource. + */ +@InternalApi +public class AzureMsiTokenSource implements TokenSource { + + private static final String IMDS_ENDPOINT = + "http://169.254.169.254/metadata/identity/oauth2/token"; + + private final HttpClient httpClient; + private final String resource; + private final String clientId; + private final ObjectMapper mapper = new ObjectMapper(); + + /** Response structure from the Azure IMDS token endpoint. */ + @JsonIgnoreProperties(ignoreUnknown = true) + static class MsiTokenResponse { + @JsonProperty("token_type") + private String tokenType; + + @JsonProperty("access_token") + private String accessToken; + + @JsonProperty("expires_on") + private String expiresOn; + + Token toToken() { + if (accessToken == null || accessToken.isEmpty()) { + throw new DatabricksException("MSI token response missing or empty 'access_token' field"); + } + if (tokenType == null || tokenType.isEmpty()) { + throw new DatabricksException("MSI token response missing or empty 'token_type' field"); + } + if (expiresOn == null || expiresOn.isEmpty()) { + throw new DatabricksException("MSI token response missing 'expires_on' field"); + } + long epoch; + try { + epoch = Long.parseLong(expiresOn); + } catch (NumberFormatException e) { + throw new DatabricksException( + "Invalid 'expires_on' value in MSI token response: " + expiresOn, e); + } + return new Token(accessToken, tokenType, Instant.ofEpochSecond(epoch)); + } + } + + /** + * Creates a new AzureMsiTokenSource. + * + * @param httpClient The HTTP client to use for requests to the IMDS endpoint. + * @param resource The Azure resource for which to obtain an access token. + * @param clientId The client ID of the managed identity to use. May be null for system-assigned + * identities. + */ + public AzureMsiTokenSource(HttpClient httpClient, String resource, String clientId) { + this.httpClient = httpClient; + this.resource = resource; + this.clientId = clientId; + } + + @Override + public Token getToken() { + Request req = new Request("GET", IMDS_ENDPOINT); + req.withQueryParam("api-version", "2018-02-01"); + req.withQueryParam("resource", resource); + if (clientId != null && !clientId.isEmpty()) { + req.withQueryParam("client_id", clientId); + } + req.withHeader("Metadata", "true"); + + Response resp; + try { + resp = httpClient.execute(req); + } catch (IOException e) { + throw new DatabricksException( + "Failed to request MSI token from IMDS endpoint: " + e.getMessage(), e); + } + + if (resp.getStatusCode() != 200) { + throw new DatabricksException( + "Failed to request MSI token: status code " + + resp.getStatusCode() + + ", response body: " + + resp.getDebugBody()); + } + + try { + MsiTokenResponse msiToken = mapper.readValue(resp.getBody(), MsiTokenResponse.class); + return msiToken.toToken(); + } catch (IOException e) { + throw new DatabricksException("Failed to parse MSI token response: " + e.getMessage(), e); + } + } +} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureMsiCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureMsiCredentialsProviderTest.java new file mode 100644 index 000000000..d25d0396f --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureMsiCredentialsProviderTest.java @@ -0,0 +1,374 @@ +package com.databricks.sdk.core; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; + +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Request; +import com.databricks.sdk.core.http.Response; +import com.databricks.sdk.core.oauth.*; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import java.net.URL; +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Stream; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +class AzureMsiCredentialsProviderTest { + + private static final String WORKSPACE_RESOURCE_ID = "/a/b/c"; + private static final String RESOURCE = "https://management.azure.com/"; + private static final String ACCESS_TOKEN = "test-access-token"; + private static final String TOKEN_TYPE = "Bearer"; + private static final ObjectMapper MAPPER = new ObjectMapper(); + + private static CachedTokenSource mockTokenSource(String tokenValue) { + TokenSource inner = Mockito.mock(TokenSource.class); + Mockito.when(inner.getToken()) + .thenReturn(new Token(tokenValue, "Bearer", Instant.now().plusSeconds(3600))); + return new CachedTokenSource.Builder(inner).build(); + } + + private static Response makeImdsTokenResponse( + String accessToken, String tokenType, String expiresOn) throws Exception { + Map body = new HashMap<>(); + if (accessToken != null) body.put("access_token", accessToken); + if (tokenType != null) body.put("token_type", tokenType); + if (expiresOn != null) body.put("expires_on", expiresOn); + String json = MAPPER.writeValueAsString(body); + return new Response(json, new URL("http://169.254.169.254/")); + } + + // ── Token source tests ────────────────────────────────────────────────── + + @Nested + @TestInstance(TestInstance.Lifecycle.PER_CLASS) + class TokenSourceTests { + + @Test + void testHappyPath() throws Exception { + HttpClient mockClient = Mockito.mock(HttpClient.class); + long expiresOn = Instant.now().plusSeconds(3600).getEpochSecond(); + Mockito.when(mockClient.execute(any())) + .thenReturn(makeImdsTokenResponse(ACCESS_TOKEN, TOKEN_TYPE, String.valueOf(expiresOn))); + + AzureMsiTokenSource tokenSource = new AzureMsiTokenSource(mockClient, RESOURCE, null); + Token token = tokenSource.getToken(); + + assertEquals(ACCESS_TOKEN, token.getAccessToken()); + assertEquals(TOKEN_TYPE, token.getTokenType()); + assertEquals(Instant.ofEpochSecond(expiresOn), token.getExpiry()); + } + + @Test + void testRequestIncludesCorrectParams() throws Exception { + HttpClient mockClient = Mockito.mock(HttpClient.class); + long expiresOn = Instant.now().plusSeconds(3600).getEpochSecond(); + Mockito.when(mockClient.execute(any())) + .thenReturn(makeImdsTokenResponse(ACCESS_TOKEN, TOKEN_TYPE, String.valueOf(expiresOn))); + + AzureMsiTokenSource tokenSource = new AzureMsiTokenSource(mockClient, RESOURCE, null); + tokenSource.getToken(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Request.class); + Mockito.verify(mockClient).execute(captor.capture()); + Request req = captor.getValue(); + + assertEquals("GET", req.getMethod()); + assertTrue(req.getUrl().contains("169.254.169.254/metadata/identity/oauth2/token")); + assertEquals("true", req.getHeaders().get("Metadata")); + assertTrue(req.getUri().toString().contains("api-version=2018-02-01")); + assertTrue( + req.getUri() + .toString() + .contains("resource=" + java.net.URLEncoder.encode(RESOURCE, "UTF-8"))); + assertFalse(req.getUri().toString().contains("client_id")); + } + + Stream clientIdCases() { + return Stream.of( + Arguments.of("my-client-id", true, "non-empty client_id is included"), + Arguments.of("", false, "empty client_id is excluded"), + Arguments.of(null, false, "null client_id is excluded")); + } + + @ParameterizedTest(name = "{2}") + @MethodSource("clientIdCases") + void testClientIdInRequest(String clientId, boolean expectPresent, String description) + throws Exception { + HttpClient mockClient = Mockito.mock(HttpClient.class); + long expiresOn = Instant.now().plusSeconds(3600).getEpochSecond(); + Mockito.when(mockClient.execute(any())) + .thenReturn(makeImdsTokenResponse(ACCESS_TOKEN, TOKEN_TYPE, String.valueOf(expiresOn))); + + AzureMsiTokenSource tokenSource = new AzureMsiTokenSource(mockClient, RESOURCE, clientId); + tokenSource.getToken(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Request.class); + Mockito.verify(mockClient).execute(captor.capture()); + String uri = captor.getValue().getUri().toString(); + + if (expectPresent) { + assertTrue(uri.contains("client_id=" + clientId)); + } else { + assertFalse(uri.contains("client_id")); + } + } + + Stream invalidResponseCases() { + return Stream.of( + Arguments.of(null, TOKEN_TYPE, "1700000000", "access_token", "missing access_token"), + Arguments.of(ACCESS_TOKEN, null, "1700000000", "token_type", "missing token_type"), + Arguments.of(ACCESS_TOKEN, TOKEN_TYPE, null, "expires_on", "missing expires_on"), + Arguments.of( + ACCESS_TOKEN, + TOKEN_TYPE, + "not-a-number", + "Invalid 'expires_on' value", + "non-numeric expires_on"), + Arguments.of( + ACCESS_TOKEN, + TOKEN_TYPE, + "99999999999999999999999999999", + "Invalid 'expires_on' value", + "overflow expires_on")); + } + + @ParameterizedTest(name = "{4}") + @MethodSource("invalidResponseCases") + void testInvalidResponseThrowsException( + String accessToken, + String tokenType, + String expiresOn, + String expectedMessage, + String description) + throws Exception { + HttpClient mockClient = Mockito.mock(HttpClient.class); + Mockito.when(mockClient.execute(any())) + .thenReturn(makeImdsTokenResponse(accessToken, tokenType, expiresOn)); + + AzureMsiTokenSource tokenSource = new AzureMsiTokenSource(mockClient, RESOURCE, null); + + DatabricksException ex = assertThrows(DatabricksException.class, tokenSource::getToken); + assertTrue( + ex.getMessage().contains(expectedMessage), + "Expected '" + expectedMessage + "' in: " + ex.getMessage()); + } + + @Test + void testNon200StatusThrowsException() throws Exception { + HttpClient mockClient = Mockito.mock(HttpClient.class); + Mockito.when(mockClient.execute(any())) + .thenReturn( + new Response("Not Found", 404, "Not Found", new URL("http://169.254.169.254/"))); + + AzureMsiTokenSource tokenSource = new AzureMsiTokenSource(mockClient, RESOURCE, null); + + DatabricksException ex = assertThrows(DatabricksException.class, tokenSource::getToken); + assertTrue(ex.getMessage().contains("status code 404")); + } + + @Test + void testIOExceptionThrowsDatabricksException() throws Exception { + HttpClient mockClient = Mockito.mock(HttpClient.class); + Mockito.when(mockClient.execute(any())).thenThrow(new IOException("connection refused")); + + AzureMsiTokenSource tokenSource = new AzureMsiTokenSource(mockClient, RESOURCE, null); + + DatabricksException ex = assertThrows(DatabricksException.class, tokenSource::getToken); + assertTrue(ex.getMessage().contains("Failed to request MSI token from IMDS endpoint")); + assertTrue(ex.getMessage().contains("connection refused")); + } + } + + // ── Credentials provider tests ────────────────────────────────────────── + + @Nested + @TestInstance(TestInstance.Lifecycle.PER_CLASS) + class ProviderTests { + + @Test + void testAuthType() { + assertEquals("azure-msi", new AzureMsiCredentialsProvider().authType()); + } + + Stream returnsNullCases() { + return Stream.of( + Arguments.of( + "https://accounts.cloud.databricks.com", true, null, false, "non-Azure host"), + Arguments.of( + "https://adb-123.4.azuredatabricks.net", false, null, false, "MSI not enabled"), + Arguments.of(null, true, null, true, "neither host nor resource ID")); + } + + @ParameterizedTest(name = "returns null when {4}") + @MethodSource("returnsNullCases") + void testReturnsNull( + String host, + boolean azureUseMsi, + String resourceId, + boolean mockIsAzure, + String description) { + AzureMsiCredentialsProvider provider = Mockito.spy(new AzureMsiCredentialsProvider()); + DatabricksConfig config; + if (mockIsAzure) { + config = Mockito.spy(new DatabricksConfig().setCredentialsProvider(provider)); + Mockito.doReturn(true).when(config).isAzure(); + } else { + config = new DatabricksConfig().setCredentialsProvider(provider); + if (host != null) config.setHost(host); + if (resourceId != null) config.setAzureWorkspaceResourceId(resourceId); + } + if (azureUseMsi) config.setAzureUseMsi(true); + + assertNull(provider.configure(config)); + } + + @Test + void testHappyFlowWithResourceId() throws Exception { + // Mirrors TestMsiHappyFlow from Go SDK. + AzureMsiCredentialsProvider provider = Mockito.spy(new AzureMsiCredentialsProvider()); + + CachedTokenSource armTokenSource = mockTokenSource("bcd"); + CachedTokenSource innerTokenSource = mockTokenSource("cde"); + CachedTokenSource cloudTokenSource = mockTokenSource("def"); + Mockito.doReturn(armTokenSource) + .doReturn(innerTokenSource) + .doReturn(cloudTokenSource) + .when(provider) + .tokenSourceFor(any(), anyString()); + + HttpClient mockClient = Mockito.mock(HttpClient.class); + Map armResponse = new HashMap<>(); + Map properties = new HashMap<>(); + properties.put("workspaceUrl", "abc.azuredatabricks.net"); + armResponse.put("properties", properties); + Mockito.when(mockClient.execute(any())) + .thenReturn( + new Response( + MAPPER.writeValueAsString(armResponse), + 200, + "OK", + new URL("https://management.azure.com/"))); + + DatabricksConfig config = + new DatabricksConfig() + .setCredentialsProvider(provider) + .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID) + .setAzureUseMsi(true) + .setHttpClient(mockClient); + + OAuthHeaderFactory result = provider.configure(config); + + assertNotNull(result); + Map headers = result.headers(); + assertEquals("Bearer cde", headers.get("Authorization")); + assertEquals("def", headers.get("X-Databricks-Azure-SP-Management-Token")); + assertEquals(WORKSPACE_RESOURCE_ID, headers.get("X-Databricks-Azure-Workspace-Resource-Id")); + + // Verify host was resolved and ARM call was correct. + assertEquals("https://abc.azuredatabricks.net", config.getHost()); + ArgumentCaptor reqCaptor = ArgumentCaptor.forClass(Request.class); + Mockito.verify(mockClient).execute(reqCaptor.capture()); + assertTrue(reqCaptor.getValue().getUrl().contains(WORKSPACE_RESOURCE_ID)); + assertEquals("Bearer bcd", reqCaptor.getValue().getHeaders().get("Authorization")); + + // Verify tokenSourceFor was called 3 times: ARM, inner, cloud. + ArgumentCaptor resCaptor = ArgumentCaptor.forClass(String.class); + Mockito.verify(provider, Mockito.times(3)).tokenSourceFor(any(), resCaptor.capture()); + assertEquals("https://management.azure.com/", resCaptor.getAllValues().get(0)); + assertEquals(AzureEnvironment.ARM_DATABRICKS_RESOURCE_ID, resCaptor.getAllValues().get(1)); + assertEquals("https://management.core.windows.net/", resCaptor.getAllValues().get(2)); + } + + @Test + void testFailsOnResolveWorkspace() throws Exception { + // Mirrors TestMsiFailsOnResolveWorkspace from Go SDK. + AzureMsiCredentialsProvider provider = Mockito.spy(new AzureMsiCredentialsProvider()); + CachedTokenSource armTokenSource = mockTokenSource("bcd"); + Mockito.doReturn(armTokenSource).when(provider).tokenSourceFor(any(), anyString()); + + HttpClient mockClient = Mockito.mock(HttpClient.class); + Mockito.when(mockClient.execute(any())) + .thenReturn( + new Response( + "Not Found", 404, "Not Found", new URL("https://management.azure.com/"))); + + DatabricksConfig config = + new DatabricksConfig() + .setCredentialsProvider(provider) + .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID) + .setAzureUseMsi(true) + .setHttpClient(mockClient); + + assertThrows(DatabricksException.class, () -> provider.configure(config)); + } + + @Test + void testHappyFlowWithHostAndNoResourceId() { + // Mirrors TestMsiHappyFlowWithHostAndNoResourceID from Go SDK. + AzureMsiCredentialsProvider provider = Mockito.spy(new AzureMsiCredentialsProvider()); + + CachedTokenSource innerTokenSource = mockTokenSource("cde"); + CachedTokenSource cloudTokenSource = mockTokenSource("def"); + Mockito.doReturn(innerTokenSource) + .doReturn(cloudTokenSource) + .when(provider) + .tokenSourceFor(any(), anyString()); + + DatabricksConfig config = + new DatabricksConfig() + .setHost("https://adb-123.4.azuredatabricks.net") + .setCredentialsProvider(provider) + .setAzureUseMsi(true); + + OAuthHeaderFactory result = provider.configure(config); + + assertNotNull(result); + Map headers = result.headers(); + assertEquals("Bearer cde", headers.get("Authorization")); + assertEquals("def", headers.get("X-Databricks-Azure-SP-Management-Token")); + assertNull(headers.get("X-Databricks-Azure-Workspace-Resource-Id")); + + // Verify tokenSourceFor called with correct resources (no ARM call). + ArgumentCaptor resCaptor = ArgumentCaptor.forClass(String.class); + Mockito.verify(provider, Mockito.times(2)).tokenSourceFor(any(), resCaptor.capture()); + assertEquals(AzureEnvironment.ARM_DATABRICKS_RESOURCE_ID, resCaptor.getAllValues().get(0)); + assertEquals("https://management.core.windows.net/", resCaptor.getAllValues().get(1)); + } + + @Test + void testWithClientId() { + // Mirrors TestMsiTokenNotFound scenario from Go SDK. + AzureMsiCredentialsProvider provider = Mockito.spy(new AzureMsiCredentialsProvider()); + + CachedTokenSource innerTokenSource = mockTokenSource("cde"); + CachedTokenSource cloudTokenSource = mockTokenSource("def"); + Mockito.doReturn(innerTokenSource) + .doReturn(cloudTokenSource) + .when(provider) + .tokenSourceFor(any(), anyString()); + + DatabricksConfig config = + new DatabricksConfig() + .setHost("https://adb-123.4.azuredatabricks.net") + .setCredentialsProvider(provider) + .setAzureUseMsi(true) + .setAzureClientId("abc"); + + OAuthHeaderFactory result = provider.configure(config); + + assertNotNull(result); + assertEquals("Bearer cde", result.headers().get("Authorization")); + } + } +}