diff --git a/backend/auth_handler.py b/backend/auth_handler.py
new file mode 100644
index 000000000..3afb37147
--- /dev/null
+++ b/backend/auth_handler.py
@@ -0,0 +1,192 @@
+import json
+import logging
+import os
+import urllib.request
+import urllib.parse
+import base64
+from http.cookies import SimpleCookie
+
+logger = logging.getLogger(__name__)
+logger.setLevel(os.environ.get('LOG_LEVEL', 'INFO'))
+
+
+def handler(event, context):
+ """Main Lambda handler - routes requests to appropriate function"""
+ path = event.get('path', '')
+ method = event.get('httpMethod', '')
+
+ if path == '/auth/token-exchange' and method == 'POST':
+ return token_exchange_handler(event)
+ elif path == '/auth/logout' and method == 'POST':
+ return logout_handler(event)
+ elif path == '/auth/userinfo' and method == 'GET':
+ return userinfo_handler(event)
+ else:
+ return error_response(404, 'Not Found', event)
+
+
+def error_response(status_code, message, event=None):
+ """Return error response with CORS headers"""
+ response = {
+ 'statusCode': status_code,
+ 'headers': get_cors_headers(event) if event else {'Content-Type': 'application/json'},
+ 'body': json.dumps({'error': message}),
+ }
+ return response
+
+
+def get_cors_headers(event):
+ """Get CORS headers for response"""
+ cloudfront_url = os.environ.get('CLOUDFRONT_URL', '')
+ return {
+ 'Content-Type': 'application/json',
+ 'Access-Control-Allow-Origin': cloudfront_url,
+ 'Access-Control-Allow-Credentials': 'true',
+ 'Access-Control-Allow-Methods': 'GET, POST, OPTIONS',
+ 'Access-Control-Allow-Headers': 'Content-Type',
+ }
+
+
+def token_exchange_handler(event):
+ """Exchange authorization code for tokens and set httpOnly cookies"""
+ try:
+ body = json.loads(event.get('body', '{}'))
+ code = body.get('code')
+ code_verifier = body.get('code_verifier')
+
+ if not code or not code_verifier:
+ return error_response(400, 'Missing code or code_verifier', event)
+
+ okta_url = os.environ.get('CUSTOM_AUTH_URL', '')
+ client_id = os.environ.get('CUSTOM_AUTH_CLIENT_ID', '')
+ redirect_uri = os.environ.get('CUSTOM_AUTH_REDIRECT_URL', '')
+
+ if not okta_url or not client_id:
+ return error_response(500, 'Missing Okta configuration', event)
+
+ # Call Okta token endpoint
+ token_url = f'{okta_url}/v1/token'
+ token_data = {
+ 'grant_type': 'authorization_code',
+ 'code': code,
+ 'code_verifier': code_verifier,
+ 'client_id': client_id,
+ 'redirect_uri': redirect_uri,
+ }
+
+ data = urllib.parse.urlencode(token_data).encode('utf-8')
+ req = urllib.request.Request(
+ token_url,
+ data=data,
+ headers={'Content-Type': 'application/x-www-form-urlencoded'},
+ )
+
+ try:
+ with urllib.request.urlopen(req, timeout=10) as response:
+ tokens = json.loads(response.read().decode('utf-8'))
+ except urllib.error.HTTPError as e:
+ error_body = e.read().decode('utf-8')
+ logger.error(f'Token exchange failed: {error_body}')
+ return error_response(401, 'Authentication failed. Please try again.', event)
+
+ cookies = build_cookies(tokens)
+
+ return {
+ 'statusCode': 200,
+ 'headers': get_cors_headers(event),
+ 'multiValueHeaders': {'Set-Cookie': cookies},
+ 'body': json.dumps({'success': True}),
+ }
+
+ except Exception as e:
+ logger.error(f'Token exchange error: {str(e)}')
+ return error_response(500, 'Internal server error', event)
+
+
+def build_cookies(tokens):
+ """Build httpOnly cookies for tokens"""
+ cookies = []
+ secure = True
+ httponly = True
+ samesite = 'Lax'
+ max_age = 3600 # 1 hour
+
+ for token_name in ['access_token', 'id_token']:
+ if tokens.get(token_name):
+ cookie = SimpleCookie()
+ cookie[token_name] = tokens[token_name]
+ cookie[token_name]['path'] = '/'
+ cookie[token_name]['secure'] = secure
+ cookie[token_name]['httponly'] = httponly
+ cookie[token_name]['samesite'] = samesite
+ cookie[token_name]['max-age'] = max_age
+ cookies.append(cookie[token_name].OutputString())
+
+ return cookies
+
+
+def logout_handler(event):
+ """Clear all auth cookies"""
+ cookies = []
+ for cookie_name in ['access_token', 'id_token', 'refresh_token']:
+ cookie = SimpleCookie()
+ cookie[cookie_name] = ''
+ cookie[cookie_name]['path'] = '/'
+ cookie[cookie_name]['max-age'] = 0
+ cookies.append(cookie[cookie_name].OutputString())
+
+ return {
+ 'statusCode': 200,
+ 'headers': get_cors_headers(event),
+ 'multiValueHeaders': {'Set-Cookie': cookies},
+ 'body': json.dumps({'success': True}),
+ }
+
+
+def userinfo_handler(event):
+ """Return user info from id_token cookie"""
+ try:
+ cookie_header = event.get('headers', {}).get('Cookie') or event.get('headers', {}).get('cookie', '')
+ cookies = SimpleCookie()
+ cookies.load(cookie_header)
+
+ id_token_cookie = cookies.get('id_token')
+ if not id_token_cookie:
+ return error_response(401, 'Not authenticated', event)
+
+ id_token = id_token_cookie.value
+
+ # Decode JWT payload
+ parts = id_token.split('.')
+ if len(parts) != 3:
+ return error_response(401, 'Invalid token format', event)
+
+ payload = parts[1]
+ padding = 4 - len(payload) % 4
+ if padding != 4:
+ payload += '=' * padding
+
+ decoded = base64.urlsafe_b64decode(payload)
+ claims = json.loads(decoded)
+
+ email_claim = os.environ.get('CLAIMS_MAPPING_EMAIL', 'email')
+ user_id_claim = os.environ.get('CLAIMS_MAPPING_USER_ID', 'sub')
+
+ email = claims.get(email_claim, claims.get('email', claims.get('sub', '')))
+ user_id = claims.get(user_id_claim, claims.get('sub', ''))
+
+ return {
+ 'statusCode': 200,
+ 'headers': get_cors_headers(event),
+ 'body': json.dumps(
+ {
+ 'email': email,
+ 'name': claims.get('name', email),
+ 'sub': user_id,
+ }
+ ),
+ }
+
+ except Exception as e:
+ logger.error(f'Userinfo error: {str(e)}')
+ return error_response(500, 'Internal server error', event)
diff --git a/deploy/custom_resources/custom_authorizer/custom_authorizer_lambda.py b/deploy/custom_resources/custom_authorizer/custom_authorizer_lambda.py
index 153191946..20593fffd 100644
--- a/deploy/custom_resources/custom_authorizer/custom_authorizer_lambda.py
+++ b/deploy/custom_resources/custom_authorizer/custom_authorizer_lambda.py
@@ -1,5 +1,6 @@
import logging
import os
+from http.cookies import SimpleCookie
from requests import HTTPError
@@ -23,10 +24,32 @@
def lambda_handler(incoming_event, context):
- # Get the Token which is sent in the Authorization Header
+ # Get the Token - first try Cookie header, then Authorization header
logger.debug(incoming_event)
- auth_token = incoming_event['headers']['Authorization']
+ headers = incoming_event.get('headers', {})
+
+ # Try to get access_token from Cookie header first (for cookie-based auth)
+ auth_token = None
+ cookie_header = headers.get('Cookie') or headers.get('cookie', '')
+
+ if cookie_header:
+ # Parse cookies to find access_token
+ cookies = SimpleCookie()
+ cookies.load(cookie_header)
+ access_token_cookie = cookies.get('access_token')
+ if access_token_cookie:
+ # Add Bearer prefix for consistency with existing validation
+ auth_token = f'Bearer {access_token_cookie.value}'
+ logger.debug('Using access_token from Cookie header')
+
+ # Fallback to Authorization header (for backward compatibility)
+ if not auth_token:
+ auth_token = headers.get('Authorization') or headers.get('authorization')
+ if auth_token:
+ logger.debug('Using token from Authorization header')
+
if not auth_token:
+ logger.warning('No authentication token found in Cookie or Authorization header')
return AuthServices.generate_deny_policy(incoming_event['methodArn'])
# Validate User is Active with Proper Access Token
diff --git a/deploy/stacks/cloudfront.py b/deploy/stacks/cloudfront.py
index d61f3208c..bd62585c5 100644
--- a/deploy/stacks/cloudfront.py
+++ b/deploy/stacks/cloudfront.py
@@ -11,6 +11,7 @@
Duration,
RemovalPolicy,
CfnOutput,
+ Fn,
)
from .cdk_asset_trail import setup_cdk_asset_trail
@@ -30,6 +31,7 @@ def __init__(
custom_waf_rules=None,
tooling_account_id=None,
backend_region=None,
+ custom_auth=None,
**kwargs,
):
super().__init__(scope, id, **kwargs)
@@ -166,6 +168,55 @@ def __init__(
log_file_prefix='cloudfront-logs/frontend',
)
+ # Add API Gateway behaviors for cookie-based authentication (when using custom_auth)
+ if custom_auth and backend_region:
+ # Get API Gateway URL from SSM parameter (set by backend stack)
+ api_gateway_url_param = ssm.StringParameter.from_string_parameter_name(
+ self,
+ 'ApiGatewayUrlParam',
+ string_parameter_name=f'/dataall/{envname}/apiGateway/backendUrl',
+ )
+
+ # Extract API Gateway domain from URL using CloudFormation intrinsic functions
+ # Input: https://xyz123.execute-api.us-east-1.amazonaws.com/prod/
+ # Split by '/': ['https:', '', 'xyz123.execute-api.us-east-1.amazonaws.com', 'prod', '']
+ # Select index 2: 'xyz123.execute-api.us-east-1.amazonaws.com'
+ api_gateway_origin = origins.HttpOrigin(
+ domain_name=Fn.select(2, Fn.split('/', api_gateway_url_param.string_value)),
+ origin_path='/prod',
+ protocol_policy=cloudfront.OriginProtocolPolicy.HTTPS_ONLY,
+ )
+
+ # Add behavior for /auth/* routes (token exchange, userinfo, logout)
+ cloudfront_distribution.add_behavior(
+ path_pattern='/auth/*',
+ origin=api_gateway_origin,
+ cache_policy=cloudfront.CachePolicy.CACHING_DISABLED,
+ origin_request_policy=cloudfront.OriginRequestPolicy.ALL_VIEWER_EXCEPT_HOST_HEADER,
+ allowed_methods=cloudfront.AllowedMethods.ALLOW_ALL,
+ viewer_protocol_policy=cloudfront.ViewerProtocolPolicy.HTTPS_ONLY,
+ )
+
+ # Add behavior for /graphql/* routes
+ cloudfront_distribution.add_behavior(
+ path_pattern='/graphql/*',
+ origin=api_gateway_origin,
+ cache_policy=cloudfront.CachePolicy.CACHING_DISABLED,
+ origin_request_policy=cloudfront.OriginRequestPolicy.ALL_VIEWER_EXCEPT_HOST_HEADER,
+ allowed_methods=cloudfront.AllowedMethods.ALLOW_ALL,
+ viewer_protocol_policy=cloudfront.ViewerProtocolPolicy.HTTPS_ONLY,
+ )
+
+ # Add behavior for /search/* routes
+ cloudfront_distribution.add_behavior(
+ path_pattern='/search/*',
+ origin=api_gateway_origin,
+ cache_policy=cloudfront.CachePolicy.CACHING_DISABLED,
+ origin_request_policy=cloudfront.OriginRequestPolicy.ALL_VIEWER_EXCEPT_HOST_HEADER,
+ allowed_methods=cloudfront.AllowedMethods.ALLOW_ALL,
+ viewer_protocol_policy=cloudfront.ViewerProtocolPolicy.HTTPS_ONLY,
+ )
+
ssm_distribution_id = ssm.StringParameter(
self,
f'SSMDistribution{envname}',
@@ -276,16 +327,12 @@ def __init__(
@staticmethod
def error_responses():
+ # Only intercept 404 for SPA routing (redirect to index.html)
+ # Do NOT intercept 403 - let API Gateway errors pass through
return [
cloudfront.ErrorResponse(
http_status=404,
- response_http_status=404,
- ttl=Duration.seconds(0),
- response_page_path='/index.html',
- ),
- cloudfront.ErrorResponse(
- http_status=403,
- response_http_status=403,
+ response_http_status=200,
ttl=Duration.seconds(0),
response_page_path='/index.html',
),
diff --git a/deploy/stacks/lambda_api.py b/deploy/stacks/lambda_api.py
index a73f18726..17c448db1 100644
--- a/deploy/stacks/lambda_api.py
+++ b/deploy/stacks/lambda_api.py
@@ -23,7 +23,7 @@
BundlingOptions,
)
from cdk_klayers import Klayers
-from aws_cdk.aws_apigateway import EndpointType, SecurityPolicy
+from aws_cdk.aws_apigateway import DomainNameOptions, EndpointType, SecurityPolicy
from aws_cdk.aws_certificatemanager import Certificate
from aws_cdk.aws_ec2 import (
InterfaceVpcEndpoint,
@@ -35,7 +35,6 @@
from .pyNestedStack import pyNestedClass
from .solution_bundling import SolutionBundling
from .waf_rules import get_waf_rules
-from .runtime_options import PYTHON_LAMBDA_RUNTIME
DEFAULT_API_RATE_LIMIT = 10000
DEFAULT_API_BURST_LIMIT = 5000
@@ -159,6 +158,8 @@ def __init__(
api_handler_env['frontend_domain_url'] = f'https://{custom_domain.get("hosted_zone_name", None)}'
if custom_auth:
api_handler_env['custom_auth'] = custom_auth.get('provider', None)
+ api_handler_env['custom_auth_url'] = custom_auth.get('url', None)
+ api_handler_env['custom_auth_client'] = custom_auth.get('client_id', None)
self.api_handler = _lambda.DockerImageFunction(
self,
'LambdaGraphQL',
@@ -242,6 +243,66 @@ def __init__(
)
)
+ # Auth handler Lambda for cookie-based authentication
+ self.auth_handler_dlq = self.set_dlq(f'{resource_prefix}-{envname}-authhandler-dlq')
+ auth_handler_sg = self.create_lambda_sgs(envname, 'authhandler', resource_prefix, vpc)
+
+ # Get CloudFront URL from custom_domain config or use default
+ if custom_domain and custom_domain.get('hosted_zone_name'):
+ cloudfront_url = f'https://{custom_domain.get("hosted_zone_name")}'
+ else:
+ cloudfront_url = '' # Must be configured via custom_domain in cdk.json
+
+ auth_handler_env = {
+ 'envname': envname,
+ 'LOG_LEVEL': log_level,
+ 'CLOUDFRONT_URL': cloudfront_url,
+ }
+
+ # Add custom auth config for token exchange with Okta
+ if custom_auth:
+ auth_handler_env['CUSTOM_AUTH_URL'] = custom_auth.get('url', '')
+ auth_handler_env['CUSTOM_AUTH_CLIENT_ID'] = custom_auth.get('client_id', '')
+ auth_handler_env['CUSTOM_AUTH_REDIRECT_URL'] = custom_auth.get('redirect_url', cloudfront_url + '/callback')
+ # Pass claims mapping for user info extraction
+ claims_mapping = custom_auth.get('claims_mapping', {})
+ auth_handler_env['CLAIMS_MAPPING_EMAIL'] = claims_mapping.get('email', 'email')
+ auth_handler_env['CLAIMS_MAPPING_USER_ID'] = claims_mapping.get('user_id', 'sub')
+
+ self.auth_handler = _lambda.DockerImageFunction(
+ self,
+ 'AuthHandler',
+ function_name=f'{resource_prefix}-{envname}-authhandler',
+ log_group=logs.LogGroup(
+ self,
+ 'authhandlerloggroup',
+ log_group_name=f'/aws/lambda/{resource_prefix}-{envname}-backend-authhandler',
+ retention=getattr(logs.RetentionDays, self.log_retention_duration),
+ ),
+ description='dataall auth handler for cookie-based authentication',
+ role=self.create_function_role(envname, resource_prefix, 'authhandler', pivot_role_name, vpc),
+ code=_lambda.DockerImageCode.from_ecr(
+ repository=ecr_repository, tag=image_tag, cmd=['auth_handler.handler']
+ ),
+ vpc=vpc,
+ security_groups=[auth_handler_sg],
+ memory_size=512 if prod_sizing else 256,
+ timeout=Duration.seconds(30),
+ environment=auth_handler_env,
+ environment_encryption=lambda_env_key,
+ dead_letter_queue_enabled=True,
+ dead_letter_queue=self.auth_handler_dlq,
+ on_failure=lambda_destination.SqsDestination(self.auth_handler_dlq),
+ tracing=_lambda.Tracing.ACTIVE,
+ logging_format=_lambda.LoggingFormat.JSON,
+ application_log_level_v2=getattr(_lambda.ApplicationLogLevel, log_level),
+ )
+
+ # Allow auth handler to access internet (for Okta API calls)
+ self.auth_handler.connections.allow_to(
+ ec2.Peer.any_ipv4(), ec2.Port.tcp(443), 'Allow NAT Internet Access for Okta'
+ )
+
# Create the custom authorizer lambda
custom_authorizer_assets = os.path.realpath(
os.path.join(
@@ -283,7 +344,8 @@ def __init__(
)
# Initialize Klayers
- klayers = Klayers(self, python_version=PYTHON_LAMBDA_RUNTIME, region=self.region)
+ runtime = _lambda.Runtime.PYTHON_3_12
+ klayers = Klayers(self, python_version=runtime, region=self.region)
# get the latest layer version for the cryptography package
cryptography_layer = klayers.layer_version(self, 'cryptography')
@@ -303,7 +365,7 @@ def __init__(
code=_lambda.Code.from_asset(
path=custom_authorizer_assets,
bundling=BundlingOptions(
- image=PYTHON_LAMBDA_RUNTIME.bundling_image,
+ image=_lambda.Runtime.PYTHON_3_9.bundling_image,
local=SolutionBundling(source_path=custom_authorizer_assets),
),
),
@@ -314,7 +376,7 @@ def __init__(
environment_encryption=lambda_env_key,
vpc=vpc,
security_groups=[authorizer_fn_sg],
- runtime=PYTHON_LAMBDA_RUNTIME,
+ runtime=runtime,
layers=[cryptography_layer],
logging_format=_lambda.LoggingFormat.JSON,
application_log_level_v2=getattr(_lambda.ApplicationLogLevel, log_level),
@@ -368,6 +430,7 @@ def __init__(
user_pool,
custom_auth,
throttling_config,
+ custom_domain,
)
self.create_sns_topic(
@@ -540,6 +603,7 @@ def create_api_gateway(
user_pool,
custom_auth,
throttling_config,
+ custom_domain,
):
api_deploy_options = apigw.StageOptions(
throttling_rate_limit=throttling_config.get('global_rate_limit', DEFAULT_API_RATE_LIMIT),
@@ -563,6 +627,7 @@ def create_api_gateway(
resource_prefix,
user_pool,
custom_auth,
+ custom_domain,
)
# Create IP set if IP filtering enabled in CDK.json
@@ -623,6 +688,7 @@ def set_up_graphql_api_gateway(
resource_prefix,
user_pool,
custom_auth,
+ custom_domain,
):
# Create a custom Authorizer
custom_authorizer_role = iam.Role(
@@ -644,10 +710,13 @@ def set_up_graphql_api_gateway(
self,
'CustomAuthorizer',
handler=self.authorizer_fn,
- identity_sources=[apigw.IdentitySource.header('Authorization')],
+ # Empty identity_sources allows Lambda to be invoked without specific headers
+ # This enables cookie-based auth where tokens come from Cookie header
+ identity_sources=[],
authorizer_name=f'{resource_prefix}-{envname}-custom-authorizer',
assume_role=custom_authorizer_role,
- results_cache_ttl=Duration.minutes(1),
+ # Disable caching to ensure cookies are read on every request
+ results_cache_ttl=Duration.seconds(0),
)
if not internet_facing:
if apig_vpce:
@@ -829,6 +898,64 @@ def set_up_graphql_api_gateway(
request_models={'application/json': search_validation_model},
)
+ # Auth routes for cookie-based authentication
+ auth_integration = apigw.LambdaIntegration(self.auth_handler)
+ auth = gw.root.add_resource(path_part='auth')
+
+ # Get CloudFront URL for CORS (use custom domain if available)
+ if custom_domain and custom_domain.get('hosted_zone_name'):
+ cors_origin = f'https://{custom_domain.get("hosted_zone_name")}'
+ else:
+ cors_origin = '' # Must be configured via custom_domain in cdk.json
+
+ # Token exchange route - NO authorization (public endpoint for OAuth callback)
+ token_exchange = auth.add_resource(
+ path_part='token-exchange',
+ default_cors_preflight_options=apigw.CorsOptions(
+ allow_methods=['POST', 'OPTIONS'],
+ allow_origins=[cors_origin],
+ allow_credentials=True,
+ allow_headers=['Content-Type'],
+ ),
+ )
+ token_exchange.add_method(
+ 'POST',
+ auth_integration,
+ authorization_type=apigw.AuthorizationType.NONE,
+ )
+
+ # Logout route - NO authorization (needs to work even with expired tokens)
+ logout = auth.add_resource(
+ path_part='logout',
+ default_cors_preflight_options=apigw.CorsOptions(
+ allow_methods=['POST', 'OPTIONS'],
+ allow_origins=[cors_origin],
+ allow_credentials=True,
+ allow_headers=['Content-Type'],
+ ),
+ )
+ logout.add_method(
+ 'POST',
+ auth_integration,
+ authorization_type=apigw.AuthorizationType.NONE,
+ )
+
+ # Userinfo route - NO authorization (Lambda reads cookies and validates)
+ userinfo = auth.add_resource(
+ path_part='userinfo',
+ default_cors_preflight_options=apigw.CorsOptions(
+ allow_methods=['GET', 'OPTIONS'],
+ allow_origins=[cors_origin],
+ allow_credentials=True,
+ allow_headers=['Content-Type'],
+ ),
+ )
+ userinfo.add_method(
+ 'GET',
+ auth_integration,
+ authorization_type=apigw.AuthorizationType.NONE,
+ )
+
apigateway_log_group = logs.LogGroup(
self,
f'{resource_prefix}/{envname}/apigateway',
diff --git a/frontend/src/authentication/contexts/GenericAuthContext.js b/frontend/src/authentication/contexts/GenericAuthContext.js
index 8dee6a86c..698b009ca 100644
--- a/frontend/src/authentication/contexts/GenericAuthContext.js
+++ b/frontend/src/authentication/contexts/GenericAuthContext.js
@@ -1,13 +1,13 @@
import { createContext, useEffect, useReducer } from 'react';
import { SET_ERROR } from 'globalErrors';
import PropTypes from 'prop-types';
-import { useAuth } from 'react-oidc-context';
import {
fetchAuthSession,
fetchUserAttributes,
signInWithRedirect,
signOut
} from 'aws-amplify/auth';
+import { generatePKCE, generateState } from '../../utils/pkce';
const CUSTOM_AUTH = process.env.REACT_APP_CUSTOM_AUTH;
@@ -70,10 +70,6 @@ export const GenericAuthContext = createContext({
export const GenericAuthProvider = (props) => {
const { children } = props;
const [state, dispatch] = useReducer(reducer, initialState);
- const auth = useAuth();
- const isLoading = auth ? auth.isLoading : false;
- const userProfile = auth ? auth.user : null;
- const authEvents = auth ? auth.events : null;
useEffect(() => {
const initialize = async () => {
@@ -94,109 +90,40 @@ export const GenericAuthProvider = (props) => {
}
}
});
- } catch (error) {
- if (CUSTOM_AUTH) {
- processLoadingStateChange();
- } else {
- dispatch({
- type: 'INITIALIZE',
- payload: {
- isAuthenticated: false,
- isInitialized: true,
- user: null
- }
- });
- }
- }
- };
-
- initialize().catch((e) => dispatch({ type: SET_ERROR, error: e.message }));
- }, []);
-
- // useEffect needed for React OIDC context
- // Process OIDC state when isLoading state changes
- useEffect(() => {
- if (CUSTOM_AUTH) {
- processLoadingStateChange();
- }
- }, [isLoading]);
-
- // useEffect to process when a user is loaded by react OIDC
- // This is triggered when the userProfile ( i.e. auth.user ) is loaded by react OIDC
- useEffect(() => {
- const processStateChange = async () => {
- try {
- const user = await getAuthenticatedUser();
- dispatch({
- type: 'LOGIN',
- payload: {
- user: {
- id: user.email,
- email: user.email,
- name: user.email,
- id_token: user.id_token,
- short_id: user.short_id,
- access_token: user.access_token
- }
- }
- });
} catch (error) {
dispatch({
- type: 'LOGOUT',
+ type: 'INITIALIZE',
payload: {
isAuthenticated: false,
+ isInitialized: true,
user: null
}
});
}
};
- if (CUSTOM_AUTH) {
- processStateChange().catch((e) =>
- dispatch({ type: SET_ERROR, error: e.message })
- );
- }
- }, [userProfile]);
-
- // useEffect to process auth events generated by react OIDC
- // This is used to logout user when the token expires
- useEffect(() => {
- if (CUSTOM_AUTH) {
- return auth.events.addAccessTokenExpired(() => {
- auth.signoutSilent().then((r) => {
- dispatch({
- type: 'LOGOUT',
- payload: {
- isAuthenticated: false,
- user: null
- }
- });
- });
- });
- }
- }, [authEvents]);
+ initialize().catch((e) => dispatch({ type: SET_ERROR, error: e.message }));
+ }, []);
const getAuthenticatedUser = async () => {
if (CUSTOM_AUTH) {
- if (!auth.user) throw Error('User not initialized');
+ // Use relative URL - CloudFront proxies to API Gateway (same-origin)
+ const response = await fetch('/auth/userinfo', {
+ credentials: 'include'
+ });
+ if (!response.ok) throw Error('User not authenticated');
+ const user = await response.json();
return {
- email:
- auth.user.profile[
- process.env.REACT_APP_CUSTOM_AUTH_EMAIL_CLAIM_MAPPING
- ],
- id_token: auth.user.id_token,
- access_token: auth.user.access_token,
- short_id:
- auth.user.profile[
- process.env.REACT_APP_CUSTOM_AUTH_USERID_CLAIM_MAPPING
- ]
+ email: user.email,
+ id_token: 'cookie',
+ access_token: 'cookie',
+ short_id: user.sub
};
} else {
const [session, attrs] = await Promise.all([
fetchAuthSession(),
fetchUserAttributes()
]);
-
return {
email: attrs.email,
id_token: session.tokens.idToken.toString(),
@@ -206,39 +133,31 @@ export const GenericAuthProvider = (props) => {
}
};
- // Function to process OIDC State when it transitions from false to true
- function processLoadingStateChange() {
- if (isLoading) {
- dispatch({
- type: 'INITIALIZE',
- payload: {
- isAuthenticated: false,
- isInitialized: false, // setting to false when the OIDC State is loading
- user: null
- }
- });
- } else {
- dispatch({
- type: 'INITIALIZE',
- payload: {
- isAuthenticated: false,
- isInitialized: true, // setting to true when the OIDC state is completely loaded
- user: null
- }
- });
- }
- }
-
const login = async () => {
try {
if (CUSTOM_AUTH) {
- await auth.signinRedirect();
+ const { verifier, challenge } = await generatePKCE();
+ const state = generateState();
+
+ sessionStorage.setItem('pkce_verifier', verifier);
+ sessionStorage.setItem('pkce_state', state);
+
+ const params = new URLSearchParams({
+ client_id: process.env.REACT_APP_CUSTOM_AUTH_CLIENT_ID,
+ redirect_uri: window.location.origin + '/callback',
+ response_type: 'code',
+ scope: process.env.REACT_APP_CUSTOM_AUTH_SCOPES,
+ code_challenge: challenge,
+ code_challenge_method: 'S256',
+ state
+ });
+
+ window.location.href = `${process.env.REACT_APP_CUSTOM_AUTH_URL}/v1/authorize?${params}`;
} else {
await signInWithRedirect();
}
} catch (error) {
if (error.name === 'UserAlreadyAuthenticatedException') {
- // User is already authenticated, ignore this error
return;
}
console.error('Failed to authenticate user', error);
@@ -248,7 +167,8 @@ export const GenericAuthProvider = (props) => {
const logout = async () => {
try {
if (CUSTOM_AUTH) {
- await auth.signoutSilent();
+ // Use relative URL - CloudFront proxies to API Gateway (same-origin)
+ await fetch('/auth/logout', { method: 'POST', credentials: 'include' });
dispatch({
type: 'LOGOUT',
payload: {
@@ -256,6 +176,8 @@ export const GenericAuthProvider = (props) => {
user: null
}
});
+ sessionStorage.clear();
+ window.location.href = window.location.origin;
} else {
await signOut({ global: true });
dispatch({
@@ -265,8 +187,8 @@ export const GenericAuthProvider = (props) => {
user: null
}
});
+ sessionStorage.removeItem('window-location');
}
- sessionStorage.removeItem('window-location');
} catch (error) {
console.error('Failed to signout', error);
}
@@ -275,14 +197,13 @@ export const GenericAuthProvider = (props) => {
const reauth = async () => {
if (CUSTOM_AUTH) {
try {
- auth.signoutSilent().then((r) => {
- dispatch({
- type: 'REAUTH',
- payload: {
- reAuthStatus: false,
- requestInfo: null
- }
- });
+ await logout();
+ dispatch({
+ type: 'REAUTH',
+ payload: {
+ reAuthStatus: false,
+ requestInfo: null
+ }
});
} catch (error) {
console.error('Failed to ReAuth', error);
@@ -296,8 +217,8 @@ export const GenericAuthProvider = (props) => {
requestInfo: null
}
});
+ sessionStorage.removeItem('window-location');
}
- sessionStorage.removeItem('window-location');
};
return (
@@ -309,7 +230,7 @@ export const GenericAuthProvider = (props) => {
login,
logout,
reauth,
- isLoading
+ isLoading: !state.isInitialized
}}
>
{children}
@@ -317,6 +238,6 @@ export const GenericAuthProvider = (props) => {
);
};
-GenericAuthContext.propTypes = {
+GenericAuthProvider.propTypes = {
children: PropTypes.node.isRequired
};
diff --git a/frontend/src/authentication/views/Callback.js b/frontend/src/authentication/views/Callback.js
new file mode 100644
index 000000000..5daeac5d6
--- /dev/null
+++ b/frontend/src/authentication/views/Callback.js
@@ -0,0 +1,99 @@
+import { useEffect, useState } from 'react';
+import { useNavigate } from 'react-router-dom';
+import { Box, CircularProgress, Typography } from '@mui/material';
+
+const Callback = () => {
+ const navigate = useNavigate();
+ const [error, setError] = useState(null);
+
+ useEffect(() => {
+ const exchangeCode = async () => {
+ try {
+ const params = new URLSearchParams(window.location.search);
+ const code = params.get('code');
+ const state = params.get('state');
+ const errorParam = params.get('error');
+
+ if (errorParam) {
+ throw new Error(params.get('error_description') || errorParam);
+ }
+
+ if (!code) {
+ throw new Error('No authorization code received');
+ }
+
+ // Verify state matches
+ const savedState = sessionStorage.getItem('pkce_state');
+ if (state !== savedState) {
+ throw new Error('State mismatch - possible CSRF attack');
+ }
+
+ // Get code verifier
+ const codeVerifier = sessionStorage.getItem('pkce_verifier');
+ if (!codeVerifier) {
+ throw new Error('No code verifier found');
+ }
+
+ // Exchange code for tokens via backend
+ // Add AbortController for timeout
+ const controller = new AbortController();
+ const timeoutId = setTimeout(() => controller.abort(), 30000); // 30 second timeout
+
+ try {
+ const response = await fetch('/auth/token-exchange', {
+ method: 'POST',
+ headers: { 'Content-Type': 'application/json' },
+ credentials: 'include',
+ body: JSON.stringify({
+ code,
+ code_verifier: codeVerifier
+ }),
+ signal: controller.signal
+ });
+ clearTimeout(timeoutId);
+
+ if (!response.ok) {
+ const data = await response.json();
+ throw new Error(data.error || 'Token exchange failed');
+ }
+ } catch (fetchErr) {
+ clearTimeout(timeoutId);
+ if (fetchErr.name === 'AbortError') {
+ throw new Error('Request timed out. Please try again.');
+ }
+ throw fetchErr;
+ }
+
+ // Clear PKCE values
+ sessionStorage.removeItem('pkce_verifier');
+ sessionStorage.removeItem('pkce_state');
+
+ // Redirect to app
+ navigate('/console/environments', { replace: true });
+ } catch (err) {
+ console.error('Callback error:', err);
+ setError(err.message);
+ }
+ };
+
+ exchangeCode();
+ }, [navigate]);
+
+ if (error) {
+ return (
+
+ Authentication Error
+ {error}
+
+ );
+ }
+
+ return (
+
+
+ Completing sign in...
+
+ );
+};
+
+export default Callback;
diff --git a/frontend/src/routes.js b/frontend/src/routes.js
index 502f91abf..f9b00a248 100644
--- a/frontend/src/routes.js
+++ b/frontend/src/routes.js
@@ -13,6 +13,9 @@ const Loadable = (Component) => (props) =>
// Authentication pages
const Login = Loadable(lazy(() => import('./authentication/views/Login')));
+const Callback = Loadable(
+ lazy(() => import('./authentication/views/Callback'))
+);
// Error pages
const NotFound = Loadable(
@@ -206,6 +209,10 @@ const routes = [
)
+ },
+ {
+ path: 'callback',
+ element:
}
]
},
diff --git a/frontend/src/services/hooks/useClient.js b/frontend/src/services/hooks/useClient.js
index 9e20d4619..cdc89b17e 100644
--- a/frontend/src/services/hooks/useClient.js
+++ b/frontend/src/services/hooks/useClient.js
@@ -47,18 +47,30 @@ export const useClient = () => {
useEffect(() => {
const initClient = async () => {
const t = token;
+ const CUSTOM_AUTH = process.env.REACT_APP_CUSTOM_AUTH;
+
+ // Use relative URL for custom auth (CloudFront proxy), otherwise use env var
+ const graphqlUri = CUSTOM_AUTH
+ ? '/graphql/api'
+ : process.env.REACT_APP_GRAPHQL_API;
+
const httpLink = new HttpLink({
- uri: process.env.REACT_APP_GRAPHQL_API
+ uri: graphqlUri,
+ // Include credentials for cookie-based auth
+ credentials: CUSTOM_AUTH ? 'include' : 'same-origin'
});
const authLink = new ApolloLink((operation, forward) => {
- operation.setContext({
- headers: {
- Authorization: t ? `Bearer ${t}` : '',
- AccessKeyId: 'none',
- SecretKey: 'none'
- }
- });
+ // For custom auth, cookies are sent automatically via credentials: 'include'
+ // For Cognito, use Authorization header
+ const headers = CUSTOM_AUTH
+ ? { AccessKeyId: 'none', SecretKey: 'none' }
+ : {
+ Authorization: t ? `Bearer ${t}` : '',
+ AccessKeyId: 'none',
+ SecretKey: 'none'
+ };
+ operation.setContext({ headers });
return forward(operation);
});
const errorLink = onError(
@@ -97,6 +109,6 @@ export const useClient = () => {
if (token) {
initClient().catch((e) => console.error(e));
}
- }, [token, dispatch]);
+ }, [token, dispatch, setReAuth]);
return client;
};
diff --git a/frontend/src/utils/pkce.js b/frontend/src/utils/pkce.js
new file mode 100644
index 000000000..370261b81
--- /dev/null
+++ b/frontend/src/utils/pkce.js
@@ -0,0 +1,21 @@
+const base64URLEncode = (buffer) =>
+ btoa(String.fromCharCode(...new Uint8Array(buffer)))
+ .replace(/\+/g, '-')
+ .replace(/\//g, '_')
+ .replace(/=/g, '');
+
+const sha256 = async (plain) => {
+ const encoder = new TextEncoder();
+ const data = encoder.encode(plain);
+ return await crypto.subtle.digest('SHA-256', data);
+};
+
+export const generatePKCE = async () => {
+ // 96 bytes = 128 characters after base64url encoding (max per RFC 7636)
+ const verifier = base64URLEncode(crypto.getRandomValues(new Uint8Array(96)));
+ const challenge = base64URLEncode(await sha256(verifier));
+ return { verifier, challenge };
+};
+
+export const generateState = () =>
+ base64URLEncode(crypto.getRandomValues(new Uint8Array(32)));
\ No newline at end of file