diff --git a/backend/app/api/admin.py b/backend/app/api/admin.py index 8921f6bb..428da39b 100644 --- a/backend/app/api/admin.py +++ b/backend/app/api/admin.py @@ -26,6 +26,7 @@ # ─── Schemas ──────────────────────────────────────────── + class CompanyStats(BaseModel): id: uuid.UUID name: str @@ -43,6 +44,14 @@ class CompanyStats(BaseModel): class CompanyCreateRequest(BaseModel): name: str = Field(min_length=1, max_length=200) + slug: str | None = None + + +class CompanyUpdateRequest(BaseModel): + name: str | None = None + slug: str | None = None + sso_enabled: bool | None = None + sso_domain: str | None = None class CompanyCreateResponse(BaseModel): @@ -62,6 +71,7 @@ class PlatformSettingsUpdate(BaseModel): # ─── Company Management ──────────────────────────────── + @router.get("/companies", response_model=list[CompanyStats]) async def list_companies( current_user: User = Depends(require_role("platform_admin")), @@ -75,30 +85,22 @@ async def list_companies( tid = tenant.id # User count - uc = await db.execute( - select(sqla_func.count()).select_from(User).where(User.tenant_id == tid) - ) + uc = await db.execute(select(sqla_func.count()).select_from(User).where(User.tenant_id == tid)) user_count = uc.scalar() or 0 # Agent count - ac = await db.execute( - select(sqla_func.count()).select_from(Agent).where(Agent.tenant_id == tid) - ) + ac = await db.execute(select(sqla_func.count()).select_from(Agent).where(Agent.tenant_id == tid)) agent_count = ac.scalar() or 0 # Running agents rc = await db.execute( - select(sqla_func.count()).select_from(Agent).where( - Agent.tenant_id == tid, Agent.status == "running" - ) + select(sqla_func.count()).select_from(Agent).where(Agent.tenant_id == tid, Agent.status == "running") ) agent_running = rc.scalar() or 0 # Total tokens tc = await db.execute( - select(sqla_func.coalesce(sqla_func.sum(Agent.tokens_used_total), 0)).where( - Agent.tenant_id == tid - ) + select(sqla_func.coalesce(sqla_func.sum(Agent.tokens_used_total), 0)).where(Agent.tenant_id == tid) ) total_tokens = tc.scalar() or 0 @@ -112,20 +114,22 @@ async def list_companies( ) org_admin_email = admin_q.scalar() - result.append(CompanyStats( - id=tenant.id, - name=tenant.name, - slug=tenant.slug, - is_active=tenant.is_active, - sso_enabled=tenant.sso_enabled, - sso_domain=tenant.sso_domain, - created_at=tenant.created_at, - user_count=user_count, - agent_count=agent_count, - agent_running_count=agent_running, - total_tokens=total_tokens, - org_admin_email=org_admin_email, - )) + result.append( + CompanyStats( + id=tenant.id, + name=tenant.name, + slug=tenant.slug, + is_active=tenant.is_active, + sso_enabled=tenant.sso_enabled, + sso_domain=tenant.sso_domain, + created_at=tenant.created_at, + user_count=user_count, + agent_count=agent_count, + agent_running_count=agent_running, + total_tokens=total_tokens, + org_admin_email=org_admin_email, + ) + ) return result @@ -139,10 +143,15 @@ async def create_company( """Create a new company and generate an admin invitation code (max_uses=1).""" import re - slug = re.sub(r"[^a-z0-9]+", "-", data.name.lower().strip()).strip("-")[:40] - if not slug: - slug = "company" - slug = f"{slug}-{secrets.token_hex(3)}" + # Use provided slug or generate one from name + if data.slug: + slug = re.sub(r"[^a-z0-9]+", "-", data.slug.lower().strip()).strip("-")[:40] + if not slug: + slug = "company" + else: + slug = re.sub(r"[^a-z0-9]+", "-", data.name.lower().strip()).strip("-")[:40] + if not slug: + slug = "company" tenant = Tenant(name=data.name, slug=slug, im_provider="web_only") db.add(tenant) @@ -188,9 +197,7 @@ async def toggle_company( # When disabling: pause all running agents if not new_state: - agents = await db.execute( - select(Agent).where(Agent.tenant_id == company_id, Agent.status == "running") - ) + agents = await db.execute(select(Agent).where(Agent.tenant_id == company_id, Agent.status == "running")) for agent in agents.scalars().all(): agent.status = "paused" @@ -198,11 +205,38 @@ async def toggle_company( return {"ok": True, "is_active": new_state} +@router.put("/companies/{company_id}") +async def update_company( + company_id: uuid.UUID, + data: CompanyUpdateRequest, + current_user: User = Depends(require_role("platform_admin")), + db: AsyncSession = Depends(get_db), +): + """Update a company's settings including SSO configuration.""" + result = await db.execute(select(Tenant).where(Tenant.id == company_id)) + tenant = result.scalar_one_or_none() + if not tenant: + raise HTTPException(status_code=404, detail="Company not found") + + if data.name is not None: + tenant.name = data.name + if data.slug is not None: + tenant.slug = data.slug + if data.sso_enabled is not None: + tenant.sso_enabled = data.sso_enabled + if data.sso_domain is not None: + tenant.sso_domain = data.sso_domain if data.sso_domain.strip() else None + + await db.flush() + return {"ok": True} + + # ─── Platform Metrics Dashboard ───────────────────────── from typing import Any from fastapi import Query + @router.get("/metrics/timeseries", response_model=list[dict[str, Any]]) async def get_platform_timeseries( start_date: datetime, @@ -222,50 +256,37 @@ async def get_platform_timeseries( # 1. New Companies per day companies_q = await db.execute( - select( - cast(Tenant.created_at, Date).label('d'), - sqla_func.count().label('c') - ).where( - Tenant.created_at >= start_date, - Tenant.created_at <= end_date - ).group_by('d') + select(cast(Tenant.created_at, Date).label("d"), sqla_func.count().label("c")) + .where(Tenant.created_at >= start_date, Tenant.created_at <= end_date) + .group_by("d") ) companies_by_day = {row.d: row.c for row in companies_q.all()} # 2. New Users per day users_q = await db.execute( - select( - cast(User.created_at, Date).label('d'), - sqla_func.count().label('c') - ).where( - User.created_at >= start_date, - User.created_at <= end_date - ).group_by('d') + select(cast(User.created_at, Date).label("d"), sqla_func.count().label("c")) + .where(User.created_at >= start_date, User.created_at <= end_date) + .group_by("d") ) users_by_day = {row.d: row.c for row in users_q.all()} # 3. Tokens consumed per day tokens_q = await db.execute( - select( - cast(DailyTokenUsage.date, Date).label('d'), - sqla_func.sum(DailyTokenUsage.tokens_used).label('c') - ).where( - DailyTokenUsage.date >= start_date, - DailyTokenUsage.date <= end_date - ).group_by('d') + select(cast(DailyTokenUsage.date, Date).label("d"), sqla_func.sum(DailyTokenUsage.tokens_used).label("c")) + .where(DailyTokenUsage.date >= start_date, DailyTokenUsage.date <= end_date) + .group_by("d") ) tokens_by_day = {row.d: row.c for row in tokens_q.all()} # 4. New Sessions per day (DAU = distinct users with sessions that day) sessions_q = await db.execute( select( - cast(ChatSession.created_at, Date).label('d'), - sqla_func.count().label('sessions'), - sqla_func.count(sqla_func.distinct(ChatSession.user_id)).label('dau'), - ).where( - ChatSession.created_at >= start_date, - ChatSession.created_at <= end_date - ).group_by('d') + cast(ChatSession.created_at, Date).label("d"), + sqla_func.count().label("sessions"), + sqla_func.count(sqla_func.distinct(ChatSession.user_id)).label("dau"), + ) + .where(ChatSession.created_at >= start_date, ChatSession.created_at <= end_date) + .group_by("d") ) sessions_by_day = {} dau_by_day = {} @@ -275,7 +296,8 @@ async def get_platform_timeseries( # 5. WAU/MAU: for each day, count distinct users in rolling 7/30-day window. # Use a single SQL query with window functions for efficiency. - wau_mau_q = await db.execute(text(""" + wau_mau_q = await db.execute( + text(""" WITH daily_users AS ( SELECT DISTINCT DATE(created_at) AS d, @@ -299,12 +321,14 @@ async def get_platform_timeseries( WHERE du.d BETWEEN ds.d - 29 AND ds.d) AS mau FROM day_series ds ORDER BY ds.d - """), { - "range_start": start_date - timedelta(days=30), - "range_end": end_date, - "series_start": start_date.date(), - "series_end": end_date.date(), - }) + """), + { + "range_start": start_date - timedelta(days=30), + "range_end": end_date, + "series_start": start_date.date(), + "series_end": end_date.date(), + }, + ) wau_by_day = {} mau_by_day = {} for row in wau_mau_q.all(): @@ -317,10 +341,20 @@ async def get_platform_timeseries( end_d = end_date.date() # Cumulative totals up to start_date - total_companies = (await db.execute(select(sqla_func.count()).select_from(Tenant).where(Tenant.created_at < start_date))).scalar() or 0 - total_users = (await db.execute(select(sqla_func.count()).select_from(User).where(User.created_at < start_date))).scalar() or 0 - total_tokens = (await db.execute(select(sqla_func.coalesce(sqla_func.sum(Agent.tokens_used_total), 0)).where(Agent.created_at < start_date))).scalar() or 0 - total_sessions = (await db.execute(select(sqla_func.count()).select_from(ChatSession).where(ChatSession.created_at < start_date))).scalar() or 0 + total_companies = ( + await db.execute(select(sqla_func.count()).select_from(Tenant).where(Tenant.created_at < start_date)) + ).scalar() or 0 + total_users = ( + await db.execute(select(sqla_func.count()).select_from(User).where(User.created_at < start_date)) + ).scalar() or 0 + total_tokens = ( + await db.execute( + select(sqla_func.coalesce(sqla_func.sum(Agent.tokens_used_total), 0)).where(Agent.created_at < start_date) + ) + ).scalar() or 0 + total_sessions = ( + await db.execute(select(sqla_func.count()).select_from(ChatSession).where(ChatSession.created_at < start_date)) + ).scalar() or 0 while current_d <= end_d: nc = companies_by_day.get(current_d, 0) @@ -333,21 +367,23 @@ async def get_platform_timeseries( total_tokens += nt total_sessions += ns - result.append({ - "date": current_d.isoformat(), - "new_companies": nc, - "total_companies": total_companies, - "new_users": nu, - "total_users": total_users, - "new_tokens": nt, - "total_tokens": total_tokens, - # New metrics - "new_sessions": ns, - "total_sessions": total_sessions, - "dau": dau_by_day.get(current_d, 0), - "wau": wau_by_day.get(current_d, 0), - "mau": mau_by_day.get(current_d, 0), - }) + result.append( + { + "date": current_d.isoformat(), + "new_companies": nc, + "total_companies": total_companies, + "new_users": nu, + "total_users": total_users, + "new_tokens": nt, + "total_tokens": total_tokens, + # New metrics + "new_sessions": ns, + "total_sessions": total_sessions, + "dau": dau_by_day.get(current_d, 0), + "wau": wau_by_day.get(current_d, 0), + "mau": mau_by_day.get(current_d, 0), + } + ) current_d += timedelta(days=1) return result @@ -361,7 +397,7 @@ async def get_platform_leaderboards( """Get Top 20 token consuming companies and agents.""" # Top 20 Companies by total tokens top_companies_q = await db.execute( - select(Tenant.name, sqla_func.coalesce(sqla_func.sum(Agent.tokens_used_total), 0).label('total')) + select(Tenant.name, sqla_func.coalesce(sqla_func.sum(Agent.tokens_used_total), 0).label("total")) .join(Agent, Agent.tenant_id == Tenant.id) .group_by(Tenant.id) .order_by(sqla_func.sum(Agent.tokens_used_total).desc()) @@ -371,17 +407,16 @@ async def get_platform_leaderboards( # Top 20 Agents by total tokens top_agents_q = await db.execute( - select(Agent.name, Tenant.name.label('tenant_name'), Agent.tokens_used_total) + select(Agent.name, Tenant.name.label("tenant_name"), Agent.tokens_used_total) .join(Tenant, Tenant.id == Agent.tenant_id) .order_by(Agent.tokens_used_total.desc()) .limit(20) ) - top_agents = [{"name": row.name, "company": row.tenant_name, "tokens": row.tokens_used_total} for row in top_agents_q.all()] + top_agents = [ + {"name": row.name, "company": row.tenant_name, "tokens": row.tokens_used_total} for row in top_agents_q.all() + ] - return { - "top_companies": top_companies, - "top_agents": top_agents - } + return {"top_companies": top_companies, "top_agents": top_agents} @router.get("/metrics/enhanced") @@ -403,20 +438,25 @@ async def get_enhanced_metrics( # Sum of daily_token_usage / count of chat_sessions in last 30 days thirty_days_ago = now - timedelta(days=30) from app.models.activity_log import DailyTokenUsage - total_tok_30d = (await db.execute( - select(sqla_func.coalesce(sqla_func.sum(DailyTokenUsage.tokens_used), 0)) - .where(DailyTokenUsage.date >= thirty_days_ago) - )).scalar() or 0 - total_sess_30d = (await db.execute( - select(sqla_func.count()) - .select_from(ChatSession) - .where(ChatSession.created_at >= thirty_days_ago) - )).scalar() or 1 # avoid div by zero + + total_tok_30d = ( + await db.execute( + select(sqla_func.coalesce(sqla_func.sum(DailyTokenUsage.tokens_used), 0)).where( + DailyTokenUsage.date >= thirty_days_ago + ) + ) + ).scalar() or 0 + total_sess_30d = ( + await db.execute( + select(sqla_func.count()).select_from(ChatSession).where(ChatSession.created_at >= thirty_days_ago) + ) + ).scalar() or 1 # avoid div by zero avg_tokens_per_session = round(total_tok_30d / max(total_sess_30d, 1)) # ── 2. 7-Day Retention Rate (excluding companies <14 days old) ── # Last week = 14..7 days ago, This week = 7..0 days ago - retention_q = await db.execute(text(""" + retention_q = await db.execute( + text(""" WITH established AS ( SELECT id FROM tenants WHERE created_at < NOW() - INTERVAL '14 days' ), @@ -440,7 +480,8 @@ async def get_enhanced_metrics( WHERE lw.tenant_id IN (SELECT tenant_id FROM this_week_active) ) AS retained FROM last_week_active lw - """)) + """) + ) ret_row = retention_q.first() last_week_total = ret_row[0] if ret_row else 0 retained = ret_row[1] if ret_row else 0 @@ -448,38 +489,28 @@ async def get_enhanced_metrics( # ── 3. Channel Distribution (last 30 days) ── channel_q = await db.execute( - select( - ChatSession.source_channel, - sqla_func.count().label('count') - ).where( - ChatSession.created_at >= thirty_days_ago - ).group_by(ChatSession.source_channel) + select(ChatSession.source_channel, sqla_func.count().label("count")) + .where(ChatSession.created_at >= thirty_days_ago) + .group_by(ChatSession.source_channel) .order_by(sqla_func.count().desc()) ) - channel_distribution = [ - {"channel": row.source_channel, "count": row.count} - for row in channel_q.all() - ] + channel_distribution = [{"channel": row.source_channel, "count": row.count} for row in channel_q.all()] # ── 4. Top 10 Tool Categories ── # Count enabled agent_tools grouped by tool category tool_q = await db.execute( - select( - Tool.category, - sqla_func.count().label('count') - ).join(AgentTool, AgentTool.tool_id == Tool.id) + select(Tool.category, sqla_func.count().label("count")) + .join(AgentTool, AgentTool.tool_id == Tool.id) .where(AgentTool.enabled == True) # noqa: E712 .group_by(Tool.category) .order_by(sqla_func.count().desc()) .limit(10) ) - tool_category_top10 = [ - {"category": row.category or "uncategorized", "count": row.count} - for row in tool_q.all() - ] + tool_category_top10 = [{"category": row.category or "uncategorized", "count": row.count} for row in tool_q.all()] # ── 5. Churn Warnings (>10M tokens, 14+ days inactive) ── - churn_q = await db.execute(text(""" + churn_q = await db.execute( + text(""" SELECT t.name, SUM(a.tokens_used_total) AS total_tokens, @@ -495,15 +526,18 @@ async def get_enhanced_metrics( OR MAX(cs.created_at) < NOW() - INTERVAL '14 days' ) ORDER BY SUM(a.tokens_used_total) DESC - """)) + """) + ) churn_warnings = [] for row in churn_q.all(): - churn_warnings.append({ - "name": row[0], - "total_tokens": row[1], - "last_active": row[2].isoformat() if row[2] else None, - "days_inactive": row[3] if row[3] else None, - }) + churn_warnings.append( + { + "name": row[0], + "total_tokens": row[1], + "last_active": row[2].isoformat() if row[2] else None, + "days_inactive": row[3] if row[3] else None, + } + ) return { "avg_tokens_per_session_30d": avg_tokens_per_session, @@ -518,6 +552,7 @@ async def get_enhanced_metrics( # ─── Platform Settings ───────────────────────────────── + @router.get("/platform-settings", response_model=PlatformSettingsOut) async def get_platform_settings( current_user: User = Depends(require_role("platform_admin")), diff --git a/backend/app/api/tenants.py b/backend/app/api/tenants.py index 0bb4e0a7..91031013 100644 --- a/backend/app/api/tenants.py +++ b/backend/app/api/tenants.py @@ -24,10 +24,13 @@ # ─── Schemas ──────────────────────────────────────────── + class TenantCreate(BaseModel): name: str = Field(min_length=1, max_length=200) + slug: str | None = None target_tenant_id: uuid.UUID | None = None + class TenantOut(BaseModel): id: uuid.UUID name: str @@ -53,6 +56,7 @@ class TenantUpdate(BaseModel): # ─── Helpers ──────────────────────────────────────────── + def _slugify(name: str) -> str: """Generate a URL-friendly slug from a company name.""" # Replace CJK and non-alphanumeric chars with hyphens @@ -67,6 +71,7 @@ def _slugify(name: str) -> str: class SelfCreateResponse(BaseModel): """Response for self-create company, includes token for context switching.""" + tenant: TenantOut access_token: str | None = None # Non-null when a new User record was created (multi-tenant switch) @@ -85,23 +90,38 @@ async def self_create_company( """ # Block self-creation if locked to a specific tenant (Dedicated Link flow) if data.target_tenant_id is not None: - raise HTTPException(status_code=403, detail="Company creation is not allowed via this link. Please join your assigned organization.") + raise HTTPException( + status_code=403, + detail="Company creation is not allowed via this link. Please join your assigned organization.", + ) # Check if self-creation is allowed from app.models.system_settings import SystemSetting - setting = await db.execute( - select(SystemSetting).where(SystemSetting.key == "allow_self_create_company") - ) + + setting = await db.execute(select(SystemSetting).where(SystemSetting.key == "allow_self_create_company")) s = setting.scalar_one_or_none() allowed = s.value.get("enabled", True) if s else True if not allowed and current_user.role != "platform_admin": raise HTTPException(status_code=403, detail="Company self-creation is currently disabled") - slug = _slugify(data.name) + if data.slug: + import re + + slug = re.sub(r"[^a-z0-9]+", "-", data.slug.lower().strip()).strip("-")[:40] + if not slug: + slug = "company" + else: + slug = _slugify(data.name) tenant = Tenant(name=data.name, slug=slug, im_provider="web_only") db.add(tenant) await db.flush() + from app.services.platform_service import platform_service + + sso_base = await platform_service.get_tenant_sso_base_url(db, tenant) + tenant.sso_domain = sso_base + await db.flush() + access_token = None if current_user.tenant_id is not None: @@ -126,12 +146,14 @@ async def self_create_company( await db.flush() # Create Participant for the new user record - db.add(Participant( - type="user", - ref_id=new_user.id, - display_name=new_user.display_name, - avatar_url=new_user.avatar_url, - )) + db.add( + Participant( + type="user", + ref_id=new_user.id, + display_name=new_user.display_name, + avatar_url=new_user.avatar_url, + ) + ) await db.flush() # Generate token scoped to the new user so frontend can switch context @@ -155,6 +177,7 @@ async def self_create_company( # ─── Self-Service: Join Company via Invite Code ───────── + class JoinRequest(BaseModel): invitation_code: str = Field(min_length=1, max_length=32) target_tenant_id: uuid.UUID | None = None @@ -178,6 +201,7 @@ async def join_company( - Registration flow (user has no tenant yet): assigns tenant directly - Switch-org flow (user already has a tenant): creates a new User record""" from app.models.invitation_code import InvitationCode + ic_result = await db.execute( select(InvitationCode).where( InvitationCode.code == data.invitation_code, @@ -191,7 +215,9 @@ async def join_company( # Verify matching tenant if locked (Dedicated Link flow) if data.target_tenant_id and str(code_obj.tenant_id) != str(data.target_tenant_id): - raise HTTPException(status_code=403, detail="This invitation code does not belong to the required organization.") + raise HTTPException( + status_code=403, detail="This invitation code does not belong to the required organization." + ) if code_obj.used_count >= code_obj.max_uses: raise HTTPException(status_code=400, detail="Invitation code has reached its usage limit") @@ -214,7 +240,9 @@ async def join_company( # Check if this company has an org_admin already admin_check = await db.execute( - select(sqla_func.count()).select_from(User).where( + select(sqla_func.count()) + .select_from(User) + .where( User.tenant_id == tenant.id, User.role.in_(["org_admin", "platform_admin"]), ) @@ -248,12 +276,14 @@ async def join_company( await db.flush() # Create Participant for the new user record - db.add(Participant( - type="user", - ref_id=new_user.id, - display_name=new_user.display_name, - avatar_url=new_user.avatar_url, - )) + db.add( + Participant( + type="user", + ref_id=new_user.id, + display_name=new_user.display_name, + avatar_url=new_user.avatar_url, + ) + ) await db.flush() # Generate token scoped to the new user so frontend can switch context @@ -284,13 +314,13 @@ async def join_company( # ─── Registration Config ─────────────────────────────── + @router.get("/registration-config") async def get_registration_config(db: AsyncSession = Depends(get_db)): """Public — returns whether self-creation of companies is allowed.""" from app.models.system_settings import SystemSetting - result = await db.execute( - select(SystemSetting).where(SystemSetting.key == "allow_self_create_company") - ) + + result = await db.execute(select(SystemSetting).where(SystemSetting.key == "allow_self_create_company")) s = result.scalar_one_or_none() allowed = s.value.get("enabled", True) if s else True return {"allow_self_create_company": allowed} @@ -298,6 +328,7 @@ async def get_registration_config(db: AsyncSession = Depends(get_db)): # ─── Public: Resolve Tenant by Domain ─────────────────── + @router.get("/resolve-by-domain") async def resolve_tenant_by_domain( domain: str, @@ -317,9 +348,7 @@ async def resolve_tenant_by_domain( # 1. Match by stripping protocol from stored sso_domain # sso_domain = "https://acme.clawith.ai" → compare against "acme.clawith.ai" for proto in ("https://", "http://"): - result = await db.execute( - select(Tenant).where(Tenant.sso_domain == f"{proto}{domain}") - ) + result = await db.execute(select(Tenant).where(Tenant.sso_domain == f"{proto}{domain}")) tenant = result.scalar_one_or_none() if tenant: break @@ -328,9 +357,7 @@ async def resolve_tenant_by_domain( if not tenant and ":" in domain: domain_no_port = domain.split(":")[0] for proto in ("https://", "http://"): - result = await db.execute( - select(Tenant).where(Tenant.sso_domain.like(f"{proto}{domain_no_port}%")) - ) + result = await db.execute(select(Tenant).where(Tenant.sso_domain.like(f"{proto}{domain_no_port}%"))) tenant = result.scalar_one_or_none() if tenant: break @@ -338,6 +365,7 @@ async def resolve_tenant_by_domain( # 3. Fallback: extract slug from subdomain pattern if not tenant: import re + m = re.match(r"^([a-z0-9][a-z0-9\-]*[a-z0-9])\.clawith\.ai$", domain.lower()) if m: slug = m.group(1) @@ -356,8 +384,10 @@ async def resolve_tenant_by_domain( "is_active": tenant.is_active, } + # ─── Authenticated: List / Get ────────────────────────── + @router.get("/", response_model=list[TenantOut]) async def list_tenants( current_user: User = Depends(require_role("platform_admin")), @@ -402,7 +432,7 @@ async def update_tenant( raise HTTPException(status_code=404, detail="Tenant not found") update_data = data.model_dump(exclude_unset=True) - + # SSO configuration is managed exclusively by the company's own org_admin # via the Enterprise Settings page. Platform admins should not override it here. if current_user.role == "platform_admin": diff --git a/backend/app/core/public_url.py b/backend/app/core/public_url.py new file mode 100644 index 00000000..37d1b646 --- /dev/null +++ b/backend/app/core/public_url.py @@ -0,0 +1,56 @@ +"""Utility functions for getting platform public URL.""" + +import os +from urllib.parse import urlparse + + +def get_public_base_url_sync() -> str: + """Get the platform public base URL (sync version - only checks env var). + + For async version with database lookup, use get_public_base_url_async(). + """ + env_url = os.environ.get("PUBLIC_BASE_URL", "").strip() + if env_url: + return env_url.rstrip("/") + return "" + + +async def get_public_base_url_async(db) -> str: + """Get the platform public base URL from database. + + Args: + db: Database session + + Returns: + The public base URL or empty string if not set + """ + try: + from sqlalchemy import select + from app.models.system_settings import SystemSetting + + result = await db.execute(select(SystemSetting).where(SystemSetting.key == "platform")) + setting = result.scalar_one_or_none() + if setting and setting.value: + url = setting.value.get("public_base_url", "") + if url: + return url.rstrip("/") + except Exception: + pass + return "" + + +def get_sso_domain_from_slug(slug: str, public_url: str = "") -> str: + """Generate SSO domain from slug using the platform public URL. + + Args: + slug: The tenant slug (subdomain) + public_url: Optional pre-fetched public URL + + Returns: + Full SSO domain like "slug.example.com" or "slug.example.com:3008" + """ + if public_url: + parsed = urlparse(public_url) + return f"{slug}.{parsed.netloc}" + else: + return f"{slug}.clawith.ai" diff --git a/backend/app/services/org_sync_adapter.py b/backend/app/services/org_sync_adapter.py index b317b358..79004501 100644 --- a/backend/app/services/org_sync_adapter.py +++ b/backend/app/services/org_sync_adapter.py @@ -9,7 +9,7 @@ from dataclasses import dataclass, field from datetime import datetime, timedelta from typing import Any -from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, delete, func, or_, select, update +from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, delete, func, select, update import httpx from loguru import logger @@ -18,7 +18,7 @@ from app.models.identity import IdentityProvider from app.models.org import OrgDepartment, OrgMember -from app.models.user import User, Identity +from app.models.user import User from pypinyin import pinyin, Style from app.core.security import hash_password @@ -310,7 +310,7 @@ async def _ensure_provider(self, db: AsyncSession) -> IdentityProvider: query = query.where(IdentityProvider.tenant_id.is_(None)) result = await db.execute(query) - provider = result.scalars().first() + provider = result.scalar_one_or_none() if not provider: provider = IdentityProvider( @@ -337,7 +337,7 @@ async def _upsert_department( OrgDepartment.provider_id == provider.id, ) ) - existing = result.scalars().first() + existing = result.scalar_one_or_none() now = datetime.now() path = f"{dept.parent_external_id}/{dept.name}" if dept.parent_external_id else dept.name @@ -351,7 +351,7 @@ async def _upsert_department( OrgDepartment.provider_id == provider.id, ) ) - parent_dept = parent_result.scalars().first() + parent_dept = parent_result.scalar_one_or_none() if parent_dept: parent_id = parent_dept.id @@ -402,7 +402,7 @@ async def _upsert_member( OrgDepartment.provider_id == provider.id, ) ) - department = dept_result.scalars().first() + department = dept_result.scalar_one_or_none() if department: break # Fallback: use the department_external_id that was set during fetch_users @@ -413,25 +413,16 @@ async def _upsert_member( OrgDepartment.provider_id == provider.id, ) ) - department = dept_result.scalars().first() - - # Check if exists by unionid or external_id or open_id (any matches), and provider - conditions = [] - if user.unionid: - conditions.append(OrgMember.unionid == user.unionid) - if user.external_id: - conditions.append(OrgMember.external_id == user.external_id) - if user.open_id: - conditions.append(OrgMember.open_id == user.open_id) - - if conditions: - result = await db.execute( - select(OrgMember).where( - OrgMember.provider_id == provider.id, - or_(*conditions) - ) + department = dept_result.scalar_one_or_none() + + # Check if exists by external_id and provider + result = await db.execute( + select(OrgMember).where( + OrgMember.external_id == user.external_id, + OrgMember.provider_id == provider.id, ) - existing_member = result.scalars().first() + ) + existing_member = result.scalar_one_or_none() now = datetime.now() @@ -445,20 +436,20 @@ async def _upsert_member( mobile = _normalize_contact(user.mobile) if email: - user_query = select(User).join(User.identity).where(Identity.email == email) + user_query = select(User).where(User.email.ilike(email)) if self.tenant_id: user_query = user_query.where(User.tenant_id == self.tenant_id) user_res = await db.execute(user_query) - platform_user = user_res.scalars().first() + platform_user = user_res.scalar_one_or_none() if platform_user: user_id = platform_user.id if not user_id and mobile: - user_query = select(User).join(User.identity).where(Identity.phone == mobile) + user_query = select(User).where(User.primary_mobile == mobile) if self.tenant_id: user_query = user_query.where(User.tenant_id == self.tenant_id) user_res = await db.execute(user_query) - platform_user = user_res.scalars().first() + platform_user = user_res.scalar_one_or_none() if platform_user: user_id = platform_user.id @@ -482,8 +473,7 @@ async def _upsert_member( # Universal ID fields existing_member.external_id = user.external_id existing_member.open_id = user.open_id - existing_member.unionid = user.unionid - + existing_member.provider_id = provider.id existing_member.synced_at = now if user_id and not existing_member.user_id: @@ -496,7 +486,6 @@ async def _upsert_member( new_member = OrgMember( external_id=user.external_id, open_id=user.open_id, - unionid=user.unionid, provider_id=provider.id, user_id=user_id, @@ -521,7 +510,7 @@ async def _upsert_member( if not target_user and (user_id or (existing_member and existing_member.user_id)): target_id = user_id or existing_member.user_id user_res = await db.execute(select(User).where(User.id == target_id)) - target_user = user_res.scalars().first() + target_user = user_res.scalar_one_or_none() if target_user: if email and target_user.email != email: @@ -538,18 +527,18 @@ async def _resolve_platform_user(self, db: AsyncSession, user: ExternalUser) -> email = _normalize_contact(user.email) if email: result = await db.execute( - select(User).join(User.identity).where(Identity.email == email) + select(User).where(User.email.ilike(email)) ) - u = result.scalars().first() + u = result.scalar_one_or_none() if u: return u # 2. Try by mobile matching mobile = _normalize_contact(user.mobile) if mobile: result = await db.execute( - select(User).join(User.identity).where(Identity.phone == mobile) + select(User).where(User.primary_mobile == mobile) ) - u = result.scalars().first() + u = result.scalar_one_or_none() if u: return u return None @@ -658,13 +647,149 @@ async def fetch_children(parent_id: str): logger.info(f"Feishu fetched {len(all_depts)} departments total.") return all_depts - async def fetch_users(self, department_external_id: str) -> list[ExternalUser]: - """Fetch users in a department. - - Uses user_id_type=user_id which requires the contact:user.employee_id:readonly - permission. If the Feishu API returns an error due to missing permission, raises - a clear error instructing the user to add the required scope. - """ + async def sync_org_structure(self, db: AsyncSession) -> dict[str, Any]: + """Override to use global user list API so we can get users regardless of department hierarchy.""" + errors = [] + dept_count = 0 + member_count = 0 + user_count = 0 + profile_count = 0 + sync_start = datetime.now() + + provider = await self._ensure_provider(db) + + try: + # Fetch and sync departments + departments = await self.fetch_departments() + for dept in departments: + try: + async with db.begin_nested(): + await self._upsert_department(db, provider, dept) + dept_count += 1 + except Exception as e: + errors.append(f"Department {dept.external_id}: {str(e)}") + logger.error(f"[OrgSync] Failed to sync department {dept.external_id}: {e}") + + # Fetch ALL users using global user list API (works even without department access) + all_users = await self._fetch_all_users() + logger.info(f"Feishu fetched {len(all_users)} total users globally.") + + for user in all_users: + try: + async with db.begin_nested(): + # Use first department from user's department_ids, fallback to "0" + dept_ext_id = user.department_ids[0] if user.department_ids else "0" + + # Ensure department exists - if not found, create it on the fly + dept_result = await db.execute( + select(OrgDepartment).where( + OrgDepartment.external_id == dept_ext_id, + OrgDepartment.provider_id == provider.id, + ) + ) + dept = dept_result.scalar_one_or_none() + if not dept: + # Check if department exists but was marked deleted + del_result = await db.execute( + select(OrgDepartment).where( + OrgDepartment.external_id == dept_ext_id, + OrgDepartment.provider_id == provider.id, + OrgDepartment.status == "deleted", + ) + ) + dept = del_result.scalar_one_or_none() + fetched_dept_name = None + if dept: + # Reactivate deleted department + dept.status = "active" + dept.synced_at = datetime.now() + if fetched_dept_name: + dept.name = fetched_dept_name + await db.flush() + logger.info(f"[OrgSync] Reactivated deleted department: {dept.external_id} -> {fetched_dept_name or dept.name}") + # Try to fetch real name for reactivated dept + try: + token = await self.get_access_token() + async with httpx.AsyncClient() as client: + resp = await client.get( + f"https://open.feishu.cn/open-apis/contact/v3/departments/{dept.external_id}", + params={"department_id_type": "open_department_id"}, + headers={"Authorization": f"Bearer {token}"}, + ) + data = resp.json() + if data.get("code") == 0: + fetched_dept_name = data.get("data", {}).get("department", {}).get("name") + except Exception: + pass + + if not dept: + # Fetch department details from Feishu API + dept_name = fetched_dept_name or f"部门{dept_ext_id[:8]}" + try: + token = await self.get_access_token() + async with httpx.AsyncClient() as client: + resp = await client.get( + f"https://open.feishu.cn/open-apis/contact/v3/departments/{dept_ext_id}", + params={"department_id_type": "open_department_id"}, + headers={"Authorization": f"Bearer {token}"}, + ) + data = resp.json() + if data.get("code") == 0: + dept_name = data.get("data", {}).get("department", {}).get("name", dept_name) + except Exception as e: + logger.warning(f"[OrgSync] Failed to fetch dept name for {dept_ext_id}: {e}") + + dept = OrgDepartment( + external_id=dept_ext_id, + provider_id=provider.id, + name=dept_name, + tenant_id=self.tenant_id, + synced_at=datetime.now(), + ) + db.add(dept) + await db.flush() + logger.warning(f"[OrgSync] Auto-created missing department: {dept_ext_id} - {dept_name}") + # Add to departments list so reconciliation doesn't delete it + departments.append(dept) + + stats = await self._upsert_member(db, provider, user, dept_ext_id) + if stats.get("user_created"): + user_count += 1 + if stats.get("profile_synced"): + profile_count += 1 + member_count += 1 + except Exception as e: + logger.error(f"[OrgSync] Failed to sync member {user.external_id} ({user.name}): {e}") + errors.append(f"Member {user.external_id}: {str(e)}") + + # Update provider metadata + if self.provider: + config = (self.provider.config or {}).copy() + config["last_synced_at"] = datetime.now().isoformat() + self.provider.config = config + await db.flush() + await self._reconcile(db, provider.id, sync_start) + await db.flush() + await self._update_member_counts(db, provider.id) + await db.flush() + + except Exception as e: + import traceback + logger.error(f"[OrgSync] Critical error during sync: {e}\n{traceback.format_exc()}") + errors.append(f"Critical: {str(e)}") + + return { + "departments": dept_count, + "members": member_count, + "users_created": user_count, + "profiles_synced": profile_count, + "errors": errors, + "provider": self.provider_type, + "synced_at": datetime.now().isoformat() + } + + async def _fetch_all_users(self) -> list[ExternalUser]: + """Fetch all users from Feishu using global users API.""" token = await self.get_access_token() users: list[ExternalUser] = [] page_token = "" @@ -672,74 +797,57 @@ async def fetch_users(self, department_external_id: str) -> list[ExternalUser]: async with httpx.AsyncClient() as client: while True: params = { - "department_id": department_external_id, - "department_id_type": "open_department_id", - "user_id_type": "user_id", # Requires contact:user.employee_id:readonly "page_size": "50", + "user_id_type": "user_id", + "sort_type": "NameOrder", } if page_token: params["page_token"] = page_token resp = await client.get( - self.FEISHU_USERS_URL, + "https://open.feishu.cn/open-apis/contact/v3/users", params=params, headers={"Authorization": f"Bearer {token}"}, ) data = resp.json() if data.get("code") != 0: - error_code = data.get("code") - error_msg = data.get("msg", "") - logger.error( - f"Feishu fetch users error for dept {department_external_id}: " - f"code={error_code}, msg={error_msg}" - ) - # Raise a user-friendly error for permission issues - raise RuntimeError( - f"Feishu API error (code {error_code}): {error_msg}. " - f"Please ensure the Feishu app has the 'contact:user.employee_id:readonly' " - f"permission enabled. Go to Feishu Open Platform -> App -> Permissions -> " - f"search 'employee_id' -> enable and publish a new version." - ) + logger.error(f"Feishu fetch all users error: {data}") + break res_data = data.get("data", {}) items = res_data.get("items", []) or [] for item in items: - # Collect all departments the user belongs to raw_dept_ids = item.get("department_ids", []) - department_ids = [str(did) for did in raw_dept_ids] if raw_dept_ids else [department_external_id] - - external_id = item.get("user_id", "") or item.get("open_id", "") - - # For Feishu, a user is considered inactive if they are explicitly frozen or resigned. - # Merely not being activated (is_activated=False) shouldn't hide them from the org chart. - feishu_status = item.get("status", {}) - is_frozen = feishu_status.get("is_frozen", False) - is_resigned = feishu_status.get("is_resigned", False) - member_status = "inactive" if (is_frozen or is_resigned) else "active" + department_ids = [str(did) for did in raw_dept_ids] if raw_dept_ids else ["0"] user = ExternalUser( - external_id=external_id, + external_id=item.get("user_id", "") or item.get("open_id", ""), open_id=item.get("open_id", ""), unionid=item.get("union_id", ""), name=item.get("name", ""), email=item.get("email", ""), avatar_url=item.get("avatar_url", ""), title=item.get("title", ""), - department_external_id=department_external_id, + department_external_id=department_ids[0] if department_ids else "0", department_ids=department_ids, mobile=item.get("mobile", ""), - status=member_status, + status="active" if item.get("status", {}).get("is_activated") else "inactive", raw_data=item, ) users.append(user) page_token = res_data.get("page_token", "") - if not page_token: + has_more = res_data.get("has_more", False) + if not has_more or not page_token: break return users + async def fetch_users(self, department_external_id: str) -> list[ExternalUser]: + # Dummy implementation - not used since we override sync_org_structure + return [] + class DingTalkOrgSyncAdapter(BaseOrgSyncAdapter): """DingTalk organization sync adapter.""" diff --git a/backend/requirements.txt b/backend/requirements.txt new file mode 100644 index 00000000..66fba5cc --- /dev/null +++ b/backend/requirements.txt @@ -0,0 +1 @@ +pypinyin diff --git a/docker-compose.yml b/docker-compose.yml index 4c84de15..ffb03a3e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -67,12 +67,16 @@ services: max-size: "10m" max-file: "3" frontend: - build: ./frontend + build: + context: ./frontend + args: + VITE_PUBLIC_URL: ${VITE_PUBLIC_URL:-} restart: unless-stopped ports: - "${FRONTEND_PORT:-3008}:3000" environment: VITE_API_URL: http://localhost:8000 + VITE_PUBLIC_URL: ${VITE_PUBLIC_URL:-} volumes: - ./frontend/src:/app/src - ./frontend/public:/app/public diff --git a/frontend/Dockerfile b/frontend/Dockerfile index 7d8b64e1..2defa246 100644 --- a/frontend/Dockerfile +++ b/frontend/Dockerfile @@ -3,6 +3,8 @@ WORKDIR /app COPY package*.json ./ RUN npm ci --registry https://registry.npmmirror.com COPY . . +ARG VITE_PUBLIC_URL +ENV VITE_PUBLIC_URL=$VITE_PUBLIC_URL RUN npm run build FROM nginx:alpine diff --git a/frontend/nginx.conf b/frontend/nginx.conf index e82afab7..4c174276 100644 --- a/frontend/nginx.conf +++ b/frontend/nginx.conf @@ -1,6 +1,6 @@ server { listen 3000; - server_name localhost; + server_name _; root /usr/share/nginx/html; index index.html; diff --git a/frontend/src/pages/AdminCompanies.tsx b/frontend/src/pages/AdminCompanies.tsx index 884b0f62..cacc7b48 100644 --- a/frontend/src/pages/AdminCompanies.tsx +++ b/frontend/src/pages/AdminCompanies.tsx @@ -668,10 +668,14 @@ function CompaniesTab() { // Create company const [showCreate, setShowCreate] = useState(false); const [newName, setNewName] = useState(''); + const [newSlug, setNewSlug] = useState(''); const [creating, setCreating] = useState(false); const [createdCode, setCreatedCode] = useState(''); const [createdCompanyName, setCreatedCompanyName] = useState(''); + // Edit company + const [editingCompany, setEditingCompany] = useState(null); + // Toast const [toast, setToast] = useState<{ msg: string; type: 'success' | 'error' } | null>(null); const showToast = (msg: string, type: 'success' | 'error' = 'success') => { @@ -734,10 +738,11 @@ function CompaniesTab() { if (!newName.trim()) return; setCreating(true); try { - const result = await adminApi.createCompany({ name: newName.trim() }); + const result = await adminApi.createCompany({ name: newName.trim(), slug: newSlug.trim() || undefined }); setCreatedCompanyName(newName.trim()); setCreatedCode(result.admin_invitation_code || ''); setNewName(''); + setNewSlug(''); setShowCreate(false); loadCompanies(); } catch (e: any) { @@ -888,17 +893,26 @@ function CompaniesTab() {
{t('admin.createCompany', 'Create Company')}
-
- setNewName(e.target.value)} - placeholder={t('admin.companyNamePlaceholder', 'Company name')} - onKeyDown={e => e.key === 'Enter' && handleCreate()} - style={{ flex: 1 }} autoFocus /> - - +
+
+ setNewName(e.target.value)} + placeholder={t('admin.companyNamePlaceholder', 'Company name')} + onKeyDown={e => e.key === 'Enter' && handleCreate()} + style={{ flex: 1 }} autoFocus /> + setNewSlug(e.target.value)} + placeholder={t('admin.slugPlaceholder', 'Custom slug (optional)')} + onKeyDown={e => e.key === 'Enter' && handleCreate()} + style={{ flex: 1 }} /> + + +
+
+ {t('admin.slugHelp', 'Slug is used for SSO domain: slug.bigbear.cool. Leave empty to auto-generate from company name.')} +
)} @@ -1023,6 +1037,16 @@ function CompaniesTab() {
+
)} + + {/* Edit Company Modal */} + {editingCompany && ( + setEditingCompany(null)} + onUpdated={loadCompanies} + /> + )} ); @@ -1085,6 +1118,8 @@ function CompaniesTab() { // ─── Edit Company Modal ─────────────────────────────── function EditCompanyModal({ company, onClose, onUpdated }: { company: any, onClose: () => void, onUpdated: () => void }) { const { t } = useTranslation(); + const [name, setName] = useState(company.name || ''); + const [slug, setSlug] = useState(company.slug || ''); const [ssoEnabled, setSsoEnabled] = useState(!!company.sso_enabled); const [ssoDomain, setSsoDomain] = useState(company.sso_domain || ''); const [saving, setSaving] = useState(false); @@ -1095,6 +1130,8 @@ function EditCompanyModal({ company, onClose, onUpdated }: { company: any, onClo setError(''); try { await adminApi.updateCompany(company.id, { + name: name.trim() || undefined, + slug: slug.trim() || undefined, sso_enabled: ssoEnabled, sso_domain: ssoDomain.trim() || null, }); @@ -1113,12 +1150,12 @@ function EditCompanyModal({ company, onClose, onUpdated }: { company: any, onClo backdropFilter: 'blur(4px)', }} onClick={onClose}>
e.stopPropagation()}>

- {t('admin.editCompany', 'Edit Company')}: {company.name} + {t('admin.editCompany', 'Edit Company')}

-

- {t('admin.ssoConfigTitle', 'SSO & Domain Configuration')} -

-

- {t('admin.ssoConfigDesc', 'Configure SSO and custom domain for this company.')} -

+
+

+ {t('admin.basicInfo', 'Basic Information')} +

+
+ + setName(e.target.value)} + style={{ fontSize: '13px' }} + /> +
+
+ + setSlug(e.target.value)} + placeholder={t('admin.slugPlaceholder', 'e.g. acme')} + style={{ fontSize: '13px' }} + /> +
+ {t('admin.slugHelp', 'Used for SSO domain: slug.bigbear.cool')} +
+
+
+

+ {t('admin.ssoConfigTitle', 'SSO & Domain Configuration')} +