Skip to content

Commit 8d25bcc

Browse files
authored
Merge pull request #1605 from codeflash-ai/fix-tracer-replay-discovery
fix: resolve test file paths in discover_tests_pytest to fix path com…
2 parents 6346c74 + 54cb606 commit 8d25bcc

6 files changed

Lines changed: 64 additions & 326 deletions

File tree

codeflash/cli_cmds/cli.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -352,32 +352,52 @@ def _handle_show_config() -> None:
352352
from codeflash.setup.detector import detect_project, has_existing_config
353353

354354
project_root = Path.cwd()
355-
detected = detect_project(project_root)
355+
config_exists, _ = has_existing_config(project_root)
356356

357-
# Check if config exists or is auto-detected
358-
config_exists, config_file = has_existing_config(project_root)
359-
status = "Saved config" if config_exists else "Auto-detected (not saved)"
357+
if config_exists:
358+
from codeflash.code_utils.config_parser import parse_config_file
360359

361-
console.print()
362-
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
363-
if config_exists and config_file:
364-
console.print(f"[dim]Config file: {project_root / config_file}[/dim]")
365-
console.print()
360+
config, config_file_path = parse_config_file()
361+
status = "Saved config"
366362

367-
table = Table(show_header=True, header_style="bold cyan")
368-
table.add_column("Setting", style="dim")
369-
table.add_column("Value")
370-
371-
table.add_row("Language", detected.language)
372-
table.add_row("Project root", str(detected.project_root))
373-
table.add_row("Module root", str(detected.module_root))
374-
table.add_row("Tests root", str(detected.tests_root) if detected.tests_root else "(not detected)")
375-
table.add_row("Test runner", detected.test_runner or "(not detected)")
376-
table.add_row("Formatter", ", ".join(detected.formatter_cmds) if detected.formatter_cmds else "(not detected)")
377-
table.add_row(
378-
"Ignore paths", ", ".join(str(p) for p in detected.ignore_paths) if detected.ignore_paths else "(none)"
379-
)
380-
table.add_row("Confidence", f"{detected.confidence:.0%}")
363+
console.print()
364+
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
365+
console.print(f"[dim]Config file: {config_file_path}[/dim]")
366+
console.print()
367+
368+
table = Table(show_header=True, header_style="bold cyan")
369+
table.add_column("Setting", style="dim")
370+
table.add_column("Value")
371+
372+
table.add_row("Project root", str(project_root))
373+
table.add_row("Module root", config.get("module_root", "(not set)"))
374+
table.add_row("Tests root", config.get("tests_root", "(not set)"))
375+
table.add_row("Test runner", config.get("test_framework", config.get("pytest_cmd", "(not set)")))
376+
table.add_row("Formatter", ", ".join(config["formatter_cmds"]) if config.get("formatter_cmds") else "(not set)")
377+
ignore_paths = config.get("ignore_paths", [])
378+
table.add_row("Ignore paths", ", ".join(str(p) for p in ignore_paths) if ignore_paths else "(none)")
379+
else:
380+
detected = detect_project(project_root)
381+
status = "Auto-detected (not saved)"
382+
383+
console.print()
384+
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
385+
console.print()
386+
387+
table = Table(show_header=True, header_style="bold cyan")
388+
table.add_column("Setting", style="dim")
389+
table.add_column("Value")
390+
391+
table.add_row("Language", detected.language)
392+
table.add_row("Project root", str(detected.project_root))
393+
table.add_row("Module root", str(detected.module_root))
394+
table.add_row("Tests root", str(detected.tests_root) if detected.tests_root else "(not detected)")
395+
table.add_row("Test runner", detected.test_runner or "(not detected)")
396+
table.add_row("Formatter", ", ".join(detected.formatter_cmds) if detected.formatter_cmds else "(not detected)")
397+
table.add_row(
398+
"Ignore paths", ", ".join(str(p) for p in detected.ignore_paths) if detected.ignore_paths else "(none)"
399+
)
400+
table.add_row("Confidence", f"{detected.confidence:.0%}")
381401

382402
console.print(table)
383403
console.print()

codeflash/code_utils/time_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from codeflash.result.critic import performance_gain
4+
35

46
def humanize_runtime(time_in_ns: int) -> str:
57
runtime_human: str = str(time_in_ns)
@@ -89,3 +91,13 @@ def format_perf(percentage: float) -> str:
8991
if abs_perc >= 1:
9092
return f"{percentage:.2f}"
9193
return f"{percentage:.3f}"
94+
95+
96+
def format_runtime_comment(original_time_ns: int, optimized_time_ns: int, comment_prefix: str = "#") -> str:
97+
perf_gain = format_perf(
98+
abs(performance_gain(original_runtime_ns=original_time_ns, optimized_runtime_ns=optimized_time_ns) * 100)
99+
)
100+
status = "slower" if optimized_time_ns > original_time_ns else "faster"
101+
return (
102+
f"{comment_prefix} {format_time(original_time_ns)} -> {format_time(optimized_time_ns)} ({perf_gain}% {status})"
103+
)

