Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 61 additions & 5 deletions apps/api/src/auth/auth-server-origins.spec.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Tests for the getTrustedOrigins logic.
* Tests for the getTrustedOrigins / isTrustedOrigin logic.
*
* Because auth.server.ts has side effects at module load time (better-auth
* initialization, DB connections, validateSecurityConfig), we test the logic
Expand All @@ -25,6 +25,32 @@ function getTrustedOriginsLogic(authTrustedOrigins: string | undefined): string[
];
}

/**
* Mirror of isStaticTrustedOrigin from auth.server.ts for isolated testing.
* The full isTrustedOrigin is async (checks DB for custom domains) —
* that path is tested via integration tests.
*/
function isStaticTrustedOriginLogic(
origin: string,
trustedOrigins: string[],
): boolean {
if (trustedOrigins.includes(origin)) {
return true;
}

try {
const url = new URL(origin);
return (
url.hostname.endsWith('.trycomp.ai') ||
url.hostname.endsWith('.staging.trycomp.ai') ||
url.hostname.endsWith('.trust.inc') ||
url.hostname === 'trust.inc'
);
} catch {
return false;
}
}

describe('getTrustedOrigins', () => {
it('should return env-configured origins when AUTH_TRUSTED_ORIGINS is set', () => {
const origins = getTrustedOriginsLogic('https://a.com, https://b.com');
Expand All @@ -45,17 +71,47 @@ describe('getTrustedOrigins', () => {
const origins = getTrustedOriginsLogic(' https://a.com , https://b.com ');
expect(origins).toEqual(['https://a.com', 'https://b.com']);
});
});

describe('isStaticTrustedOrigin', () => {
const defaults = getTrustedOriginsLogic(undefined);

it('should allow static trusted origins', () => {
expect(isStaticTrustedOriginLogic('https://app.trycomp.ai', defaults)).toBe(true);
});

it('should allow trust portal subdomains of trycomp.ai', () => {
expect(isStaticTrustedOriginLogic('https://security.trycomp.ai', defaults)).toBe(true);
expect(isStaticTrustedOriginLogic('https://acme.trycomp.ai', defaults)).toBe(true);
});

it('should allow trust portal subdomains of staging.trycomp.ai', () => {
expect(isStaticTrustedOriginLogic('https://security.staging.trycomp.ai', defaults)).toBe(true);
});

it('should allow trust.inc and its subdomains', () => {
expect(isStaticTrustedOriginLogic('https://trust.inc', defaults)).toBe(true);
expect(isStaticTrustedOriginLogic('https://acme.trust.inc', defaults)).toBe(true);
});

it('should reject unknown origins', () => {
expect(isStaticTrustedOriginLogic('https://evil.com', defaults)).toBe(false);
expect(isStaticTrustedOriginLogic('https://trycomp.ai.evil.com', defaults)).toBe(false);
});

it('should handle invalid origins gracefully', () => {
expect(isStaticTrustedOriginLogic('not-a-url', defaults)).toBe(false);
});

it('main.ts should use getTrustedOrigins instead of origin: true', () => {
// Validate the CORS config change was made correctly by checking file content
it('main.ts should use isTrustedOrigin for CORS', () => {
const fs = require('fs');
const path = require('path');
const mainTs = fs.readFileSync(
path.join(__dirname, '..', 'main.ts'),
'utf-8',
) as string;
expect(mainTs).not.toContain('origin: true');
expect(mainTs).toContain('origin: getTrustedOrigins()');
expect(mainTs).toContain("import { getTrustedOrigins } from './auth/auth.server'");
expect(mainTs).toContain('isTrustedOrigin');
expect(mainTs).toContain("import { isTrustedOrigin } from './auth/auth.server'");
});
});
88 changes: 88 additions & 0 deletions apps/api/src/auth/auth.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
} from 'better-auth/plugins';
import { ac, allRoles } from '@trycompai/auth';
import { createAuthMiddleware } from 'better-auth/api';
import { Redis } from '@upstash/redis';

const MAGIC_LINK_EXPIRES_IN_SECONDS = 60 * 60; // 1 hour

Expand Down Expand Up @@ -56,6 +57,93 @@ export function getTrustedOrigins(): string[] {
];
}

/**
* Check if an origin matches a known trusted pattern (static list + subdomains).
* This is a fast synchronous check that doesn't hit the DB.
*/
export function isStaticTrustedOrigin(origin: string): boolean {
const trustedOrigins = getTrustedOrigins();
if (trustedOrigins.includes(origin)) {
return true;
}

try {
const url = new URL(origin);
return (
url.hostname.endsWith('.trycomp.ai') ||
url.hostname.endsWith('.staging.trycomp.ai') ||
url.hostname.endsWith('.trust.inc') ||
url.hostname === 'trust.inc'
);
} catch {
return false;
}
}

// ── Custom domain lookup via Redis cache ─────────────────────────────────────

const CORS_DOMAINS_CACHE_KEY = 'cors:custom-domains';
const CORS_DOMAINS_CACHE_TTL_SECONDS = 5 * 60; // 5 minutes

const corsRedisClient = new Redis({
url: process.env.UPSTASH_REDIS_REST_URL!,
token: process.env.UPSTASH_REDIS_REST_TOKEN!,
});

async function getCustomDomains(): Promise<Set<string>> {
try {
// Try Redis cache first
const cached = await corsRedisClient.get<string[]>(CORS_DOMAINS_CACHE_KEY);
if (cached) {
return new Set(cached);
}

// Cache miss — query DB and store in Redis
const trusts = await db.trust.findMany({
where: {
domain: { not: null },
domainVerified: true,
status: 'published',
},
select: { domain: true },
});

const domains = trusts
.map((t) => t.domain)
.filter((d): d is string => d !== null);

await corsRedisClient.set(CORS_DOMAINS_CACHE_KEY, domains, {
ex: CORS_DOMAINS_CACHE_TTL_SECONDS,
});

return new Set(domains);
} catch (error) {
console.error('[CORS] Failed to fetch custom domains:', error);
return new Set();
}
}

/**
* Check if an origin is trusted. Checks (in order):
* 1. Static trusted origins list
* 2. *.trycomp.ai / *.trust.inc subdomains
* 3. Verified custom domains from the DB (cached in Redis, TTL 5 min)
*/
export async function isTrustedOrigin(origin: string): Promise<boolean> {
if (isStaticTrustedOrigin(origin)) {
return true;
}

// Check verified custom domains from DB via Redis cache
try {
const url = new URL(origin);
const customDomains = await getCustomDomains();
return customDomains.has(url.hostname);
} catch {
return false;
}
}

// Build social providers config
const socialProviders: Record<string, unknown> = {};

Expand Down
46 changes: 34 additions & 12 deletions apps/api/src/auth/origin-check.middleware.spec.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
import { originCheckMiddleware } from './origin-check.middleware';

// Mock getTrustedOrigins
// Mock isTrustedOrigin (async version)
jest.mock('./auth.server', () => ({
getTrustedOrigins: () => [
'http://localhost:3000',
'http://localhost:3002',
'https://app.trycomp.ai',
'https://portal.trycomp.ai',
],
isTrustedOrigin: async (origin: string) => {
const staticOrigins = [
'http://localhost:3000',
'http://localhost:3002',
'https://app.trycomp.ai',
'https://portal.trycomp.ai',
];
if (staticOrigins.includes(origin)) return true;
try {
const url = new URL(origin);
return (
url.hostname.endsWith('.trycomp.ai') ||
url.hostname.endsWith('.staging.trycomp.ai') ||
url.hostname.endsWith('.trust.inc') ||
url.hostname === 'trust.inc'
);
} catch {
return false;
}
},
}));

function createMockReq(
Expand All @@ -22,6 +36,9 @@ function createMockReq(
};
}

/** Flush the microtask queue so async middleware completes. */
const flushPromises = () => new Promise((resolve) => setImmediate(resolve));

function createMockRes(): Record<string, unknown> & { statusCode?: number; body?: unknown } {
const res: Record<string, unknown> & { statusCode?: number; body?: unknown } = {};
res.status = jest.fn().mockImplementation((code: number) => {
Expand Down Expand Up @@ -66,44 +83,48 @@ describe('originCheckMiddleware', () => {
expect(next).toHaveBeenCalled();
});

it('should allow POST from trusted origin', () => {
it('should allow POST from trusted origin', async () => {
const req = createMockReq('POST', '/v1/organization/api-keys', 'http://localhost:3000');
const res = createMockRes();
const next = jest.fn();

originCheckMiddleware(req as any, res as any, next);
await flushPromises();

expect(next).toHaveBeenCalled();
});

it('should block POST from untrusted origin', () => {
it('should block POST from untrusted origin', async () => {
const req = createMockReq('POST', '/v1/organization/transfer-ownership', 'http://evil.com');
const res = createMockRes();
const next = jest.fn();

originCheckMiddleware(req as any, res as any, next);
await flushPromises();

expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(403);
});

it('should block DELETE from untrusted origin', () => {
it('should block DELETE from untrusted origin', async () => {
const req = createMockReq('DELETE', '/v1/organization', 'http://evil.com');
const res = createMockRes();
const next = jest.fn();

originCheckMiddleware(req as any, res as any, next);
await flushPromises();

expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(403);
});

it('should block PATCH from untrusted origin', () => {
it('should block PATCH from untrusted origin', async () => {
const req = createMockReq('PATCH', '/v1/members/123/role', 'http://evil.com');
const res = createMockRes();
const next = jest.fn();

originCheckMiddleware(req as any, res as any, next);
await flushPromises();

expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(403);
Expand Down Expand Up @@ -139,12 +160,13 @@ describe('originCheckMiddleware', () => {
expect(next).toHaveBeenCalled();
});

it('should allow production origins', () => {
it('should allow production origins', async () => {
const req = createMockReq('POST', '/v1/organization/api-keys', 'https://app.trycomp.ai');
const res = createMockRes();
const next = jest.fn();

originCheckMiddleware(req as any, res as any, next);
await flushPromises();

expect(next).toHaveBeenCalled();
});
Expand Down
29 changes: 18 additions & 11 deletions apps/api/src/auth/origin-check.middleware.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import type { Request, Response, NextFunction } from 'express';
import { getTrustedOrigins } from './auth.server';
import { isTrustedOrigin } from './auth.server';

const SAFE_METHODS = new Set(['GET', 'HEAD', 'OPTIONS']);

Expand Down Expand Up @@ -52,14 +52,21 @@ export function originCheckMiddleware(
return next();
}

// Validate Origin against trusted origins
const trustedOrigins = getTrustedOrigins();
if (trustedOrigins.includes(origin)) {
return next();
}

res.status(403).json({
statusCode: 403,
message: 'Forbidden',
});
// Validate Origin against trusted origins (includes dynamic subdomains + custom domains)
isTrustedOrigin(origin)
.then((trusted) => {
if (trusted) {
return next();
}
res.status(403).json({
statusCode: 403,
message: 'Forbidden',
});
})
.catch(() => {
res.status(403).json({
statusCode: 403,
message: 'Forbidden',
});
});
}
Loading
Loading