Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/core/analyzer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
collectRecognizedNames,
collectStringVariables,
decoratorExtractor,
factoryCallExtractor,
getNodesByType,
importExtractor,
includeRouterExtractor,
Expand Down Expand Up @@ -46,6 +47,10 @@ export function analyzeTree(tree: Tree, filePath: string): FileAnalysis {
// Get all router assignments
const assignments = nodesByType.get("assignment") ?? []
const { fastAPINames, apiRouterNames } = collectRecognizedNames(nodesByType)
const knownConstructors = new Set([...fastAPINames, ...apiRouterNames])
const factoryCalls = assignments
.map((node) => factoryCallExtractor(node, knownConstructors))
.filter(notNull)
const routers = assignments
.map((node) => routerExtractor(node, apiRouterNames, fastAPINames))
.filter(notNull)
Expand Down Expand Up @@ -84,6 +89,7 @@ export function analyzeTree(tree: Tree, filePath: string): FileAnalysis {
includeRouters,
mounts,
imports,
factoryCalls,
}
}

Expand Down
39 changes: 39 additions & 0 deletions src/core/extractors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import type { Node } from "web-tree-sitter"
import type {
FactoryCallInfo,
ImportedName,
ImportInfo,
IncludeRouterInfo,
Expand Down Expand Up @@ -582,3 +583,41 @@ export function mountExtractor(node: Node): MountInfo | null {
app: appNode?.text ?? "",
}
}

export function factoryCallExtractor(
node: Node,
knownConstructors: Set<string>,
): FactoryCallInfo | null {
if (node.type !== "assignment") {
return null
}

const variableNameNode = node.childForFieldName("left")
const valueNode = node.childForFieldName("right")
if (!variableNameNode || valueNode?.type !== "call") {
return null
}

const functionNode = valueNode.childForFieldName("function")
if (functionNode?.type !== "identifier") {
return null
}

const functionName = functionNode.text
if (knownConstructors.has(functionName)) {
return null
}

// Skip function and class-local variables to avoid false positives
if (
hasAncestor(node, "function_definition") ||
hasAncestor(node, "class_definition")
) {
return null
}

return {
variableName: variableNameNode.text,
functionName: functionName,
}
}
6 changes: 6 additions & 0 deletions src/core/internal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,19 @@ export interface MountInfo {
app: string
}

export interface FactoryCallInfo {
variableName: string
functionName: string
}

export interface FileAnalysis {
filePath: string
routes: RouteInfo[]
routers: RouterInfo[]
includeRouters: IncludeRouterInfo[]
mounts: MountInfo[]
imports: ImportInfo[]
factoryCalls: FactoryCallInfo[]
}

export interface RouterNode {
Expand Down
47 changes: 46 additions & 1 deletion src/core/routerResolver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,52 @@ async function buildRouterGraphInternal(
}
}

if (!appRouter || !analysis) {
// Factory function in another module: if the entrypoint variable is assigned via
// `app = create_app()` where `create_app` is imported, follow the import to the
// factory file and build the router graph from there. This works because
// routerExtractor and includeRouterExtractor recurse into function bodies, so
// `app = FastAPI()` and `app.include_router(...)` inside `create_app` are visible
// when analyzing the factory file directly.
if (!appRouter && targetVariable) {
const factoryCall = analysis.factoryCalls.find(
(fc) => fc.variableName === targetVariable,
)
if (factoryCall) {
const matchingImport = analysis.imports.find((imp) =>
imp.names.includes(factoryCall.functionName),
)
if (matchingImport) {
const namedImport = matchingImport.namedImports.find(
(ni) => (ni.alias ?? ni.name) === factoryCall.functionName,
)
const originalName = namedImport?.name ?? factoryCall.functionName
const factoryFileUri = await resolveNamedImport(
{
modulePath: matchingImport.modulePath,
names: [originalName],
isRelative: matchingImport.isRelative,
relativeDots: matchingImport.relativeDots,
},
entryFileUri,
projectRootUri,
fs,
analyzeFileFn,
)
if (factoryFileUri && !visited.has(factoryFileUri)) {
const factoryGraph = await buildRouterGraphInternal(
factoryFileUri,
ctx,
)
if (factoryGraph) {
factoryGraph.variableName = targetVariable
return factoryGraph
}
}
}
}
}

if (!appRouter) {
return null
}

Expand Down
33 changes: 33 additions & 0 deletions src/test/core/routerResolver.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,39 @@ suite("routerResolver", () => {
assert.strictEqual(result, null)
})

test("follows imported factory function to resolve include_router calls", async () => {
const result = await buildRouterGraph(
fixtures.factoryFunc.factoryMainPy,
parser,
fixtures.factoryFunc.root,
nodeFileSystem,
"app",
)

assert.ok(result, "Should find app via imported factory function")
assert.strictEqual(result.type, "FastAPI")
assert.strictEqual(result.variableName, "app")
assert.strictEqual(
result.children.length,
1,
"Should have one included router",
)
assert.ok(
result.children[0].router.routes.length >= 2,
"Should have routes from routers.py",
)
})

test("returns null without targetVariable when factory function has no local routes", async () => {
const result = await buildRouterGraph(
fixtures.factoryFunc.factoryMainPy,
parser,
fixtures.factoryFunc.root,
nodeFileSystem,
)
assert.strictEqual(result, null)
})

test("resolves custom APIRouter subclass as child router", async () => {
const result = await buildRouterGraph(
fixtures.customSubclass.mainPy,
Expand Down
7 changes: 7 additions & 0 deletions src/test/fixtures/factory-func/app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
from fastapi import FastAPI
from routers import router


def get_fastapi_app() -> FastAPI:
return FastAPI()


def create_app() -> FastAPI:
app = FastAPI()
app.include_router(router, prefix="/users")
return app
3 changes: 3 additions & 0 deletions src/test/fixtures/factory-func/factory_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from app import create_app

app = create_app()
13 changes: 13 additions & 0 deletions src/test/fixtures/factory-func/routers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from fastapi import APIRouter

router = APIRouter()


@router.get("/")
def list_users():
return []


@router.get("/{user_id}")
def get_user(user_id: int):
return {"id": user_id}
1 change: 1 addition & 0 deletions src/test/testUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ export const fixtures = {
factoryFunc: {
root: uri(join(fixturesPath, "factory-func")),
mainPy: uri(join(fixturesPath, "factory-func", "main.py")),
factoryMainPy: uri(join(fixturesPath, "factory-func", "factory_main.py")),
},
flat: {
root: uri(join(fixturesPath, "flat")),
Expand Down
Loading