diff --git a/src/core/analyzer.ts b/src/core/analyzer.ts index 2bbc4ac..5822052 100644 --- a/src/core/analyzer.ts +++ b/src/core/analyzer.ts @@ -8,6 +8,7 @@ import { collectRecognizedNames, collectStringVariables, decoratorExtractor, + factoryCallExtractor, getNodesByType, importExtractor, includeRouterExtractor, @@ -46,6 +47,10 @@ export function analyzeTree(tree: Tree, filePath: string): FileAnalysis { // Get all router assignments const assignments = nodesByType.get("assignment") ?? [] const { fastAPINames, apiRouterNames } = collectRecognizedNames(nodesByType) + const knownConstructors = new Set([...fastAPINames, ...apiRouterNames]) + const factoryCalls = assignments + .map((node) => factoryCallExtractor(node, knownConstructors)) + .filter(notNull) const routers = assignments .map((node) => routerExtractor(node, apiRouterNames, fastAPINames)) .filter(notNull) @@ -84,6 +89,7 @@ export function analyzeTree(tree: Tree, filePath: string): FileAnalysis { includeRouters, mounts, imports, + factoryCalls, } } diff --git a/src/core/extractors.ts b/src/core/extractors.ts index c306b6d..a9cc2df 100644 --- a/src/core/extractors.ts +++ b/src/core/extractors.ts @@ -4,6 +4,7 @@ import type { Node } from "web-tree-sitter" import type { + FactoryCallInfo, ImportedName, ImportInfo, IncludeRouterInfo, @@ -582,3 +583,41 @@ export function mountExtractor(node: Node): MountInfo | null { app: appNode?.text ?? "", } } + +export function factoryCallExtractor( + node: Node, + knownConstructors: Set, +): FactoryCallInfo | null { + if (node.type !== "assignment") { + return null + } + + const variableNameNode = node.childForFieldName("left") + const valueNode = node.childForFieldName("right") + if (!variableNameNode || valueNode?.type !== "call") { + return null + } + + const functionNode = valueNode.childForFieldName("function") + if (functionNode?.type !== "identifier") { + return null + } + + const functionName = functionNode.text + if (knownConstructors.has(functionName)) { + return null + } + + // Skip function and class-local variables to avoid false positives + if ( + hasAncestor(node, "function_definition") || + hasAncestor(node, "class_definition") + ) { + return null + } + + return { + variableName: variableNameNode.text, + functionName: functionName, + } +} diff --git a/src/core/internal.ts b/src/core/internal.ts index 4720cfc..ede8b20 100644 --- a/src/core/internal.ts +++ b/src/core/internal.ts @@ -77,6 +77,11 @@ export interface MountInfo { app: string } +export interface FactoryCallInfo { + variableName: string + functionName: string +} + export interface FileAnalysis { filePath: string routes: RouteInfo[] @@ -84,6 +89,7 @@ export interface FileAnalysis { includeRouters: IncludeRouterInfo[] mounts: MountInfo[] imports: ImportInfo[] + factoryCalls: FactoryCallInfo[] } export interface RouterNode { diff --git a/src/core/routerResolver.ts b/src/core/routerResolver.ts index 54a3262..c7a11bd 100644 --- a/src/core/routerResolver.ts +++ b/src/core/routerResolver.ts @@ -197,7 +197,52 @@ async function buildRouterGraphInternal( } } - if (!appRouter || !analysis) { + // Factory function in another module: if the entrypoint variable is assigned via + // `app = create_app()` where `create_app` is imported, follow the import to the + // factory file and build the router graph from there. This works because + // routerExtractor and includeRouterExtractor recurse into function bodies, so + // `app = FastAPI()` and `app.include_router(...)` inside `create_app` are visible + // when analyzing the factory file directly. + if (!appRouter && targetVariable) { + const factoryCall = analysis.factoryCalls.find( + (fc) => fc.variableName === targetVariable, + ) + if (factoryCall) { + const matchingImport = analysis.imports.find((imp) => + imp.names.includes(factoryCall.functionName), + ) + if (matchingImport) { + const namedImport = matchingImport.namedImports.find( + (ni) => (ni.alias ?? ni.name) === factoryCall.functionName, + ) + const originalName = namedImport?.name ?? factoryCall.functionName + const factoryFileUri = await resolveNamedImport( + { + modulePath: matchingImport.modulePath, + names: [originalName], + isRelative: matchingImport.isRelative, + relativeDots: matchingImport.relativeDots, + }, + entryFileUri, + projectRootUri, + fs, + analyzeFileFn, + ) + if (factoryFileUri && !visited.has(factoryFileUri)) { + const factoryGraph = await buildRouterGraphInternal( + factoryFileUri, + ctx, + ) + if (factoryGraph) { + factoryGraph.variableName = targetVariable + return factoryGraph + } + } + } + } + } + + if (!appRouter) { return null } diff --git a/src/test/core/routerResolver.test.ts b/src/test/core/routerResolver.test.ts index 3e92675..3512c51 100644 --- a/src/test/core/routerResolver.test.ts +++ b/src/test/core/routerResolver.test.ts @@ -571,6 +571,39 @@ suite("routerResolver", () => { assert.strictEqual(result, null) }) + test("follows imported factory function to resolve include_router calls", async () => { + const result = await buildRouterGraph( + fixtures.factoryFunc.factoryMainPy, + parser, + fixtures.factoryFunc.root, + nodeFileSystem, + "app", + ) + + assert.ok(result, "Should find app via imported factory function") + assert.strictEqual(result.type, "FastAPI") + assert.strictEqual(result.variableName, "app") + assert.strictEqual( + result.children.length, + 1, + "Should have one included router", + ) + assert.ok( + result.children[0].router.routes.length >= 2, + "Should have routes from routers.py", + ) + }) + + test("returns null without targetVariable when factory function has no local routes", async () => { + const result = await buildRouterGraph( + fixtures.factoryFunc.factoryMainPy, + parser, + fixtures.factoryFunc.root, + nodeFileSystem, + ) + assert.strictEqual(result, null) + }) + test("resolves custom APIRouter subclass as child router", async () => { const result = await buildRouterGraph( fixtures.customSubclass.mainPy, diff --git a/src/test/fixtures/factory-func/app.py b/src/test/fixtures/factory-func/app.py index a7f51c6..1234c9c 100644 --- a/src/test/fixtures/factory-func/app.py +++ b/src/test/fixtures/factory-func/app.py @@ -1,5 +1,12 @@ from fastapi import FastAPI +from routers import router def get_fastapi_app() -> FastAPI: return FastAPI() + + +def create_app() -> FastAPI: + app = FastAPI() + app.include_router(router, prefix="/users") + return app diff --git a/src/test/fixtures/factory-func/factory_main.py b/src/test/fixtures/factory-func/factory_main.py new file mode 100644 index 0000000..0a23b5a --- /dev/null +++ b/src/test/fixtures/factory-func/factory_main.py @@ -0,0 +1,3 @@ +from app import create_app + +app = create_app() diff --git a/src/test/fixtures/factory-func/routers.py b/src/test/fixtures/factory-func/routers.py new file mode 100644 index 0000000..27fb4c4 --- /dev/null +++ b/src/test/fixtures/factory-func/routers.py @@ -0,0 +1,13 @@ +from fastapi import APIRouter + +router = APIRouter() + + +@router.get("/") +def list_users(): + return [] + + +@router.get("/{user_id}") +def get_user(user_id: int): + return {"id": user_id} diff --git a/src/test/testUtils.ts b/src/test/testUtils.ts index 3df3bb0..a146b0a 100644 --- a/src/test/testUtils.ts +++ b/src/test/testUtils.ts @@ -50,6 +50,7 @@ export const fixtures = { factoryFunc: { root: uri(join(fixturesPath, "factory-func")), mainPy: uri(join(fixturesPath, "factory-func", "main.py")), + factoryMainPy: uri(join(fixturesPath, "factory-func", "factory_main.py")), }, flat: { root: uri(join(fixturesPath, "flat")),