Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ void refreshCredentialsIfRequired() throws IOException {
}
try {
// Wait for the refresh task to complete.
currentRefreshTask.task.get();
currentRefreshTask.get();
} catch (InterruptedException e) {
// Restore the interrupted status and throw an exception.
Thread.currentThread().interrupt();
Expand Down Expand Up @@ -495,31 +495,17 @@ class RefreshTask extends AbstractFuture<IntermediateCredentials> implements Run
this.task = task;
this.isNew = isNew;

// Add listener to update factory's credentials when the task completes.
// Single listener to guarantee that finishRefreshTask updates the internal state BEFORE
// the outer future completes and unblocks waiters.
task.addListener(
() -> {
try {
finishRefreshTask(task);
RefreshTask.this.set(Futures.getDone(task));
} catch (ExecutionException e) {
Throwable cause = e.getCause();
RefreshTask.this.setException(cause);
}
},
MoreExecutors.directExecutor());

// Add callback to set the result or exception based on the outcome.
Futures.addCallback(
task,
new FutureCallback<IntermediateCredentials>() {
@Override
public void onSuccess(IntermediateCredentials result) {
RefreshTask.this.set(result);
}

@Override
public void onFailure(@Nullable Throwable t) {
RefreshTask.this.setException(
t != null ? t : new IOException("Refresh failed with null Throwable."));
RefreshTask.this.setException(e.getCause());
} catch (Exception e) {
RefreshTask.this.setException(e);
}
},
MoreExecutors.directExecutor());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -988,4 +988,51 @@ void generateToken_withMalformSessionKey_failure() throws Exception {

assertThrows(GeneralSecurityException.class, () -> factory.generateToken(accessBoundary));
}

@org.junit.jupiter.api.Test
void generateToken_freshInstance_concurrent_noNpe() throws Exception {
for (int run = 0; run < 10; run++) { // Run 10 times in a single test instance to save time
GoogleCredentials sourceCredentials = getServiceAccountSourceCredentials(mockTokenServerTransportFactory);
ClientSideCredentialAccessBoundaryFactory factory = ClientSideCredentialAccessBoundaryFactory.newBuilder()
.setSourceCredential(sourceCredentials)
.setHttpTransportFactory(mockStsTransportFactory)
.build();

CredentialAccessBoundary.Builder cabBuilder = CredentialAccessBoundary.newBuilder();
CredentialAccessBoundary accessBoundary = cabBuilder
.addRule(
CredentialAccessBoundary.AccessBoundaryRule.newBuilder()
.setAvailableResource("resource")
.setAvailablePermissions(ImmutableList.of("role"))
.build())
.build();

int numThreads = 5;
Thread[] threads = new Thread[numThreads];
CountDownLatch latch = new CountDownLatch(numThreads);
java.util.concurrent.atomic.AtomicInteger npeCount = new java.util.concurrent.atomic.AtomicInteger();

for (int i = 0; i < numThreads; i++) {
threads[i] = new Thread(() -> {
try {
latch.countDown();
latch.await();
factory.generateToken(accessBoundary);
} catch (NullPointerException e) {
npeCount.incrementAndGet();
} catch (Exception e) {
// Ignore other exceptions for the sake of the race reproduction
}
});
threads[i].start();
}

for (Thread thread : threads) {
thread.join();
}

org.junit.jupiter.api.Assertions.assertEquals(0, npeCount.get(),
"Expected zero NullPointerExceptions due to the race condition, but some were thrown.");
}
}
}
Loading