From 848f3433a61d57f2947b800a766f10e77d4d6280 Mon Sep 17 00:00:00 2001 From: Sania Parveen Date: Thu, 12 Mar 2026 21:16:21 +0000 Subject: [PATCH 1/8] Add a wrapper on top of IsolationChannel to maintain a set of primary and fallback channel --- .../worker/StreamingDataflowWorker.java | 66 ++++-- .../FanOutStreamingEngineWorkerHarness.java | 15 +- .../client/grpc/GrpcDispatcherClient.java | 2 +- .../client/grpc/stubs/FailoverChannel.java | 207 ++++++++++++++++++ ...anOutStreamingEngineWorkerHarnessTest.java | 20 +- .../grpc/stubs/FailoverChannelTest.java | 158 +++++++++++++ .../grpc/stubs/IsolationChannelTest.java | 32 +-- .../client/grpc/stubs/NoopClientCall.java | 51 +++++ .../windmill/src/main/proto/windmill.proto | 2 + 9 files changed, 498 insertions(+), 55 deletions(-) create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/NoopClientCall.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index aad27b869863..9866f85a0956 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -87,8 +87,10 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcDispatcherClient; import 
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillServer; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillStreamFactory; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.auth.VendoredCredentialsAdapter; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.ChannelCache; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.ChannelCachingRemoteStubFactory; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.FailoverChannel; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.IsolationChannel; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactoryFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactoryFactoryImpl; @@ -113,6 +115,8 @@ import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.util.construction.CoderTranslation; import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.auth.MoreCallCredentials; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheStats; @@ -376,7 +380,7 @@ private StreamingWorkerHarnessFactoryOutput createFanOutStreamingEngineWorkerHar MemoryMonitor memoryMonitor, GrpcDispatcherClient dispatcherClient) { WeightedSemaphore maxCommitByteSemaphore = Commits.maxCommitByteSemaphore(); - ChannelCache channelCache = createChannelCache(options, configFetcher); + ChannelCache channelCache = createChannelCache(options, configFetcher, dispatcherClient); FanOutStreamingEngineWorkerHarness fanOutStreamingEngineWorkerHarness = FanOutStreamingEngineWorkerHarness.create( 
createJobHeader(options, clientId), @@ -789,20 +793,54 @@ private static void validateWorkerOptions(DataflowWorkerHarnessOptions options) } private static ChannelCache createChannelCache( - DataflowWorkerHarnessOptions workerOptions, ComputationConfig.Fetcher configFetcher) { + DataflowWorkerHarnessOptions workerOptions, + ComputationConfig.Fetcher configFetcher, + GrpcDispatcherClient dispatcherClient) { ChannelCache channelCache = - ChannelCache.create( - (currentFlowControlSettings, serviceAddress) -> { - // IsolationChannel will create and manage separate RPC channels to the same - // serviceAddress. - return IsolationChannel.create( - () -> - remoteChannel( - serviceAddress, - workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), - currentFlowControlSettings), - currentFlowControlSettings.getOnReadyThresholdBytes()); - }); + Boolean.TRUE.equals( + workerOptions + .getUseWindmillIsolatedChannels()) // Create failover channel only if isolated + // channels + // is enabled for dispatcher client + ? ChannelCache.create( + (currentFlowControlSettings, serviceAddress) -> { + ManagedChannel primaryChannel = + IsolationChannel.create( + () -> + remoteChannel( + serviceAddress, + workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), + currentFlowControlSettings), + currentFlowControlSettings.getOnReadyThresholdBytes()); + // Create an isolated fallback channel from dispatcher endpoints. + // This ensures both primary and fallback use separate isolated channels. 
+ ManagedChannel fallbackChannel = + IsolationChannel.create( + () -> + remoteChannel( + dispatcherClient.getDispatcherEndpoints().iterator().next(), + workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), + currentFlowControlSettings), + currentFlowControlSettings.getOnReadyThresholdBytes()); + return FailoverChannel.create( + primaryChannel, + fallbackChannel, + MoreCallCredentials.from( + new VendoredCredentialsAdapter(workerOptions.getGcpCredential()))); + }) + : ChannelCache.create( + (currentFlowControlSettings, serviceAddress) -> { + // IsolationChannel will create and manage separate RPC channels to the same + // serviceAddress. + return IsolationChannel.create( + () -> + remoteChannel( + serviceAddress, + workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), + currentFlowControlSettings), + currentFlowControlSettings.getOnReadyThresholdBytes()); + }); + configFetcher .getGlobalConfigHandle() .registerConfigObserver( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java index 75c2b91af603..63ab5379bd49 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java @@ -410,15 +410,18 @@ private GlobalDataStreamSender getOrCreateGlobalDataSteam( } private WindmillStreamSender createAndStartWindmillStreamSender(Endpoint endpoint) { + GetWorkRequest.Builder getWorkRequestBuilder = + GetWorkRequest.newBuilder() + .setClientId(jobHeader.getClientId()) + .setJobId(jobHeader.getJobId()) + .setProjectId(jobHeader.getProjectId()) 
+ .setWorkerId(jobHeader.getWorkerId()); + endpoint.workerToken().ifPresent(getWorkRequestBuilder::setBackendWorkerToken); + WindmillStreamSender windmillStreamSender = WindmillStreamSender.create( WindmillConnection.from(endpoint, this::createWindmillStub), - GetWorkRequest.newBuilder() - .setClientId(jobHeader.getClientId()) - .setJobId(jobHeader.getJobId()) - .setProjectId(jobHeader.getProjectId()) - .setWorkerId(jobHeader.getWorkerId()) - .build(), + getWorkRequestBuilder.build(), GetWorkBudget.noBudget(), streamFactory, workItemScheduler, diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java index 82e66c4b0d74..0d8f75dd816a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java @@ -128,7 +128,7 @@ public CloudWindmillServiceV1Alpha1Stub getWindmillServiceStub() { : randomlySelectNextStub(windmillServiceStubs)); } - ImmutableSet getDispatcherEndpoints() { + public ImmutableSet getDispatcherEndpoints() { return dispatcherStubs.get().dispatcherEndpoints(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java new file mode 100644 index 000000000000..12ebc9a4991b --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java @@ 
-0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs; + +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.LongSupplier; +import javax.annotation.Nullable; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallCredentials; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallOptions; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ClientCall; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ForwardingClientCall.SimpleForwardingClientCall; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Metadata; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.MethodDescriptor; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A {@link ManagedChannel} that wraps a primary and a fallback channel. 
It fails over to the + * fallback channel if the primary channel returns {@link Status#UNAVAILABLE}. + */ +@Internal +public final class FailoverChannel extends ManagedChannel { + private static final Logger LOG = LoggerFactory.getLogger(FailoverChannel.class); + // Time to wait before retrying the primary channel after a failure, to avoid retrying too quickly + private static final long FALLBACK_COOLING_PERIOD_NANOS = TimeUnit.HOURS.toNanos(1); + private final ManagedChannel primary; + private final ManagedChannel fallback; + @Nullable private final CallCredentials fallbackCallCredentials; + private final AtomicBoolean useFallback = new AtomicBoolean(false); + private final AtomicLong lastFallbackTimeNanos = new AtomicLong(0); + private final LongSupplier nanoClock; + + private FailoverChannel( + ManagedChannel primary, + ManagedChannel fallback, + @Nullable CallCredentials fallbackCallCredentials, + LongSupplier nanoClock) { + this.primary = primary; + this.fallback = fallback; + this.fallbackCallCredentials = fallbackCallCredentials; + this.nanoClock = nanoClock; + } + + public static FailoverChannel create(ManagedChannel primary, ManagedChannel fallback) { + return new FailoverChannel(primary, fallback, null, System::nanoTime); + } + + public static FailoverChannel create( + ManagedChannel primary, ManagedChannel fallback, CallCredentials fallbackCallCredentials) { + return new FailoverChannel(primary, fallback, fallbackCallCredentials, System::nanoTime); + } + + static FailoverChannel forTest( + ManagedChannel primary, + ManagedChannel fallback, + CallCredentials fallbackCallCredentials, + LongSupplier nanoClock) { + return new FailoverChannel(primary, fallback, fallbackCallCredentials, nanoClock); + } + + @Override + public String authority() { + return primary.authority(); + } + + @Override + public ClientCall newCall( + MethodDescriptor methodDescriptor, CallOptions callOptions) { + if (useFallback.get()) { + long elapsedNanos = nanoClock.getAsLong() - 
lastFallbackTimeNanos.get(); + if (elapsedNanos > FALLBACK_COOLING_PERIOD_NANOS) { + if (useFallback.compareAndSet(true, false)) { + LOG.info("Fallback cooling period elapsed. Retrying direct path."); + } + } else { + CallOptions fallbackCallOptions = callOptions; + if (fallbackCallCredentials != null && callOptions.getCredentials() == null) { + fallbackCallOptions = callOptions.withCallCredentials(fallbackCallCredentials); + } + // The boolean `true` marks that the ClientCall is using the + // fallback (cloudpath) channel. The inner call listener uses this + // flag so `notifyFailure` will only transition to fallback when a + // non-fallback (primary) call fails; fallback calls simply log + // failures and do not re-trigger another fallback transition. + return new FailoverClientCall<>( + fallback.newCall(methodDescriptor, fallbackCallOptions), + true, + methodDescriptor.getFullMethodName()); + } + } + // The boolean `false` marks that the ClientCall is using the + // primary (direct) channel. If this call closes with a non-OK status, + // `notifyFailure` will flip `useFallback` to true, causing subsequent + // calls to go to the fallback channel for the cooling period. 
+ return new FailoverClientCall<>( + primary.newCall(methodDescriptor, callOptions), + false, + methodDescriptor.getFullMethodName()); + } + + @Override + public ManagedChannel shutdown() { + primary.shutdown(); + if (fallback != null) { + fallback.shutdown(); + } + return this; + } + + @Override + public ManagedChannel shutdownNow() { + primary.shutdownNow(); + if (fallback != null) { + fallback.shutdownNow(); + } + return this; + } + + @Override + public boolean isShutdown() { + return primary.isShutdown() && (fallback == null || fallback.isShutdown()); + } + + @Override + public boolean isTerminated() { + return primary.isTerminated() && (fallback == null || fallback.isTerminated()); + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + long endTimeNanos = nanoClock.getAsLong() + unit.toNanos(timeout); + boolean primaryTerminated = primary.awaitTermination(timeout, unit); + if (fallback != null) { + long remainingNanos = Math.max(0, endTimeNanos - nanoClock.getAsLong()); + return primaryTerminated && fallback.awaitTermination(remainingNanos, TimeUnit.NANOSECONDS); + } + return primaryTerminated; + } + + private void notifyFailure(Status status, boolean isFallback, String methodName) { + if (!status.isOk() && !isFallback && fallback != null) { + if (useFallback.compareAndSet(false, true)) { + lastFallbackTimeNanos.set(nanoClock.getAsLong()); + LOG.warn( + "Direct path connection failed with status {} for method: {}. 
Falling back to" + + " cloudpath for 1 hour.", + status, + methodName); + } + } else if (isFallback) { + LOG.warn("Fallback channel call for method: {} closed with status: {}", methodName, status); + } + } + + private final class FailoverClientCall + extends SimpleForwardingClientCall { + private final boolean isFallback; + private final String methodName; + + /** + * @param delegate the underlying ClientCall (either primary or fallback) + * @param isFallback true if `delegate` is a fallback channel call, false if it is a primary + * channel call. This flag is inspected by {@link #notifyFailure} to determine whether a + * failure should trigger switching to the fallback channel (only primary failures do). + * @param methodName full gRPC method name (for logging) + */ + FailoverClientCall(ClientCall delegate, boolean isFallback, String methodName) { + super(delegate); + this.isFallback = isFallback; + this.methodName = methodName; + } + + @Override + public void start(Listener responseListener, Metadata headers) { + super.start( + new SimpleForwardingClientCallListener(responseListener) { + @Override + public void onClose(Status status, Metadata trailers) { + notifyFailure(status, isFallback, methodName); + super.onClose(status, trailers); + } + }, + headers); + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java index 94c8f4b75957..b5f244f77eb6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java @@ -132,7 
+132,7 @@ private static WorkItemScheduler noOpProcessWorkItemFn() { getWorkStreamLatencies) -> {}; } - private static GetWorkRequest getWorkRequest(long items, long bytes) { + private static GetWorkRequest getWorkRequest(long items, long bytes, String backendWorkerToken) { return GetWorkRequest.newBuilder() .setJobId(JOB_ID) .setProjectId(PROJECT_ID) @@ -140,6 +140,7 @@ private static GetWorkRequest getWorkRequest(long items, long bytes) { .setClientId(JOB_HEADER.getClientId()) .setMaxItems(items) .setMaxBytes(bytes) + .setBackendWorkerToken(backendWorkerToken) .build(); } @@ -239,9 +240,22 @@ public void testStreamsStartCorrectly() throws InterruptedException { .distributeBudget( any(), eq(GetWorkBudget.builder().setItems(items).setBytes(bytes).build())); - verify(streamFactory, times(2)) + verify(streamFactory, times(1)) .createDirectGetWorkStream( - any(), eq(getWorkRequest(0, 0)), any(), any(), any(), eq(noOpProcessWorkItemFn())); + any(), + eq(getWorkRequest(0, 0, workerToken)), + any(), + any(), + any(), + eq(noOpProcessWorkItemFn())); + verify(streamFactory, times(1)) + .createDirectGetWorkStream( + any(), + eq(getWorkRequest(0, 0, workerToken2)), + any(), + any(), + any(), + eq(noOpProcessWorkItemFn())); verify(streamFactory, times(2)).createDirectGetDataStream(any()); verify(streamFactory, times(2)).createDirectCommitWorkStream(any()); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java new file mode 100644 index 000000000000..4710b46ec31e --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallCredentials; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallOptions; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ClientCall; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ClientCall.Listener; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Metadata; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.MethodDescriptor; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; + +@RunWith(JUnit4.class) +public class FailoverChannelTest { + + private 
MethodDescriptor methodDescriptor = + MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.UNARY) + .setFullMethodName(MethodDescriptor.generateFullMethodName("test", "test")) + .setRequestMarshaller(new IsolationChannelTest.NoopMarshaller()) + .setResponseMarshaller(new IsolationChannelTest.NoopMarshaller()) + .build(); + + @Test + public void testFallbackAndRetry() throws Exception { + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + ClientCall underlyingCall = mock(ClientCall.class); + ClientCall fallbackCall = mock(ClientCall.class); + when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(fallbackCall); + + FailoverChannel failoverChannel = FailoverChannel.create(mockChannel, mockFallbackChannel); + + // First call, triggers fallback + ClientCall call1 = + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + Metadata metadata1 = new Metadata(); + call1.start(new NoopClientCall.NoopClientCallListener<>(), metadata1); + + ArgumentCaptor> captor = ArgumentCaptor.forClass(ClientCall.Listener.class); + verify(underlyingCall).start(captor.capture(), same(metadata1)); + + // Fail with UNAVAILABLE + captor.getValue().onClose(Status.UNAVAILABLE, new Metadata()); + + // Second call should use fallback Channel + ClientCall call2 = + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + call2.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); + verify(mockFallbackChannel, atLeastOnce()).newCall(any(), any()); + } + + @Test + public void testFallbackAndPeriodicRetry() throws Exception { + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + ClientCall underlyingCall = mock(ClientCall.class); + ClientCall underlyingCall2 = mock(ClientCall.class); + ClientCall fallbackCall = 
mock(ClientCall.class); + when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall, underlyingCall2); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(fallbackCall); + + AtomicLong time = new AtomicLong(0); + + FailoverChannel failoverChannel = + FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); + + // Trigger fallback + ClientCall call1 = + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + call1.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); + ArgumentCaptor> captor = ArgumentCaptor.forClass(ClientCall.Listener.class); + verify(underlyingCall).start(captor.capture(), any()); + captor.getValue().onClose(Status.UNAVAILABLE, new Metadata()); + + // Advance time by 30 mins (less than 1 hour) + time.addAndGet(TimeUnit.MINUTES.toNanos(30)); + + // Should still use fallback + ClientCall call2 = + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + call2.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); + verify(mockFallbackChannel, atLeastOnce()).newCall(any(), any()); + + // Advance time by another 40 mins (total > 1 hour) + time.addAndGet(TimeUnit.MINUTES.toNanos(40)); + + // Should retry direct path + ClientCall call3 = + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + call3.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); + verify(mockChannel, atLeast(2)).newCall(any(), any()); + } + + @Test + public void testFallbackWithCredentials() throws Exception { + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + ClientCall underlyingCall = mock(ClientCall.class); + ClientCall fallbackCall = mock(ClientCall.class); + CallCredentials mockCredentials = mock(CallCredentials.class); + + when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(fallbackCall); + + FailoverChannel 
failoverChannel = + FailoverChannel.create(mockChannel, mockFallbackChannel, mockCredentials); + + // Trigger fallback + ClientCall call1 = + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + call1.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); + ArgumentCaptor> captor = ArgumentCaptor.forClass(ClientCall.Listener.class); + verify(underlyingCall).start(captor.capture(), any()); + captor.getValue().onClose(Status.UNAVAILABLE, new Metadata()); + + // Next call should use fallback with credentials + ClientCall call2 = + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + call2.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); + + ArgumentCaptor optionsCaptor = ArgumentCaptor.forClass(CallOptions.class); + verify(mockFallbackChannel).newCall(same(methodDescriptor), optionsCaptor.capture()); + assertEquals(mockCredentials, optionsCaptor.getValue().getCredentials()); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/IsolationChannelTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/IsolationChannelTest.java index 20321bbd66c3..580bf873d916 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/IsolationChannelTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/IsolationChannelTest.java @@ -49,42 +49,12 @@ import org.mockito.ArgumentCaptor; import org.mockito.InOrder; -/** - * {@link NoopClientCall} is a class that is designed for use in tests. It is designed to be used in - * places where a scriptable call is necessary. By default, all methods are noops, and designed to - * be overridden. 
- */ -class NoopClientCall extends ClientCall { - - /** - * {@link NoopClientCall.NoopClientCallListener} is a class that is designed for use in tests. It - * is designed to be used in places where a scriptable call listener is necessary. By default, all - * methods are noops, and designed to be overridden. - */ - public static class NoopClientCallListener extends ClientCall.Listener {} - - @Override - public void start(ClientCall.Listener listener, Metadata headers) {} - - @Override - public void request(int numMessages) {} - - @Override - public void cancel(String message, Throwable cause) {} - - @Override - public void halfClose() {} - - @Override - public void sendMessage(ReqT message) {} -} - @RunWith(JUnit4.class) public class IsolationChannelTest { private Supplier channelSupplier = mock(Supplier.class); - private static class NoopMarshaller implements Marshaller { + public static class NoopMarshaller implements Marshaller { @Override public InputStream stream(Object o) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/NoopClientCall.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/NoopClientCall.java new file mode 100644 index 000000000000..93a421e0d618 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/NoopClientCall.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs; + +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ClientCall; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Metadata; + +/** + * {@link NoopClientCall} is a class that is designed for use in tests. It is designed to be used in + * places where a scriptable call is necessary. By default, all methods are noops, and designed to + * be overridden. + */ +public class NoopClientCall extends ClientCall { + + /** + * {@link NoopClientCall.NoopClientCallListener} is a class that is designed for use in tests. It + * is designed to be used in places where a scriptable call listener is necessary. By default, all + * methods are noops, and designed to be overridden. 
+ */ + public static class NoopClientCallListener extends ClientCall.Listener {} + + @Override + public void start(ClientCall.Listener listener, Metadata headers) {} + + @Override + public void request(int numMessages) {} + + @Override + public void cancel(String message, Throwable cause) {} + + @Override + public void halfClose() {} + + @Override + public void sendMessage(ReqT message) {} +} diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto index a4b3df906dd9..6286b2d67110 100644 --- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto +++ b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto @@ -470,6 +470,8 @@ message GetWorkRequest { optional string project_id = 7; optional int64 max_items = 2 [default = 0xffffffff]; optional int64 max_bytes = 3 [default = 0x7fffffffffffffff]; + repeated string computation_id_filter = 8; + optional string backend_worker_token = 9; reserved 6; } From 42ec76ca1c67ea7382aafb04902289be144a1f99 Mon Sep 17 00:00:00 2001 From: Sania Parveen Date: Mon, 16 Mar 2026 03:11:40 +0000 Subject: [PATCH 2/8] Removing check to verify isolation channel option is set --- .../worker/StreamingDataflowWorker.java | 65 ++++++------------- 1 file changed, 21 insertions(+), 44 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 9866f85a0956..1f2c861e7a62 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -793,61 +793,38 @@ 
private static void validateWorkerOptions(DataflowWorkerHarnessOptions options) } private static ChannelCache createChannelCache( - DataflowWorkerHarnessOptions workerOptions, - ComputationConfig.Fetcher configFetcher, - GrpcDispatcherClient dispatcherClient) { - ChannelCache channelCache = - Boolean.TRUE.equals( - workerOptions - .getUseWindmillIsolatedChannels()) // Create failover channel only if isolated - // channels - // is enabled for dispatcher client - ? ChannelCache.create( - (currentFlowControlSettings, serviceAddress) -> { - ManagedChannel primaryChannel = - IsolationChannel.create( - () -> - remoteChannel( + DataflowWorkerHarnessOptions workerOptions, + ComputationConfig.Fetcher configFetcher, + GrpcDispatcherClient dispatcherClient) { + ChannelCache channelCache = ChannelCache.create( + (currentFlowControlSettings, serviceAddress) -> { + ManagedChannel primaryChannel = IsolationChannel.create( + () -> remoteChannel( serviceAddress, workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), currentFlowControlSettings), currentFlowControlSettings.getOnReadyThresholdBytes()); // Create an isolated fallback channel from dispatcher endpoints. // This ensures both primary and fallback use separate isolated channels. - ManagedChannel fallbackChannel = - IsolationChannel.create( - () -> - remoteChannel( + ManagedChannel fallbackChannel = IsolationChannel.create( + () -> remoteChannel( dispatcherClient.getDispatcherEndpoints().iterator().next(), workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), currentFlowControlSettings), currentFlowControlSettings.getOnReadyThresholdBytes()); return FailoverChannel.create( - primaryChannel, - fallbackChannel, - MoreCallCredentials.from( - new VendoredCredentialsAdapter(workerOptions.getGcpCredential()))); - }) - : ChannelCache.create( - (currentFlowControlSettings, serviceAddress) -> { - // IsolationChannel will create and manage separate RPC channels to the same - // serviceAddress. 
- return IsolationChannel.create( - () -> - remoteChannel( - serviceAddress, - workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), - currentFlowControlSettings), - currentFlowControlSettings.getOnReadyThresholdBytes()); - }); - - configFetcher - .getGlobalConfigHandle() - .registerConfigObserver( - config -> - channelCache.consumeFlowControlSettings( - config.userWorkerJobSettings().getFlowControlSettings())); - return channelCache; + primaryChannel, + fallbackChannel, + MoreCallCredentials.from( + new VendoredCredentialsAdapter(workerOptions.getGcpCredential()))); + }); + + configFetcher + .getGlobalConfigHandle() + .registerConfigObserver( + config -> channelCache.consumeFlowControlSettings( + config.userWorkerJobSettings().getFlowControlSettings())); + return channelCache; } @VisibleForTesting From d3f00694e252929afb318e10c17d66087e6fedb1 Mon Sep 17 00:00:00 2001 From: Sania Parveen Date: Thu, 19 Mar 2026 02:28:19 +0000 Subject: [PATCH 3/8] Adding failover mode based on connection status --- .../worker/StreamingDataflowWorker.java | 73 +++---- .../client/grpc/stubs/FailoverChannel.java | 180 ++++++++++++++---- .../grpc/stubs/FailoverChannelTest.java | 114 ++++++++--- 3 files changed, 262 insertions(+), 105 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 5c4ac9b6811c..cc9501583641 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -385,7 +385,8 @@ private StreamingWorkerHarnessFactoryOutput createFanOutStreamingEngineWorkerHar MemoryMonitor memoryMonitor, GrpcDispatcherClient dispatcherClient) { 
WeightedSemaphore maxCommitByteSemaphore = Commits.maxCommitByteSemaphore(); - ChannelCache channelCache = createChannelCache(options, checkNotNull(configFetcher), dispatcherClient); + ChannelCache channelCache = + createChannelCache(options, checkNotNull(configFetcher), dispatcherClient); @SuppressWarnings("methodref.receiver.bound") FanOutStreamingEngineWorkerHarness fanOutStreamingEngineWorkerHarness = FanOutStreamingEngineWorkerHarness.create( @@ -808,38 +809,44 @@ private static void validateWorkerOptions(DataflowWorkerHarnessOptions options) } private static ChannelCache createChannelCache( - DataflowWorkerHarnessOptions workerOptions, - ComputationConfig.Fetcher configFetcher, - GrpcDispatcherClient dispatcherClient) { - ChannelCache channelCache = ChannelCache.create( - (currentFlowControlSettings, serviceAddress) -> { - ManagedChannel primaryChannel = IsolationChannel.create( - () -> remoteChannel( - serviceAddress, - workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), - currentFlowControlSettings), - currentFlowControlSettings.getOnReadyThresholdBytes()); - // Create an isolated fallback channel from dispatcher endpoints. - // This ensures both primary and fallback use separate isolated channels. 
- ManagedChannel fallbackChannel = IsolationChannel.create( - () -> remoteChannel( - dispatcherClient.getDispatcherEndpoints().iterator().next(), - workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), - currentFlowControlSettings), - currentFlowControlSettings.getOnReadyThresholdBytes()); - return FailoverChannel.create( - primaryChannel, - fallbackChannel, - MoreCallCredentials.from( - new VendoredCredentialsAdapter(workerOptions.getGcpCredential()))); - }); - - configFetcher - .getGlobalConfigHandle() - .registerConfigObserver( - config -> channelCache.consumeFlowControlSettings( - config.userWorkerJobSettings().getFlowControlSettings())); - return channelCache; + DataflowWorkerHarnessOptions workerOptions, + ComputationConfig.Fetcher configFetcher, + GrpcDispatcherClient dispatcherClient) { + ChannelCache channelCache = + ChannelCache.create( + (currentFlowControlSettings, serviceAddress) -> { + ManagedChannel primaryChannel = + IsolationChannel.create( + () -> + remoteChannel( + serviceAddress, + workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), + currentFlowControlSettings), + currentFlowControlSettings.getOnReadyThresholdBytes()); + // Create an isolated fallback channel from dispatcher endpoints. + // This ensures both primary and fallback use separate isolated channels. 
+ ManagedChannel fallbackChannel = + IsolationChannel.create( + () -> + remoteChannel( + dispatcherClient.getDispatcherEndpoints().iterator().next(), + workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), + currentFlowControlSettings), + currentFlowControlSettings.getOnReadyThresholdBytes()); + return FailoverChannel.create( + primaryChannel, + fallbackChannel, + MoreCallCredentials.from( + new VendoredCredentialsAdapter(workerOptions.getGcpCredential()))); + }); + + configFetcher + .getGlobalConfigHandle() + .registerConfigObserver( + config -> + channelCache.consumeFlowControlSettings( + config.userWorkerJobSettings().getFlowControlSettings())); + return channelCache; } @VisibleForTesting diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java index 12ebc9a4991b..ea550f361f3e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java @@ -26,6 +26,7 @@ import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallCredentials; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallOptions; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ClientCall; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ConnectivityState; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ForwardingClientCall.SimpleForwardingClientCall; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ManagedChannel; @@ -36,32 +37,52 @@ import org.slf4j.LoggerFactory; /** - * A {@link 
ManagedChannel} that wraps a primary and a fallback channel. It fails over to the - * fallback channel if the primary channel returns {@link Status#UNAVAILABLE}. + * A {@link ManagedChannel} that wraps a primary and a fallback channel. + * + *

Routes requests to either primary or fallback channel based on two independent failover modes: + * + *

    + *
  • Connection Status Failover: If the primary channel is not ready for 10+ seconds + * (e.g., during network issues), routes to fallback channel. Switches back as soon as the + * primary channel becomes READY again. + *
  • RPC Failover: If primary channel RPC fails with transient errors ({@link + * Status.Code#UNAVAILABLE}, {@link Status.Code#DEADLINE_EXCEEDED}, or {@link + * Status.Code#UNKNOWN}), switches to fallback channel and waits for a 1-hour cooling period + * before retrying primary. + *
*/ @Internal public final class FailoverChannel extends ManagedChannel { private static final Logger LOG = LoggerFactory.getLogger(FailoverChannel.class); - // Time to wait before retrying the primary channel after a failure, to avoid retrying too quickly + // Time to wait before retrying the primary channel after an RPC-based fallback. private static final long FALLBACK_COOLING_PERIOD_NANOS = TimeUnit.HOURS.toNanos(1); + private static final long PRIMARY_NOT_READY_WAIT_NANOS = TimeUnit.SECONDS.toNanos(10); private final ManagedChannel primary; - private final ManagedChannel fallback; + @Nullable private final ManagedChannel fallback; @Nullable private final CallCredentials fallbackCallCredentials; - private final AtomicBoolean useFallback = new AtomicBoolean(false); - private final AtomicLong lastFallbackTimeNanos = new AtomicLong(0); + // Set when primary's connection state has been unavailable for too long. + private final AtomicBoolean useFallbackDueToState = new AtomicBoolean(false); + // Set when an RPC on primary fails with a transient error. + private final AtomicBoolean useFallbackDueToRPC = new AtomicBoolean(false); + private final AtomicLong lastRPCFallbackTimeNanos = new AtomicLong(0); + private final AtomicLong primaryNotReadySinceNanos = new AtomicLong(-1); private final LongSupplier nanoClock; + private final AtomicBoolean stateChangeListenerRegistered = new AtomicBoolean(false); private FailoverChannel( ManagedChannel primary, - ManagedChannel fallback, + @Nullable ManagedChannel fallback, @Nullable CallCredentials fallbackCallCredentials, LongSupplier nanoClock) { this.primary = primary; this.fallback = fallback; this.fallbackCallCredentials = fallbackCallCredentials; this.nanoClock = nanoClock; + // Register callback to monitor primary channel state changes + registerPrimaryStateChangeListener(); } + // Test-only. 
public static FailoverChannel create(ManagedChannel primary, ManagedChannel fallback) { return new FailoverChannel(primary, fallback, null, System::nanoTime); } @@ -87,32 +108,34 @@ public String authority() { @Override public ClientCall newCall( MethodDescriptor methodDescriptor, CallOptions callOptions) { - if (useFallback.get()) { - long elapsedNanos = nanoClock.getAsLong() - lastFallbackTimeNanos.get(); - if (elapsedNanos > FALLBACK_COOLING_PERIOD_NANOS) { - if (useFallback.compareAndSet(true, false)) { - LOG.info("Fallback cooling period elapsed. Retrying direct path."); + // Check if the RPC-based cooling period has elapsed. + if (useFallbackDueToRPC.get()) { + long timeSinceLastFallback = nanoClock.getAsLong() - lastRPCFallbackTimeNanos.get(); + if (timeSinceLastFallback >= FALLBACK_COOLING_PERIOD_NANOS) { + if (useFallbackDueToRPC.compareAndSet(true, false)) { + LOG.info("Primary channel cooling period elapsed; switching back from fallback."); } - } else { - CallOptions fallbackCallOptions = callOptions; - if (fallbackCallCredentials != null && callOptions.getCredentials() == null) { - fallbackCallOptions = callOptions.withCallCredentials(fallbackCallCredentials); - } - // The boolean `true` marks that the ClientCall is using the - // fallback (cloudpath) channel. The inner call listener uses this - // flag so `notifyFailure` will only transition to fallback when a - // non-fallback (primary) call fails; fallback calls simply log - // failures and do not re-trigger another fallback transition. - return new FailoverClientCall<>( - fallback.newCall(methodDescriptor, fallbackCallOptions), - true, - methodDescriptor.getFullMethodName()); } } - // The boolean `false` marks that the ClientCall is using the - // primary (direct) channel. If this call closes with a non-OK status, - // `notifyFailure` will flip `useFallback` to true, causing subsequent - // calls to go to the fallback channel for the cooling period. 
+ + if (fallback != null && (useFallbackDueToRPC.get() || useFallbackDueToState.get())) { + return new FailoverClientCall<>( + fallback.newCall(methodDescriptor, applyFallbackCredentials(callOptions)), + true, + methodDescriptor.getFullMethodName()); + } + + // If primary has not become ready for a sustained period, fail over to fallback. + if (fallback != null && shouldFallBackDueToPrimaryState()) { + if (useFallbackDueToState.compareAndSet(false, true)) { + LOG.warn("Primary connection unavailable. Switching to secondary connection."); + } + return new FailoverClientCall<>( + fallback.newCall(methodDescriptor, applyFallbackCredentials(callOptions)), + true, + methodDescriptor.getFullMethodName()); + } + return new FailoverClientCall<>( primary.newCall(methodDescriptor, callOptions), false, @@ -158,18 +181,56 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE return primaryTerminated; } + private boolean shouldFallbackBasedOnRPCStatus(Status status) { + switch (status.getCode()) { + case UNAVAILABLE: + case DEADLINE_EXCEEDED: + case UNKNOWN: + return true; + default: + return false; + } + } + + private boolean hasFallbackChannel() { + return fallback != null; + } + + private CallOptions applyFallbackCredentials(CallOptions callOptions) { + if (fallbackCallCredentials != null && callOptions.getCredentials() == null) { + return callOptions.withCallCredentials(fallbackCallCredentials); + } + return callOptions; + } + + private boolean shouldFallBackDueToPrimaryState() { + ConnectivityState connectivityState = primary.getState(true); + if (connectivityState == ConnectivityState.READY) { + primaryNotReadySinceNanos.set(-1); + return false; + } + long currentTimeNanos = nanoClock.getAsLong(); + if (primaryNotReadySinceNanos.get() < 0) { + primaryNotReadySinceNanos.set(currentTimeNanos); + } + return currentTimeNanos - primaryNotReadySinceNanos.get() > PRIMARY_NOT_READY_WAIT_NANOS; + } + private void notifyFailure(Status status, boolean 
isFallback, String methodName) { - if (!status.isOk() && !isFallback && fallback != null) { - if (useFallback.compareAndSet(false, true)) { - lastFallbackTimeNanos.set(nanoClock.getAsLong()); + if (!status.isOk() + && !isFallback + && hasFallbackChannel() + && shouldFallbackBasedOnRPCStatus(status)) { + if (useFallbackDueToRPC.compareAndSet(false, true)) { + lastRPCFallbackTimeNanos.set(nanoClock.getAsLong()); LOG.warn( - "Direct path connection failed with status {} for method: {}. Falling back to" - + " cloudpath for 1 hour.", - status, - methodName); + "Primary connection failed for method: {}. Switching to secondary connection. Status: {}", + methodName, + status.getCode()); } - } else if (isFallback) { - LOG.warn("Fallback channel call for method: {} closed with status: {}", methodName, status); + } else if (isFallback && !status.isOk()) { + LOG.warn( + "Secondary connection failed for method: {}. Status: {}", methodName, status.getCode()); } } @@ -180,9 +241,10 @@ private final class FailoverClientCall /** * @param delegate the underlying ClientCall (either primary or fallback) - * @param isFallback true if `delegate` is a fallback channel call, false if it is a primary - * channel call. This flag is inspected by {@link #notifyFailure} to determine whether a - * failure should trigger switching to the fallback channel (only primary failures do). + * @param isFallback true if {@code delegate} is a fallback channel call, false if it is a + * primary channel call. This flag is inspected by {@link #notifyFailure} to determine + * whether a failure should trigger switching to the fallback channel (only primary failures + * do). * @param methodName full gRPC method name (for logging) */ FailoverClientCall(ClientCall delegate, boolean isFallback, String methodName) { @@ -204,4 +266,38 @@ public void onClose(Status status, Metadata trailers) { headers); } } + + /** Registers callback for primary channel state changes. 
*/ + private void registerPrimaryStateChangeListener() { + if (!stateChangeListenerRegistered.getAndSet(true)) { + try { + ConnectivityState currentState = primary.getState(false); + primary.notifyWhenStateChanged(currentState, this::onPrimaryStateChanged); + } catch (Exception e) { + LOG.warn( + "Failed to register channel state monitor. Continuing with fallback detection.", e); + stateChangeListenerRegistered.set(false); + } + } + } + + /** Callback invoked when primary channel connectivity state changes. */ + private void onPrimaryStateChanged() { + if (isShutdown() || isTerminated()) { + return; + } + + // If primary is READY, clear state-based fallback immediately. + if (primary.getState(false) == ConnectivityState.READY) { + if (useFallbackDueToState.compareAndSet(true, false)) { + LOG.info("Primary channel recovered; switching back from fallback."); + } + } + + // Always re-register for next state change (unless shutdown) + if (!isShutdown() && !isTerminated()) { + stateChangeListenerRegistered.set(false); + registerPrimaryStateChangeListener(); + } + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java index 4710b46ec31e..c98004943613 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java @@ -22,16 +22,19 @@ import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static 
org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallCredentials; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallOptions; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ClientCall; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ClientCall.Listener; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ConnectivityState; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ManagedChannel; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Metadata; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.MethodDescriptor; @@ -53,7 +56,8 @@ public class FailoverChannelTest { .build(); @Test - public void testFallbackAndRetry() throws Exception { + public void testRPCFailureTriggersFallback() throws Exception { + // RPC failure with UNAVAILABLE should switch to fallback channel. ManagedChannel mockChannel = mock(ManagedChannel.class); ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); ClientCall underlyingCall = mock(ClientCall.class); @@ -63,7 +67,6 @@ public void testFallbackAndRetry() throws Exception { FailoverChannel failoverChannel = FailoverChannel.create(mockChannel, mockFallbackChannel); - // First call, triggers fallback ClientCall call1 = failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); Metadata metadata1 = new Metadata(); @@ -71,11 +74,8 @@ public void testFallbackAndRetry() throws Exception { ArgumentCaptor> captor = ArgumentCaptor.forClass(ClientCall.Listener.class); verify(underlyingCall).start(captor.capture(), same(metadata1)); - - // Fail with UNAVAILABLE captor.getValue().onClose(Status.UNAVAILABLE, new Metadata()); - // Second call should use fallback Channel ClientCall call2 = failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); call2.start(new NoopClientCall.NoopClientCallListener<>(), new 
Metadata()); @@ -83,21 +83,21 @@ public void testFallbackAndRetry() throws Exception { } @Test - public void testFallbackAndPeriodicRetry() throws Exception { + public void testRPCFailureRecoveryAfterCoolingPeriod() throws Exception { + // After RPC failure, channel stays on fallback during cooling period, then returns to primary. ManagedChannel mockChannel = mock(ManagedChannel.class); ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); ClientCall underlyingCall = mock(ClientCall.class); - ClientCall underlyingCall2 = mock(ClientCall.class); ClientCall fallbackCall = mock(ClientCall.class); - when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall, underlyingCall2); + when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall, mock(ClientCall.class)); when(mockFallbackChannel.newCall(any(), any())).thenReturn(fallbackCall); + when(mockChannel.getState(true)).thenReturn(ConnectivityState.READY); AtomicLong time = new AtomicLong(0); - FailoverChannel failoverChannel = FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); - // Trigger fallback + // Trigger RPC failure fallback ClientCall call1 = failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); call1.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); @@ -105,40 +105,30 @@ public void testFallbackAndPeriodicRetry() throws Exception { verify(underlyingCall).start(captor.capture(), any()); captor.getValue().onClose(Status.UNAVAILABLE, new Metadata()); - // Advance time by 30 mins (less than 1 hour) + // Within cooling period: still on fallback time.addAndGet(TimeUnit.MINUTES.toNanos(30)); - - // Should still use fallback - ClientCall call2 = - failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); - call2.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); verify(mockFallbackChannel, atLeastOnce()).newCall(any(), any()); - // Advance time by 
another 40 mins (total > 1 hour) + // After cooling period: recovers to primary time.addAndGet(TimeUnit.MINUTES.toNanos(40)); - - // Should retry direct path - ClientCall call3 = - failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); - call3.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); verify(mockChannel, atLeast(2)).newCall(any(), any()); } @Test public void testFallbackWithCredentials() throws Exception { + // Fallback channel should receive custom credentials when provided. ManagedChannel mockChannel = mock(ManagedChannel.class); ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); ClientCall underlyingCall = mock(ClientCall.class); - ClientCall fallbackCall = mock(ClientCall.class); CallCredentials mockCredentials = mock(CallCredentials.class); - when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall); - when(mockFallbackChannel.newCall(any(), any())).thenReturn(fallbackCall); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); FailoverChannel failoverChannel = FailoverChannel.create(mockChannel, mockFallbackChannel, mockCredentials); - // Trigger fallback ClientCall call1 = failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); call1.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); @@ -146,13 +136,77 @@ public void testFallbackWithCredentials() throws Exception { verify(underlyingCall).start(captor.capture(), any()); captor.getValue().onClose(Status.UNAVAILABLE, new Metadata()); - // Next call should use fallback with credentials - ClientCall call2 = - failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); - call2.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); ArgumentCaptor optionsCaptor = ArgumentCaptor.forClass(CallOptions.class); 
verify(mockFallbackChannel).newCall(same(methodDescriptor), optionsCaptor.capture()); assertEquals(mockCredentials, optionsCaptor.getValue().getCredentials()); } + + @Test + public void testStateFallbackAfterPrimaryNotReady() { + // If primary connection is not ready for 10+ seconds, routes to fallback. + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + when(mockChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + when(mockChannel.getState(true)).thenReturn(ConnectivityState.IDLE, ConnectivityState.IDLE); + + AtomicLong time = new AtomicLong(0); + FailoverChannel failoverChannel = + FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); + + // Within 10 seconds: still routes to primary + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockChannel).newCall(any(), any()); + + // After 10 seconds: routes to fallback + time.addAndGet(TimeUnit.SECONDS.toNanos(11)); + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockFallbackChannel).newCall(any(), any()); + } + + @Test + public void testStateBasedFallbackRecoveryViaCallback() { + // After state-based fallback, recovery to primary is immediate when callback fires with READY. 
+ ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + when(mockChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + // getState(true): IDLE starts timer, IDLE exceeds timer, READY on recovery check + when(mockChannel.getState(true)) + .thenReturn(ConnectivityState.IDLE, ConnectivityState.IDLE, ConnectivityState.READY); + // getState(false): IDLE for constructor registration, READY when callback fires + when(mockChannel.getState(false)) + .thenReturn(ConnectivityState.IDLE, ConnectivityState.READY, ConnectivityState.READY); + + AtomicReference stateChangeCallback = new AtomicReference<>(); + doAnswer( + invocation -> { + stateChangeCallback.set(invocation.getArgument(1)); + return null; + }) + .when(mockChannel) + .notifyWhenStateChanged(any(), any()); + + AtomicLong time = new AtomicLong(0); + FailoverChannel failoverChannel = + FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); + + // First call - primary not yet timed out, routes to primary + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockChannel).newCall(any(), any()); + + // After 10 seconds: state-based fallback kicks in + time.addAndGet(TimeUnit.SECONDS.toNanos(11)); + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockFallbackChannel).newCall(any(), any()); + + // Callback fires with primary now READY: clears state flag immediately + stateChangeCallback.get().run(); + + // Next call recovers to primary with no waiting + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockChannel, atLeast(2)).newCall(any(), any()); + } } From 03dcb31df824444505fafdc1ac1a6c49afe7e8ae Mon Sep 17 00:00:00 2001 From: Sania Parveen Date: Tue, 24 Mar 2026 05:27:10 +0000 Subject: [PATCH 4/8] Change to wrap FailoverChannel by IsolationChannel --- 
.../worker/StreamingDataflowWorker.java | 45 ++- .../client/grpc/stubs/FailoverChannel.java | 198 ++++++----- .../grpc/stubs/FailoverChannelTest.java | 313 +++++++++++++++++- 3 files changed, 440 insertions(+), 116 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index cc9501583641..3780f46d8829 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -116,7 +116,6 @@ import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.util.construction.CoderTranslation; import org.apache.beam.sdk.values.WindowedValues; -import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ManagedChannel; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.auth.MoreCallCredentials; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; @@ -814,31 +813,25 @@ private static ChannelCache createChannelCache( GrpcDispatcherClient dispatcherClient) { ChannelCache channelCache = ChannelCache.create( - (currentFlowControlSettings, serviceAddress) -> { - ManagedChannel primaryChannel = - IsolationChannel.create( - () -> - remoteChannel( - serviceAddress, - workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), - currentFlowControlSettings), - currentFlowControlSettings.getOnReadyThresholdBytes()); - // Create an isolated fallback channel from dispatcher endpoints. - // This ensures both primary and fallback use separate isolated channels. 
- ManagedChannel fallbackChannel = - IsolationChannel.create( - () -> - remoteChannel( - dispatcherClient.getDispatcherEndpoints().iterator().next(), - workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), - currentFlowControlSettings), - currentFlowControlSettings.getOnReadyThresholdBytes()); - return FailoverChannel.create( - primaryChannel, - fallbackChannel, - MoreCallCredentials.from( - new VendoredCredentialsAdapter(workerOptions.getGcpCredential()))); - }); + (currentFlowControlSettings, serviceAddress) -> + // IsolationChannel wraps FailoverChannel so that each active RPC gets its own + // FailoverChannel instance. FailoverChannel creates two channels (primary, + // fallback) + // per active RPC. + IsolationChannel.create( + () -> + FailoverChannel.create( + remoteChannel( + serviceAddress, + workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), + currentFlowControlSettings), + remoteChannel( + dispatcherClient.getDispatcherEndpoints().iterator().next(), + workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), + currentFlowControlSettings), + MoreCallCredentials.from( + new VendoredCredentialsAdapter(workerOptions.getGcpCredential()))), + currentFlowControlSettings.getOnReadyThresholdBytes())); configFetcher .getGlobalConfigHandle() diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java index ea550f361f3e..120b18609adb 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java @@ -19,9 +19,9 @@ import java.util.concurrent.TimeUnit; import 
java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; import java.util.function.LongSupplier; import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallCredentials; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallOptions; @@ -46,9 +46,10 @@ * (e.g., during network issues), routes to fallback channel. Switches back as soon as the * primary channel becomes READY again. *
  • RPC Failover: If primary channel RPC fails with transient errors ({@link - * Status.Code#UNAVAILABLE}, {@link Status.Code#DEADLINE_EXCEEDED}, or {@link - * Status.Code#UNKNOWN}), switches to fallback channel and waits for a 1-hour cooling period - * before retrying primary. + * Status.Code#UNAVAILABLE} or {@link Status.Code#UNKNOWN}), or with {@link + * Status.Code#DEADLINE_EXCEEDED} before receiving any response (indicating the connection was + * never established) and connection status is not READY, switches to fallback channel and + * waits for a 1-hour cooling period before retrying primary. * */ @Internal @@ -57,21 +58,35 @@ public final class FailoverChannel extends ManagedChannel { // Time to wait before retrying the primary channel after an RPC-based fallback. private static final long FALLBACK_COOLING_PERIOD_NANOS = TimeUnit.HOURS.toNanos(1); private static final long PRIMARY_NOT_READY_WAIT_NANOS = TimeUnit.SECONDS.toNanos(10); + private final ManagedChannel primary; - @Nullable private final ManagedChannel fallback; + private final ManagedChannel fallback; @Nullable private final CallCredentials fallbackCallCredentials; - // Set when primary's connection state has been unavailable for too long. - private final AtomicBoolean useFallbackDueToState = new AtomicBoolean(false); - // Set when an RPC on primary fails with a transient error. - private final AtomicBoolean useFallbackDueToRPC = new AtomicBoolean(false); - private final AtomicLong lastRPCFallbackTimeNanos = new AtomicLong(0); - private final AtomicLong primaryNotReadySinceNanos = new AtomicLong(-1); private final LongSupplier nanoClock; + // Held only during registration to prevent duplicate listener registration. private final AtomicBoolean stateChangeListenerRegistered = new AtomicBoolean(false); + // All mutable routing state is consolidated here to ensure related fields are updated atomically. 
+ private final FailoverState state = new FailoverState(); + + private static final class FailoverState { + // Set when primary's connection state has been unavailable for too long. + @GuardedBy("this") + boolean useFallbackDueToState; + // Set when an RPC on primary fails with an error. + @GuardedBy("this") + boolean useFallbackDueToRPC; + // Timestamp when RPC-based fallback was triggered. Only meaningful when useFallbackDueToRPC + // is true. + @GuardedBy("this") + long lastRPCFallbackTimeNanos; + // Time when primary first became not-ready. -1 when primary is currently READY. + @GuardedBy("this") + long primaryNotReadySinceNanos = -1; + } private FailoverChannel( ManagedChannel primary, - @Nullable ManagedChannel fallback, + ManagedChannel fallback, @Nullable CallCredentials fallbackCallCredentials, LongSupplier nanoClock) { this.primary = primary; @@ -82,11 +97,6 @@ private FailoverChannel( registerPrimaryStateChangeListener(); } - // Test-only. - public static FailoverChannel create(ManagedChannel primary, ManagedChannel fallback) { - return new FailoverChannel(primary, fallback, null, System::nanoTime); - } - public static FailoverChannel create( ManagedChannel primary, ManagedChannel fallback, CallCredentials fallbackCallCredentials) { return new FailoverChannel(primary, fallback, fallbackCallCredentials, System::nanoTime); @@ -108,28 +118,36 @@ public String authority() { @Override public ClientCall newCall( MethodDescriptor methodDescriptor, CallOptions callOptions) { - // Check if the RPC-based cooling period has elapsed. - if (useFallbackDueToRPC.get()) { - long timeSinceLastFallback = nanoClock.getAsLong() - lastRPCFallbackTimeNanos.get(); - if (timeSinceLastFallback >= FALLBACK_COOLING_PERIOD_NANOS) { - if (useFallbackDueToRPC.compareAndSet(true, false)) { + // Read connectivity state before acquiring the lock to avoid calling an external API while + // holding our lock. 
+ ConnectivityState primaryState = primary.getState(false); + final boolean useFallback; + synchronized (state) { + // Step 1: If we switched to fallback due to a failed RPC, check whether enough time has + // elapsed to retry primary. If so, clear the flag — the next step will then re-evaluate + // whether primary is actually healthy before committing to routing there. + if (state.useFallbackDueToRPC) { + long timeSinceLastFallback = nanoClock.getAsLong() - state.lastRPCFallbackTimeNanos; + if (timeSinceLastFallback >= FALLBACK_COOLING_PERIOD_NANOS) { + state.useFallbackDueToRPC = false; LOG.info("Primary channel cooling period elapsed; switching back from fallback."); } } - } - if (fallback != null && (useFallbackDueToRPC.get() || useFallbackDueToState.get())) { - return new FailoverClientCall<>( - fallback.newCall(methodDescriptor, applyFallbackCredentials(callOptions)), - true, - methodDescriptor.getFullMethodName()); + // Step 2: If neither fallback flag is set, inspect the primary's connectivity state. This + // may set useFallbackDueToState if primary has been non-READY for longer than the + // threshold. Skipped when already on fallback. + // useFallbackDueToState is cleared in onPrimaryStateChanged callback when primary becomes + // READY again. + if (!state.useFallbackDueToRPC && !state.useFallbackDueToState) { + checkAndUpdateStateFallback(primaryState); + } + + // Step 3: Decide which channel to route the request to based on the current state. + useFallback = state.useFallbackDueToRPC || state.useFallbackDueToState; } - // If primary has not become ready for a sustained period, fail over to fallback. - if (fallback != null && shouldFallBackDueToPrimaryState()) { - if (useFallbackDueToState.compareAndSet(false, true)) { - LOG.warn("Primary connection unavailable. 
Switching to secondary connection."); - } + if (useFallback) { return new FailoverClientCall<>( fallback.newCall(methodDescriptor, applyFallbackCredentials(callOptions)), true, @@ -145,57 +163,49 @@ public ClientCall newCall( @Override public ManagedChannel shutdown() { primary.shutdown(); - if (fallback != null) { - fallback.shutdown(); - } + fallback.shutdown(); return this; } @Override public ManagedChannel shutdownNow() { primary.shutdownNow(); - if (fallback != null) { - fallback.shutdownNow(); - } + fallback.shutdownNow(); return this; } @Override public boolean isShutdown() { - return primary.isShutdown() && (fallback == null || fallback.isShutdown()); + return primary.isShutdown() && fallback.isShutdown(); } @Override public boolean isTerminated() { - return primary.isTerminated() && (fallback == null || fallback.isTerminated()); + return primary.isTerminated() && fallback.isTerminated(); } @Override public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { long endTimeNanos = nanoClock.getAsLong() + unit.toNanos(timeout); boolean primaryTerminated = primary.awaitTermination(timeout, unit); - if (fallback != null) { - long remainingNanos = Math.max(0, endTimeNanos - nanoClock.getAsLong()); - return primaryTerminated && fallback.awaitTermination(remainingNanos, TimeUnit.NANOSECONDS); - } - return primaryTerminated; + long remainingNanos = Math.max(0, endTimeNanos - nanoClock.getAsLong()); + return primaryTerminated && fallback.awaitTermination(remainingNanos, TimeUnit.NANOSECONDS); } - private boolean shouldFallbackBasedOnRPCStatus(Status status) { + private boolean shouldFallbackBasedOnRPCStatus(Status status, boolean receivedResponse) { switch (status.getCode()) { case UNAVAILABLE: - case DEADLINE_EXCEEDED: case UNKNOWN: return true; + case DEADLINE_EXCEEDED: + // Only failover if no response was received. 
If a response was received, the connection + // was healthy and the timeout is an application-level issue, not a connectivity problem. + return !receivedResponse; default: return false; } } - private boolean hasFallbackChannel() { - return fallback != null; - } - private CallOptions applyFallbackCredentials(CallOptions callOptions) { if (fallbackCallCredentials != null && callOptions.getCredentials() == null) { return callOptions.withCallCredentials(fallbackCallCredentials); @@ -203,30 +213,44 @@ private CallOptions applyFallbackCredentials(CallOptions callOptions) { return callOptions; } - private boolean shouldFallBackDueToPrimaryState() { - ConnectivityState connectivityState = primary.getState(true); - if (connectivityState == ConnectivityState.READY) { - primaryNotReadySinceNanos.set(-1); - return false; + /** + * Checks primary channel connectivity state and updates {@code state.useFallbackDueToState} if + * the primary has been not-ready long enough to warrant failover. + */ + @GuardedBy("state") + private void checkAndUpdateStateFallback(ConnectivityState connectivityState) { + // gRPC's state machine only transitions to IDLE from READY. Hence, we treat both + // READY and IDLE as healthy states. + if (connectivityState == ConnectivityState.READY + || connectivityState == ConnectivityState.IDLE) { + state.primaryNotReadySinceNanos = -1; + return; } long currentTimeNanos = nanoClock.getAsLong(); - if (primaryNotReadySinceNanos.get() < 0) { - primaryNotReadySinceNanos.set(currentTimeNanos); + if (state.primaryNotReadySinceNanos < 0) { + state.primaryNotReadySinceNanos = currentTimeNanos; + } + if (currentTimeNanos - state.primaryNotReadySinceNanos > PRIMARY_NOT_READY_WAIT_NANOS) { + if (!state.useFallbackDueToState) { + state.useFallbackDueToState = true; + LOG.warn("Primary connection unavailable. 
Switching to secondary connection."); + } } - return currentTimeNanos - primaryNotReadySinceNanos.get() > PRIMARY_NOT_READY_WAIT_NANOS; } - private void notifyFailure(Status status, boolean isFallback, String methodName) { - if (!status.isOk() - && !isFallback - && hasFallbackChannel() - && shouldFallbackBasedOnRPCStatus(status)) { - if (useFallbackDueToRPC.compareAndSet(false, true)) { - lastRPCFallbackTimeNanos.set(nanoClock.getAsLong()); - LOG.warn( - "Primary connection failed for method: {}. Switching to secondary connection. Status: {}", - methodName, - status.getCode()); + private void notifyCallDone( + Status status, boolean isFallback, String methodName, boolean receivedResponse) { + if (!status.isOk() && !isFallback && shouldFallbackBasedOnRPCStatus(status, receivedResponse)) { + synchronized (state) { + if (!state.useFallbackDueToRPC) { + state.useFallbackDueToRPC = true; + state.lastRPCFallbackTimeNanos = nanoClock.getAsLong(); + LOG.warn( + "Primary connection failed for method: {}. Switching to secondary connection." + + " Status: {}", + methodName, + status.getCode()); + } } } else if (isFallback && !status.isOk()) { LOG.warn( @@ -238,14 +262,18 @@ private final class FailoverClientCall extends SimpleForwardingClientCall { private final boolean isFallback; private final String methodName; + // Tracks whether any response message was received. Volatile ensures the write in onMessage + // is visible to the read in onClose even if they execute on different threads within gRPC's + // SerializingExecutor. + private volatile boolean receivedResponse = false; /** * @param delegate the underlying ClientCall (either primary or fallback) * @param isFallback true if {@code delegate} is a fallback channel call, false if it is a - * primary channel call. This flag is inspected by {@link #notifyFailure} to determine + * primary channel call. 
This flag is inspected by {@link #notifyCallDone} to determine * whether a failure should trigger switching to the fallback channel (only primary failures * do). - * @param methodName full gRPC method name (for logging) + * @param methodName gRPC method name (for logging) */ FailoverClientCall(ClientCall delegate, boolean isFallback, String methodName) { super(delegate); @@ -257,9 +285,15 @@ private final class FailoverClientCall public void start(Listener responseListener, Metadata headers) { super.start( new SimpleForwardingClientCallListener(responseListener) { + @Override + public void onMessage(RespT message) { + receivedResponse = true; + super.onMessage(message); + } + @Override public void onClose(Status status, Metadata trailers) { - notifyFailure(status, isFallback, methodName); + notifyCallDone(status, isFallback, methodName, receivedResponse); super.onClose(status, trailers); } }, @@ -287,14 +321,22 @@ private void onPrimaryStateChanged() { return; } - // If primary is READY, clear state-based fallback immediately. + // If primary is READY, clear both useFallbackDueToState useFallbackDueToRPC flag. + // This ensures we switch back to primary as soon as it recovers, + // regardless of which failover mode triggered the switch. if (primary.getState(false) == ConnectivityState.READY) { - if (useFallbackDueToState.compareAndSet(true, false)) { - LOG.info("Primary channel recovered; switching back from fallback."); + synchronized (state) { + boolean wasOnFallback = state.useFallbackDueToState || state.useFallbackDueToRPC; + state.useFallbackDueToState = false; + state.useFallbackDueToRPC = false; + state.primaryNotReadySinceNanos = -1; + if (wasOnFallback) { + LOG.info("Primary channel recovered; switching back from fallback."); + } } } - // Always re-register for next state change (unless shutdown) + // Always re-register for next state change (unless shutdown). 
if (!isShutdown() && !isTerminated()) { stateChangeListenerRegistered.set(false); registerPrimaryStateChangeListener(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java index c98004943613..204ae7d60ca0 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java @@ -24,9 +24,17 @@ import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; @@ -55,7 +63,26 @@ public class FailoverChannelTest { .setResponseMarshaller(new IsolationChannelTest.NoopMarshaller()) .build(); - @Test + private static FailoverChannel createForTest(ManagedChannel primary, ManagedChannel fallback) { + return FailoverChannel.forTest(primary, fallback, null, System::nanoTime); + } + + /** + * Starts a call on the primary channel, captures the injected listener, and fires onClose with + * the given status. Use this to trigger RPC-based failover in tests. 
+ */ + private void triggerRPCFailure( + FailoverChannel channel, ClientCall underlying, Status status) + throws Exception { + channel + .newCall(methodDescriptor, CallOptions.DEFAULT) + .start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); + ArgumentCaptor> captor = ArgumentCaptor.forClass(ClientCall.Listener.class); + verify(underlying).start(captor.capture(), any()); + captor.getValue().onClose(status, new Metadata()); + } + + public void testRPCFailureTriggersFallback() throws Exception { // RPC failure with UNAVAILABLE should switch to fallback channel. ManagedChannel mockChannel = mock(ManagedChannel.class); @@ -65,7 +92,7 @@ public void testRPCFailureTriggersFallback() throws Exception { when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall); when(mockFallbackChannel.newCall(any(), any())).thenReturn(fallbackCall); - FailoverChannel failoverChannel = FailoverChannel.create(mockChannel, mockFallbackChannel); + FailoverChannel failoverChannel = createForTest(mockChannel, mockFallbackChannel); ClientCall call1 = failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); @@ -91,7 +118,8 @@ public void testRPCFailureRecoveryAfterCoolingPeriod() throws Exception { ClientCall fallbackCall = mock(ClientCall.class); when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall, mock(ClientCall.class)); when(mockFallbackChannel.newCall(any(), any())).thenReturn(fallbackCall); - when(mockChannel.getState(true)).thenReturn(ConnectivityState.READY); + + when(mockChannel.getState(false)).thenReturn(ConnectivityState.READY); AtomicLong time = new AtomicLong(0); FailoverChannel failoverChannel = @@ -105,17 +133,64 @@ public void testRPCFailureRecoveryAfterCoolingPeriod() throws Exception { verify(underlyingCall).start(captor.capture(), any()); captor.getValue().onClose(Status.UNAVAILABLE, new Metadata()); - // Within cooling period: still on fallback + // Within cooling period, still on fallback 
time.addAndGet(TimeUnit.MINUTES.toNanos(30)); failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); verify(mockFallbackChannel, atLeastOnce()).newCall(any(), any()); - // After cooling period: recovers to primary + // After cooling period, recovers to primary time.addAndGet(TimeUnit.MINUTES.toNanos(40)); failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); verify(mockChannel, atLeast(2)).newCall(any(), any()); } + @Test + public void testRPCFallbackClearedByConnectivityRecovery() throws Exception { + // Race condition: RPC failure observed just before connectivity callback fires READY. + // Once the channel goes through unhealthy→healthy, the cooling period must be cancelled + // and traffic must return to primary immediately (not wait 1 hour). + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + ClientCall underlyingCall = mock(ClientCall.class); + when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall, mock(ClientCall.class)); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + + when(mockChannel.getState(false)).thenReturn(ConnectivityState.IDLE, ConnectivityState.READY); + + AtomicReference stateChangeCallback = new AtomicReference<>(); + doAnswer( + invocation -> { + stateChangeCallback.set(invocation.getArgument(1)); + return null; + }) + .when(mockChannel) + .notifyWhenStateChanged(any(), any()); + + AtomicLong time = new AtomicLong(0); + FailoverChannel failoverChannel = + FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); + + // RPC failure results in entering cooling period + ClientCall call1 = + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + call1.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); + ArgumentCaptor> captor = ArgumentCaptor.forClass(ClientCall.Listener.class); + verify(underlyingCall).start(captor.capture(), any()); + 
captor.getValue().onClose(Status.UNAVAILABLE, new Metadata()); + + // Still within cooling period, routes to fallback + time.addAndGet(TimeUnit.MINUTES.toNanos(30)); + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockFallbackChannel, atLeastOnce()).newCall(any(), any()); + + // Primary recovers and callback fires READY, clearing the cooling period + stateChangeCallback.get().run(); + + // Verify immediately routes back to primary + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockChannel, atLeast(2)).newCall(any(), any()); + } + @Test public void testFallbackWithCredentials() throws Exception { // Fallback channel should receive custom credentials when provided. @@ -150,7 +225,13 @@ public void testStateFallbackAfterPrimaryNotReady() { ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); when(mockChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); - when(mockChannel.getState(true)).thenReturn(ConnectivityState.IDLE, ConnectivityState.IDLE); + // IDLE for constructor registration, TRANSIENT_FAILURE for the + // two checkAndUpdateStateFallback() calls. + when(mockChannel.getState(false)) + .thenReturn( + ConnectivityState.IDLE, + ConnectivityState.TRANSIENT_FAILURE, + ConnectivityState.TRANSIENT_FAILURE); AtomicLong time = new AtomicLong(0); FailoverChannel failoverChannel = @@ -166,6 +247,32 @@ public void testStateFallbackAfterPrimaryNotReady() { verify(mockFallbackChannel).newCall(any(), any()); } + @Test + public void testIdleStateNotTreatedAsFallback() { + // IDLE is a normal healthy state (channel is not actively connected but will reconnect on + // demand). It must NOT start the not-ready timer or trigger state-based fallback, even after + // more than 10 seconds. 
+ ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + when(mockChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + // Primary stays IDLE the entire time (constructor registration + all state checks). + when(mockChannel.getState(false)).thenReturn(ConnectivityState.IDLE); + + AtomicLong time = new AtomicLong(0); + FailoverChannel failoverChannel = + FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); + + // Advance well past the 10-second threshold while primary remains IDLE + time.addAndGet(TimeUnit.SECONDS.toNanos(30)); + + // IDLE must not trigger fallback — all calls still route to primary + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockChannel, atLeast(2)).newCall(any(), any()); + verify(mockFallbackChannel, never()).newCall(any(), any()); + } + @Test public void testStateBasedFallbackRecoveryViaCallback() { // After state-based fallback, recovery to primary is immediate when callback fires with READY. 
@@ -173,12 +280,15 @@ public void testStateBasedFallbackRecoveryViaCallback() { ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); when(mockChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); - // getState(true): IDLE starts timer, IDLE exceeds timer, READY on recovery check - when(mockChannel.getState(true)) - .thenReturn(ConnectivityState.IDLE, ConnectivityState.IDLE, ConnectivityState.READY); - // getState(false): IDLE for constructor registration, READY when callback fires + // IDLE for constructor registration, TRANSIENT_FAILURE for call1 (starts + // the 10s timer) and call2 (timer exceeds 10s), READY when callback fires + // (clears state flag) and for subsequent re-registration and state checks. when(mockChannel.getState(false)) - .thenReturn(ConnectivityState.IDLE, ConnectivityState.READY, ConnectivityState.READY); + .thenReturn( + ConnectivityState.IDLE, + ConnectivityState.TRANSIENT_FAILURE, + ConnectivityState.TRANSIENT_FAILURE, + ConnectivityState.READY); AtomicReference stateChangeCallback = new AtomicReference<>(); doAnswer( @@ -202,11 +312,190 @@ public void testStateBasedFallbackRecoveryViaCallback() { failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); verify(mockFallbackChannel).newCall(any(), any()); - // Callback fires with primary now READY: clears state flag immediately + // Callback fires with primary now READY stateChangeCallback.get().run(); // Next call recovers to primary with no waiting failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); verify(mockChannel, atLeast(2)).newCall(any(), any()); } + + // --- DEADLINE_EXCEEDED tests --- + + @Test + public void testDeadlineExceededWithoutResponseTriggersFallback() throws Exception { + // DEADLINE_EXCEEDED with no response = connection never established. Should failover. 
+ ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + ClientCall underlyingCall = mock(ClientCall.class); + when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + + FailoverChannel failoverChannel = createForTest(mockChannel, mockFallbackChannel); + + ClientCall call1 = + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + Metadata metadata1 = new Metadata(); + call1.start(new NoopClientCall.NoopClientCallListener<>(), metadata1); + + ArgumentCaptor> captor = ArgumentCaptor.forClass(ClientCall.Listener.class); + verify(underlyingCall).start(captor.capture(), same(metadata1)); + // Close with DEADLINE_EXCEEDED and no prior onMessage, should trigger failover + captor.getValue().onClose(Status.DEADLINE_EXCEEDED, new Metadata()); + + ClientCall call2 = + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + call2.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); + verify(mockFallbackChannel, atLeastOnce()).newCall(any(), any()); + } + + @Test + public void testDeadlineExceededWithResponseDoesNotTriggerFallback() throws Exception { + // DEADLINE_EXCEEDED after receiving a response, should NOT + // failover. The connection was healthy since at least one response was delivered. 
+ ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + ClientCall underlyingCall = mock(ClientCall.class); + when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall, mock(ClientCall.class)); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + + FailoverChannel failoverChannel = createForTest(mockChannel, mockFallbackChannel); + + ClientCall call1 = + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + Metadata metadata1 = new Metadata(); + call1.start(new NoopClientCall.NoopClientCallListener<>(), metadata1); + + ArgumentCaptor> captor = ArgumentCaptor.forClass(ClientCall.Listener.class); + verify(underlyingCall).start(captor.capture(), same(metadata1)); + // Simulate receiving a response before the timeout + captor.getValue().onMessage(new Object()); + // Close with DEADLINE_EXCEEDED after a response, should NOT trigger failover + captor.getValue().onClose(Status.DEADLINE_EXCEEDED, new Metadata()); + + // Next call should still route to primary + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockChannel, atLeast(2)).newCall(any(), any()); + verify(mockFallbackChannel, never()).newCall(any(), any()); + } + + // --- Concurrency tests --- + + @Test + public void testConcurrentRPCFailuresProduceConsistentFailover() throws Exception { + // Concurrent RPC failures from multiple threads should produce exactly one failover. + // After all threads complete, subsequent calls must consistently route to fallback. 
+ int numThreads = 20; + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + + // Track all primary ClientCalls created so we can fire onClose on each + List> primaryCalls = Collections.synchronizedList(new ArrayList<>()); + when(mockChannel.newCall(any(), any())) + .thenAnswer( + inv -> { + ClientCall call = mock(ClientCall.class); + primaryCalls.add(call); + return call; + }); + when(mockFallbackChannel.newCall(any(), any())).thenAnswer(inv -> mock(ClientCall.class)); + // Ensure state-based fallback does not interfere + when(mockChannel.getState(false)).thenReturn(ConnectivityState.READY); + + FailoverChannel failoverChannel = createForTest(mockChannel, mockFallbackChannel); + + // Start N calls on primary and capture their listeners + List> listeners = new ArrayList<>(); + for (int i = 0; i < numThreads; i++) { + ClientCall call = + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + call.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); + } + for (ClientCall primaryCall : primaryCalls) { + ArgumentCaptor> captor = ArgumentCaptor.forClass(ClientCall.Listener.class); + verify(primaryCall).start(captor.capture(), any()); + listeners.add(captor.getValue()); + } + + // All threads fire UNAVAILABLE simultaneously + CyclicBarrier barrier = new CyclicBarrier(numThreads); + List threads = new ArrayList<>(); + for (int i = 0; i < numThreads; i++) { + final Listener listener = listeners.get(i); + Thread t = + new Thread( + () -> { + try { + barrier.await(); + listener.onClose(Status.UNAVAILABLE, new Metadata()); + } catch (Exception e) { + Thread.currentThread().interrupt(); + } + }); + t.start(); + threads.add(t); + } + for (Thread t : threads) { + t.join(5000); + } + + // All subsequent calls must consistently route to fallback + int subsequentCalls = 5; + for (int i = 0; i < subsequentCalls; i++) { + failoverChannel.newCall(methodDescriptor, 
CallOptions.DEFAULT); + } + verify(mockFallbackChannel, atLeast(subsequentCalls)).newCall(any(), any()); + } + + @Test + public void testConcurrentNewCallsDuringRPCFailoverAreConsistent() throws Exception { + // Calls made concurrently while RPC failover is triggered must route consistently: + // none should be lost and each must go to either primary (before failover) or + // fallback (after). + int numThreads = 20; + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + ClientCall underlyingCall = mock(ClientCall.class); + + when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall); + when(mockFallbackChannel.newCall(any(), any())).thenAnswer(inv -> mock(ClientCall.class)); + when(mockChannel.getState(false)).thenReturn(ConnectivityState.READY); + + FailoverChannel failoverChannel = createForTest(mockChannel, mockFallbackChannel); + + // Set up an in-flight primary call whose failure will trigger RPC-based failover. + ClientCall triggerCall = + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + triggerCall.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); + ArgumentCaptor> listenerCaptor = + ArgumentCaptor.forClass(ClientCall.Listener.class); + verify(underlyingCall).start(listenerCaptor.capture(), any()); + ClientCall.Listener wrappedListener = listenerCaptor.getValue(); + + // All threads (failover trigger + newCall callers) start simultaneously via a barrier. 
+ CyclicBarrier barrier = new CyclicBarrier(numThreads + 1); + List> tasks = new ArrayList<>(); + tasks.add( + () -> { + barrier.await(); + wrappedListener.onClose(Status.UNAVAILABLE, new Metadata()); + return null; + }); + for (int i = 0; i < numThreads; i++) { + tasks.add( + () -> { + barrier.await(); + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + return null; + }); + } + + ExecutorService executor = Executors.newFixedThreadPool(numThreads + 1); + executor.invokeAll(tasks, 5, TimeUnit.SECONDS); + executor.shutdown(); + + // After concurrent operations, state must be coherent: subsequent calls go to fallback. + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockFallbackChannel, atLeastOnce()).newCall(any(), any()); + } } From e3ab71ee13a4e43dce5ce96f97143b326029a2ce Mon Sep 17 00:00:00 2001 From: Sania Parveen Date: Thu, 26 Mar 2026 04:57:42 +0000 Subject: [PATCH 5/8] Adding channel id in logs + addressing other review comments --- .../worker/StreamingDataflowWorker.java | 38 ++-- .../client/grpc/stubs/FailoverChannel.java | 181 ++++++++++-------- .../grpc/stubs/FailoverChannelTest.java | 53 ++--- .../grpc/stubs/IsolationChannelTest.java | 34 +--- .../client/grpc/stubs/NoopClientCall.java | 21 ++ 5 files changed, 160 insertions(+), 167 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 3780f46d8829..4ce3f9b651f4 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -813,25 +813,25 @@ private static ChannelCache createChannelCache( GrpcDispatcherClient dispatcherClient) { 
ChannelCache channelCache = ChannelCache.create( - (currentFlowControlSettings, serviceAddress) -> - // IsolationChannel wraps FailoverChannel so that each active RPC gets its own - // FailoverChannel instance. FailoverChannel creates two channels (primary, - // fallback) - // per active RPC. - IsolationChannel.create( - () -> - FailoverChannel.create( - remoteChannel( - serviceAddress, - workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), - currentFlowControlSettings), - remoteChannel( - dispatcherClient.getDispatcherEndpoints().iterator().next(), - workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), - currentFlowControlSettings), - MoreCallCredentials.from( - new VendoredCredentialsAdapter(workerOptions.getGcpCredential()))), - currentFlowControlSettings.getOnReadyThresholdBytes())); + (currentFlowControlSettings, serviceAddress) -> { + // IsolationChannel wrapping FailoverChannel so that each active RPC gets its own + // FailoverChannel instance. FailoverChannel creates two channels (primary, + // fallback) per active RPC. 
+ return IsolationChannel.create( + () -> + FailoverChannel.create( + remoteChannel( + serviceAddress, + workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), + currentFlowControlSettings), + remoteChannel( + dispatcherClient.getDispatcherEndpoints().iterator().next(), + workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), + currentFlowControlSettings), + MoreCallCredentials.from( + new VendoredCredentialsAdapter(workerOptions.getGcpCredential()))), + currentFlowControlSettings.getOnReadyThresholdBytes()); + }); configFetcher .getGlobalConfigHandle() diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java index 120b18609adb..39fdb985a6f4 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java @@ -19,6 +19,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.LongSupplier; import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; @@ -55,18 +56,20 @@ @Internal public final class FailoverChannel extends ManagedChannel { private static final Logger LOG = LoggerFactory.getLogger(FailoverChannel.class); + private static final AtomicInteger CHANNEL_ID_COUNTER = new AtomicInteger(0); // Time to wait before retrying the primary channel after an RPC-based fallback. 
private static final long FALLBACK_COOLING_PERIOD_NANOS = TimeUnit.HOURS.toNanos(1); private static final long PRIMARY_NOT_READY_WAIT_NANOS = TimeUnit.SECONDS.toNanos(10); private final ManagedChannel primary; private final ManagedChannel fallback; + private final int channelId; @Nullable private final CallCredentials fallbackCallCredentials; private final LongSupplier nanoClock; // Held only during registration to prevent duplicate listener registration. private final AtomicBoolean stateChangeListenerRegistered = new AtomicBoolean(false); // All mutable routing state is consolidated here to ensure related fields are updated atomically. - private final FailoverState state = new FailoverState(); + private final FailoverState state; private static final class FailoverState { // Set when primary's connection state has been unavailable for too long. @@ -82,6 +85,72 @@ private static final class FailoverState { // Time when primary first became not-ready. -1 when primary is currently READY. @GuardedBy("this") long primaryNotReadySinceNanos = -1; + + private final int channelId; + + FailoverState(int channelId) { + this.channelId = channelId; + } + + /** + * Determines whether the next RPC should route to the fallback channel, updating internal state + * as needed. + */ + synchronized boolean computeUseFallback(long nowNanos, ConnectivityState primaryState) { + // Clear RPC-based fallback if the cooling period has elapsed. + if (useFallbackDueToRPC + && nowNanos - lastRPCFallbackTimeNanos >= FALLBACK_COOLING_PERIOD_NANOS) { + useFallbackDueToRPC = false; + LOG.info( + "[channel-{}] Primary channel cooling period elapsed; switching back from fallback.", + channelId); + } + // If not already on fallback, check primary connectivity state. + // gRPC's state machine only transitions to IDLE from READY. Treat both as healthy. 
+ if (!useFallbackDueToRPC && !useFallbackDueToState) { + if (primaryState == ConnectivityState.READY || primaryState == ConnectivityState.IDLE) { + primaryNotReadySinceNanos = -1; + } else { + if (primaryNotReadySinceNanos < 0) { + primaryNotReadySinceNanos = nowNanos; + } + if (nowNanos - primaryNotReadySinceNanos > PRIMARY_NOT_READY_WAIT_NANOS + && !useFallbackDueToState) { + useFallbackDueToState = true; + LOG.warn( + "[channel-{}] Primary connection unavailable. Switching to secondary connection.", + channelId); + } + } + } + return useFallbackDueToRPC || useFallbackDueToState; + } + + /** + * Transitions the fallback state. + * When toFallback is true (RPC failure) it enables RPC-based fallback if + * not already active and returns true so the caller can log the failure details. + * When toFallback is false (primary recovered) it clears all fallback flags + * and returns true if recovery actually changed state, so the caller can log it. + */ + synchronized boolean transitionFallback(boolean toFallback, long nowNanos) { + if (toFallback) { + if (!useFallbackDueToRPC) { + useFallbackDueToRPC = true; + lastRPCFallbackTimeNanos = nowNanos; + // Return true to indicate fallback state was changed and caller should log the event. + return true; + } + // Already in RPC-based fallback, no state change. + return false; + } + // Clear all fallback state as primary has recovered. 
+ boolean wasOnFallback = useFallbackDueToState || useFallbackDueToRPC; + useFallbackDueToState = false; + useFallbackDueToRPC = false; + primaryNotReadySinceNanos = -1; + return wasOnFallback; + } } private FailoverChannel( @@ -91,6 +160,8 @@ private FailoverChannel( LongSupplier nanoClock) { this.primary = primary; this.fallback = fallback; + this.channelId = CHANNEL_ID_COUNTER.getAndIncrement(); + this.state = new FailoverState(channelId); this.fallbackCallCredentials = fallbackCallCredentials; this.nanoClock = nanoClock; // Register callback to monitor primary channel state changes @@ -118,34 +189,11 @@ public String authority() { @Override public ClientCall newCall( MethodDescriptor methodDescriptor, CallOptions callOptions) { - // Read connectivity state before acquiring the lock to avoid calling an external API while - // holding our lock. + // Read connectivity state and clock before the synchronized call to avoid holding external + // APIs under the state lock. ConnectivityState primaryState = primary.getState(false); - final boolean useFallback; - synchronized (state) { - // Step 1: If we switched to fallback due to a failed RPC, check whether enough time has - // elapsed to retry primary. If so, clear the flag — the next step will then re-evaluate - // whether primary is actually healthy before committing to routing there. - if (state.useFallbackDueToRPC) { - long timeSinceLastFallback = nanoClock.getAsLong() - state.lastRPCFallbackTimeNanos; - if (timeSinceLastFallback >= FALLBACK_COOLING_PERIOD_NANOS) { - state.useFallbackDueToRPC = false; - LOG.info("Primary channel cooling period elapsed; switching back from fallback."); - } - } - - // Step 2: If neither fallback flag is set, inspect the primary's connectivity state. This - // may set useFallbackDueToState if primary has been non-READY for longer than the - // threshold. Skipped when already on fallback. 
- // useFallbackDueToState is cleared in onPrimaryStateChanged callback when primary becomes - // READY again. - if (!state.useFallbackDueToRPC && !state.useFallbackDueToState) { - checkAndUpdateStateFallback(primaryState); - } - - // Step 3: Decide which channel to route the request to based on the current state. - useFallback = state.useFallbackDueToRPC || state.useFallbackDueToState; - } + long nowNanos = nanoClock.getAsLong(); + boolean useFallback = state.computeUseFallback(nowNanos, primaryState); if (useFallback) { return new FailoverClientCall<>( @@ -193,68 +241,45 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE } private boolean shouldFallbackBasedOnRPCStatus(Status status, boolean receivedResponse) { + // If a response was received, the connection was healthy and any error is an application-level + // issue, not a connectivity problem. Never failover in this case regardless of status. + if (receivedResponse) { + return false; + } switch (status.getCode()) { case UNAVAILABLE: case UNKNOWN: - return true; case DEADLINE_EXCEEDED: - // Only failover if no response was received. If a response was received, the connection - // was healthy and the timeout is an application-level issue, not a connectivity problem. - return !receivedResponse; + return true; default: return false; } } private CallOptions applyFallbackCredentials(CallOptions callOptions) { - if (fallbackCallCredentials != null && callOptions.getCredentials() == null) { + if (fallbackCallCredentials != null) { return callOptions.withCallCredentials(fallbackCallCredentials); } return callOptions; } - /** - * Checks primary channel connectivity state and updates {@code state.useFallbackDueToState} if - * the primary has been not-ready long enough to warrant failover. - */ - @GuardedBy("state") - private void checkAndUpdateStateFallback(ConnectivityState connectivityState) { - // gRPC's state machine only transitions to IDLE from READY. 
Hence, we treat both - // READY and IDLE as healthy states. - if (connectivityState == ConnectivityState.READY - || connectivityState == ConnectivityState.IDLE) { - state.primaryNotReadySinceNanos = -1; - return; - } - long currentTimeNanos = nanoClock.getAsLong(); - if (state.primaryNotReadySinceNanos < 0) { - state.primaryNotReadySinceNanos = currentTimeNanos; - } - if (currentTimeNanos - state.primaryNotReadySinceNanos > PRIMARY_NOT_READY_WAIT_NANOS) { - if (!state.useFallbackDueToState) { - state.useFallbackDueToState = true; - LOG.warn("Primary connection unavailable. Switching to secondary connection."); - } - } - } - private void notifyCallDone( Status status, boolean isFallback, String methodName, boolean receivedResponse) { if (!status.isOk() && !isFallback && shouldFallbackBasedOnRPCStatus(status, receivedResponse)) { - synchronized (state) { - if (!state.useFallbackDueToRPC) { - state.useFallbackDueToRPC = true; - state.lastRPCFallbackTimeNanos = nanoClock.getAsLong(); - LOG.warn( - "Primary connection failed for method: {}. Switching to secondary connection." - + " Status: {}", - methodName, - status.getCode()); - } + if (state.transitionFallback(true, nanoClock.getAsLong())) { + LOG.warn( + "[channel-{}] Primary connection failed for method: {}. Switching to secondary" + + " connection. Status: {}", + channelId, + methodName, + status.getCode()); } } else if (isFallback && !status.isOk()) { LOG.warn( - "Secondary connection failed for method: {}. Status: {}", methodName, status.getCode()); + "[channel-{}] Secondary connection failed for method: {}. Status: {}", + channelId, + methodName, + status.getCode()); } } @@ -309,7 +334,9 @@ private void registerPrimaryStateChangeListener() { primary.notifyWhenStateChanged(currentState, this::onPrimaryStateChanged); } catch (Exception e) { LOG.warn( - "Failed to register channel state monitor. Continuing with fallback detection.", e); + "[channel-{}] Failed to register channel state monitor. 
Continuing with fallback detection.", + channelId, + e); stateChangeListenerRegistered.set(false); } } @@ -321,18 +348,12 @@ private void onPrimaryStateChanged() { return; } - // If primary is READY, clear both useFallbackDueToState useFallbackDueToRPC flag. - // This ensures we switch back to primary as soon as it recovers, + // If primary is READY, clear both fallback flags so we immediately resume routing there, // regardless of which failover mode triggered the switch. if (primary.getState(false) == ConnectivityState.READY) { - synchronized (state) { - boolean wasOnFallback = state.useFallbackDueToState || state.useFallbackDueToRPC; - state.useFallbackDueToState = false; - state.useFallbackDueToRPC = false; - state.primaryNotReadySinceNanos = -1; - if (wasOnFallback) { - LOG.info("Primary channel recovered; switching back from fallback."); - } + if (state.transitionFallback(false, 0)) { + LOG.info( + "[channel-{}] Primary channel recovered; switching back from fallback.", channelId); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java index 204ae7d60ca0..4633dec996e6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java @@ -59,8 +59,8 @@ public class FailoverChannelTest { MethodDescriptor.newBuilder() .setType(MethodDescriptor.MethodType.UNARY) .setFullMethodName(MethodDescriptor.generateFullMethodName("test", "test")) - .setRequestMarshaller(new IsolationChannelTest.NoopMarshaller()) - .setResponseMarshaller(new 
IsolationChannelTest.NoopMarshaller()) + .setRequestMarshaller(new NoopClientCall.NoopMarshaller()) + .setResponseMarshaller(new NoopClientCall.NoopMarshaller()) .build(); private static FailoverChannel createForTest(ManagedChannel primary, ManagedChannel fallback) { @@ -74,38 +74,29 @@ private static FailoverChannel createForTest(ManagedChannel primary, ManagedChan private void triggerRPCFailure( FailoverChannel channel, ClientCall underlying, Status status) throws Exception { + Metadata metadata = new Metadata(); channel .newCall(methodDescriptor, CallOptions.DEFAULT) - .start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); + .start(new NoopClientCall.NoopClientCallListener<>(), metadata); ArgumentCaptor> captor = ArgumentCaptor.forClass(ClientCall.Listener.class); - verify(underlying).start(captor.capture(), any()); + verify(underlying).start(captor.capture(), same(metadata)); captor.getValue().onClose(status, new Metadata()); } - + @Test public void testRPCFailureTriggersFallback() throws Exception { // RPC failure with UNAVAILABLE should switch to fallback channel. 
ManagedChannel mockChannel = mock(ManagedChannel.class); ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); ClientCall underlyingCall = mock(ClientCall.class); - ClientCall fallbackCall = mock(ClientCall.class); when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall); - when(mockFallbackChannel.newCall(any(), any())).thenReturn(fallbackCall); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); FailoverChannel failoverChannel = createForTest(mockChannel, mockFallbackChannel); - ClientCall call1 = - failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); - Metadata metadata1 = new Metadata(); - call1.start(new NoopClientCall.NoopClientCallListener<>(), metadata1); + triggerRPCFailure(failoverChannel, underlyingCall, Status.UNAVAILABLE); - ArgumentCaptor> captor = ArgumentCaptor.forClass(ClientCall.Listener.class); - verify(underlyingCall).start(captor.capture(), same(metadata1)); - captor.getValue().onClose(Status.UNAVAILABLE, new Metadata()); - - ClientCall call2 = - failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); - call2.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); verify(mockFallbackChannel, atLeastOnce()).newCall(any(), any()); } @@ -115,23 +106,15 @@ public void testRPCFailureRecoveryAfterCoolingPeriod() throws Exception { ManagedChannel mockChannel = mock(ManagedChannel.class); ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); ClientCall underlyingCall = mock(ClientCall.class); - ClientCall fallbackCall = mock(ClientCall.class); when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall, mock(ClientCall.class)); - when(mockFallbackChannel.newCall(any(), any())).thenReturn(fallbackCall); - + when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); when(mockChannel.getState(false)).thenReturn(ConnectivityState.READY); AtomicLong time = new 
AtomicLong(0); FailoverChannel failoverChannel = FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); - // Trigger RPC failure fallback - ClientCall call1 = - failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); - call1.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); - ArgumentCaptor> captor = ArgumentCaptor.forClass(ClientCall.Listener.class); - verify(underlyingCall).start(captor.capture(), any()); - captor.getValue().onClose(Status.UNAVAILABLE, new Metadata()); + triggerRPCFailure(failoverChannel, underlyingCall, Status.UNAVAILABLE); // Within cooling period, still on fallback time.addAndGet(TimeUnit.MINUTES.toNanos(30)); @@ -171,12 +154,7 @@ public void testRPCFallbackClearedByConnectivityRecovery() throws Exception { FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); // RPC failure results in entering cooling period - ClientCall call1 = - failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); - call1.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); - ArgumentCaptor> captor = ArgumentCaptor.forClass(ClientCall.Listener.class); - verify(underlyingCall).start(captor.capture(), any()); - captor.getValue().onClose(Status.UNAVAILABLE, new Metadata()); + triggerRPCFailure(failoverChannel, underlyingCall, Status.UNAVAILABLE); // Still within cooling period, routes to fallback time.addAndGet(TimeUnit.MINUTES.toNanos(30)); @@ -204,12 +182,7 @@ public void testFallbackWithCredentials() throws Exception { FailoverChannel failoverChannel = FailoverChannel.create(mockChannel, mockFallbackChannel, mockCredentials); - ClientCall call1 = - failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); - call1.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); - ArgumentCaptor> captor = ArgumentCaptor.forClass(ClientCall.Listener.class); - verify(underlyingCall).start(captor.capture(), any()); - captor.getValue().onClose(Status.UNAVAILABLE, new 
Metadata()); + triggerRPCFailure(failoverChannel, underlyingCall, Status.UNAVAILABLE); failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/IsolationChannelTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/IsolationChannelTest.java index 580bf873d916..4eb31caf3501 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/IsolationChannelTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/IsolationChannelTest.java @@ -31,8 +31,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import java.io.IOException; -import java.io.InputStream; import java.util.concurrent.TimeUnit; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallOptions; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ClientCall; @@ -40,7 +38,6 @@ import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ManagedChannel; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Metadata; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.MethodDescriptor; -import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.MethodDescriptor.Marshaller; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier; import org.junit.Test; @@ -54,30 +51,12 @@ public class IsolationChannelTest { private Supplier channelSupplier = mock(Supplier.class); - public static class NoopMarshaller implements Marshaller { - - @Override - public InputStream stream(Object o) { - return new InputStream() { - @Override - public int read() throws IOException { - return 0; - } - }; - } - - @Override - public Object parse(InputStream inputStream) { - 
return null; - } - }; - private MethodDescriptor methodDescriptor = MethodDescriptor.newBuilder() .setType(MethodDescriptor.MethodType.UNARY) .setFullMethodName(MethodDescriptor.generateFullMethodName("test", "test")) - .setRequestMarshaller(new NoopMarshaller()) - .setResponseMarshaller(new NoopMarshaller()) + .setRequestMarshaller(new NoopClientCall.NoopMarshaller()) + .setResponseMarshaller(new NoopClientCall.NoopMarshaller()) .build(); @Test @@ -384,19 +363,18 @@ public void testAwaitTermination() throws Exception { when(mockChannel.shutdown()).thenReturn(mockChannel); when(mockChannel.isTerminated()).thenReturn(false, false, false, true, true); - when(mockChannel.awaitTermination(longThat(l -> l < 2_000_000), eq(TimeUnit.NANOSECONDS))) + when(mockChannel.awaitTermination(longThat(l -> l > 0), eq(TimeUnit.NANOSECONDS))) .thenReturn(false, true); isolationChannel.shutdown(); - assertFalse(isolationChannel.awaitTermination(1, TimeUnit.MILLISECONDS)); - assertTrue(isolationChannel.awaitTermination(1, TimeUnit.MILLISECONDS)); + assertFalse(isolationChannel.awaitTermination(10, TimeUnit.SECONDS)); + assertTrue(isolationChannel.awaitTermination(10, TimeUnit.SECONDS)); assertTrue(isolationChannel.isTerminated()); verify(channelSupplier, times(1)).get(); verify(mockChannel, times(1)).shutdown(); verify(mockChannel, times(5)).isTerminated(); - verify(mockChannel, times(2)) - .awaitTermination(longThat(l -> l < 2_000_000), eq(TimeUnit.NANOSECONDS)); + verify(mockChannel, times(2)).awaitTermination(longThat(l -> l > 0), eq(TimeUnit.NANOSECONDS)); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/NoopClientCall.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/NoopClientCall.java index 93a421e0d618..1f62aa57f57b 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/NoopClientCall.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/NoopClientCall.java @@ -17,8 +17,11 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs; +import java.io.IOException; +import java.io.InputStream; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ClientCall; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Metadata; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.MethodDescriptor.Marshaller; /** * {@link NoopClientCall} is a class that is designed for use in tests. It is designed to be used in @@ -34,6 +37,24 @@ public class NoopClientCall extends ClientCall { */ public static class NoopClientCallListener extends ClientCall.Listener {} + public static class NoopMarshaller implements Marshaller { + + @Override + public InputStream stream(Object o) { + return new InputStream() { + @Override + public int read() throws IOException { + return 0; + } + }; + } + + @Override + public Object parse(InputStream inputStream) { + return null; + } + } + @Override public void start(ClientCall.Listener listener, Metadata headers) {} From 0a6604c14339568842a9bec007dce00b4a521de0 Mon Sep 17 00:00:00 2001 From: Sania Parveen Date: Thu, 26 Mar 2026 06:02:46 +0000 Subject: [PATCH 6/8] Removing getState from newCall --- .../client/grpc/stubs/FailoverChannel.java | 64 ++++++++++--------- .../grpc/stubs/FailoverChannelTest.java | 44 +++++++++---- 2 files changed, 66 insertions(+), 42 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java index 39fdb985a6f4..71119e45150a 100644 
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java @@ -96,7 +96,7 @@ private static final class FailoverState { * Determines whether the next RPC should route to the fallback channel, updating internal state * as needed. */ - synchronized boolean computeUseFallback(long nowNanos, ConnectivityState primaryState) { + synchronized boolean computeUseFallback(long nowNanos) { // Clear RPC-based fallback if the cooling period has elapsed. if (useFallbackDueToRPC && nowNanos - lastRPCFallbackTimeNanos >= FALLBACK_COOLING_PERIOD_NANOS) { @@ -105,33 +105,35 @@ synchronized boolean computeUseFallback(long nowNanos, ConnectivityState primary "[channel-{}] Primary channel cooling period elapsed; switching back from fallback.", channelId); } - // If not already on fallback, check primary connectivity state. - // gRPC's state machine only transitions to IDLE from READY. Treat both as healthy. - if (!useFallbackDueToRPC && !useFallbackDueToState) { - if (primaryState == ConnectivityState.READY || primaryState == ConnectivityState.IDLE) { - primaryNotReadySinceNanos = -1; - } else { - if (primaryNotReadySinceNanos < 0) { - primaryNotReadySinceNanos = nowNanos; - } - if (nowNanos - primaryNotReadySinceNanos > PRIMARY_NOT_READY_WAIT_NANOS - && !useFallbackDueToState) { - useFallbackDueToState = true; - LOG.warn( - "[channel-{}] Primary connection unavailable. Switching to secondary connection.", - channelId); - } - } + // Check if primary has been not-ready long enough to switch to fallback. + // primaryNotReadySinceNanos is set by the state-change callback when primary is not ready. 
+ if (!useFallbackDueToRPC + && !useFallbackDueToState + && primaryNotReadySinceNanos >= 0 + && nowNanos - primaryNotReadySinceNanos > PRIMARY_NOT_READY_WAIT_NANOS) { + useFallbackDueToState = true; + LOG.warn( + "[channel-{}] Primary connection unavailable. Switching to secondary connection.", + channelId); } return useFallbackDueToRPC || useFallbackDueToState; } /** - * Transitions the fallback state. - * When toFallback is true (RPC failure) it enables RPC-based fallback if - * not already active and returns true so the caller can log the failure details. - * When toFallback is false (primary recovered) it clears all fallback flags - * and returns true if recovery actually changed state, so the caller can log it. + * Starts the not-ready grace period timer. Called by the state-change callback when primary + * transitions to a non-ready state. Has no effect if already tracking or already on fallback. + */ + synchronized void markPrimaryNotReady(long nowNanos) { + if (!useFallbackDueToRPC && !useFallbackDueToState && primaryNotReadySinceNanos < 0) { + primaryNotReadySinceNanos = nowNanos; + } + } + + /** + * Transitions the fallback state. When toFallback is true (RPC failure) it enables RPC-based + * fallback if not already active and returns true so the caller can log the failure details. + * When toFallback is false (primary recovered) it clears all fallback flags and returns true if + * recovery actually changed state, so the caller can log it. */ synchronized boolean transitionFallback(boolean toFallback, long nowNanos) { if (toFallback) { @@ -189,11 +191,9 @@ public String authority() { @Override public ClientCall newCall( MethodDescriptor methodDescriptor, CallOptions callOptions) { - // Read connectivity state and clock before the synchronized call to avoid holding external - // APIs under the state lock. - ConnectivityState primaryState = primary.getState(false); + // Read the clock before the synchronized call to avoid holding it under the state lock. 
long nowNanos = nanoClock.getAsLong(); - boolean useFallback = state.computeUseFallback(nowNanos, primaryState); + boolean useFallback = state.computeUseFallback(nowNanos); if (useFallback) { return new FailoverClientCall<>( @@ -348,13 +348,17 @@ private void onPrimaryStateChanged() { return; } - // If primary is READY, clear both fallback flags so we immediately resume routing there, - // regardless of which failover mode triggered the switch. - if (primary.getState(false) == ConnectivityState.READY) { + ConnectivityState newState = primary.getState(false); + // IDLE means the channel was READY but has no active RPCs — treat as healthy. + if (newState == ConnectivityState.READY || newState == ConnectivityState.IDLE) { if (state.transitionFallback(false, 0)) { LOG.info( "[channel-{}] Primary channel recovered; switching back from fallback.", channelId); } + } else { + // Primary is not ready; start the grace period timer so computeUseFallback can + // switch to fallback once PRIMARY_NOT_READY_WAIT_NANOS elapses. + state.markPrimaryNotReady(nanoClock.getAsLong()); } // Always re-register for next state change (unless shutdown). 
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java index 4633dec996e6..ee72a5d4a993 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java @@ -193,28 +193,41 @@ public void testFallbackWithCredentials() throws Exception { @Test public void testStateFallbackAfterPrimaryNotReady() { - // If primary connection is not ready for 10+ seconds, routes to fallback. + // If the state-change callback signals primary is not ready for 10+ seconds, + // the next newCall() should route to fallback. ManagedChannel mockChannel = mock(ManagedChannel.class); ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); when(mockChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); - // IDLE for constructor registration, TRANSIENT_FAILURE for the - // two checkAndUpdateStateFallback() calls. + // IDLE for constructor registration, TRANSIENT_FAILURE when callback fires, + // TRANSIENT_FAILURE for re-registration after the callback. 
when(mockChannel.getState(false)) .thenReturn( ConnectivityState.IDLE, ConnectivityState.TRANSIENT_FAILURE, ConnectivityState.TRANSIENT_FAILURE); + AtomicReference stateChangeCallback = new AtomicReference<>(); + doAnswer( + invocation -> { + stateChangeCallback.set(invocation.getArgument(1)); + return null; + }) + .when(mockChannel) + .notifyWhenStateChanged(any(), any()); + AtomicLong time = new AtomicLong(0); FailoverChannel failoverChannel = FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); - // Within 10 seconds: still routes to primary + // Callback fires: primary is TRANSIENT_FAILURE, starts the not-ready timer. + stateChangeCallback.get().run(); + + // Within 10 seconds: grace period not elapsed, routes to primary. failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); verify(mockChannel).newCall(any(), any()); - // After 10 seconds: routes to fallback + // After 10 seconds: routes to fallback. time.addAndGet(TimeUnit.SECONDS.toNanos(11)); failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); verify(mockFallbackChannel).newCall(any(), any()); @@ -253,14 +266,18 @@ public void testStateBasedFallbackRecoveryViaCallback() { ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); when(mockChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); - // IDLE for constructor registration, TRANSIENT_FAILURE for call1 (starts - // the 10s timer) and call2 (timer exceeds 10s), READY when callback fires - // (clears state flag) and for subsequent re-registration and state checks. + // getState() calls in order: + // 1. constructor registerPrimaryStateChangeListener() → IDLE + // 2. onPrimaryStateChanged() fires (TRANSIENT_FAILURE) → TRANSIENT_FAILURE + // 3. re-registerPrimaryStateChangeListener() after 1st callback → TRANSIENT_FAILURE + // 4. onPrimaryStateChanged() fires (READY) → READY + // 5. 
re-registerPrimaryStateChangeListener() after 2nd callback → READY when(mockChannel.getState(false)) .thenReturn( ConnectivityState.IDLE, ConnectivityState.TRANSIENT_FAILURE, ConnectivityState.TRANSIENT_FAILURE, + ConnectivityState.READY, ConnectivityState.READY); AtomicReference stateChangeCallback = new AtomicReference<>(); @@ -276,19 +293,22 @@ public void testStateBasedFallbackRecoveryViaCallback() { FailoverChannel failoverChannel = FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); - // First call - primary not yet timed out, routes to primary + // First callback fires: primary is TRANSIENT_FAILURE, starts the not-ready timer at t=0. + stateChangeCallback.get().run(); + + // Within grace period: routes to primary. failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); verify(mockChannel).newCall(any(), any()); - // After 10 seconds: state-based fallback kicks in + // After 10 seconds: state-based fallback kicks in. time.addAndGet(TimeUnit.SECONDS.toNanos(11)); failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); verify(mockFallbackChannel).newCall(any(), any()); - // Callback fires with primary now READY + // Second callback fires: primary is now READY, clears all fallback state. stateChangeCallback.get().run(); - // Next call recovers to primary with no waiting + // Next call recovers to primary immediately. 
failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); verify(mockChannel, atLeast(2)).newCall(any(), any()); } From 3debb37b64aa3a1c94389be018948f775c1df3ed Mon Sep 17 00:00:00 2001 From: Sania Parveen Date: Tue, 31 Mar 2026 20:15:15 +0000 Subject: [PATCH 7/8] Initializing fallback channel lazily --- .../worker/StreamingDataflowWorker.java | 13 +- .../client/grpc/stubs/FailoverChannel.java | 156 ++++++++++++++---- .../grpc/stubs/FailoverChannelTest.java | 60 ++++++- 3 files changed, 186 insertions(+), 43 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 4ce3f9b651f4..b78d416d4f6a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -815,8 +815,8 @@ private static ChannelCache createChannelCache( ChannelCache.create( (currentFlowControlSettings, serviceAddress) -> { // IsolationChannel wrapping FailoverChannel so that each active RPC gets its own - // FailoverChannel instance. FailoverChannel creates two channels (primary, - // fallback) per active RPC. + // FailoverChannel instance. The fallback channel is created lazily, at most once, + // only if failover is actually needed. 
return IsolationChannel.create( () -> FailoverChannel.create( @@ -824,10 +824,11 @@ private static ChannelCache createChannelCache( serviceAddress, workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), currentFlowControlSettings), - remoteChannel( - dispatcherClient.getDispatcherEndpoints().iterator().next(), - workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), - currentFlowControlSettings), + () -> + remoteChannel( + dispatcherClient.getDispatcherEndpoints().iterator().next(), + workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), + currentFlowControlSettings), MoreCallCredentials.from( new VendoredCredentialsAdapter(workerOptions.getGcpCredential()))), currentFlowControlSettings.getOnReadyThresholdBytes()); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java index 71119e45150a..1dda60142732 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java @@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.LongSupplier; +import java.util.function.Supplier; import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; import org.apache.beam.sdk.annotations.Internal; @@ -34,6 +35,7 @@ import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Metadata; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.MethodDescriptor; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status; +import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -46,11 +48,11 @@ *
  • Connection Status Failover: If the primary channel is not ready for 10+ seconds * (e.g., during network issues), routes to fallback channel. Switches back as soon as the * primary channel becomes READY again. - *
  • RPC Failover: If primary channel RPC fails with transient errors ({@link - * Status.Code#UNAVAILABLE} or {@link Status.Code#UNKNOWN}), or with {@link + *
  • RPC Failover: If primary channel RPCs fail continuously with transient errors + * ({@link Status.Code#UNAVAILABLE} or {@link Status.Code#UNKNOWN}), or with {@link * Status.Code#DEADLINE_EXCEEDED} before receiving any response (indicating the connection was - * never established) and connection status is not READY, switches to fallback channel and - * waits for a 1-hour cooling period before retrying primary. + * never established), for 30+ seconds without any successful response, switches to fallback + * channel and waits for a 1-hour cooling period before retrying primary. * */ @Internal @@ -60,15 +62,18 @@ public final class FailoverChannel extends ManagedChannel { // Time to wait before retrying the primary channel after an RPC-based fallback. private static final long FALLBACK_COOLING_PERIOD_NANOS = TimeUnit.HOURS.toNanos(1); private static final long PRIMARY_NOT_READY_WAIT_NANOS = TimeUnit.SECONDS.toNanos(10); + // Minimum duration of continuous RPC failures required before switching to fallback. + private static final long RPC_FAILURE_THRESHOLD_NANOS = TimeUnit.SECONDS.toNanos(30); private final ManagedChannel primary; - private final ManagedChannel fallback; + private final Supplier fallbackSupplier; + // Non-null once the fallback channel has been created. + @Nullable private volatile ManagedChannel fallback; private final int channelId; @Nullable private final CallCredentials fallbackCallCredentials; private final LongSupplier nanoClock; // Held only during registration to prevent duplicate listener registration. private final AtomicBoolean stateChangeListenerRegistered = new AtomicBoolean(false); - // All mutable routing state is consolidated here to ensure related fields are updated atomically. private final FailoverState state; private static final class FailoverState { @@ -85,11 +90,16 @@ private static final class FailoverState { // Time when primary first became not-ready. -1 when primary is currently READY. 
@GuardedBy("this") long primaryNotReadySinceNanos = -1; + // Time when the first consecutive RPC failure was observed. -1 when no failure streak. + @GuardedBy("this") + long firstRPCFailureSinceNanos = -1; private final int channelId; + private final long rpcFailureThresholdNanos; - FailoverState(int channelId) { + FailoverState(int channelId, long rpcFailureThresholdNanos) { this.channelId = channelId; + this.rpcFailureThresholdNanos = rpcFailureThresholdNanos; } /** @@ -101,6 +111,7 @@ synchronized boolean computeUseFallback(long nowNanos) { if (useFallbackDueToRPC && nowNanos - lastRPCFallbackTimeNanos >= FALLBACK_COOLING_PERIOD_NANOS) { useFallbackDueToRPC = false; + firstRPCFailureSinceNanos = -1; LOG.info( "[channel-{}] Primary channel cooling period elapsed; switching back from fallback.", channelId); @@ -130,49 +141,84 @@ synchronized void markPrimaryNotReady(long nowNanos) { } /** - * Transitions the fallback state. When toFallback is true (RPC failure) it enables RPC-based - * fallback if not already active and returns true so the caller can log the failure details. - * When toFallback is false (primary recovered) it clears all fallback flags and returns true if - * recovery actually changed state, so the caller can log it. + * Clears all fallback state when the primary channel recovers (READY/IDLE callback). Returns + * true if any fallback state was actually cleared, so the caller can log the recovery. + */ + synchronized boolean markPrimaryReady() { + boolean wasOnFallback = useFallbackDueToState || useFallbackDueToRPC; + useFallbackDueToState = false; + useFallbackDueToRPC = false; + primaryNotReadySinceNanos = -1; + firstRPCFailureSinceNanos = -1; + return wasOnFallback; + } + + /** + * Records an RPC failure on the primary channel. Switches to RPC-based fallback only after + * failures have persisted for {@link FailoverChannel#RPC_FAILURE_THRESHOLD_NANOS}. Returns true + * if fallback was newly triggered so the caller can log the event. 
*/ - synchronized boolean transitionFallback(boolean toFallback, long nowNanos) { - if (toFallback) { - if (!useFallbackDueToRPC) { + synchronized boolean notePrimaryRpcFailure(long nowNanos) { + if (useFallbackDueToRPC) { + return false; + } + if (firstRPCFailureSinceNanos < 0) { + if (rpcFailureThresholdNanos <= 0) { useFallbackDueToRPC = true; lastRPCFallbackTimeNanos = nowNanos; - // Return true to indicate fallback state was changed and caller should log the event. return true; } - // Already in RPC-based fallback, no state change. + // This is the first failure. Start the timer. + firstRPCFailureSinceNanos = nowNanos; return false; } - // Clear all fallback state as primary has recovered. - boolean wasOnFallback = useFallbackDueToState || useFallbackDueToRPC; - useFallbackDueToState = false; - useFallbackDueToRPC = false; - primaryNotReadySinceNanos = -1; - return wasOnFallback; + if (nowNanos - firstRPCFailureSinceNanos >= rpcFailureThresholdNanos) { + // Failures have persisted long enough. Switch to fallback. + useFallbackDueToRPC = true; + lastRPCFallbackTimeNanos = nowNanos; + firstRPCFailureSinceNanos = -1; + return true; + } + return false; + } + + /** Resets the RPC failure streak. Called when a primary RPC succeeds. 
*/ + synchronized void notePrimaryRpcSuccess() { + firstRPCFailureSinceNanos = -1; } } private FailoverChannel( ManagedChannel primary, - ManagedChannel fallback, + Supplier fallbackSupplier, @Nullable CallCredentials fallbackCallCredentials, - LongSupplier nanoClock) { + LongSupplier nanoClock, + long rpcFailureThresholdNanos) { this.primary = primary; - this.fallback = fallback; + this.fallbackSupplier = Suppliers.memoize(fallbackSupplier::get); this.channelId = CHANNEL_ID_COUNTER.getAndIncrement(); - this.state = new FailoverState(channelId); + this.state = new FailoverState(channelId, rpcFailureThresholdNanos); this.fallbackCallCredentials = fallbackCallCredentials; this.nanoClock = nanoClock; // Register callback to monitor primary channel state changes registerPrimaryStateChangeListener(); } + public static FailoverChannel create( + ManagedChannel primary, + Supplier fallbackSupplier, + CallCredentials fallbackCallCredentials) { + return new FailoverChannel( + primary, + fallbackSupplier, + fallbackCallCredentials, + System::nanoTime, + RPC_FAILURE_THRESHOLD_NANOS); + } + public static FailoverChannel create( ManagedChannel primary, ManagedChannel fallback, CallCredentials fallbackCallCredentials) { - return new FailoverChannel(primary, fallback, fallbackCallCredentials, System::nanoTime); + return create(primary, () -> fallback, fallbackCallCredentials); } static FailoverChannel forTest( @@ -180,7 +226,23 @@ static FailoverChannel forTest( ManagedChannel fallback, CallCredentials fallbackCallCredentials, LongSupplier nanoClock) { - return new FailoverChannel(primary, fallback, fallbackCallCredentials, nanoClock); + return forTest(primary, fallback, fallbackCallCredentials, nanoClock, 0L); + } + + static FailoverChannel forTest( + ManagedChannel primary, + ManagedChannel fallback, + CallCredentials fallbackCallCredentials, + LongSupplier nanoClock, + long rpcFailureThresholdNanos) { + return new FailoverChannel( + primary, () -> fallback, 
fallbackCallCredentials, nanoClock, rpcFailureThresholdNanos); + } + + /** Returns the fallback channel, creating it from the supplier at most once. */ + private ManagedChannel getOrCreateFallback() { + fallback = fallbackSupplier.get(); + return fallback; } @Override @@ -197,7 +259,7 @@ public ClientCall newCall( if (useFallback) { return new FailoverClientCall<>( - fallback.newCall(methodDescriptor, applyFallbackCredentials(callOptions)), + getOrCreateFallback().newCall(methodDescriptor, applyFallbackCredentials(callOptions)), true, methodDescriptor.getFullMethodName()); } @@ -211,33 +273,41 @@ public ClientCall newCall( @Override public ManagedChannel shutdown() { primary.shutdown(); - fallback.shutdown(); + if (fallback != null) { + fallback.shutdown(); + } return this; } @Override public ManagedChannel shutdownNow() { primary.shutdownNow(); - fallback.shutdownNow(); + if (fallback != null) { + fallback.shutdownNow(); + } return this; } @Override public boolean isShutdown() { - return primary.isShutdown() && fallback.isShutdown(); + return primary.isShutdown() && (fallback == null || fallback.isShutdown()); } @Override public boolean isTerminated() { - return primary.isTerminated() && fallback.isTerminated(); + return primary.isTerminated() && (fallback == null || fallback.isTerminated()); } @Override public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { long endTimeNanos = nanoClock.getAsLong() + unit.toNanos(timeout); boolean primaryTerminated = primary.awaitTermination(timeout, unit); + ManagedChannel fb = fallback; + if (fb == null) { + return primaryTerminated; + } long remainingNanos = Math.max(0, endTimeNanos - nanoClock.getAsLong()); - return primaryTerminated && fallback.awaitTermination(remainingNanos, TimeUnit.NANOSECONDS); + return primaryTerminated && fb.awaitTermination(remainingNanos, TimeUnit.NANOSECONDS); } private boolean shouldFallbackBasedOnRPCStatus(Status status, boolean receivedResponse) { @@ -266,7 
+336,7 @@ private CallOptions applyFallbackCredentials(CallOptions callOptions) { private void notifyCallDone( Status status, boolean isFallback, String methodName, boolean receivedResponse) { if (!status.isOk() && !isFallback && shouldFallbackBasedOnRPCStatus(status, receivedResponse)) { - if (state.transitionFallback(true, nanoClock.getAsLong())) { + if (state.notePrimaryRpcFailure(nanoClock.getAsLong())) { LOG.warn( "[channel-{}] Primary connection failed for method: {}. Switching to secondary" + " connection. Status: {}", @@ -274,6 +344,10 @@ private void notifyCallDone( methodName, status.getCode()); } + } else if (!isFallback && (status.isOk() || receivedResponse)) { + // Primary RPC succeeded (clean close or received at least one response). + // Reset the failure streak so transient errors don't accumulate toward failover. + state.notePrimaryRpcSuccess(); } else if (isFallback && !status.isOk()) { LOG.warn( "[channel-{}] Secondary connection failed for method: {}. Status: {}", @@ -331,6 +405,16 @@ private void registerPrimaryStateChangeListener() { if (!stateChangeListenerRegistered.getAndSet(true)) { try { ConnectivityState currentState = primary.getState(false); + // Seed failover state from the current connectivity state at registration time. + // Without this, if primary starts in a non-ready state (e.g. TRANSIENT_FAILURE) and + // never transitions, markPrimaryNotReady() would never be called and state-based + // failover would not trigger even after the grace period. + if (currentState == ConnectivityState.READY || currentState == ConnectivityState.IDLE) { + state.markPrimaryReady(); + } else { + // Seed the not-ready timer even if there is no future state transition. 
+ state.markPrimaryNotReady(nanoClock.getAsLong()); + } primary.notifyWhenStateChanged(currentState, this::onPrimaryStateChanged); } catch (Exception e) { LOG.warn( @@ -351,7 +435,7 @@ private void onPrimaryStateChanged() { ConnectivityState newState = primary.getState(false); // IDLE means the channel was READY but has no active RPCs — treat as healthy. if (newState == ConnectivityState.READY || newState == ConnectivityState.IDLE) { - if (state.transitionFallback(false, 0)) { + if (state.markPrimaryReady()) { LOG.info( "[channel-{}] Primary channel recovered; switching back from fallback.", channelId); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java index ee72a5d4a993..9c47ddd53776 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java @@ -100,6 +100,39 @@ public void testRPCFailureTriggersFallback() throws Exception { verify(mockFallbackChannel, atLeastOnce()).newCall(any(), any()); } + @Test + public void testRPCFallbackRespectsThirtySecondGracePeriod() throws Exception { + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + ClientCall firstUnderlyingCall = mock(ClientCall.class); + ClientCall secondUnderlyingCall = mock(ClientCall.class); + ClientCall thirdUnderlyingCall = mock(ClientCall.class); + when(mockChannel.newCall(any(), any())) + .thenReturn(firstUnderlyingCall, secondUnderlyingCall, thirdUnderlyingCall); + when(mockFallbackChannel.newCall(any(), 
any())).thenReturn(mock(ClientCall.class)); + when(mockChannel.getState(false)).thenReturn(ConnectivityState.READY); + + AtomicLong time = new AtomicLong(0); + FailoverChannel failoverChannel = + FailoverChannel.forTest( + mockChannel, mockFallbackChannel, null, time::get, TimeUnit.SECONDS.toNanos(30)); + + // First failure at t=0 should not trigger fallback. + triggerRPCFailure(failoverChannel, firstUnderlyingCall, Status.UNAVAILABLE); + verify(mockFallbackChannel, never()).newCall(any(), any()); + + // Second failure before 30s should still not trigger fallback. + time.addAndGet(TimeUnit.SECONDS.toNanos(29)); + triggerRPCFailure(failoverChannel, secondUnderlyingCall, Status.UNAVAILABLE); + verify(mockFallbackChannel, never()).newCall(any(), any()); + + // Failure at/after 30s should trigger fallback. + time.addAndGet(TimeUnit.SECONDS.toNanos(1)); + triggerRPCFailure(failoverChannel, thirdUnderlyingCall, Status.UNAVAILABLE); + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockFallbackChannel, atLeastOnce()).newCall(any(), any()); + } + @Test public void testRPCFailureRecoveryAfterCoolingPeriod() throws Exception { // After RPC failure, channel stays on fallback during cooling period, then returns to primary. 
@@ -179,8 +212,9 @@ public void testFallbackWithCredentials() throws Exception { when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall); when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + AtomicLong time = new AtomicLong(0); FailoverChannel failoverChannel = - FailoverChannel.create(mockChannel, mockFallbackChannel, mockCredentials); + FailoverChannel.forTest(mockChannel, mockFallbackChannel, mockCredentials, time::get); triggerRPCFailure(failoverChannel, underlyingCall, Status.UNAVAILABLE); @@ -233,6 +267,30 @@ public void testStateFallbackAfterPrimaryNotReady() { verify(mockFallbackChannel).newCall(any(), any()); } + @Test + public void testStateFallbackWhenPrimaryStartsNonReadyWithoutTransition() { + // Primary starts in a non-ready state and stays there. Even without a state transition, + // fallback should start after the 10s grace period. + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + when(mockChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + when(mockChannel.getState(false)).thenReturn(ConnectivityState.TRANSIENT_FAILURE); + + AtomicLong time = new AtomicLong(0); + FailoverChannel failoverChannel = + FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); + + // Before grace period, still routes to primary. + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockChannel).newCall(any(), any()); + + // After 10 seconds in non-ready state, should route to fallback. 
+ time.addAndGet(TimeUnit.SECONDS.toNanos(11)); + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockFallbackChannel).newCall(any(), any()); + } + @Test public void testIdleStateNotTreatedAsFallback() { // IDLE is a normal healthy state (channel is not actively connected but will reconnect on From 5d060fcd3cf92eb485dce2af2f53afb732a8a4b0 Mon Sep 17 00:00:00 2001 From: Sania Parveen Date: Wed, 1 Apr 2026 20:50:42 +0000 Subject: [PATCH 8/8] Removing unused helper method --- .../client/grpc/stubs/FailoverChannel.java | 17 ++------ .../grpc/stubs/FailoverChannelTest.java | 43 ++++++++++++++----- 2 files changed, 36 insertions(+), 24 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java index 1dda60142732..faa08c497c8f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java @@ -216,19 +216,6 @@ public static FailoverChannel create( RPC_FAILURE_THRESHOLD_NANOS); } - public static FailoverChannel create( - ManagedChannel primary, ManagedChannel fallback, CallCredentials fallbackCallCredentials) { - return create(primary, () -> fallback, fallbackCallCredentials); - } - - static FailoverChannel forTest( - ManagedChannel primary, - ManagedChannel fallback, - CallCredentials fallbackCallCredentials, - LongSupplier nanoClock) { - return forTest(primary, fallback, fallbackCallCredentials, nanoClock, 0L); - } - static FailoverChannel forTest( ManagedChannel primary, ManagedChannel fallback, @@ -241,7 +228,9 @@ static FailoverChannel forTest( /** 
Returns the fallback channel, creating it from the supplier at most once. */ private ManagedChannel getOrCreateFallback() { - fallback = fallbackSupplier.get(); + if (fallback == null) { + fallback = fallbackSupplier.get(); + } return fallback; } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java index 9c47ddd53776..9a46e5bc5489 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java @@ -38,6 +38,8 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.LongSupplier; +import javax.annotation.Nullable; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallCredentials; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallOptions; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ClientCall; @@ -54,7 +56,6 @@ @RunWith(JUnit4.class) public class FailoverChannelTest { - private MethodDescriptor methodDescriptor = MethodDescriptor.newBuilder() .setType(MethodDescriptor.MethodType.UNARY) @@ -64,7 +65,29 @@ public class FailoverChannelTest { .build(); private static FailoverChannel createForTest(ManagedChannel primary, ManagedChannel fallback) { - return FailoverChannel.forTest(primary, fallback, null, System::nanoTime); + return createForTest(primary, fallback, null, System::nanoTime, null); + } + + private static FailoverChannel createForTest( + ManagedChannel primary, + ManagedChannel fallback, + @Nullable CallCredentials fallbackCallCredentials, + 
LongSupplier nanoClock) { + return createForTest(primary, fallback, fallbackCallCredentials, nanoClock, null); + } + + private static FailoverChannel createForTest( + ManagedChannel primary, + ManagedChannel fallback, + @Nullable CallCredentials fallbackCallCredentials, + LongSupplier nanoClock, + @Nullable Long rpcFailureThresholdNanos) { + return FailoverChannel.forTest( + primary, + fallback, + fallbackCallCredentials, + nanoClock, + rpcFailureThresholdNanos != null ? rpcFailureThresholdNanos : 0L); } /** @@ -114,7 +137,7 @@ public void testRPCFallbackRespectsThirtySecondGracePeriod() throws Exception { AtomicLong time = new AtomicLong(0); FailoverChannel failoverChannel = - FailoverChannel.forTest( + createForTest( mockChannel, mockFallbackChannel, null, time::get, TimeUnit.SECONDS.toNanos(30)); // First failure at t=0 should not trigger fallback. @@ -145,7 +168,7 @@ public void testRPCFailureRecoveryAfterCoolingPeriod() throws Exception { AtomicLong time = new AtomicLong(0); FailoverChannel failoverChannel = - FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); + createForTest(mockChannel, mockFallbackChannel, null, time::get); triggerRPCFailure(failoverChannel, underlyingCall, Status.UNAVAILABLE); @@ -184,7 +207,7 @@ public void testRPCFallbackClearedByConnectivityRecovery() throws Exception { AtomicLong time = new AtomicLong(0); FailoverChannel failoverChannel = - FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); + createForTest(mockChannel, mockFallbackChannel, null, time::get); // RPC failure results in entering cooling period triggerRPCFailure(failoverChannel, underlyingCall, Status.UNAVAILABLE); @@ -214,7 +237,7 @@ public void testFallbackWithCredentials() throws Exception { AtomicLong time = new AtomicLong(0); FailoverChannel failoverChannel = - FailoverChannel.forTest(mockChannel, mockFallbackChannel, mockCredentials, time::get); + createForTest(mockChannel, mockFallbackChannel, mockCredentials, 
time::get); triggerRPCFailure(failoverChannel, underlyingCall, Status.UNAVAILABLE); @@ -252,7 +275,7 @@ public void testStateFallbackAfterPrimaryNotReady() { AtomicLong time = new AtomicLong(0); FailoverChannel failoverChannel = - FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); + createForTest(mockChannel, mockFallbackChannel, null, time::get); // Callback fires: primary is TRANSIENT_FAILURE, starts the not-ready timer. stateChangeCallback.get().run(); @@ -279,7 +302,7 @@ public void testStateFallbackWhenPrimaryStartsNonReadyWithoutTransition() { AtomicLong time = new AtomicLong(0); FailoverChannel failoverChannel = - FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); + createForTest(mockChannel, mockFallbackChannel, null, time::get); // Before grace period, still routes to primary. failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); @@ -305,7 +328,7 @@ public void testIdleStateNotTreatedAsFallback() { AtomicLong time = new AtomicLong(0); FailoverChannel failoverChannel = - FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); + createForTest(mockChannel, mockFallbackChannel, null, time::get); // Advance well past the 10-second threshold while primary remains IDLE time.addAndGet(TimeUnit.SECONDS.toNanos(30)); @@ -349,7 +372,7 @@ public void testStateBasedFallbackRecoveryViaCallback() { AtomicLong time = new AtomicLong(0); FailoverChannel failoverChannel = - FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); + createForTest(mockChannel, mockFallbackChannel, null, time::get); // First callback fires: primary is TRANSIENT_FAILURE, starts the not-ready timer at t=0. stateChangeCallback.get().run();