diff --git a/apps/api/src/auth/auth.module.ts b/apps/api/src/auth/auth.module.ts index 4128f9f15d..64ef0ce2d5 100644 --- a/apps/api/src/auth/auth.module.ts +++ b/apps/api/src/auth/auth.module.ts @@ -14,6 +14,13 @@ import { PermissionGuard } from './permission.guard'; auth, // Don't register global auth guard - we use HybridAuthGuard disableGlobalAuthGuard: true, + // CORS is already configured in main.ts — prevent the module from + // overriding it with its own trustedOrigins-based CORS. + disableTrustedOriginsCors: true, + // Body parsing for non-auth routes is handled in main.ts with a + // custom middleware that skips /api/auth paths. Disable the module's + // own SkipBodyParsingMiddleware to avoid conflicts. + disableBodyParser: true, }), ], controllers: [AuthController], diff --git a/apps/api/src/auth/auth.server.ts b/apps/api/src/auth/auth.server.ts index b45706ae97..1085f81bdc 100644 --- a/apps/api/src/auth/auth.server.ts +++ b/apps/api/src/auth/auth.server.ts @@ -19,7 +19,7 @@ const MAGIC_LINK_EXPIRES_IN_SECONDS = 60 * 60; // 1 hour */ function getCookieDomain(): string | undefined { const baseUrl = - process.env.AUTH_BASE_URL || process.env.BETTER_AUTH_URL || ''; + process.env.BASE_URL || ''; if (baseUrl.includes('staging.trycomp.ai')) { return '.staging.trycomp.ai'; @@ -109,10 +109,10 @@ function validateSecurityConfig(): void { // Warn about development defaults in production if (process.env.NODE_ENV === 'production') { const baseUrl = - process.env.AUTH_BASE_URL || process.env.BETTER_AUTH_URL || ''; + process.env.BASE_URL || ''; if (baseUrl.includes('localhost')) { console.warn( - 'SECURITY WARNING: AUTH_BASE_URL contains "localhost" in production. ' + + 'SECURITY WARNING: BASE_URL contains "localhost" in production. ' + 'This may cause issues with OAuth callbacks and cookies.', ); } @@ -125,23 +125,21 @@ validateSecurityConfig(); /** * The auth server instance - single source of truth for authentication. * - * IMPORTANT: For OAuth to work correctly with the app's auth proxy: - * - Set AUTH_BASE_URL to the app's URL (e.g., http://localhost:3000 in dev) - * - This ensures OAuth callbacks point to the app, which proxies to this API - * - Cookies will be set for the app's domain, not the API's domain - * - * In production, use the app's public URL (e.g., https://app.trycomp.ai) + * BASE_URL must point to the API (e.g., https://api.trycomp.ai). + * OAuth callbacks go directly to the API. Clients send absolute callbackURLs + * so better-auth redirects to the correct app after processing. + * Cross-subdomain cookies (.trycomp.ai) ensure the session works on all apps. */ export const auth = betterAuth({ database: prismaAdapter(db, { provider: 'postgresql', }), - // Use AUTH_BASE_URL pointing to the app (client), not the API itself - // This is critical for OAuth callbacks and cookie domains to work correctly - baseURL: - process.env.AUTH_BASE_URL || - process.env.BETTER_AUTH_URL || - 'http://localhost:3000', + // baseURL must point to the API (e.g., https://api.trycomp.ai) so that + // OAuth callbacks go directly to the API regardless of which frontend + // initiated the flow. Clients must send absolute callbackURLs so that + // after OAuth processing, better-auth redirects to the correct app. + // Cross-subdomain cookies (.trycomp.ai) ensure the session works everywhere. + baseURL: process.env.BASE_URL || 'http://localhost:3333', trustedOrigins: getTrustedOrigins(), emailAndPassword: { enabled: true, @@ -322,6 +320,12 @@ export const auth = betterAuth({ enabled: true, trustedProviders: ['google', 'github', 'microsoft'], }, + // Skip the state cookie CSRF check for OAuth flows. + // In our cross-origin setup (app/portal → API), the state cookie may not + // survive the OAuth redirect flow. The OAuth state parameter stored in the + // database already provides CSRF protection (random 32-char string validated + // against the DB). This is the same approach better-auth's oAuthProxy plugin uses. + skipStateCookieCheck: true, }, verification: { modelName: 'Verification', diff --git a/apps/api/src/config/better-auth.config.ts b/apps/api/src/config/better-auth.config.ts index b7d87a7531..e28c965fa2 100644 --- a/apps/api/src/config/better-auth.config.ts +++ b/apps/api/src/config/better-auth.config.ts @@ -2,7 +2,7 @@ import { registerAs } from '@nestjs/config'; import { z } from 'zod'; const betterAuthConfigSchema = z.object({ - url: z.string().url('AUTH_BASE_URL must be a valid URL'), + url: z.string().url('BASE_URL must be a valid URL'), }); export type BetterAuthConfig = z.infer; @@ -10,7 +10,7 @@ export type BetterAuthConfig = z.infer; /** * Better Auth configuration for the API. * - * Since the API now runs the auth server, AUTH_BASE_URL should point to the API itself. + * BASE_URL should point to the API itself since the API is the auth server. * For example: * - Production: https://api.trycomp.ai * - Staging: https://api.staging.trycomp.ai @@ -19,17 +19,14 @@ export type BetterAuthConfig = z.infer; export const betterAuthConfig = registerAs( 'betterAuth', (): BetterAuthConfig => { - // AUTH_BASE_URL is the URL of the auth server (which is now the API) - // Fall back to BETTER_AUTH_URL for backwards compatibility during migration - const url = process.env.AUTH_BASE_URL || process.env.BETTER_AUTH_URL; + const url = process.env.BASE_URL; if (!url) { - throw new Error('AUTH_BASE_URL or BETTER_AUTH_URL environment variable is required'); + throw new Error('BASE_URL environment variable is required'); } const config = { url }; - // Validate configuration at startup const result = betterAuthConfigSchema.safeParse(config); if (!result.success) { diff --git a/apps/api/src/integration-platform/controllers/checks.controller.ts b/apps/api/src/integration-platform/controllers/checks.controller.ts index e6c3fffc1a..fd0e596f53 100644 --- a/apps/api/src/integration-platform/controllers/checks.controller.ts +++ b/apps/api/src/integration-platform/controllers/checks.controller.ts @@ -13,12 +13,14 @@ import { ApiTags, ApiSecurity } from '@nestjs/swagger'; import { HybridAuthGuard } from '../../auth/hybrid-auth.guard'; import { PermissionGuard } from '../../auth/permission.guard'; import { RequirePermission } from '../../auth/require-permission.decorator'; +import { OrganizationId } from '../../auth/auth-context.decorator'; import { getManifest, getAvailableChecks, runAllChecks, } from '@comp/integration-platform'; import { ConnectionRepository } from '../repositories/connection.repository'; +import { ConnectionService } from '../services/connection.service'; import { CredentialVaultService } from '../services/credential-vault.service'; import { ProviderRepository } from '../repositories/provider.repository'; import { CheckRunRepository } from '../repositories/check-run.repository'; @@ -40,6 +42,7 @@ export class ChecksController { private readonly providerRepository: ProviderRepository, private readonly credentialVaultService: CredentialVaultService, private readonly checkRunRepository: CheckRunRepository, + private readonly connectionService: ConnectionService, ) {} /** @@ -68,7 +71,11 @@ export class ChecksController { */ @Get('connections/:connectionId') @RequirePermission('integration', 'read') - async listConnectionChecks(@Param('connectionId') connectionId: string) { + async listConnectionChecks( + @Param('connectionId') connectionId: string, + @OrganizationId() organizationId: string, + ) { + await this.connectionService.getConnectionForOrg(connectionId, organizationId); const connection = await this.connectionRepository.findById(connectionId); if (!connection) { throw new HttpException('Connection not found', HttpStatus.NOT_FOUND); @@ -106,7 +113,9 @@ export class ChecksController { async runConnectionChecks( @Param('connectionId') connectionId: string, @Body() body: RunChecksDto, + @OrganizationId() organizationId: string, ) { + await this.connectionService.getConnectionForOrg(connectionId, organizationId); const connection = await this.connectionRepository.findById(connectionId); if (!connection) { throw new HttpException('Connection not found', HttpStatus.NOT_FOUND); @@ -306,7 +315,8 @@ export class ChecksController { async runSingleCheck( @Param('connectionId') connectionId: string, @Param('checkId') checkId: string, + @OrganizationId() organizationId: string, ) { - return this.runConnectionChecks(connectionId, { checkId }); + return this.runConnectionChecks(connectionId, { checkId }, organizationId); } } diff --git a/apps/api/src/integration-platform/controllers/connections.controller.ts b/apps/api/src/integration-platform/controllers/connections.controller.ts index b50542bf75..539ff4cf6f 100644 --- a/apps/api/src/integration-platform/controllers/connections.controller.ts +++ b/apps/api/src/integration-platform/controllers/connections.controller.ts @@ -260,8 +260,11 @@ export class ConnectionsController { */ @Get(':id') @RequirePermission('integration', 'read') - async getConnection(@Param('id') id: string) { - const connection = await this.connectionService.getConnection(id); + async getConnection( + @Param('id') id: string, + @OrganizationId() organizationId: string, + ) { + const connection = await this.connectionService.getConnectionForOrg(id, organizationId); const providerSlug = (connection as { provider?: { slug: string } }) .provider?.slug; @@ -654,8 +657,11 @@ export class ConnectionsController { */ @Post(':id/test') @RequirePermission('integration', 'update') - async testConnection(@Param('id') id: string) { - const connection = await this.connectionService.getConnection(id); + async testConnection( + @Param('id') id: string, + @OrganizationId() organizationId: string, + ) { + const connection = await this.connectionService.getConnectionForOrg(id, organizationId); const providerSlug = (connection as any).provider?.slug; if (!providerSlug) { @@ -744,7 +750,11 @@ export class ConnectionsController { */ @Post(':id/pause') @RequirePermission('integration', 'update') - async pauseConnection(@Param('id') id: string) { + async pauseConnection( + @Param('id') id: string, + @OrganizationId() organizationId: string, + ) { + await this.connectionService.getConnectionForOrg(id, organizationId); const connection = await this.connectionService.pauseConnection(id); return { id: connection.id, status: connection.status }; } @@ -754,7 +764,11 @@ export class ConnectionsController { */ @Post(':id/resume') @RequirePermission('integration', 'update') - async resumeConnection(@Param('id') id: string) { + async resumeConnection( + @Param('id') id: string, + @OrganizationId() organizationId: string, + ) { + await this.connectionService.getConnectionForOrg(id, organizationId); const connection = await this.connectionService.activateConnection(id); return { id: connection.id, status: connection.status }; } @@ -764,7 +778,11 @@ export class ConnectionsController { */ @Post(':id/disconnect') @RequirePermission('integration', 'delete') - async disconnectConnection(@Param('id') id: string) { + async disconnectConnection( + @Param('id') id: string, + @OrganizationId() organizationId: string, + ) { + await this.connectionService.getConnectionForOrg(id, organizationId); const connection = await this.connectionService.disconnectConnection(id); return { id: connection.id, status: connection.status }; } @@ -774,7 +792,11 @@ export class ConnectionsController { */ @Delete(':id') @RequirePermission('integration', 'delete') - async deleteConnection(@Param('id') id: string) { + async deleteConnection( + @Param('id') id: string, + @OrganizationId() organizationId: string, + ) { + await this.connectionService.getConnectionForOrg(id, organizationId); await this.connectionService.deleteConnection(id); return { success: true }; } @@ -789,13 +811,7 @@ export class ConnectionsController { @OrganizationId() organizationId: string, @Body() body: { metadata?: Record }, ) { - const connection = await this.connectionService.getConnection(id); - if (connection.organizationId !== organizationId) { - throw new HttpException( - 'Connection does not belong to this organization', - HttpStatus.FORBIDDEN, - ); - } + const connection = await this.connectionService.getConnectionForOrg(id, organizationId); if (body.metadata && Object.keys(body.metadata).length > 0) { // Merge with existing metadata @@ -824,11 +840,7 @@ export class ConnectionsController { @Param('id') id: string, @OrganizationId() organizationId: string, ) { - const connection = await this.connectionService.getConnection(id); - - if (connection.organizationId !== organizationId) { - throw new HttpException('Connection not found', HttpStatus.NOT_FOUND); - } + const connection = await this.connectionService.getConnectionForOrg(id, organizationId); if (connection.status !== 'active') { throw new HttpException( @@ -988,11 +1000,7 @@ export class ConnectionsController { @OrganizationId() organizationId: string, @Body() body: { credentials: Record }, ) { - const connection = await this.connectionService.getConnection(id); - - if (connection.organizationId !== organizationId) { - throw new HttpException('Connection not found', HttpStatus.NOT_FOUND); - } + const connection = await this.connectionService.getConnectionForOrg(id, organizationId); const providerSlug = (connection as { provider?: { slug: string } }) .provider?.slug; diff --git a/apps/api/src/integration-platform/controllers/variables.controller.ts b/apps/api/src/integration-platform/controllers/variables.controller.ts index ea13a09357..52f0ea3e4f 100644 --- a/apps/api/src/integration-platform/controllers/variables.controller.ts +++ b/apps/api/src/integration-platform/controllers/variables.controller.ts @@ -13,8 +13,10 @@ import { ApiTags, ApiSecurity } from '@nestjs/swagger'; import { HybridAuthGuard } from '../../auth/hybrid-auth.guard'; import { PermissionGuard } from '../../auth/permission.guard'; import { RequirePermission } from '../../auth/require-permission.decorator'; +import { OrganizationId } from '../../auth/auth-context.decorator'; import { getManifest, type CheckVariable } from '@comp/integration-platform'; import { ConnectionRepository } from '../repositories/connection.repository'; +import { ConnectionService } from '../services/connection.service'; import { ProviderRepository } from '../repositories/provider.repository'; import { CredentialVaultService } from '../services/credential-vault.service'; import { AutoCheckRunnerService } from '../services/auto-check-runner.service'; @@ -52,6 +54,7 @@ export class VariablesController { private readonly providerRepository: ProviderRepository, private readonly credentialVaultService: CredentialVaultService, private readonly autoCheckRunnerService: AutoCheckRunnerService, + private readonly connectionService: ConnectionService, ) {} /** @@ -109,7 +112,12 @@ export class VariablesController { */ @Get('connections/:connectionId') @RequirePermission('integration', 'read') - async getConnectionVariables(@Param('connectionId') connectionId: string) { + async getConnectionVariables( + @Param('connectionId') connectionId: string, + @OrganizationId() organizationId: string, + ) { + await this.connectionService.getConnectionForOrg(connectionId, organizationId); + const connection = await this.connectionRepository.findById(connectionId); if (!connection) { throw new HttpException('Connection not found', HttpStatus.NOT_FOUND); @@ -179,7 +187,10 @@ export class VariablesController { async fetchVariableOptions( @Param('connectionId') connectionId: string, @Param('variableId') variableId: string, + @OrganizationId() organizationId: string, ): Promise<{ options: VariableOption[] }> { + await this.connectionService.getConnectionForOrg(connectionId, organizationId); + const connection = await this.connectionRepository.findById(connectionId); if (!connection) { throw new HttpException('Connection not found', HttpStatus.NOT_FOUND); @@ -386,7 +397,10 @@ export class VariablesController { async saveConnectionVariables( @Param('connectionId') connectionId: string, @Body() body: SaveVariablesDto, + @OrganizationId() organizationId: string, ) { + await this.connectionService.getConnectionForOrg(connectionId, organizationId); + const connection = await this.connectionRepository.findById(connectionId); if (!connection) { throw new HttpException('Connection not found', HttpStatus.NOT_FOUND); diff --git a/apps/api/src/integration-platform/services/connection.service.ts b/apps/api/src/integration-platform/services/connection.service.ts index e195a1f62d..1599b91089 100644 --- a/apps/api/src/integration-platform/services/connection.service.ts +++ b/apps/api/src/integration-platform/services/connection.service.ts @@ -35,6 +35,17 @@ export class ConnectionService { return connection; } + async getConnectionForOrg( + connectionId: string, + organizationId: string, + ): Promise { + const connection = await this.connectionRepository.findById(connectionId); + if (!connection || connection.organizationId !== organizationId) { + throw new NotFoundException(`Connection ${connectionId} not found`); + } + return connection; + } + async getConnectionByProviderSlug( providerSlug: string, organizationId: string, diff --git a/apps/api/src/main.ts b/apps/api/src/main.ts index c887b87333..299ebae6fa 100644 --- a/apps/api/src/main.ts +++ b/apps/api/src/main.ts @@ -45,8 +45,21 @@ async function bootstrap(): Promise { // STEP 3: Configure body parser // NOTE: Attachment uploads are sent as base64 in JSON, so request payloads are // larger than the raw file size. Keep this above the user-facing max file size. - app.use(express.json({ limit: '150mb' })); - app.use(express.urlencoded({ limit: '150mb', extended: true })); + // IMPORTANT: Skip body parsing for /api/auth routes — better-auth needs the raw + // request stream to properly read the body (including OAuth callbackURL). + // Express-level middleware runs BEFORE NestJS module middleware, so without this + // skip, express.json() would consume the stream before better-auth's handler. + const jsonParser = express.json({ limit: '150mb' }); + const urlencodedParser = express.urlencoded({ limit: '150mb', extended: true }); + app.use((req: express.Request, res: express.Response, next: express.NextFunction) => { + if (req.path.startsWith('/api/auth')) { + return next(); + } + jsonParser(req, res, (err?: unknown) => { + if (err) return next(err); + urlencodedParser(req, res, next); + }); + }); // STEP 4: Enable global pipes and filters app.useGlobalPipes( diff --git a/apps/api/src/people/people.service.spec.ts b/apps/api/src/people/people.service.spec.ts index 6fb12d5301..3d1f4271f9 100644 --- a/apps/api/src/people/people.service.spec.ts +++ b/apps/api/src/people/people.service.spec.ts @@ -241,6 +241,7 @@ describe('PeopleService', () => { expect(result).toEqual(updatedMember); expect(MemberQueries.updateMember).toHaveBeenCalledWith( 'mem_1', + 'org_123', updateData, ); }); @@ -333,7 +334,7 @@ describe('PeopleService', () => { expect(result.success).toBe(true); expect(result.deletedMember.id).toBe('mem_1'); expect(db.member.update).toHaveBeenCalledWith({ - where: { id: 'mem_1' }, + where: { id: 'mem_1', organizationId: 'org_123' }, data: { deactivated: true, isActive: false }, }); expect(db.session.deleteMany).toHaveBeenCalledWith({ @@ -440,6 +441,10 @@ describe('PeopleService', () => { expect(result.fleetDmLabelId).toBeNull(); expect(fleetService.removeHostsByLabel).toHaveBeenCalledWith(42); + expect(MemberQueries.unlinkDevice).toHaveBeenCalledWith( + 'mem_1', + 'org_123', + ); }); it('should skip fleet removal when no label exists', async () => { diff --git a/apps/api/src/people/people.service.ts b/apps/api/src/people/people.service.ts index 457ef351f4..8c69b27f2d 100644 --- a/apps/api/src/people/people.service.ts +++ b/apps/api/src/people/people.service.ts @@ -281,6 +281,7 @@ export class PeopleService { const updatedMember = await MemberQueries.updateMember( memberId, + organizationId, updateData, ); @@ -345,7 +346,7 @@ export class PeopleService { await removeMemberFromOrgChart({ organizationId, memberId }); await db.member.update({ - where: { id: memberId }, + where: { id: memberId, organizationId }, data: { deactivated: true, isActive: false }, }); @@ -404,7 +405,7 @@ export class PeopleService { } return db.member.update({ - where: { id: memberId }, + where: { id: memberId, organizationId }, data: { deactivated: false, isActive: true }, select: MemberQueries.MEMBER_SELECT, }); @@ -472,7 +473,7 @@ export class PeopleService { ); } - const updatedMember = await MemberQueries.unlinkDevice(memberId); + const updatedMember = await MemberQueries.unlinkDevice(memberId, organizationId); this.logger.log( `Unlinked device for member: ${updatedMember.user.name} (${memberId})`, diff --git a/apps/api/src/people/utils/member-queries.ts b/apps/api/src/people/utils/member-queries.ts index f0bb09f82c..23a20f487f 100644 --- a/apps/api/src/people/utils/member-queries.ts +++ b/apps/api/src/people/utils/member-queries.ts @@ -91,10 +91,11 @@ export class MemberQueries { } /** - * Update a member by ID + * Update a member by ID within an organization */ static async updateMember( memberId: string, + organizationId: string, updateData: UpdatePeopleDto, ): Promise { // Separate user-level fields from member-level fields @@ -123,9 +124,9 @@ export class MemberQueries { // If we need to update both user and member, use a transaction if (hasUserUpdates) { return db.$transaction(async (tx) => { - // Get the member to find the associated userId - const member = await tx.member.findUniqueOrThrow({ - where: { id: memberId }, + // Get the member to find the associated userId (scoped to org) + const member = await tx.member.findFirstOrThrow({ + where: { id: memberId, organizationId }, select: { userId: true }, }); @@ -142,15 +143,15 @@ export class MemberQueries { // Update member fields if any if (hasMemberUpdates) { return tx.member.update({ - where: { id: memberId }, + where: { id: memberId, organizationId }, data: updatePayload, select: this.MEMBER_SELECT, }); } // Return updated member with fresh user data - return tx.member.findUniqueOrThrow({ - where: { id: memberId }, + return tx.member.findFirstOrThrow({ + where: { id: memberId, organizationId }, select: this.MEMBER_SELECT, }); }); @@ -158,7 +159,7 @@ export class MemberQueries { // Only member-level updates return db.member.update({ - where: { id: memberId }, + where: { id: memberId, organizationId }, data: updatePayload, select: this.MEMBER_SELECT, }); @@ -193,20 +194,26 @@ export class MemberQueries { } /** - * Delete a member by ID + * Delete a member by ID within an organization */ - static async deleteMember(memberId: string): Promise { + static async deleteMember( + memberId: string, + organizationId: string, + ): Promise { await db.member.delete({ - where: { id: memberId }, + where: { id: memberId, organizationId }, }); } /** - * Unlink device by resetting fleetDmLabelId to null + * Unlink device by resetting fleetDmLabelId to null within an organization */ - static async unlinkDevice(memberId: string): Promise { + static async unlinkDevice( + memberId: string, + organizationId: string, + ): Promise { return db.member.update({ - where: { id: memberId }, + where: { id: memberId, organizationId }, data: { fleetDmLabelId: null }, select: this.MEMBER_SELECT, }); diff --git a/apps/api/src/questionnaire/questionnaire.controller.spec.ts b/apps/api/src/questionnaire/questionnaire.controller.spec.ts index 86d8572469..12774d1ddc 100644 --- a/apps/api/src/questionnaire/questionnaire.controller.spec.ts +++ b/apps/api/src/questionnaire/questionnaire.controller.spec.ts @@ -5,6 +5,13 @@ jest.mock('../auth/auth.server', () => ({ auth: { api: { getSession: jest.fn() } }, })); +jest.mock('@comp/auth', () => ({ + statement: {}, + ac: {}, + allRoles: {}, + BUILT_IN_ROLE_PERMISSIONS: {}, +})); + jest.mock('@/vector-store/lib', () => ({ syncOrganizationEmbeddings: jest.fn(), findSimilarContentBatch: jest.fn(), @@ -211,13 +218,40 @@ describe('QuestionnaireController', () => { error: undefined, }); - const result = await controller.answerSingleQuestion(dto as any); + const result = await controller.answerSingleQuestion( + dto as any, + 'org_1', + ); expect(result.success).toBe(true); expect(result.data.answer).toBe('Our policy covers...'); expect(result.data.questionIndex).toBe(0); expect(result.data.sources).toHaveLength(1); }); + + it('should override body organizationId with auth-derived org', async () => { + const dto = { + question: 'What is your policy?', + organizationId: 'org_attacker', + questionIndex: 0, + totalQuestions: 5, + }; + mockService.answerSingleQuestion.mockResolvedValue({ + success: true, + questionIndex: 0, + question: 'What is your policy?', + answer: 'Answer', + sources: [], + error: undefined, + }); + + await controller.answerSingleQuestion(dto as any, 'org_1'); + + expect(dto.organizationId).toBe('org_1'); + expect(service.answerSingleQuestion).toHaveBeenCalledWith( + expect.objectContaining({ organizationId: 'org_1' }), + ); + }); }); describe('saveAnswer', () => { @@ -231,10 +265,25 @@ describe('QuestionnaireController', () => { }; mockService.saveAnswer.mockResolvedValue({ success: true }); - const result = await controller.saveAnswer(dto as any); + const result = await controller.saveAnswer(dto as any, 'org_1'); expect(result).toEqual({ success: true }); }); + + it('should override body organizationId with auth-derived org', async () => { + const dto = { + questionnaireId: 'q1', + organizationId: 'org_attacker', + questionIndex: 0, + answer: 'Yes', + status: 'manual', + }; + mockService.saveAnswer.mockResolvedValue({ success: true }); + + await controller.saveAnswer(dto as any, 'org_1'); + + expect(dto.organizationId).toBe('org_1'); + }); }); describe('deleteAnswer', () => { @@ -246,10 +295,23 @@ describe('QuestionnaireController', () => { }; mockService.deleteAnswer.mockResolvedValue({ success: true }); - const result = await controller.deleteAnswer(dto as any); + const result = await controller.deleteAnswer(dto as any, 'org_1'); expect(result).toEqual({ success: true }); }); + + it('should override body organizationId with auth-derived org', async () => { + const dto = { + questionnaireId: 'q1', + organizationId: 'org_attacker', + questionAnswerId: 'qa1', + }; + mockService.deleteAnswer.mockResolvedValue({ success: true }); + + await controller.deleteAnswer(dto as any, 'org_1'); + + expect(dto.organizationId).toBe('org_1'); + }); }); describe('uploadAndParse', () => { @@ -266,9 +328,30 @@ describe('QuestionnaireController', () => { totalQuestions: 10, }); - const result = await controller.uploadAndParse(dto as any); + const result = await controller.uploadAndParse(dto as any, 'org_1'); expect(result).toEqual({ questionnaireId: 'q1', totalQuestions: 10 }); }); + + it('should override body organizationId with auth-derived org', async () => { + const dto = { + organizationId: 'org_attacker', + fileName: 'test.pdf', + fileType: 'application/pdf', + fileData: 'base64data', + source: 'internal', + }; + mockService.uploadAndParse.mockResolvedValue({ + questionnaireId: 'q1', + totalQuestions: 10, + }); + + await controller.uploadAndParse(dto as any, 'org_1'); + + expect(dto.organizationId).toBe('org_1'); + expect(service.uploadAndParse).toHaveBeenCalledWith( + expect.objectContaining({ organizationId: 'org_1' }), + ); + }); }); }); diff --git a/apps/api/src/questionnaire/questionnaire.controller.ts b/apps/api/src/questionnaire/questionnaire.controller.ts index 52ab0d9e1a..9ea0369f8f 100644 --- a/apps/api/src/questionnaire/questionnaire.controller.ts +++ b/apps/api/src/questionnaire/questionnaire.controller.ts @@ -169,7 +169,11 @@ export class QuestionnaireController { }, }, }) - async answerSingleQuestion(@Body() dto: AnswerSingleQuestionDto) { + async answerSingleQuestion( + @Body() dto: AnswerSingleQuestionDto, + @OrganizationId() organizationId: string, + ) { + dto.organizationId = organizationId; const result = await this.questionnaireService.answerSingleQuestion(dto); return { success: result.success, @@ -196,7 +200,11 @@ export class QuestionnaireController { }, }, }) - async saveAnswer(@Body() dto: SaveAnswerDto) { + async saveAnswer( + @Body() dto: SaveAnswerDto, + @OrganizationId() organizationId: string, + ) { + dto.organizationId = organizationId; return this.questionnaireService.saveAnswer(dto); } @@ -213,7 +221,11 @@ export class QuestionnaireController { }, }, }) - async deleteAnswer(@Body() dto: DeleteAnswerDto) { + async deleteAnswer( + @Body() dto: DeleteAnswerDto, + @OrganizationId() organizationId: string, + ) { + dto.organizationId = organizationId; return this.questionnaireService.deleteAnswer(dto); } @@ -232,7 +244,9 @@ export class QuestionnaireController { async exportById( @Body() dto: ExportByIdDto, @Res({ passthrough: true }) res: Response, + @OrganizationId() organizationId: string, ): Promise { + dto.organizationId = organizationId; const result = await this.questionnaireService.exportById(dto); res.setHeader('Content-Type', result.mimeType); @@ -258,7 +272,11 @@ export class QuestionnaireController { }, }, }) - async uploadAndParse(@Body() dto: UploadAndParseDto) { + async uploadAndParse( + @Body() dto: UploadAndParseDto, + @OrganizationId() organizationId: string, + ) { + dto.organizationId = organizationId; return this.questionnaireService.uploadAndParse(dto); } @@ -307,16 +325,14 @@ export class QuestionnaireController { organizationId: string; source?: 'internal' | 'external'; }, + @OrganizationId() organizationId: string, ) { if (!file) { throw new BadRequestException('file is required'); } - if (!body.organizationId) { - throw new BadRequestException('organizationId is required'); - } const dto: UploadAndParseDto = { - organizationId: body.organizationId, + organizationId, fileName: file.originalname, fileType: file.mimetype, fileData: file.buffer.toString('base64'), @@ -374,18 +390,16 @@ export class QuestionnaireController { source?: 'internal' | 'external'; }, @Res({ passthrough: true }) res: Response, + @OrganizationId() organizationId: string, ): Promise { if (!file) { throw new BadRequestException('file is required'); } - if (!body.organizationId) { - throw new BadRequestException('organizationId is required'); - } const dto: ExportQuestionnaireDto = { fileData: file.buffer.toString('base64'), fileType: file.mimetype, - organizationId: body.organizationId, + organizationId, fileName: file.originalname, vendorName: undefined, format: body.format || 'xlsx', @@ -498,7 +512,9 @@ export class QuestionnaireController { async autoAnswerAndExport( @Body() dto: ExportQuestionnaireDto, @Res({ passthrough: true }) res: Response, + @OrganizationId() organizationId: string, ): Promise { + dto.organizationId = organizationId; const result = await this.questionnaireService.autoAnswerAndExport(dto); res.setHeader('Content-Type', result.mimeType); @@ -550,18 +566,16 @@ export class QuestionnaireController { @UploadedFile() file: Express.Multer.File, @Body() body: { organizationId: string; format?: 'pdf' | 'csv' | 'xlsx' }, @Res({ passthrough: true }) res: Response, + @OrganizationId() organizationId: string, ): Promise { if (!file) { throw new BadRequestException('file is required'); } - if (!body.organizationId) { - throw new BadRequestException('organizationId is required'); - } const dto: ExportQuestionnaireDto = { fileData: file.buffer.toString('base64'), fileType: file.mimetype, - organizationId: body.organizationId, + organizationId, fileName: file.originalname, vendorName: undefined, format: body.format || 'xlsx', @@ -589,7 +603,9 @@ export class QuestionnaireController { async autoAnswer( @Body() dto: AutoAnswerDto, @Res() res: Response, + @OrganizationId() organizationId: string, ): Promise { + dto.organizationId = organizationId; setupSSEHeaders(res); const send = createSafeSSESender(res); diff --git a/apps/app/src/app/api/auth/[...all]/route.ts b/apps/app/src/app/api/auth/[...all]/route.ts deleted file mode 100644 index 47804f415d..0000000000 --- a/apps/app/src/app/api/auth/[...all]/route.ts +++ /dev/null @@ -1,315 +0,0 @@ -/** - * Auth API route proxy. - * - * This route proxies auth requests to the API server. - * The actual auth server runs on the API - this app only forwards requests. - * - * SECURITY: - * - Rate limiting to prevent brute force attacks - * - Redirect URL validation to prevent open redirects - * - Conditional logging (development only) - */ - -import { NextRequest, NextResponse } from 'next/server'; - -// IMPORTANT: This proxy must always point to the actual API server. -// Do NOT use BETTER_AUTH_URL here - that may be set to the app's URL which would cause a loop. -const API_URL = - process.env.BACKEND_API_URL || process.env.NEXT_PUBLIC_API_URL || 'http://localhost:3333'; - -const IS_DEVELOPMENT = process.env.NODE_ENV === 'development'; - -// ============================================================================= -// Rate Limiting (in-memory, per-instance) -// ============================================================================= - -interface RateLimitEntry { - count: number; - resetTime: number; -} - -// Simple in-memory rate limiter -// In production, consider using Redis or a distributed rate limiter -const rateLimitMap = new Map(); - -const RATE_LIMIT_WINDOW_MS = 60 * 1000; // 1 minute -const RATE_LIMIT_MAX_REQUESTS = 60; // 60 requests per minute per IP - -// Stricter limits for sensitive endpoints -const SENSITIVE_ENDPOINTS = [ - '/api/auth/sign-in', - '/api/auth/sign-up', - '/api/auth/magic-link', - '/api/auth/email-otp', - '/api/auth/verify-otp', - '/api/auth/reset-password', -]; -const SENSITIVE_RATE_LIMIT_MAX = 10; // 10 requests per minute for sensitive endpoints - -function getClientIP(request: NextRequest): string { - // Check various headers for the real IP (behind proxies/load balancers) - const forwarded = request.headers.get('x-forwarded-for'); - if (forwarded) { - return forwarded.split(',')[0].trim(); - } - const realIP = request.headers.get('x-real-ip'); - if (realIP) { - return realIP; - } - // Fallback - not ideal but better than nothing - return 'unknown'; -} - -function checkRateLimit(ip: string, pathname: string): { allowed: boolean; retryAfter?: number } { - const now = Date.now(); - const key = `${ip}:${pathname}`; - - // Determine the rate limit based on endpoint sensitivity - const isSensitive = SENSITIVE_ENDPOINTS.some((ep) => pathname.startsWith(ep)); - const maxRequests = isSensitive ? SENSITIVE_RATE_LIMIT_MAX : RATE_LIMIT_MAX_REQUESTS; - - const entry = rateLimitMap.get(key); - - if (!entry || now > entry.resetTime) { - // New window - rateLimitMap.set(key, { - count: 1, - resetTime: now + RATE_LIMIT_WINDOW_MS, - }); - return { allowed: true }; - } - - if (entry.count >= maxRequests) { - const retryAfter = Math.ceil((entry.resetTime - now) / 1000); - return { allowed: false, retryAfter }; - } - - entry.count++; - return { allowed: true }; -} - -// Clean up old entries periodically (every 5 minutes) -setInterval(() => { - const now = Date.now(); - for (const [key, entry] of rateLimitMap.entries()) { - if (now > entry.resetTime) { - rateLimitMap.delete(key); - } - } -}, 5 * 60 * 1000); - -// ============================================================================= -// Redirect URL Validation -// ============================================================================= - -function getAllowedHosts(): string[] { - const hosts = [ - 'localhost:3000', - 'localhost:3002', - 'localhost:3333', - 'app.trycomp.ai', - 'portal.trycomp.ai', - 'api.trycomp.ai', - 'app.staging.trycomp.ai', - 'portal.staging.trycomp.ai', - 'api.staging.trycomp.ai', - ]; - - // Add any custom allowed hosts from environment - const customHosts = process.env.AUTH_ALLOWED_REDIRECT_HOSTS; - if (customHosts) { - hosts.push(...customHosts.split(',').map((h) => h.trim())); - } - - return hosts; -} - -function isAllowedRedirectUrl(redirectUrl: string, requestOrigin: string): boolean { - try { - const url = new URL(redirectUrl); - const allowedHosts = getAllowedHosts(); - - // Allow redirects to the request's own origin - const originUrl = new URL(requestOrigin); - if (url.host === originUrl.host) { - return true; - } - - // Allow redirects to configured allowed hosts - return allowedHosts.includes(url.host); - } catch { - // If URL parsing fails, check if it's a relative URL (which is safe) - return redirectUrl.startsWith('/'); - } -} - -// ============================================================================= -// Proxy Implementation -// ============================================================================= - -async function proxyRequest(request: NextRequest): Promise { - const url = new URL(request.url); - const clientIP = getClientIP(request); - - // Check rate limit - const rateLimit = checkRateLimit(clientIP, url.pathname); - if (!rateLimit.allowed) { - if (IS_DEVELOPMENT) { - console.log(`[auth proxy] Rate limit exceeded for ${clientIP} on ${url.pathname}`); - } - return NextResponse.json( - { error: 'Too many requests. Please try again later.' }, - { - status: 429, - headers: { - 'Retry-After': String(rateLimit.retryAfter || 60), - }, - } - ); - } - - const targetUrl = `${API_URL}${url.pathname}${url.search}`; - - if (IS_DEVELOPMENT) { - console.log(`[auth proxy] ${request.method} ${url.pathname} -> ${targetUrl}`); - } - - try { - // Forward the request to the API - const response = await fetch(targetUrl, { - method: request.method, - headers: { - // Forward all headers except host - ...Object.fromEntries( - Array.from(request.headers.entries()).filter( - ([key]) => key.toLowerCase() !== 'host' - ) - ), - }, - body: request.method !== 'GET' && request.method !== 'HEAD' ? await request.text() : undefined, - // Don't follow redirects - let the client handle them - redirect: 'manual', - }); - - if (IS_DEVELOPMENT) { - console.log(`[auth proxy] Response: ${response.status} ${response.statusText}`); - } - - // Create response with the same status and headers - const responseHeaders = new Headers(); - - // Handle Set-Cookie headers specially - they need to be appended, not set - // Use getSetCookie() with fallback for runtimes that don't support it - let setCookieHeaders: string[] = []; - if (typeof response.headers.getSetCookie === 'function') { - setCookieHeaders = response.headers.getSetCookie(); - } - - // Fallback: extract from raw headers if getSetCookie didn't work - // This handles Vercel/edge runtimes where getSetCookie may not be available - if (setCookieHeaders.length === 0) { - const raw = response.headers.get('set-cookie'); - if (raw) { - // Split on comma-space but NOT within Expires date values - // e.g. "Expires=Thu, 01 Jan 2026" contains a comma we must not split on - setCookieHeaders = raw.split(/,(?=\s*[a-zA-Z_\-.]+=)/); - } - } - - response.headers.forEach((value, key) => { - const lowerKey = key.toLowerCase(); - // Skip set-cookie here, we'll handle it separately - if (lowerKey === 'set-cookie') { - return; - } - responseHeaders.set(key, value); - }); - - // Process Set-Cookie headers - if (setCookieHeaders.length > 0) { - if (IS_DEVELOPMENT) { - console.log(`[auth proxy] Forwarding ${setCookieHeaders.length} Set-Cookie headers`); - } - for (const cookie of setCookieHeaders) { - let processedCookie = cookie; - - // In development, cookies between localhost:3000 and localhost:3333 - // need to have their domain removed to work correctly - if (IS_DEVELOPMENT) { - // Remove domain attribute so cookie is set for current host - processedCookie = processedCookie.replace(/;\s*domain=[^;]*/gi, ''); - } - - responseHeaders.append('set-cookie', processedCookie); - - // When a cookie has a Domain attribute (cross-subdomain), also delete - // any stale host-only cookie with the same name. Host-only cookies - // take precedence and would shadow the new cross-subdomain cookie. - if (!IS_DEVELOPMENT && /;\s*domain=/i.test(processedCookie)) { - const nameMatch = processedCookie.match(/^([^=]+)=/); - if (nameMatch) { - const cookieName = nameMatch[1].trim(); - responseHeaders.append( - 'set-cookie', - `${cookieName}=; Path=/; Max-Age=0; Secure; HttpOnly; SameSite=Lax`, - ); - } - } - } - } - - // Handle redirects with URL validation - if (response.status >= 300 && response.status < 400) { - const location = response.headers.get('location'); - if (location) { - // Rewrite API URLs to app URLs in redirects - const rewrittenLocation = location.replace(API_URL, url.origin); - - // Validate the redirect URL for security - if (!isAllowedRedirectUrl(rewrittenLocation, url.origin)) { - console.error(`[auth proxy] SECURITY: Blocked suspicious redirect to ${rewrittenLocation}`); - return NextResponse.json( - { error: 'Invalid redirect URL' }, - { status: 400 } - ); - } - - if (IS_DEVELOPMENT) { - console.log(`[auth proxy] Redirect: ${location} -> ${rewrittenLocation}`); - } - responseHeaders.set('location', rewrittenLocation); - } - } - - const body = response.status === 204 ? null : await response.text(); - - return new NextResponse(body, { - status: response.status, - statusText: response.statusText, - headers: responseHeaders, - }); - } catch (error) { - console.error('[auth proxy] Failed to proxy request:', error); - return NextResponse.json({ error: 'Auth service unavailable' }, { status: 503 }); - } -} - -export async function GET(request: NextRequest): Promise { - return proxyRequest(request); -} - -export async function POST(request: NextRequest): Promise { - return proxyRequest(request); -} - -export async function PUT(request: NextRequest): Promise { - return proxyRequest(request); -} - -export async function DELETE(request: NextRequest): Promise { - return proxyRequest(request); -} - -export async function PATCH(request: NextRequest): Promise { - return proxyRequest(request); -} diff --git a/apps/app/src/utils/auth-callback.ts b/apps/app/src/utils/auth-callback.ts index 84af3ac21d..dabb28e72c 100644 --- a/apps/app/src/utils/auth-callback.ts +++ b/apps/app/src/utils/auth-callback.ts @@ -28,6 +28,8 @@ export const getSafeRedirectPath = (path?: string | null): string | undefined => /** * Builds the auth callback URL for sign-in flows. + * Returns absolute URLs so that OAuth callbacks redirect to the correct app + * regardless of which server processes the OAuth flow. * * Priority: * 1. If inviteCode is provided, redirect to /invite/{code} @@ -40,17 +42,20 @@ export const buildAuthCallbackUrl = (options?: { }): string => { const { inviteCode, redirectTo } = options ?? {}; + // Determine the base origin for absolute URLs + const origin = typeof window !== 'undefined' ? window.location.origin : ''; + // Invite code takes priority if (inviteCode) { - return `/invite/${inviteCode}`; + return `${origin}/invite/${inviteCode}`; } // Use redirectTo if valid const safeRedirect = getSafeRedirectPath(redirectTo); if (safeRedirect) { - return safeRedirect; + return `${origin}${safeRedirect}`; } // Default to root - return '/'; + return `${origin}/`; }; diff --git a/apps/app/src/utils/auth-client.ts b/apps/app/src/utils/auth-client.ts index 48bc9b975b..6390df2ece 100644 --- a/apps/app/src/utils/auth-client.ts +++ b/apps/app/src/utils/auth-client.ts @@ -10,20 +10,15 @@ import { ac, allRoles } from './permissions'; /** * Auth client for browser-side authentication. * - * This client uses the app's own URL as the base, which routes through the - * auth proxy at /api/auth/[...all]. This ensures cookies are set for the - * correct domain (the app's domain). + * Points directly to the API server. Cross-subdomain cookies (.trycomp.ai) + * ensure the session works across all apps (app, portal, etc.). * * For server-side session validation, use auth.ts instead. * * SECURITY NOTE: Authentication is handled via httpOnly cookies set by the API. * We do not store tokens in localStorage to prevent XSS attacks. */ - -// Use empty string for relative URLs - this makes all auth requests go through -// the app's own /api/auth/* routes, which proxy to the API server. -// This ensures cookies are set for the app's domain, not the API's domain. -const BASE_URL = ''; +const BASE_URL = process.env.NEXT_PUBLIC_API_URL || 'http://localhost:3333'; export const authClient = createAuthClient({ baseURL: BASE_URL, diff --git a/apps/app/src/utils/auth.ts b/apps/app/src/utils/auth.ts index 42b694ee65..f2a2349f64 100644 --- a/apps/app/src/utils/auth.ts +++ b/apps/app/src/utils/auth.ts @@ -14,9 +14,7 @@ import { ac, allRoles } from './permissions'; // Re-export permissions for convenience export { ac, allRoles }; -// IMPORTANT: This must point to the actual API server, not the app itself. -// Use BACKEND_API_URL for server-to-server communication, or fall back to NEXT_PUBLIC_API_URL. -// Do NOT use BETTER_AUTH_URL here - that may be the app's URL for the client-side auth. +// Must point to the API server for server-to-server auth calls. const API_URL = process.env.BACKEND_API_URL || process.env.NEXT_PUBLIC_API_URL || 'http://localhost:3333'; diff --git a/apps/portal/src/app/api/auth/[...all]/route.ts b/apps/portal/src/app/api/auth/[...all]/route.ts deleted file mode 100644 index b56615f17f..0000000000 --- a/apps/portal/src/app/api/auth/[...all]/route.ts +++ /dev/null @@ -1,144 +0,0 @@ -/** - * Auth API route proxy for the Portal. - * - * Proxies all auth requests to the NestJS API server. - * The actual auth server (better-auth) runs on the API — the portal only forwards requests. - */ - -import { NextRequest, NextResponse } from 'next/server'; - -const API_URL = - process.env.BACKEND_API_URL || process.env.NEXT_PUBLIC_API_URL || 'http://localhost:3333'; - -const IS_DEVELOPMENT = process.env.NODE_ENV === 'development'; - -// Rate limiting -interface RateLimitEntry { count: number; resetTime: number; } -const rateLimitMap = new Map(); -const RATE_LIMIT_WINDOW_MS = 60 * 1000; -const RATE_LIMIT_MAX_REQUESTS = 60; -const SENSITIVE_ENDPOINTS = [ - '/api/auth/sign-in', - '/api/auth/sign-up', - '/api/auth/email-otp', - '/api/auth/verify-otp', -]; -const SENSITIVE_RATE_LIMIT_MAX = 10; - -function getClientIP(request: NextRequest): string { - return request.headers.get('x-forwarded-for')?.split(',')[0].trim() - || request.headers.get('x-real-ip') - || 'unknown'; -} - -function checkRateLimit(ip: string, pathname: string): { allowed: boolean; retryAfter?: number } { - const now = Date.now(); - const key = `${ip}:${pathname}`; - const isSensitive = SENSITIVE_ENDPOINTS.some((ep) => pathname.startsWith(ep)); - const maxRequests = isSensitive ? SENSITIVE_RATE_LIMIT_MAX : RATE_LIMIT_MAX_REQUESTS; - const entry = rateLimitMap.get(key); - - if (!entry || now > entry.resetTime) { - rateLimitMap.set(key, { count: 1, resetTime: now + RATE_LIMIT_WINDOW_MS }); - return { allowed: true }; - } - if (entry.count >= maxRequests) { - return { allowed: false, retryAfter: Math.ceil((entry.resetTime - now) / 1000) }; - } - entry.count++; - return { allowed: true }; -} - -setInterval(() => { - const now = Date.now(); - for (const [key, entry] of rateLimitMap.entries()) { - if (now > entry.resetTime) rateLimitMap.delete(key); - } -}, 5 * 60 * 1000); - -// Redirect URL validation -function isAllowedRedirectUrl(redirectUrl: string, requestOrigin: string): boolean { - try { - const url = new URL(redirectUrl); - const originUrl = new URL(requestOrigin); - if (url.host === originUrl.host) return true; - const allowedHosts = [ - 'localhost:3000', 'localhost:3002', 'localhost:3333', - 'app.trycomp.ai', 'portal.trycomp.ai', 'api.trycomp.ai', - 'app.staging.trycomp.ai', 'portal.staging.trycomp.ai', 'api.staging.trycomp.ai', - ]; - return allowedHosts.includes(url.host); - } catch { - return redirectUrl.startsWith('/'); - } -} - -// Proxy implementation -async function proxyRequest(request: NextRequest): Promise { - const url = new URL(request.url); - - const rateLimit = checkRateLimit(getClientIP(request), url.pathname); - if (!rateLimit.allowed) { - return NextResponse.json( - { error: 'Too many requests. Please try again later.' }, - { status: 429, headers: { 'Retry-After': String(rateLimit.retryAfter || 60) } }, - ); - } - - const targetUrl = `${API_URL}${url.pathname}${url.search}`; - - try { - const response = await fetch(targetUrl, { - method: request.method, - headers: { - ...Object.fromEntries( - Array.from(request.headers.entries()).filter(([key]) => key.toLowerCase() !== 'host'), - ), - }, - body: request.method !== 'GET' && request.method !== 'HEAD' ? await request.text() : undefined, - redirect: 'manual', - }); - - const responseHeaders = new Headers(); - const setCookieHeaders = response.headers.getSetCookie?.() || []; - - response.headers.forEach((value, key) => { - if (key.toLowerCase() === 'set-cookie') return; - responseHeaders.set(key, value); - }); - - if (setCookieHeaders.length > 0) { - for (const cookie of setCookieHeaders) { - let processedCookie = cookie; - if (IS_DEVELOPMENT) { - processedCookie = processedCookie.replace(/;\s*domain=[^;]*/gi, ''); - } - responseHeaders.append('set-cookie', processedCookie); - } - } - - if (response.status >= 300 && response.status < 400) { - const location = response.headers.get('location'); - if (location) { - const rewrittenLocation = location.replace(API_URL, url.origin); - if (!isAllowedRedirectUrl(rewrittenLocation, url.origin)) { - console.error(`[auth proxy] Blocked suspicious redirect to ${rewrittenLocation}`); - return NextResponse.json({ error: 'Invalid redirect URL' }, { status: 400 }); - } - responseHeaders.set('location', rewrittenLocation); - } - } - - const body = response.status === 204 ? null : await response.text(); - return new NextResponse(body, { status: response.status, statusText: response.statusText, headers: responseHeaders }); - } catch (error) { - console.error('[auth proxy] Failed to proxy request:', error); - return NextResponse.json({ error: 'Auth service unavailable' }, { status: 503 }); - } -} - -export async function GET(request: NextRequest) { return proxyRequest(request); } -export async function POST(request: NextRequest) { return proxyRequest(request); } -export async function PUT(request: NextRequest) { return proxyRequest(request); } -export async function DELETE(request: NextRequest) { return proxyRequest(request); } -export async function PATCH(request: NextRequest) { return proxyRequest(request); } diff --git a/apps/portal/src/app/components/otp-form.tsx b/apps/portal/src/app/components/otp-form.tsx index 2733d17cc8..3dce834044 100644 --- a/apps/portal/src/app/components/otp-form.tsx +++ b/apps/portal/src/app/components/otp-form.tsx @@ -1,5 +1,6 @@ 'use client'; +import { authClient } from '@/app/lib/auth-client'; import { Form, FormControl, FormField, FormItem, FormMessage } from '@comp/ui/form'; import { zodResolver } from '@hookform/resolvers/zod'; import { Button } from '@comp/ui/button'; @@ -13,9 +14,6 @@ import { z } from 'zod'; const INPUT_LENGTH = 6; -const API_URL = - process.env.NEXT_PUBLIC_API_URL || 'http://localhost:3333'; - const otpFormSchema = z.object({ email: z.string().email(), otp: z.string().min(INPUT_LENGTH, 'OTP is required'), @@ -43,20 +41,13 @@ export function OtpForm({ email, deviceAuthRedirect }: OtpFormProps) { try { setIsLoading(true); - const response = await fetch('/api/auth/sign-in/email-otp', { - method: 'POST', - credentials: 'include', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - email: formData.email, - otp: formData.otp, - }), + const { error } = await authClient.signIn.emailOtp({ + email: formData.email, + otp: formData.otp, }); - if (!response.ok) { - const errorData = await response.json().catch(() => ({})); - const errorMessage = errorData.message || 'Login failed'; - const lower = errorMessage.toLowerCase(); + if (error) { + const lower = (error.message || '').toLowerCase(); if (lower.includes('invalid') && lower.includes('otp')) { toast.error('Invalid OTP code. Please check your code and try again.'); diff --git a/apps/portal/src/app/lib/auth-client.ts b/apps/portal/src/app/lib/auth-client.ts index e023556e54..902966a9ef 100644 --- a/apps/portal/src/app/lib/auth-client.ts +++ b/apps/portal/src/app/lib/auth-client.ts @@ -7,8 +7,7 @@ import { createAuthClient } from 'better-auth/react'; import { ac, allRoles } from '@comp/auth'; export const authClient = createAuthClient({ - // Empty baseURL = calls go through the portal's own /api/auth/* proxy - baseURL: '', + baseURL: process.env.NEXT_PUBLIC_API_URL || 'http://localhost:3333', plugins: [ organizationClient({ ac, roles: allRoles }), emailOTPClient(),