codeflash/discovery/discover_unit_tests.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,10 @@ def discover_tests_pytest(
728728
logger.debug(f"Pytest collection exit code: {exitcode}")
729729
if pytest_rootdir is not None:
730730
cfg.tests_project_rootdir = Path(pytest_rootdir)
731+
if discover_only_these_tests:
732+
resolved_discover_only = {p.resolve() for p in discover_only_these_tests}
733+
else:
734+
resolved_discover_only = None
731735
file_to_test_map: dict[Path, list[FunctionCalledInTest]] = defaultdict(list)
732736
for test in tests:
733737
if "__replay_test" in test["test_file"]:
@@ -737,13 +741,14 @@ def discover_tests_pytest(
737741
else:
738742
test_type = TestType.EXISTING_UNIT_TEST
739743

744+
test_file_path = Path(test["test_file"]).resolve()
740745
test_obj = TestsInFile(
741-
test_file=Path(test["test_file"]),
746+
test_file=test_file_path,
742747
test_class=test["test_class"],
743748
test_function=test["test_function"],
744749
test_type=test_type,
745750
)
746-
if discover_only_these_tests and test_obj.test_file not in discover_only_these_tests:
751+
if resolved_discover_only and test_obj.test_file not in resolved_discover_only:
747752
continue
748753
file_to_test_map[test_obj.test_file].append(test_obj)
749754
# Within these test files, find the project functions they are referring to and return their names/locations

codeflash/languages/javascript/edit_tests.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,8 @@
1111
from pathlib import Path
1212

1313
from codeflash.cli_cmds.console import logger
14-
from codeflash.code_utils.time_utils import format_perf, format_time
14+
from codeflash.code_utils.time_utils import format_runtime_comment
1515
from codeflash.models.models import GeneratedTests, GeneratedTestsList
16-
from codeflash.result.critic import performance_gain
17-
18-
19-
def format_runtime_comment(original_time: int, optimized_time: int) -> str:
20-
"""Format a runtime comparison comment for JavaScript.
21-
22-
Args:
23-
original_time: Original runtime in nanoseconds.
24-
optimized_time: Optimized runtime in nanoseconds.
25-
26-
Returns:
27-
Formatted comment string with // prefix.
28-
29-
"""
30-
perf_gain = format_perf(
31-
abs(performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time) * 100)
32-
)
33-
status = "slower" if optimized_time > original_time else "faster"
34-
return f"// {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
3516

3617

3718
def add_runtime_comments(source: str, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]) -> str:
@@ -120,7 +101,7 @@ def find_matching_test(test_description: str) -> str | None:
120101
# Only add comment if line has a function call and doesn't already have a comment
121102
if func_call_pattern.search(line) and "//" not in line and "expect(" in line:
122103
orig_time, opt_time = timing_by_full_name[current_matched_full_name]
123-
comment = format_runtime_comment(orig_time, opt_time)
104+
comment = format_runtime_comment(orig_time, opt_time, comment_prefix="//")
124105
logger.debug(f"[js-annotations] Adding comment to test '{current_test_name}': {comment}")
125106
# Add comment at end of line
126107
line = f"{line.rstrip()} {comment}"

codeflash/languages/python/static_analysis/code_replacer.py

Lines changed: 0 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -239,149 +239,6 @@ def add_custom_marker_to_all_tests(test_paths: list[Path]) -> None:
239239
test_path.write_text(modified_module.code, encoding="utf-8")
240240

241241

