diff --git a/gateway-service/src/main/java/org/zowe/apiml/gateway/loadbalancer/DeterministicLoadBalancer.java b/gateway-service/src/main/java/org/zowe/apiml/gateway/loadbalancer/DeterministicLoadBalancer.java index ee5e756a83..06d00b646a 100644 --- a/gateway-service/src/main/java/org/zowe/apiml/gateway/loadbalancer/DeterministicLoadBalancer.java +++ b/gateway-service/src/main/java/org/zowe/apiml/gateway/loadbalancer/DeterministicLoadBalancer.java @@ -16,6 +16,7 @@ import com.nimbusds.jwt.proc.ExpiredJWTException; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.Strings; import org.springframework.cloud.client.ServiceInstance; import org.springframework.cloud.client.loadbalancer.Request; import org.springframework.cloud.client.loadbalancer.RequestDataContext; @@ -25,7 +26,6 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.web.server.ResponseStatusException; -import org.zowe.apiml.constants.ApimlConstants; import org.zowe.apiml.gateway.caching.LoadBalancerCache; import org.zowe.apiml.gateway.caching.LoadBalancerCache.LoadBalancerCacheRecord; import reactor.core.publisher.Flux; @@ -35,18 +35,12 @@ import java.text.ParseException; import java.time.Clock; import java.time.LocalDateTime; -import java.util.Base64; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.atomic.AtomicReference; +import java.util.*; import java.util.stream.Stream; import static org.apache.commons.lang3.StringUtils.isNotBlank; import static org.zowe.apiml.constants.ApimlConstants.X_INSTANCEID; import static reactor.core.publisher.Flux.just; -import static reactor.core.publisher.Mono.empty; /** * A sticky session load balancer that ensures requests from the same user are routed to the same service instance. @@ -54,6 +48,7 @@ @Slf4j public class DeterministicLoadBalancer extends SameInstancePreferenceServiceInstanceListSupplier { + public static final String HEADER_PREFIX = "Bearer "; private static final String HEADER_NONE_SIGNATURE = Base64.getEncoder().encodeToString("{\"typ\":\"JWT\",\"alg\":\"none\"}".getBytes(StandardCharsets.UTF_8)); private final LoadBalancerCache cache; @@ -85,22 +80,47 @@ public Flux> get(Request request) { if (serviceId == null) { return Flux.empty(); } - AtomicReference principal = new AtomicReference<>(); + + var requestContext = request.getContext(); + var instanceId = getInstanceId(requestContext); + if (instanceId != null) { + // if instanceId is set in headers use it + try { + return delegate.get(request) + .map(serviceInstances -> checkInstanceIdHeader(instanceId, serviceInstances)); + } catch (ResponseStatusException ex) { + return Flux.error(new ResponseStatusException(HttpStatus.NOT_FOUND, "Service instance not found for the provided instance ID")); + } + } + + var userId = getSub(requestContext); + if (userId == null) { + // if no userId is available return all + log.debug("No authentication present on request, not filtering the service: {}", serviceId); + return delegate.get(request); + } + return delegate.get(request) - .flatMap(serviceInstances -> getSub(request.getContext()) - .switchIfEmpty(Mono.just("")) - .flatMap(user -> { - if (user == null || user.isEmpty()) { - log.debug("No authentication present on request, not filtering the service: {}", serviceId); - return empty(); - } else { - principal.set(user); - return cache.retrieve(user, serviceId).onErrorResume(t -> Mono.empty()); - } - }) - .switchIfEmpty(Mono.just(LoadBalancerCacheRecord.NONE)) - .flatMapMany(cacheRecord -> filterInstances(principal.get(), serviceId, cacheRecord, serviceInstances, request.getContext())) - ) + .flatMap(serviceInstances -> { + if (serviceInstances.isEmpty()) { + // no instances available - just return + log.debug("No services selected"); + return Flux.just(serviceInstances); + } + + boolean stickySession = lbTypeIsAuthentication(serviceInstances.iterator().next()); + if (!stickySession) { + // service does not support sticky session by userId, just return + log.debug("Service {} does not support sticky session", serviceId); + return Flux.just(serviceInstances); + } + + log.debug("Obtain service instances for {} from the cache", serviceId); + return cache.retrieve(userId, serviceId) + .onErrorResume(t -> Mono.empty()) + .flatMapMany(cacheRecord -> filterInstances(userId, serviceId, cacheRecord, serviceInstances)) + .switchIfEmpty(Flux.just(serviceInstances)); + }) .doOnError(e -> log.debug("Error in determining service instances", e)); } @@ -115,30 +135,23 @@ private boolean isTooOld(LocalDateTime cachedDate) { return now.isAfter(cachedDate); } - private Mono getSub(Object requestContext) { + private String getSub(Object requestContext) { if (requestContext instanceof RequestDataContext ctx) { var token = Optional.ofNullable(getTokenFromCookie(ctx)) .orElseGet(() -> getTokenFromHeader(ctx)); - return Mono.just(extractSubFromToken(token)); + return extractSubFromToken(token); } - return Mono.just(""); + return null; } private String getTokenFromCookie(RequestDataContext ctx) { - var tokens = ctx.getClientRequest().getCookies().get("apimlAuthenticationToken"); - return tokens == null || tokens.isEmpty() ? null : tokens.get(0); + return ctx.getClientRequest().getCookies().getFirst("apimlAuthenticationToken"); } private String getTokenFromHeader(RequestDataContext ctx) { - var authHeaderValues = ctx.getClientRequest().getHeaders().get(HttpHeaders.AUTHORIZATION); - var token = authHeaderValues == null || authHeaderValues.isEmpty() ? null : authHeaderValues.get(0); - if (token != null && token.startsWith(ApimlConstants.BEARER_AUTHENTICATION_PREFIX)) { - token = token.replaceFirst(ApimlConstants.BEARER_AUTHENTICATION_PREFIX, "").trim(); - if (token.isEmpty()) { - return null; - } - - return token; + var authHeaderValue = ctx.getClientRequest().getHeaders().getFirst(HttpHeaders.AUTHORIZATION); + if (Strings.CS.startsWith(authHeaderValue, HEADER_PREFIX)) { + return authHeaderValue.substring(HEADER_PREFIX.length()); } return null; } @@ -157,27 +170,17 @@ private Flux> filterInstances( String user, String serviceId, LoadBalancerCacheRecord cacheRecord, - List serviceInstances, - Object requestContext) { - - Flux> result; - if (shouldIgnore(serviceInstances, user)) { - var instanceId = getInstanceId(requestContext); - try { - return just(checkInstanceIdHeader(instanceId, serviceInstances)); - } catch (ResponseStatusException ex) { - return Flux.error(new ResponseStatusException(HttpStatus.NOT_FOUND, "Service instance not found for the provided instance ID")); + List serviceInstances + ) { + if (isNotBlank(cacheRecord.getInstanceId())) { + if (isTooOld(cacheRecord.getCreationTime())) { + return cache.delete(user, serviceId) + .thenMany(chooseOne(user, serviceInstances)); } + return chooseOne(cacheRecord.getInstanceId(), user, serviceInstances); } - if (isNotBlank(cacheRecord.getInstanceId()) && isTooOld(cacheRecord.getCreationTime())) { - result = cache.delete(user, serviceId) - .thenMany(chooseOne(user, serviceInstances)); - } else if (isNotBlank(cacheRecord.getInstanceId())) { - result = chooseOne(cacheRecord.getInstanceId(), user, serviceInstances); - } else { - result = chooseOne(user, serviceInstances); - } - return result; + + return chooseOne(user, serviceInstances); } /** @@ -253,10 +256,6 @@ private Flux> chooseOne(String user, List return chooseOne(null, user, serviceInstances); } - boolean shouldIgnore(List instances, String user) { - return StringUtils.isEmpty(user) || instances.isEmpty() || !lbTypeIsAuthentication(instances.get(0)); - } - private boolean lbTypeIsAuthentication(ServiceInstance instance) { Map metadata = instance.getMetadata(); if (metadata != null) { @@ -305,6 +304,7 @@ private String extractSubFromToken(String token) { return claims.getSubject(); } } - return ""; + return null; } + } diff --git a/gateway-service/src/test/java/org/zowe/apiml/gateway/loadbalancer/DeterministicLoadBalancerTest.java b/gateway-service/src/test/java/org/zowe/apiml/gateway/loadbalancer/DeterministicLoadBalancerTest.java index 7eed522b64..0b120a1f31 100644 --- a/gateway-service/src/test/java/org/zowe/apiml/gateway/loadbalancer/DeterministicLoadBalancerTest.java +++ b/gateway-service/src/test/java/org/zowe/apiml/gateway/loadbalancer/DeterministicLoadBalancerTest.java @@ -45,12 +45,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.argThat; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.lenient; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; @ExtendWith(MockitoExtension.class) @TestInstance(Lifecycle.PER_CLASS) @@ -162,7 +158,6 @@ void setUp() { @Test void whenServiceDoesNotHaveMetadata_thenUseDefaultList() { when(instance1.getMetadata()).thenReturn(null); - when(lbCache.retrieve("USER", "service")).thenReturn(Mono.just(LoadBalancerCacheRecord.NONE)); StepVerifier.create(loadBalancer.get(request)) .assertNext(chosenInstances -> { @@ -179,8 +174,6 @@ void whenServiceDoesNotUseSticky_thenUseDefaultList() { metadata.put("apiml.lb.type", "somethingelse"); when(instance1.getMetadata()).thenReturn(metadata); - when(lbCache.retrieve("USER", "service")).thenReturn(Mono.just(LoadBalancerCacheRecord.NONE)); - StepVerifier.create(loadBalancer.get(request)) .assertNext(chosenInstances -> { assertNotNull(chosenInstances); @@ -247,6 +240,9 @@ void whenCacheEntryExpired_thenUpdatePreference() { assertNotNull(chosenInstances); assertEquals(1, chosenInstances.size()); assertEquals("instance1", chosenInstances.get(0).getInstanceId()); + verify(lbCache).retrieve("USER", "service"); + verify(lbCache).delete("USER", "service"); + verify(lbCache).store(eq("USER"), eq("service"), any()); }) .expectComplete() .verify(); @@ -266,6 +262,8 @@ void whenNoPreferece_thenCreateOne() { assertNotNull(chosenInstances); assertEquals(1, chosenInstances.size()); assertEquals("instance1", chosenInstances.get(0).getInstanceId()); + + verify(lbCache).retrieve("USER", "service"); }) .expectComplete() .verify(); @@ -368,15 +366,11 @@ class GivenInstanceIdHeaderIsPresent { @BeforeEach void setUp() { var context = new RequestDataContext(requestData); - MultiValueMap cookie = new LinkedMultiValueMap<>(); - cookie.add("apimlAuthenticationToken", "invalidToken"); - when(request.getContext()).thenReturn(context); - when(requestData.getCookies()).thenReturn(cookie); } @Test - void whenInstanceIdExists_thenChoseeIt() { + void whenInstanceIdExists_thenChooseIt() { var headers = new HttpHeaders(); headers.add("X-InstanceId", "instance2"); when(requestData.getHeaders()).thenReturn(headers); @@ -386,6 +380,21 @@ void whenInstanceIdExists_thenChoseeIt() { assertNotNull(chosenInstances); assertEquals(1, chosenInstances.size()); assertEquals("instance2", chosenInstances.get(0).getInstanceId()); + verify(lbCache, never()).retrieve(any(), any()); + }) + .expectComplete() + .verify(); + } + + @Test + void whenNoToken_thenDoNotCallCache() { + when(requestData.getHeaders()).thenReturn(new HttpHeaders()); + when(requestData.getCookies()).thenReturn(new LinkedMultiValueMap<>()); + + StepVerifier.create(loadBalancer.get(request)) + .assertNext(chosenInstances -> { + assertNotNull(chosenInstances); + verify(lbCache, never()).retrieve(any(), any()); }) .expectComplete() .verify();