diff --git a/api/analyzers/analyzer.py b/api/analyzers/analyzer.py index 33ca5a2b..8e3e855e 100644 --- a/api/analyzers/analyzer.py +++ b/api/analyzers/analyzer.py @@ -149,3 +149,32 @@ def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_ pass + @abstractmethod + def add_file_imports(self, file: File) -> None: + """ + Add import statements to the file. + + Args: + file (File): The file to add imports to. + """ + + pass + + @abstractmethod + def resolve_import(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, import_node: Node) -> list[Entity]: + """ + Resolve an import statement to entities. + + Args: + files (dict[Path, File]): All files in the project. + lsp (SyncLanguageServer): The language server. + file_path (Path): The path to the file containing the import. + path (Path): The path to the project root. + import_node (Node): The import statement node. + + Returns: + list[Entity]: List of resolved entities. + """ + + pass + diff --git a/api/analyzers/csharp/analyzer.py b/api/analyzers/csharp/analyzer.py index 74c3906e..aa51034c 100644 --- a/api/analyzers/csharp/analyzer.py +++ b/api/analyzers/csharp/analyzer.py @@ -136,3 +136,11 @@ def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_ return self.resolve_method(files, lsp, file_path, path, symbol) else: raise ValueError(f"Unknown key {key}") + + def add_file_imports(self, file: File) -> None: + # C# import tracking not yet implemented + pass + + def resolve_import(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, import_node: Node) -> list[Entity]: + # C# import resolution not yet implemented + return [] diff --git a/api/analyzers/java/analyzer.py b/api/analyzers/java/analyzer.py index 5269d698..77ed63c9 100644 --- a/api/analyzers/java/analyzer.py +++ b/api/analyzers/java/analyzer.py @@ -2,6 +2,8 @@ from pathlib import Path import subprocess from ...entities import * +from ...entities.entity import Entity +from ...entities.file import File from typing import Optional from ..analyzer import AbstractAnalyzer @@ -127,3 +129,19 @@ def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_ return self.resolve_method(files, lsp, file_path, path, symbol) else: raise ValueError(f"Unknown key {key}") + + def add_file_imports(self, file: File) -> None: + """ + Extract and add import statements from the file. + Java imports are not yet implemented. + """ + # TODO: Implement Java import tracking + pass + + def resolve_import(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, import_node: Node) -> list[Entity]: + """ + Resolve an import statement to the entities it imports. + Java imports are not yet implemented. + """ + # TODO: Implement Java import resolution + return [] diff --git a/api/analyzers/javascript/analyzer.py b/api/analyzers/javascript/analyzer.py index abc2879f..becbb7f6 100644 --- a/api/analyzers/javascript/analyzer.py +++ b/api/analyzers/javascript/analyzer.py @@ -158,6 +158,14 @@ def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_ res.append(file.entities[method_dec]) return res + def add_file_imports(self, file: File) -> None: + """JavaScript import tracking not yet implemented.""" + pass + + def resolve_import(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, import_node: Node) -> list[Entity]: + """JavaScript import resolution not yet implemented.""" + return [] + def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, key: str, symbol: Node) -> list[Entity]: """Dispatch symbol resolution based on the symbol category. diff --git a/api/analyzers/kotlin/analyzer.py b/api/analyzers/kotlin/analyzer.py index 3758c302..ea720abe 100644 --- a/api/analyzers/kotlin/analyzer.py +++ b/api/analyzers/kotlin/analyzer.py @@ -148,6 +148,14 @@ def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_ break return res + def add_file_imports(self, file: File) -> None: + """Kotlin import tracking not yet implemented.""" + pass + + def resolve_import(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, import_node: Node) -> list[Entity]: + """Kotlin import resolution not yet implemented.""" + return [] + def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, key: str, symbol: Node) -> list[Entity]: if key in ["implement_interface", "base_class", "parameters", "return_type"]: return self.resolve_type(files, lsp, file_path, path, symbol) diff --git a/api/analyzers/python/analyzer.py b/api/analyzers/python/analyzer.py index 7a991202..074dbd39 100644 --- a/api/analyzers/python/analyzer.py +++ b/api/analyzers/python/analyzer.py @@ -5,6 +5,8 @@ import tomllib from ...entities import * +from ...entities.entity import Entity +from ...entities.file import File from typing import Optional from ..analyzer import AbstractAnalyzer @@ -96,9 +98,11 @@ def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_pa if node.type == 'attribute': node = node.child_by_field_name('attribute') for file, resolved_node in self.resolve(files, lsp, file_path, path, node): - type_dec = self.find_parent(resolved_node, ['class_definition']) - if type_dec in file.entities: - res.append(file.entities[type_dec]) + decl = resolved_node + if decl.type not in ['class_definition', 'function_definition']: + decl = self.find_parent(resolved_node, ['class_definition', 'function_definition']) + if decl in file.entities: + res.append(file.entities[decl]) return res def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]: @@ -122,3 +126,92 @@ def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_ return self.resolve_method(files, lsp, file_path, path, symbol) else: raise ValueError(f"Unknown key {key}") + + def add_file_imports(self, file: File) -> None: + """ + Extract and add import statements from the file. + + Supports: + - import module + - import module as alias + - from module import name + - from module import name1, name2 + - from module import name as alias + """ + try: + captures = self._captures(""" + (import_statement) @import + (import_from_statement) @import_from + """, file.tree.root_node) + + # Add all import statement nodes to the file + if 'import' in captures: + for import_node in captures['import']: + file.add_import(import_node) + + if 'import_from' in captures: + for import_node in captures['import_from']: + file.add_import(import_node) + except Exception as e: + logger.debug(f"Failed to extract imports from {file.path}: {e}") + + def _resolve_import_name(self, files, lsp, file_path, path, identifier): + """Try to resolve an imported name as both a type and a function.""" + resolved = self.resolve_type(files, lsp, file_path, path, identifier) + if not resolved: + resolved = self.resolve_method(files, lsp, file_path, path, identifier) + return resolved + + def resolve_import(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, import_node: Node) -> list[Entity]: + """ + Resolve an import statement to the entities it imports. + """ + res = [] + + try: + if import_node.type == 'import_statement': + # Handle "import module" or "import module as alias" + # Find all dotted_name and aliased_import nodes + for child in import_node.children: + if child.type == 'dotted_name': + # Try to resolve the module/name + identifier = child.children[0] if child.child_count > 0 else child + res.extend(self._resolve_import_name(files, lsp, file_path, path, identifier)) + elif child.type == 'aliased_import': + # Get the actual name from aliased import (before 'as') + if child.child_count > 0: + actual_name = child.children[0] + if actual_name.type == 'dotted_name' and actual_name.child_count > 0: + identifier = actual_name.children[0] + else: + identifier = actual_name + res.extend(self._resolve_import_name(files, lsp, file_path, path, identifier)) + + elif import_node.type == 'import_from_statement': + # Handle "from module import name1, name2" + # Find the 'import' keyword to know where imported names start + import_keyword_found = False + for child in import_node.children: + if child.type == 'import': + import_keyword_found = True + continue + + # After 'import' keyword, dotted_name nodes are the imported names + if import_keyword_found and child.type == 'dotted_name': + # Try to resolve the imported name + identifier = child.children[0] if child.child_count > 0 else child + res.extend(self._resolve_import_name(files, lsp, file_path, path, identifier)) + elif import_keyword_found and child.type == 'aliased_import': + # Handle "from module import name as alias" + if child.child_count > 0: + actual_name = child.children[0] + if actual_name.type == 'dotted_name' and actual_name.child_count > 0: + identifier = actual_name.children[0] + else: + identifier = actual_name + res.extend(self._resolve_import_name(files, lsp, file_path, path, identifier)) + + except Exception as e: + logger.debug(f"Failed to resolve import: {e}") + + return res diff --git a/api/analyzers/source_analyzer.py b/api/analyzers/source_analyzer.py index 9046abcf..6396c199 100644 --- a/api/analyzers/source_analyzer.py +++ b/api/analyzers/source_analyzer.py @@ -119,6 +119,10 @@ def first_pass(self, path: Path, files: list[Path], ignore: list[str], graph: Gr # Walk thought the AST graph.add_file(file) self.create_hierarchy(file, analyzer, graph) + + # Extract import statements + if not analyzer.is_dependency(str(file_path)): + analyzer.add_file_imports(file) def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None: """ @@ -162,6 +166,8 @@ def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None: continue file = self.files[file_path] logging.info(f'Processing file ({i + 1}/{files_len}): {file_path}') + + # Resolve entity symbols for _, entity in file.entities.items(): entity.resolved_symbol(lambda key, symbol, fp=file_path: analyzers[fp.suffix].resolve_symbol(self.files, lsps[fp.suffix], fp, path, key, symbol)) for key, resolved_set in entity.resolved_symbols.items(): @@ -178,6 +184,13 @@ def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None: graph.connect_entities("RETURNS", entity.id, resolved.id) elif key == "parameters": graph.connect_entities("PARAMETERS", entity.id, resolved.id) + + # Resolve file imports + for import_node in file.imports: + resolved_entities = analyzers[file_path.suffix].resolve_import(self.files, lsps[file_path.suffix], file_path, path, import_node) + for resolved_entity in resolved_entities: + file.add_resolved_import(resolved_entity) + graph.connect_entities("IMPORTS", file.id, resolved_entity.id) def analyze_files(self, files: list[Path], path: Path, graph: Graph) -> None: self.first_pass(path, files, [], graph) diff --git a/api/entities/file.py b/api/entities/file.py index c59e2b6a..a8937349 100644 --- a/api/entities/file.py +++ b/api/entities/file.py @@ -21,10 +21,30 @@ def __init__(self, path: Path, tree: Tree) -> None: self.path = path self.tree = tree self.entities: dict[Node, Entity] = {} + self.imports: list[Node] = [] + self.resolved_imports: set[Entity] = set() def add_entity(self, entity: Entity): entity.parent = self self.entities[entity.node] = entity + + def add_import(self, import_node: Node): + """ + Add an import statement node to track. + + Args: + import_node (Node): The import statement node. + """ + self.imports.append(import_node) + + def add_resolved_import(self, resolved_entity: Entity): + """ + Add a resolved import entity. + + Args: + resolved_entity (Entity): The resolved entity that is imported. + """ + self.resolved_imports.add(resolved_entity) def __str__(self) -> str: return f"path: {self.path}" diff --git a/test-project/a.c b/test-project/a.c new file mode 100644 index 00000000..bdde24d5 --- /dev/null +++ b/test-project/a.c @@ -0,0 +1,11 @@ +#include +#include "src/ff.h" + + +/* Create an empty intset. */ +intset* intsetNew(void) { + intset *is = zmalloc(sizeof(intset)); + is->encoding = intrev32ifbe(INTSET_ENC_INT16); + is->length = 0; + return is; +} \ No newline at end of file diff --git a/test-project/c.java b/test-project/c.java new file mode 100644 index 00000000..a2cec443 --- /dev/null +++ b/test-project/c.java @@ -0,0 +1,26 @@ +package test_project; + +public class c { + + private int a; + + public static void main(String[] args) { + System.out.println("Hello, World!"); + } + + public static void print() { + System.out.println("Hello, World!"); + } + + public int getA() { + return a; + } + + public void setA(int a) { + this.a = a; + } + + public void inc() { + setA(getA() + 1); + } +} diff --git a/tests/source_files/py_imports/module_a.py b/tests/source_files/py_imports/module_a.py new file mode 100644 index 00000000..b6323048 --- /dev/null +++ b/tests/source_files/py_imports/module_a.py @@ -0,0 +1,12 @@ +"""Module A with a class definition.""" + +class ClassA: + """A simple class in module A.""" + + def method_a(self): + """A method in ClassA.""" + return "Method A" + +def function_a(): + """A function in module A.""" + return "Function A" diff --git a/tests/source_files/py_imports/module_b.py b/tests/source_files/py_imports/module_b.py new file mode 100644 index 00000000..c0c1c307 --- /dev/null +++ b/tests/source_files/py_imports/module_b.py @@ -0,0 +1,11 @@ +"""Module B that imports from module A.""" + +from module_a import ClassA, function_a + +class ClassB(ClassA): + """A class that extends ClassA.""" + + def method_b(self): + """A method in ClassB.""" + result = function_a() + return f"Method B: {result}" diff --git a/tests/test_py_imports.py b/tests/test_py_imports.py new file mode 100644 index 00000000..587da8f2 --- /dev/null +++ b/tests/test_py_imports.py @@ -0,0 +1,64 @@ +import os +import unittest + +from api import SourceAnalyzer, Graph + + +class Test_PY_Imports(unittest.TestCase): + def test_import_tracking(self): + """Test that Python imports are tracked correctly.""" + # Get test file path + current_dir = os.path.dirname(os.path.abspath(__file__)) + test_path = os.path.join(current_dir, 'source_files', 'py_imports') + + # Create graph and analyze + g = Graph("py_imports_test") + analyzer = SourceAnalyzer() + + try: + analyzer.analyze_local_folder(test_path, g) + + # Verify files were created + module_a = g.get_file('', 'module_a.py', '.py') + self.assertIsNotNone(module_a, "module_a.py should be in the graph") + + module_b = g.get_file('', 'module_b.py', '.py') + self.assertIsNotNone(module_b, "module_b.py should be in the graph") + + # Verify classes were created + class_a = g.get_class_by_name('ClassA') + self.assertIsNotNone(class_a, "ClassA should be in the graph") + + class_b = g.get_class_by_name('ClassB') + self.assertIsNotNone(class_b, "ClassB should be in the graph") + + # Verify function was created + func_a = g.get_function_by_name('function_a') + self.assertIsNotNone(func_a, "function_a should be in the graph") + + # Test: module_b should have IMPORTS relationship to ClassA + # Query to check if module_b imports ClassA + query = """ + MATCH (f:File {name: 'module_b.py'})-[:IMPORTS]->(c:Class {name: 'ClassA'}) + RETURN c + """ + result = g._query(query, {}) + self.assertGreater(len(result.result_set), 0, + "module_b.py should import ClassA") + + # Test: module_b should have IMPORTS relationship to function_a + query = """ + MATCH (f:File {name: 'module_b.py'})-[:IMPORTS]->(fn:Function {name: 'function_a'}) + RETURN fn + """ + result = g._query(query, {}) + self.assertGreater(len(result.result_set), 0, + "module_b.py should import function_a") + + finally: + # Cleanup: delete the test graph + g.delete() + + +if __name__ == '__main__': + unittest.main()