242-
class OptimFunctionCollector(cst.CSTVisitor):
243-
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)
244-
245-
def __init__(
246-
self,
247-
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] | None = None,
248-
function_names: set[tuple[str | None, str]] | None = None,
249-
) -> None:
250-
super().__init__()
251-
self.preexisting_objects = preexisting_objects if preexisting_objects is not None else set()
252-
253-
self.function_names = function_names # set of (class_name, function_name)
254-
self.modified_functions: dict[
255-
tuple[str | None, str], cst.FunctionDef
256-
] = {} # keys are (class_name, function_name)
257-
self.new_functions: list[cst.FunctionDef] = []
258-
self.new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list)
259-
self.new_classes: list[cst.ClassDef] = []
260-
self.current_class = None
261-
self.modified_init_functions: dict[str, cst.FunctionDef] = {}
262-
263-
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
264-
if (self.current_class, node.name.value) in self.function_names:
265-
self.modified_functions[(self.current_class, node.name.value)] = node
266-
elif self.current_class and node.name.value == "__init__":
267-
self.modified_init_functions[self.current_class] = node
268-
elif (
269-
self.preexisting_objects
270-
and (node.name.value, ()) not in self.preexisting_objects
271-
and self.current_class is None
272-
):
273-
self.new_functions.append(node)
274-
return False
275-
276-
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
277-
if self.current_class:
278-
return False # If already in a class, do not recurse deeper
279-
self.current_class = node.name.value
280-
281-
parents = (FunctionParent(name=node.name.value, type="ClassDef"),)
282-
283-
if (node.name.value, ()) not in self.preexisting_objects:
284-
self.new_classes.append(node)
285-
286-
for child_node in node.body.body:
287-
if (
288-
self.preexisting_objects
289-
and isinstance(child_node, cst.FunctionDef)
290-
and (child_node.name.value, parents) not in self.preexisting_objects
291-
):
292-
self.new_class_functions[node.name.value].append(child_node)
293-
294-
return True
295-
296-
def leave_ClassDef(self, node: cst.ClassDef) -> None:
297-
if self.current_class:
298-
self.current_class = None
299-
300-
301-
class OptimFunctionReplacer(cst.CSTTransformer):
302-
def __init__(
303-
self,
304-
modified_functions: Optional[dict[tuple[str | None, str], cst.FunctionDef]] = None,
305-
new_classes: Optional[list[cst.ClassDef]] = None,
306-
new_functions: Optional[list[cst.FunctionDef]] = None,
307-
new_class_functions: Optional[dict[str, list[cst.FunctionDef]]] = None,
308-
modified_init_functions: Optional[dict[str, cst.FunctionDef]] = None,
309-
) -> None:
310-
super().__init__()
311-
self.modified_functions = modified_functions if modified_functions is not None else {}
312-
self.new_functions = new_functions if new_functions is not None else []
313-
self.new_classes = new_classes if new_classes is not None else []
314-
self.new_class_functions = new_class_functions if new_class_functions is not None else defaultdict(list)
315-
self.modified_init_functions: dict[str, cst.FunctionDef] = (
316-
modified_init_functions if modified_init_functions is not None else {}
317-
)
318-
self.current_class = None
319-
320-
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
321-
return False
322-
323-
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
324-
if (self.current_class, original_node.name.value) in self.modified_functions:
325-
node = self.modified_functions[(self.current_class, original_node.name.value)]
326-
return updated_node.with_changes(body=node.body, decorators=node.decorators)
327-
if original_node.name.value == "__init__" and self.current_class in self.modified_init_functions:
328-
return self.modified_init_functions[self.current_class]
329-
330-
return updated_node
331-
332-
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
333-
if self.current_class:
334-
return False # If already in a class, do not recurse deeper
335-
self.current_class = node.name.value
336-
return True
337-
338-
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
339-
if self.current_class and self.current_class == original_node.name.value:
340-
self.current_class = None
341-
if original_node.name.value in self.new_class_functions:
342-
return updated_node.with_changes(
343-
body=updated_node.body.with_changes(
344-
body=(list(updated_node.body.body) + list(self.new_class_functions[original_node.name.value]))
345-
)
346-
)
347-
return updated_node
348-
349-
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
350-
node = updated_node
351-
max_function_index = None
352-
max_class_index = None
353-
for index, _node in enumerate(node.body):
354-
if isinstance(_node, cst.FunctionDef):
355-
max_function_index = index
356-
if isinstance(_node, cst.ClassDef):
357-
max_class_index = index
358-
359-
if self.new_classes:
360-
existing_class_names = {_node.name.value for _node in node.body if isinstance(_node, cst.ClassDef)}
361-
362-
unique_classes = [
363-
new_class for new_class in self.new_classes if new_class.name.value not in existing_class_names
364-
]
365-
if unique_classes:
366-
new_classes_insertion_idx = max_class_index or find_insertion_index_after_imports(node)
367-
new_body = list(
368-
chain(node.body[:new_classes_insertion_idx], unique_classes, node.body[new_classes_insertion_idx:])
369-
)
370-
node = node.with_changes(body=new_body)
371-
372-
if max_function_index is not None:
373-
node = node.with_changes(
374-
body=(*node.body[: max_function_index + 1], *self.new_functions, *node.body[max_function_index + 1 :])
375-
)
376-
elif max_class_index is not None:
377-
node = node.with_changes(
378-
body=(*node.body[: max_class_index + 1], *self.new_functions, *node.body[max_class_index + 1 :])
379-
)
380-
else:
381-
node = node.with_changes(body=(*self.new_functions, *node.body))
382-
return node
383-
384-
385242
def replace_functions_in_file(
386243
source_code: str,
387244
original_function_names: list[str],

0 commit comments

Comments
 (0)