From 261e872047790d0a73cbf055df3150965f0ff245 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 24 Mar 2026 03:56:03 -0500 Subject: [PATCH 1/9] feat: add codeflash_core and codeflash_python packages with unified types --- codeflash/api/aiservice.py | 2 +- codeflash/benchmarking/function_ranker.py | 4 +- .../instrument_codeflash_trace.py | 2 +- codeflash/cli_cmds/console.py | 2 +- .../code_utils/instrument_existing_tests.py | 2 +- codeflash/discovery/discover_unit_tests.py | 12 +- codeflash/discovery/functions_to_optimize.py | 7 +- codeflash/languages/__init__.py | 2 +- codeflash/languages/base.py | 35 +- codeflash/languages/code_replacer.py | 2 +- codeflash/languages/function_optimizer.py | 7 +- codeflash/languages/java/context.py | 2 +- codeflash/languages/java/discovery.py | 3 +- codeflash/languages/java/instrumentation.py | 2 +- codeflash/languages/java/parse.py | 2 +- codeflash/languages/java/remove_asserts.py | 2 +- codeflash/languages/java/replacement.py | 2 +- codeflash/languages/java/support.py | 4 +- codeflash/languages/java/test_discovery.py | 2 +- .../languages/javascript/find_references.py | 6 +- .../languages/javascript/import_resolver.py | 4 +- codeflash/languages/javascript/instrument.py | 2 +- .../languages/javascript/line_profiler.py | 2 +- codeflash/languages/javascript/optimizer.py | 2 +- codeflash/languages/javascript/parse.py | 2 +- codeflash/languages/javascript/support.py | 4 +- codeflash/languages/javascript/tracer.py | 2 +- .../python/context/code_context_extractor.py | 2 +- .../context/unused_definition_remover.py | 2 +- .../languages/python/function_optimizer.py | 2 +- .../python/instrument_codeflash_capture.py | 2 +- codeflash/languages/python/optimizer.py | 2 +- codeflash/languages/python/parse_xml.py | 2 +- .../python/static_analysis/code_extractor.py | 4 +- .../static_analysis/line_profile_utils.py | 2 +- .../python/static_analysis/static_analysis.py | 2 +- codeflash/languages/python/support.py | 11 +- codeflash/lsp/beta.py | 2 
+- codeflash/models/function_types.py | 93 -- codeflash/models/models.py | 2 +- codeflash/optimization/optimizer.py | 14 +- codeflash/plugin.py | 554 +++++++++++ codeflash/plugin_ai_ops.py | 242 +++++ codeflash/plugin_helpers.py | 167 ++++ codeflash/plugin_results.py | 179 ++++ codeflash/plugin_test_lifecycle.py | 269 +++++ codeflash/result/create_pr.py | 6 +- codeflash/verification/parse_test_output.py | 2 +- codeflash/verification/test_runner.py | 277 ++++++ codeflash/verification/verification_utils.py | 48 +- codeflash/verification/verifier.py | 4 +- pyproject.toml | 4 +- src/codeflash_core/config.py | 14 +- src/codeflash_core/models.py | 25 +- src/codeflash_python/__init__.py | 0 src/codeflash_python/api/__init__.py | 0 src/codeflash_python/api/aiservice.py | 129 +++ .../api/aiservice_optimize.py | 366 +++++++ src/codeflash_python/api/aiservice_results.py | 341 +++++++ src/codeflash_python/api/aiservice_testgen.py | 213 ++++ src/codeflash_python/api/cfapi.py | 448 +++++++++ src/codeflash_python/api/types.py | 95 ++ src/codeflash_python/benchmarking/__init__.py | 0 .../benchmarking/codeflash_trace.py | 233 +++++ .../benchmarking/function_ranker.py | 239 +++++ .../instrument_codeflash_trace.py | 130 +++ .../parse_line_profile_test_output.py | 131 +++ .../benchmarking/plugin/__init__.py | 0 .../benchmarking/plugin/plugin.py | 300 ++++++ .../benchmarking/profile_stats.py | 93 ++ .../pytest_new_process_trace_benchmarks.py | 50 + .../benchmarking/replay_test.py | 305 ++++++ .../benchmarking/trace_benchmarks.py | 50 + .../benchmarking/tracing_new_process.py | 864 +++++++++++++++++ .../benchmarking/tracing_utils.py | 97 ++ src/codeflash_python/benchmarking/utils.py | 124 +++ src/codeflash_python/cli.py | 191 ++++ src/codeflash_python/cli_common.py | 15 + src/codeflash_python/code_utils/__init__.py | 0 src/codeflash_python/code_utils/checkpoint.py | 158 +++ src/codeflash_python/code_utils/code_utils.py | 139 +++ .../code_utils/codeflash_wrap_decorator.py | 210 ++++ 
src/codeflash_python/code_utils/compat.py | 17 + .../code_utils/config_consts.py | 137 +++ .../code_utils/config_parser.py | 164 ++++ src/codeflash_python/code_utils/env_utils.py | 181 ++++ src/codeflash_python/code_utils/formatter.py | 175 ++++ src/codeflash_python/code_utils/git_utils.py | 216 +++++ .../code_utils/shell_utils.py | 278 ++++++ src/codeflash_python/code_utils/tabulate.py | 915 ++++++++++++++++++ src/codeflash_python/code_utils/time_utils.py | 106 ++ .../code_utils/version_check.py | 86 ++ src/codeflash_python/context/__init__.py | 0 src/codeflash_python/context/ast_helpers.py | 323 +++++++ .../context/call_graph_index.py | 668 +++++++++++++ .../context/class_extraction.py | 562 +++++++++++ .../context/code_context_extractor.py | 331 +++++++ src/codeflash_python/context/cst_pruning.py | 264 +++++ src/codeflash_python/context/jedi_helpers.py | 174 ++++ .../context/type_extraction.py | 249 +++++ src/codeflash_python/context/types.py | 131 +++ .../context/unused_definition_remover.py | 568 +++++++++++ .../context/unused_helper_detection.py | 313 ++++++ src/codeflash_python/context/utils.py | 14 + src/codeflash_python/discovery/__init__.py | 0 .../discovery/discover_unit_tests.py | 509 ++++++++++ .../discovery/filter_criteria.py | 47 + .../discovery/function_filtering.py | 281 ++++++ .../discovery/function_visitors.py | 250 +++++ .../discovery/functions_to_optimize.py | 415 ++++++++ .../discovery/import_analyzer.py | 369 +++++++ .../discovery/pytest_new_process_discovery.py | 64 ++ src/codeflash_python/discovery/tests_cache.py | 167 ++++ src/codeflash_python/function_optimizer.py | 634 ++++++++++++ src/codeflash_python/init_config.py | 254 +++++ src/codeflash_python/models/__init__.py | 0 src/codeflash_python/models/call_graph.py | 221 +++++ .../models/experiment_metadata.py | 8 + src/codeflash_python/models/function_types.py | 17 + src/codeflash_python/models/models.py | 817 ++++++++++++++++ src/codeflash_python/models/test_result.py | 60 ++ 
src/codeflash_python/models/test_type.py | 22 + src/codeflash_python/normalizer.py | 181 ++++ src/codeflash_python/optimization/__init__.py | 7 + .../optimization/optimizer.py | 285 ++++++ src/codeflash_python/optimizer.py | 66 ++ .../optimizer_mixins/__init__.py | 25 + .../optimizer_mixins/_protocol.py | 388 ++++++++ .../optimizer_mixins/baseline.py | 246 +++++ .../optimizer_mixins/candidate_evaluation.py | 587 +++++++++++ .../optimizer_mixins/candidate_structures.py | 309 ++++++ .../optimizer_mixins/code_replacement.py | 81 ++ .../optimizer_mixins/refinement.py | 143 +++ .../optimizer_mixins/result_processing.py | 373 +++++++ .../optimizer_mixins/scoring.py | 104 ++ .../optimizer_mixins/test_execution.py | 319 ++++++ .../optimizer_mixins/test_generation.py | 259 +++++ .../optimizer_mixins/test_review.py | 322 ++++++ src/codeflash_python/picklepatch/__init__.py | 0 .../picklepatch/pickle_patcher.py | 373 +++++++ .../picklepatch/pickle_placeholder.py | 57 ++ src/codeflash_python/plugin.py | 594 ++++++++++++ src/codeflash_python/plugin_ai_ops.py | 242 +++++ src/codeflash_python/plugin_helpers.py | 167 ++++ src/codeflash_python/plugin_results.py | 177 ++++ src/codeflash_python/plugin_test_lifecycle.py | 267 +++++ src/codeflash_python/result/__init__.py | 0 src/codeflash_python/result/create_pr.py | 356 +++++++ src/codeflash_python/result/critic.py | 215 ++++ src/codeflash_python/result/explanation.py | 140 +++ src/codeflash_python/result/github_utils.py | 38 + src/codeflash_python/result/pr_comment.py | 54 ++ src/codeflash_python/setup/__init__.py | 58 ++ src/codeflash_python/setup/config_schema.py | 118 +++ src/codeflash_python/setup/config_writer.py | 118 +++ src/codeflash_python/setup/detector.py | 246 +++++ src/codeflash_python/setup/detector_python.py | 141 +++ src/codeflash_python/setup/first_run.py | 300 ++++++ .../static_analysis/__init__.py | 0 .../static_analysis/code_extractor.py | 172 ++++ .../static_analysis/code_replacer.py | 391 ++++++++ 
.../static_analysis/code_replacer_base.py | 39 + .../static_analysis/concolic_utils.py | 125 +++ .../static_analysis/coverage_utils.py | 93 ++ .../static_analysis/global_code_transforms.py | 503 ++++++++++ .../static_analysis/import_analysis.py | 313 ++++++ .../static_analysis/line_profile_utils.py | 387 ++++++++ .../static_analysis/numerical_detection.py | 200 ++++ .../static_analysis/reference_analysis.py | 568 +++++++++++ .../static_analysis/static_analysis.py | 167 ++++ src/codeflash_python/telemetry/__init__.py | 0 src/codeflash_python/telemetry/posthog_cf.py | 47 + src/codeflash_python/verification/__init__.py | 0 src/codeflash_python/verification/addopts.py | 118 +++ .../verification/async_instrumentation.py | 326 +++++++ .../verification/codeflash_capture.py | 198 ++++ .../verification/comparator.py | 666 +++++++++++++ src/codeflash_python/verification/concolic.py | 105 ++ .../verification/coverage_utils.py | 236 +++++ .../verification/device_sync.py | 314 ++++++ .../verification/edit_generated_tests.py | 343 +++++++ .../verification/equivalence.py | 204 ++++ .../instrument_codeflash_capture.py | 310 ++++++ .../verification/instrument_existing_tests.py | 731 ++++++++++++++ .../verification/parse_test_output.py | 262 +++++ .../verification/parse_xml.py | 245 +++++ .../verification/path_utils.py | 24 + .../verification/pytest_plugin.py | 592 +++++++++++ .../verification/test_output_utils.py | 357 +++++++ .../verification/test_runner.py | 511 ++++++++++ .../verification/verification_utils.py | 89 ++ src/codeflash_python/verification/verifier.py | 105 ++ .../verification/wrapper_generation.py | 399 ++++++++ ...est_benchmark_code_extract_code_context.py | 4 +- .../test_benchmark_discover_unit_tests.py | 6 +- tests/code_utils/test_coverage_utils.py | 2 +- tests/test_add_needed_imports_from_module.py | 4 +- tests/test_add_runtime_comments.py | 94 +- tests/test_async_function_discovery.py | 6 +- tests/test_async_run_and_parse_tests.py | 18 +- 
tests/test_code_context_extractor.py | 38 +- tests/test_code_replacement.py | 40 +- tests/test_codeflash_capture.py | 32 +- tests/test_existing_tests_source_for.py | 4 +- tests/test_formatter.py | 6 +- tests/test_function_dependencies.py | 12 +- tests/test_function_discovery.py | 10 +- tests/test_get_code.py | 2 +- tests/test_get_helper_code.py | 14 +- .../test_inject_profiling_used_frameworks.py | 2 +- tests/test_instrument_all_and_run.py | 12 +- tests/test_instrument_async_tests.py | 2 +- tests/test_instrument_codeflash_capture.py | 2 +- tests/test_instrument_codeflash_trace.py | 2 +- tests/test_instrument_line_profiler.py | 24 +- tests/test_instrument_tests.py | 48 +- ...t_instrumentation_run_results_aiservice.py | 10 +- tests/test_java_assertion_removal.py | 2 +- tests/test_java_test_discovery.py | 2 +- tests/test_java_tests_project_rootdir.py | 6 +- tests/test_javascript_assertion_removal.py | 2 +- tests/test_javascript_function_discovery.py | 8 +- .../test_code_context_extraction.py | 6 +- tests/test_languages/test_find_references.py | 2 +- .../test_languages/test_java/test_context.py | 4 +- .../test_java/test_instrumentation.py | 22 +- .../test_java/test_java_tracer_integration.py | 12 +- .../test_java/test_replacement.py | 12 +- .../test_java/test_run_and_parse.py | 4 +- .../test_java/test_test_discovery.py | 2 +- .../test_javascript_instrumentation.py | 2 +- .../test_javascript_optimization_flow.py | 24 +- .../test_javascript_run_and_parse.py | 8 +- .../test_javascript_setup_test_config.py | 4 +- .../test_languages/test_js_code_extractor.py | 8 +- tests/test_languages/test_js_code_replacer.py | 4 +- .../test_multi_file_code_replacer.py | 8 +- tests/test_languages/test_python_support.py | 8 +- tests/test_mock_candidate_replacement.py | 10 +- tests/test_multi_file_code_replacement.py | 8 +- tests/test_pickle_patcher.py | 4 +- tests/test_ranking_boost.py | 2 +- tests/test_test_runner.py | 26 +- tests/test_unit_test_discovery.py | 50 +- 
tests/test_unused_helper_revert.py | 80 +- tests/test_worktree.py | 4 +- 246 files changed, 33587 insertions(+), 618 deletions(-) delete mode 100644 codeflash/models/function_types.py create mode 100644 codeflash/plugin.py create mode 100644 codeflash/plugin_ai_ops.py create mode 100644 codeflash/plugin_helpers.py create mode 100644 codeflash/plugin_results.py create mode 100644 codeflash/plugin_test_lifecycle.py create mode 100644 codeflash/verification/test_runner.py create mode 100644 src/codeflash_python/__init__.py create mode 100644 src/codeflash_python/api/__init__.py create mode 100644 src/codeflash_python/api/aiservice.py create mode 100644 src/codeflash_python/api/aiservice_optimize.py create mode 100644 src/codeflash_python/api/aiservice_results.py create mode 100644 src/codeflash_python/api/aiservice_testgen.py create mode 100644 src/codeflash_python/api/cfapi.py create mode 100644 src/codeflash_python/api/types.py create mode 100644 src/codeflash_python/benchmarking/__init__.py create mode 100644 src/codeflash_python/benchmarking/codeflash_trace.py create mode 100644 src/codeflash_python/benchmarking/function_ranker.py create mode 100644 src/codeflash_python/benchmarking/instrument_codeflash_trace.py create mode 100644 src/codeflash_python/benchmarking/parse_line_profile_test_output.py create mode 100644 src/codeflash_python/benchmarking/plugin/__init__.py create mode 100644 src/codeflash_python/benchmarking/plugin/plugin.py create mode 100644 src/codeflash_python/benchmarking/profile_stats.py create mode 100644 src/codeflash_python/benchmarking/pytest_new_process_trace_benchmarks.py create mode 100644 src/codeflash_python/benchmarking/replay_test.py create mode 100644 src/codeflash_python/benchmarking/trace_benchmarks.py create mode 100644 src/codeflash_python/benchmarking/tracing_new_process.py create mode 100644 src/codeflash_python/benchmarking/tracing_utils.py create mode 100644 src/codeflash_python/benchmarking/utils.py create mode 100644 
src/codeflash_python/cli.py create mode 100644 src/codeflash_python/cli_common.py create mode 100644 src/codeflash_python/code_utils/__init__.py create mode 100644 src/codeflash_python/code_utils/checkpoint.py create mode 100644 src/codeflash_python/code_utils/code_utils.py create mode 100644 src/codeflash_python/code_utils/codeflash_wrap_decorator.py create mode 100644 src/codeflash_python/code_utils/compat.py create mode 100644 src/codeflash_python/code_utils/config_consts.py create mode 100644 src/codeflash_python/code_utils/config_parser.py create mode 100644 src/codeflash_python/code_utils/env_utils.py create mode 100644 src/codeflash_python/code_utils/formatter.py create mode 100644 src/codeflash_python/code_utils/git_utils.py create mode 100644 src/codeflash_python/code_utils/shell_utils.py create mode 100644 src/codeflash_python/code_utils/tabulate.py create mode 100644 src/codeflash_python/code_utils/time_utils.py create mode 100644 src/codeflash_python/code_utils/version_check.py create mode 100644 src/codeflash_python/context/__init__.py create mode 100644 src/codeflash_python/context/ast_helpers.py create mode 100644 src/codeflash_python/context/call_graph_index.py create mode 100644 src/codeflash_python/context/class_extraction.py create mode 100644 src/codeflash_python/context/code_context_extractor.py create mode 100644 src/codeflash_python/context/cst_pruning.py create mode 100644 src/codeflash_python/context/jedi_helpers.py create mode 100644 src/codeflash_python/context/type_extraction.py create mode 100644 src/codeflash_python/context/types.py create mode 100644 src/codeflash_python/context/unused_definition_remover.py create mode 100644 src/codeflash_python/context/unused_helper_detection.py create mode 100644 src/codeflash_python/context/utils.py create mode 100644 src/codeflash_python/discovery/__init__.py create mode 100644 src/codeflash_python/discovery/discover_unit_tests.py create mode 100644 
src/codeflash_python/discovery/filter_criteria.py create mode 100644 src/codeflash_python/discovery/function_filtering.py create mode 100644 src/codeflash_python/discovery/function_visitors.py create mode 100644 src/codeflash_python/discovery/functions_to_optimize.py create mode 100644 src/codeflash_python/discovery/import_analyzer.py create mode 100644 src/codeflash_python/discovery/pytest_new_process_discovery.py create mode 100644 src/codeflash_python/discovery/tests_cache.py create mode 100644 src/codeflash_python/function_optimizer.py create mode 100644 src/codeflash_python/init_config.py create mode 100644 src/codeflash_python/models/__init__.py create mode 100644 src/codeflash_python/models/call_graph.py create mode 100644 src/codeflash_python/models/experiment_metadata.py create mode 100644 src/codeflash_python/models/function_types.py create mode 100644 src/codeflash_python/models/models.py create mode 100644 src/codeflash_python/models/test_result.py create mode 100644 src/codeflash_python/models/test_type.py create mode 100644 src/codeflash_python/normalizer.py create mode 100644 src/codeflash_python/optimization/__init__.py create mode 100644 src/codeflash_python/optimization/optimizer.py create mode 100644 src/codeflash_python/optimizer.py create mode 100644 src/codeflash_python/optimizer_mixins/__init__.py create mode 100644 src/codeflash_python/optimizer_mixins/_protocol.py create mode 100644 src/codeflash_python/optimizer_mixins/baseline.py create mode 100644 src/codeflash_python/optimizer_mixins/candidate_evaluation.py create mode 100644 src/codeflash_python/optimizer_mixins/candidate_structures.py create mode 100644 src/codeflash_python/optimizer_mixins/code_replacement.py create mode 100644 src/codeflash_python/optimizer_mixins/refinement.py create mode 100644 src/codeflash_python/optimizer_mixins/result_processing.py create mode 100644 src/codeflash_python/optimizer_mixins/scoring.py create mode 100644 
src/codeflash_python/optimizer_mixins/test_execution.py create mode 100644 src/codeflash_python/optimizer_mixins/test_generation.py create mode 100644 src/codeflash_python/optimizer_mixins/test_review.py create mode 100644 src/codeflash_python/picklepatch/__init__.py create mode 100644 src/codeflash_python/picklepatch/pickle_patcher.py create mode 100644 src/codeflash_python/picklepatch/pickle_placeholder.py create mode 100644 src/codeflash_python/plugin.py create mode 100644 src/codeflash_python/plugin_ai_ops.py create mode 100644 src/codeflash_python/plugin_helpers.py create mode 100644 src/codeflash_python/plugin_results.py create mode 100644 src/codeflash_python/plugin_test_lifecycle.py create mode 100644 src/codeflash_python/result/__init__.py create mode 100644 src/codeflash_python/result/create_pr.py create mode 100644 src/codeflash_python/result/critic.py create mode 100644 src/codeflash_python/result/explanation.py create mode 100644 src/codeflash_python/result/github_utils.py create mode 100644 src/codeflash_python/result/pr_comment.py create mode 100644 src/codeflash_python/setup/__init__.py create mode 100644 src/codeflash_python/setup/config_schema.py create mode 100644 src/codeflash_python/setup/config_writer.py create mode 100644 src/codeflash_python/setup/detector.py create mode 100644 src/codeflash_python/setup/detector_python.py create mode 100644 src/codeflash_python/setup/first_run.py create mode 100644 src/codeflash_python/static_analysis/__init__.py create mode 100644 src/codeflash_python/static_analysis/code_extractor.py create mode 100644 src/codeflash_python/static_analysis/code_replacer.py create mode 100644 src/codeflash_python/static_analysis/code_replacer_base.py create mode 100644 src/codeflash_python/static_analysis/concolic_utils.py create mode 100644 src/codeflash_python/static_analysis/coverage_utils.py create mode 100644 src/codeflash_python/static_analysis/global_code_transforms.py create mode 100644 
src/codeflash_python/static_analysis/import_analysis.py create mode 100644 src/codeflash_python/static_analysis/line_profile_utils.py create mode 100644 src/codeflash_python/static_analysis/numerical_detection.py create mode 100644 src/codeflash_python/static_analysis/reference_analysis.py create mode 100644 src/codeflash_python/static_analysis/static_analysis.py create mode 100644 src/codeflash_python/telemetry/__init__.py create mode 100644 src/codeflash_python/telemetry/posthog_cf.py create mode 100644 src/codeflash_python/verification/__init__.py create mode 100644 src/codeflash_python/verification/addopts.py create mode 100644 src/codeflash_python/verification/async_instrumentation.py create mode 100644 src/codeflash_python/verification/codeflash_capture.py create mode 100644 src/codeflash_python/verification/comparator.py create mode 100644 src/codeflash_python/verification/concolic.py create mode 100644 src/codeflash_python/verification/coverage_utils.py create mode 100644 src/codeflash_python/verification/device_sync.py create mode 100644 src/codeflash_python/verification/edit_generated_tests.py create mode 100644 src/codeflash_python/verification/equivalence.py create mode 100644 src/codeflash_python/verification/instrument_codeflash_capture.py create mode 100644 src/codeflash_python/verification/instrument_existing_tests.py create mode 100644 src/codeflash_python/verification/parse_test_output.py create mode 100644 src/codeflash_python/verification/parse_xml.py create mode 100644 src/codeflash_python/verification/path_utils.py create mode 100644 src/codeflash_python/verification/pytest_plugin.py create mode 100644 src/codeflash_python/verification/test_output_utils.py create mode 100644 src/codeflash_python/verification/test_runner.py create mode 100644 src/codeflash_python/verification/verification_utils.py create mode 100644 src/codeflash_python/verification/verifier.py create mode 100644 src/codeflash_python/verification/wrapper_generation.py diff 
--git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 338d5eaf9..8caa61425 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -32,7 +32,6 @@ if TYPE_CHECKING: from pathlib import Path - from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.ExperimentMetadata import ExperimentMetadata from codeflash.models.models import ( AIServiceAdaptiveOptimizeRequest, @@ -40,6 +39,7 @@ AIServiceRefinerRequest, ) from codeflash.result.explanation import Explanation + from codeflash_core.models import FunctionToOptimize class AiServiceClient: diff --git a/codeflash/benchmarking/function_ranker.py b/codeflash/benchmarking/function_ranker.py index da565c6d7..644ad409b 100644 --- a/codeflash/benchmarking/function_ranker.py +++ b/codeflash/benchmarking/function_ranker.py @@ -4,14 +4,14 @@ from codeflash.cli_cmds.console import logger from codeflash.code_utils.config_consts import DEFAULT_IMPORTANCE_THRESHOLD -from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.tracing.profile_stats import ProfileStats +from codeflash_core.models import FunctionToOptimize if TYPE_CHECKING: from pathlib import Path - from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.java.jfr_parser import JfrProfile + from codeflash_core.models import FunctionToOptimize pytest_patterns = { " tuple[dict[str, set[FunctionCalledInTest]], int, int]: tests_root = cfg.tests_root - project_root = cfg.project_root_path + project_root = cfg.project_root tmp_pickle_path = get_run_tmp_file("collected_tests.pkl") with custom_addopts(): @@ -863,7 +861,7 @@ def process_test_files( ) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]: import jedi - project_root_path = cfg.project_root_path + project_root_path = cfg.project_root test_framework = cfg.test_framework if functions_to_optimize: diff --git a/codeflash/discovery/functions_to_optimize.py 
b/codeflash/discovery/functions_to_optimize.py index 5780f4def..bff90fa36 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -29,17 +29,14 @@ from codeflash.languages.language_enum import Language from codeflash.languages.registry import get_language_support, get_supported_extensions, is_language_supported from codeflash.lsp.helpers import is_LSP_enabled -from codeflash.models.function_types import FunctionParent, FunctionToOptimize from codeflash.telemetry.posthog_cf import ph -# Re-export for backward compatibility -__all__ = ["FunctionParent", "FunctionToOptimize"] - if TYPE_CHECKING: from argparse import Namespace from codeflash.models.models import CodeOptimizationContext - from codeflash.verification.verification_utils import TestConfig + from codeflash_core.config import TestConfig + from codeflash_core.models import FunctionToOptimize @dataclass(frozen=True) diff --git a/codeflash/languages/__init__.py b/codeflash/languages/__init__.py index b0daea0fb..b66c3211b 100644 --- a/codeflash/languages/__init__.py +++ b/codeflash/languages/__init__.py @@ -64,7 +64,7 @@ # Lazy imports to avoid circular imports def __getattr__(name: str): if name == "FunctionInfo": - from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash_core.models import FunctionToOptimize return FunctionToOptimize if name == "JavaScriptSupport": diff --git a/codeflash/languages/base.py b/codeflash/languages/base.py index bcdabeb8d..9c385a271 100644 --- a/codeflash/languages/base.py +++ b/codeflash/languages/base.py @@ -17,13 +17,15 @@ from collections.abc import Callable, Iterable, Sequence from pathlib import Path - from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.call_graph import CallGraph from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId, ValidCode - from codeflash.verification.verification_utils import TestConfig + 
from codeflash_core.config import TestConfig + from codeflash_core.models import FunctionToOptimize from codeflash.languages.language_enum import Language -from codeflash.models.function_types import FunctionParent +from codeflash_core.models import FunctionParent, HelperFunction + +__all__ = ["FunctionParent", "HelperFunction"] # Backward compatibility aliases - ParentInfo is now FunctionParent ParentInfo = FunctionParent @@ -33,7 +35,7 @@ # This allows `from codeflash.languages.base import FunctionInfo` to work at runtime def __getattr__(name: str) -> Any: if name == "FunctionInfo": - from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash_core.models import FunctionToOptimize return FunctionToOptimize msg = f"module {__name__!r} has no attribute {name!r}" @@ -50,31 +52,6 @@ class IndexResult: error: bool -@dataclass -class HelperFunction: - """A helper function that is a dependency of the target function. - - Helper functions are functions called by the target function that are - within the same module/project (not external libraries). - - Attributes: - name: The simple function name. - qualified_name: Full qualified name including parent scopes. - file_path: Path to the file containing the helper. - source_code: The source code of the helper function. - start_line: Starting line number. - end_line: Ending line number. - - """ - - name: str - qualified_name: str - file_path: Path - source_code: str - start_line: int - end_line: int - - @dataclass class CodeContext: """Code context extracted for optimization. 
diff --git a/codeflash/languages/code_replacer.py b/codeflash/languages/code_replacer.py index 17879ace7..550094156 100644 --- a/codeflash/languages/code_replacer.py +++ b/codeflash/languages/code_replacer.py @@ -15,9 +15,9 @@ if TYPE_CHECKING: from pathlib import Path - from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.base import LanguageSupport from codeflash.models.models import CodeStringsMarkdown + from codeflash_core.models import FunctionToOptimize # Permissive criteria for discovering functions in code snippets (no export/return filtering) _SOURCE_CRITERIA = FunctionFilterCriteria(require_return=False, require_export=False) diff --git a/codeflash/languages/function_optimizer.py b/codeflash/languages/function_optimizer.py index 9b6e01976..9cd8407a4 100644 --- a/codeflash/languages/function_optimizer.py +++ b/codeflash/languages/function_optimizer.py @@ -115,10 +115,8 @@ from argparse import Namespace from typing import Any - from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.either import Result from codeflash.languages.base import DependencyResolver - from codeflash.models.function_types import FunctionParent from codeflash.models.models import ( BenchmarkKey, CodeOptimizationContext, @@ -130,7 +128,8 @@ TestDiff, TestFileReview, ) - from codeflash.verification.verification_utils import TestConfig + from codeflash_core.config import TestConfig + from codeflash_core.models import FunctionParent, FunctionToOptimize def log_optimization_context(function_name: str, code_context: CodeOptimizationContext) -> None: @@ -475,7 +474,7 @@ def __init__( call_graph: DependencyResolver | None = None, effort_override: str | None = None, ) -> None: - self.project_root = test_cfg.project_root_path.resolve() + self.project_root = test_cfg.project_root.resolve() self.test_cfg = test_cfg self.aiservice_client = aiservice_client if aiservice_client else AiServiceClient() resolved_file_path = 
function_to_optimize.file_path.resolve() diff --git a/codeflash/languages/java/context.py b/codeflash/languages/java/context.py index 338ac5102..d0c31c533 100644 --- a/codeflash/languages/java/context.py +++ b/codeflash/languages/java/context.py @@ -21,8 +21,8 @@ from tree_sitter import Node - from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode + from codeflash_core.models import FunctionToOptimize logger = logging.getLogger(__name__) diff --git a/codeflash/languages/java/discovery.py b/codeflash/languages/java/discovery.py index cb610cb18..e588ce9bf 100644 --- a/codeflash/languages/java/discovery.py +++ b/codeflash/languages/java/discovery.py @@ -10,10 +10,9 @@ from pathlib import Path from typing import TYPE_CHECKING -from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.base import FunctionFilterCriteria from codeflash.languages.java.parser import get_java_analyzer -from codeflash.models.function_types import FunctionParent +from codeflash_core.models import FunctionParent, FunctionToOptimize if TYPE_CHECKING: from tree_sitter import Node diff --git a/codeflash/languages/java/instrumentation.py b/codeflash/languages/java/instrumentation.py index 9ecbd613e..34ce11740 100644 --- a/codeflash/languages/java/instrumentation.py +++ b/codeflash/languages/java/instrumentation.py @@ -28,8 +28,8 @@ from pathlib import Path from typing import Any - from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.java.parser import JavaAnalyzer + from codeflash_core.models import FunctionToOptimize _WORD_RE = re.compile(r"^\w+$") diff --git a/codeflash/languages/java/parse.py b/codeflash/languages/java/parse.py index 1d4b8f2f4..00dc6bc55 100644 --- a/codeflash/languages/java/parse.py +++ b/codeflash/languages/java/parse.py @@ -24,7 +24,7 @@ from pathlib import Path from codeflash.models.models import 
TestFiles - from codeflash.verification.verification_utils import TestConfig + from codeflash_core.config import TestConfig start_pattern = re.compile(r"!\$######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+)######\$!") end_pattern = re.compile(r"!######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+):([^:]+)######!") diff --git a/codeflash/languages/java/remove_asserts.py b/codeflash/languages/java/remove_asserts.py index 462fcc486..c7688193b 100644 --- a/codeflash/languages/java/remove_asserts.py +++ b/codeflash/languages/java/remove_asserts.py @@ -25,8 +25,8 @@ if TYPE_CHECKING: from tree_sitter import Node - from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.java.parser import JavaAnalyzer + from codeflash_core.models import FunctionToOptimize _ASSIGN_RE = re.compile(r"(\w+(?:<[^>]+>)?)\s+(\w+)\s*=\s*$") diff --git a/codeflash/languages/java/replacement.py b/codeflash/languages/java/replacement.py index 5ed9bf8f1..2f8fd3eae 100644 --- a/codeflash/languages/java/replacement.py +++ b/codeflash/languages/java/replacement.py @@ -18,8 +18,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Any -from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.java.parser import get_java_analyzer +from codeflash_core.models import FunctionToOptimize if TYPE_CHECKING: from codeflash.languages.java.parser import JavaAnalyzer diff --git a/codeflash/languages/java/support.py b/codeflash/languages/java/support.py index 825c7e7da..d88fb5b65 100644 --- a/codeflash/languages/java/support.py +++ b/codeflash/languages/java/support.py @@ -38,10 +38,10 @@ from collections.abc import Sequence from pathlib import Path - from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.base import CodeContext, FunctionFilterCriteria, HelperFunction, TestInfo, TestResult from codeflash.languages.java.concurrency_analyzer import ConcurrencyInfo from codeflash.models.models 
import GeneratedTestsList, InvocationId + from codeflash_core.models import FunctionToOptimize logger = logging.getLogger(__name__) @@ -405,7 +405,7 @@ def load_coverage( def setup_test_config(self, test_cfg: Any, file_path: Path, current_worktree: Path | None = None) -> None: """Detect test framework from project build config (pom.xml / build.gradle).""" - config = detect_java_project(test_cfg.project_root_path) + config = detect_java_project(test_cfg.project_root) if config is not None: self._test_framework = config.test_framework diff --git a/codeflash/languages/java/test_discovery.py b/codeflash/languages/java/test_discovery.py index 7db04298b..ba8a37d86 100644 --- a/codeflash/languages/java/test_discovery.py +++ b/codeflash/languages/java/test_discovery.py @@ -26,8 +26,8 @@ from tree_sitter import Node - from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.java.parser import JavaAnalyzer + from codeflash_core.models import FunctionToOptimize logger = logging.getLogger(__name__) diff --git a/codeflash/languages/javascript/find_references.py b/codeflash/languages/javascript/find_references.py index ed6e30636..0843e4f26 100644 --- a/codeflash/languages/javascript/find_references.py +++ b/codeflash/languages/javascript/find_references.py @@ -22,8 +22,8 @@ from tree_sitter import Node - from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.javascript.treesitter import ImportInfo, TreeSitterAnalyzer + from codeflash_core.models import FunctionToOptimize logger = logging.getLogger(__name__) @@ -70,7 +70,7 @@ class ReferenceFinder: Example usage: ```python from codeflash.languages.javascript.find_references import ReferenceFinder - from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash_core.models import FunctionToOptimize func = FunctionToOptimize( function_name="myHelper", file_path=Path("/my/project/src/utils.ts"), parents=[], 
language="javascript" @@ -827,7 +827,7 @@ def find_references( ```python from pathlib import Path from codeflash.languages.javascript.find_references import find_references - from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash_core.models import FunctionToOptimize func = FunctionToOptimize( function_name="myHelper", file_path=Path("/my/project/src/utils.ts"), parents=[], language="javascript" diff --git a/codeflash/languages/javascript/import_resolver.py b/codeflash/languages/javascript/import_resolver.py index 34dd1990f..bc3563827 100644 --- a/codeflash/languages/javascript/import_resolver.py +++ b/codeflash/languages/javascript/import_resolver.py @@ -12,9 +12,9 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.base import HelperFunction from codeflash.languages.javascript.treesitter import ImportInfo, TreeSitterAnalyzer + from codeflash_core.models import FunctionToOptimize logger = logging.getLogger(__name__) @@ -544,9 +544,9 @@ def _find_helpers_recursive( Dictionary mapping file paths to lists of helper functions. 
""" - from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.javascript.treesitter import get_analyzer_for_file from codeflash.languages.registry import get_language_support + from codeflash_core.models import FunctionToOptimize if context.current_depth >= context.max_depth: return {} diff --git a/codeflash/languages/javascript/instrument.py b/codeflash/languages/javascript/instrument.py index 8bcd0b2ee..15540aa10 100644 --- a/codeflash/languages/javascript/instrument.py +++ b/codeflash/languages/javascript/instrument.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from codeflash.code_utils.code_position import CodePosition - from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash_core.models import FunctionToOptimize class TestingMode: diff --git a/codeflash/languages/javascript/line_profiler.py b/codeflash/languages/javascript/line_profiler.py index 81b38983c..144a2e22c 100644 --- a/codeflash/languages/javascript/line_profiler.py +++ b/codeflash/languages/javascript/line_profiler.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from pathlib import Path - from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash_core.models import FunctionToOptimize logger = logging.getLogger(__name__) diff --git a/codeflash/languages/javascript/optimizer.py b/codeflash/languages/javascript/optimizer.py index bc88786b1..b57ebc8ab 100644 --- a/codeflash/languages/javascript/optimizer.py +++ b/codeflash/languages/javascript/optimizer.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from pathlib import Path - from codeflash.verification.verification_utils import TestConfig + from codeflash_core.config import TestConfig def prepare_javascript_module( diff --git a/codeflash/languages/javascript/parse.py b/codeflash/languages/javascript/parse.py index 8e50100da..7f5f45f97 100644 --- a/codeflash/languages/javascript/parse.py +++ b/codeflash/languages/javascript/parse.py @@ -22,7 +22,7 @@ import subprocess from 
codeflash.models.models import TestFiles - from codeflash.verification.verification_utils import TestConfig + from codeflash_core.config import TestConfig # Jest timing marker patterns (from codeflash-jest-helper.js console.log output) diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index 039d1ce98..6bc8314e2 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -13,11 +13,11 @@ from typing import TYPE_CHECKING, Any from codeflash.code_utils.git_utils import git_root_dir, mirror_path -from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.base import CodeContext, FunctionFilterCriteria, HelperFunction, Language, TestInfo, TestResult from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage, get_analyzer_for_file from codeflash.languages.registry import register_language from codeflash.models.models import FunctionParent +from codeflash_core.models import FunctionToOptimize if TYPE_CHECKING: from collections.abc import Sequence @@ -25,7 +25,7 @@ from codeflash.languages.base import ReferenceInfo from codeflash.languages.javascript.treesitter import TypeDefinition from codeflash.models.models import GeneratedTestsList, InvocationId, ValidCode - from codeflash.verification.verification_utils import TestConfig + from codeflash_core.config import TestConfig logger = logging.getLogger(__name__) diff --git a/codeflash/languages/javascript/tracer.py b/codeflash/languages/javascript/tracer.py index 2f5791ee0..e1d83fecc 100644 --- a/codeflash/languages/javascript/tracer.py +++ b/codeflash/languages/javascript/tracer.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from pathlib import Path - from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash_core.models import FunctionToOptimize logger = logging.getLogger(__name__) diff --git 
a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index e94cede2d..86718805a 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -19,7 +19,6 @@ TESTGEN_CONTEXT_TOKEN_LIMIT, TESTGEN_LIMIT_ERROR, ) -from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001 from codeflash.languages.python.context.unused_definition_remover import ( collect_top_level_defs_with_dependencies, collect_top_level_defs_with_usages, @@ -41,6 +40,7 @@ CodeStringsMarkdown, FunctionSource, ) +from codeflash_core.models import FunctionToOptimize # noqa: TC001 if TYPE_CHECKING: from pathlib import Path diff --git a/codeflash/languages/python/context/unused_definition_remover.py b/codeflash/languages/python/context/unused_definition_remover.py index aaa0435f8..dd8c2c216 100644 --- a/codeflash/languages/python/context/unused_definition_remover.py +++ b/codeflash/languages/python/context/unused_definition_remover.py @@ -18,8 +18,8 @@ if TYPE_CHECKING: from collections.abc import Callable - from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import CodeOptimizationContext, FunctionSource + from codeflash_core.models import FunctionToOptimize @dataclass diff --git a/codeflash/languages/python/function_optimizer.py b/codeflash/languages/python/function_optimizer.py index 1677bf8bb..0d3e38aae 100644 --- a/codeflash/languages/python/function_optimizer.py +++ b/codeflash/languages/python/function_optimizer.py @@ -30,7 +30,6 @@ from codeflash.either import Result from codeflash.languages.base import Language - from codeflash.models.function_types import FunctionParent from codeflash.models.models import ( CodeOptimizationContext, CodeStringsMarkdown, @@ -41,6 +40,7 @@ TestDiff, TestFileReview, ) + from codeflash_core.models import FunctionParent class 
PythonFunctionOptimizer(FunctionOptimizer): diff --git a/codeflash/languages/python/instrument_codeflash_capture.py b/codeflash/languages/python/instrument_codeflash_capture.py index cda06b1dc..3e1bd98ca 100644 --- a/codeflash/languages/python/instrument_codeflash_capture.py +++ b/codeflash/languages/python/instrument_codeflash_capture.py @@ -9,7 +9,7 @@ from codeflash.languages.python.context.code_context_extractor import _ATTRS_DECORATOR_NAMES, _ATTRS_NAMESPACES if TYPE_CHECKING: - from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash_core.models import FunctionToOptimize def instrument_codeflash_capture( diff --git a/codeflash/languages/python/optimizer.py b/codeflash/languages/python/optimizer.py index 475c834fc..e60551d61 100644 --- a/codeflash/languages/python/optimizer.py +++ b/codeflash/languages/python/optimizer.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from pathlib import Path - from codeflash.models.function_types import FunctionParent + from codeflash_core.models import FunctionParent def prepare_python_module( diff --git a/codeflash/languages/python/parse_xml.py b/codeflash/languages/python/parse_xml.py index 840fa2055..2a308eac4 100644 --- a/codeflash/languages/python/parse_xml.py +++ b/codeflash/languages/python/parse_xml.py @@ -26,7 +26,7 @@ from pathlib import Path from codeflash.models.models import TestFiles - from codeflash.verification.verification_utils import TestConfig + from codeflash_core.config import TestConfig matches_re_start = re.compile( r"!\$######([^:]*)" # group 1: module path diff --git a/codeflash/languages/python/static_analysis/code_extractor.py b/codeflash/languages/python/static_analysis/code_extractor.py index 454aeac9a..37204efcd 100644 --- a/codeflash/languages/python/static_analysis/code_extractor.py +++ b/codeflash/languages/python/static_analysis/code_extractor.py @@ -21,8 +21,8 @@ from libcst.helpers import ModuleNameAndPackage - from codeflash.discovery.functions_to_optimize import 
FunctionToOptimize from codeflash.models.models import FunctionSource + from codeflash_core.models import FunctionToOptimize _SENTINEL = object() @@ -953,9 +953,9 @@ def get_opt_review_metrics( source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path, language: Language ) -> str: """Get markdown-formatted calling function context for optimization review.""" - from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.registry import get_language_support from codeflash.models.models import FunctionParent + from codeflash_core.models import FunctionToOptimize start_time = time.perf_counter() diff --git a/codeflash/languages/python/static_analysis/line_profile_utils.py b/codeflash/languages/python/static_analysis/line_profile_utils.py index 93997b2c6..78326f4f0 100644 --- a/codeflash/languages/python/static_analysis/line_profile_utils.py +++ b/codeflash/languages/python/static_analysis/line_profile_utils.py @@ -13,8 +13,8 @@ from codeflash.code_utils.formatter import sort_imports if TYPE_CHECKING: - from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import CodeOptimizationContext + from codeflash_core.models import FunctionToOptimize # Known JIT decorators organized by module # Format: {module_path: {decorator_name, ...}} diff --git a/codeflash/languages/python/static_analysis/static_analysis.py b/codeflash/languages/python/static_analysis/static_analysis.py index a0d04bfb1..f36c3b591 100644 --- a/codeflash/languages/python/static_analysis/static_analysis.py +++ b/codeflash/languages/python/static_analysis/static_analysis.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, ConfigDict, field_validator if TYPE_CHECKING: - from codeflash.models.function_types import FunctionParent + from codeflash_core.models import FunctionParent ObjectDefT = TypeVar("ObjectDefT", ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef) diff --git 
a/codeflash/languages/python/support.py b/codeflash/languages/python/support.py index ccf74ea86..6639d5b37 100644 --- a/codeflash/languages/python/support.py +++ b/codeflash/languages/python/support.py @@ -9,7 +9,6 @@ import libcst as cst -from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.base import ( CodeContext, FunctionFilterCriteria, @@ -20,7 +19,7 @@ TestResult, ) from codeflash.languages.registry import register_language -from codeflash.models.function_types import FunctionParent +from codeflash_core.models import FunctionParent, FunctionToOptimize if TYPE_CHECKING: import ast @@ -31,7 +30,7 @@ from codeflash.languages.base import DependencyResolver from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId, ValidCode - from codeflash.verification.verification_utils import TestConfig + from codeflash_core.config import TestConfig logger = logging.getLogger(__name__) @@ -1045,7 +1044,7 @@ def prepare_module( pytest_cmd: str = "pytest" def setup_test_config(self, test_cfg: TestConfig, file_path: Path, current_worktree: Path | None = None) -> None: - self.pytest_cmd = test_cfg.pytest_cmd or "pytest" + self.pytest_cmd = test_cfg.test_command or "pytest" def pytest_cmd_tokens(self, is_posix: bool) -> list[str]: import shlex @@ -1274,7 +1273,7 @@ def generate_concolic_tests( from codeflash.languages.python.static_analysis.static_analysis import has_typed_parameters from codeflash.lsp.helpers import is_LSP_enabled from codeflash.telemetry.posthog_cf import ph - from codeflash.verification.verification_utils import TestConfig + from codeflash_core.config import TestConfig crosshair_available = importlib.util.find_spec("crosshair") is not None @@ -1342,7 +1341,7 @@ def generate_concolic_tests( concolic_test_cfg = TestConfig( tests_root=concolic_test_suite_dir, tests_project_rootdir=test_cfg.concolic_test_root_dir, - project_root_path=project_root, + project_root=project_root, ) 
function_to_concolic_tests, num_discovered_concolic_tests, _ = discover_unit_tests(concolic_test_cfg) logger.info( diff --git a/codeflash/lsp/beta.py b/codeflash/lsp/beta.py index 31349b841..ab1a28d8f 100644 --- a/codeflash/lsp/beta.py +++ b/codeflash/lsp/beta.py @@ -42,8 +42,8 @@ from lsprotocol import types - from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.lsp.server import WrappedInitializationResultT + from codeflash_core.models import FunctionToOptimize @dataclass diff --git a/codeflash/models/function_types.py b/codeflash/models/function_types.py deleted file mode 100644 index bea6672b0..000000000 --- a/codeflash/models/function_types.py +++ /dev/null @@ -1,93 +0,0 @@ -"""Simple function-related types with no dependencies. - -This module contains basic types used for function representation. -It is intentionally kept dependency-free to avoid circular imports. -""" - -from __future__ import annotations - -from pathlib import Path -from typing import Optional - -from pydantic import Field -from pydantic.dataclasses import dataclass - - -@dataclass(frozen=True) -class FunctionParent: - name: str - type: str - - def __str__(self) -> str: - return f"{self.type}:{self.name}" - - -@dataclass(frozen=True, config={"arbitrary_types_allowed": True}) -class FunctionToOptimize: - """Represent a function that is a candidate for optimization. - - This is the canonical dataclass for representing functions across all languages - (Python, JavaScript, TypeScript). It captures all information needed to identify, - locate, and work with a function. - - Attributes - ---------- - function_name: The name of the function. - file_path: The absolute file path where the function is located. - parents: A list of parent scopes, which could be classes or functions. - starting_line: The starting line number of the function in the file (1-indexed). - ending_line: The ending line number of the function in the file (1-indexed). 
- starting_col: The starting column offset (0-indexed, for precise location). - ending_col: The ending column offset (0-indexed, for precise location). - is_async: Whether this function is defined as async. - is_method: Whether this is a method (belongs to a class). - language: The programming language of this function (default: "python"). - doc_start_line: Line where docstring/JSDoc starts (or None if no doc comment). - - The qualified_name property provides the full name of the function, including - any parent class or function names. The qualified_name_with_modules_from_root - method extends this with the module name from the project root. - - """ - - function_name: str - file_path: Path - parents: list[FunctionParent] = Field(default_factory=list) - starting_line: Optional[int] = None - ending_line: Optional[int] = None - starting_col: Optional[int] = None - ending_col: Optional[int] = None - is_async: bool = False - is_method: bool = False - language: str = "python" - doc_start_line: Optional[int] = None - - @property - def top_level_parent_name(self) -> str: - return self.function_name if not self.parents else self.parents[0].name - - @property - def class_name(self) -> str | None: - """Get the immediate parent class name, if any.""" - for parent in reversed(self.parents): - if parent.type == "ClassDef": - return parent.name - return None - - def __str__(self) -> str: - qualified = f"{'.'.join([p.name for p in self.parents])}{'.' 
if self.parents else ''}{self.function_name}" - line_info = f":{self.starting_line}-{self.ending_line}" if self.starting_line and self.ending_line else "" - return f"{self.file_path}:{qualified}{line_info}" - - @property - def qualified_name(self) -> str: - if not self.parents: - return self.function_name - parent_path = ".".join(parent.name for parent in self.parents) - return f"{parent_path}.{self.function_name}" - - def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str: - # Import here to avoid circular imports - from codeflash.code_utils.code_utils import module_name_from_file_path - - return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}" diff --git a/codeflash/models/models.py b/codeflash/models/models.py index b8345dc2f..709cd7ccc 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -627,7 +627,7 @@ class CodePosition: # Re-export FunctionParent for backward compatibility -from codeflash.models.function_types import FunctionParent # noqa: E402 +from codeflash_core.models import FunctionParent # noqa: E402 class OriginalCodeBaseline(BaseModel): diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 8e9c08ac2..6ff310f1e 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -27,7 +27,7 @@ from codeflash.languages import current_language_support, set_current_language from codeflash.lsp.helpers import is_subagent_mode from codeflash.telemetry.posthog_cf import ph -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig if TYPE_CHECKING: import ast @@ -35,10 +35,10 @@ from codeflash.benchmarking.function_ranker import FunctionRanker from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint - from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.base import DependencyResolver from 
codeflash.languages.function_optimizer import FunctionOptimizer from codeflash.models.models import BenchmarkKey, FunctionCalledInTest, ValidCode + from codeflash_core.models import FunctionToOptimize def _extract_java_package_from_path(file_path: Path) -> str | None: @@ -60,9 +60,8 @@ def __init__(self, args: Namespace) -> None: self.test_cfg = TestConfig( tests_root=args.tests_root, tests_project_rootdir=args.test_project_root, - project_root_path=args.project_root, - # TODO: Can rename it for language agnostic - pytest_cmd=args.pytest_cmd if hasattr(args, "pytest_cmd") and args.pytest_cmd else "pytest", + project_root=args.project_root, + test_command=args.pytest_cmd if hasattr(args, "pytest_cmd") and args.pytest_cmd else "pytest", benchmark_tests_root=args.benchmarks_root if "benchmark" in args and "benchmarks_root" in args else None, ) @@ -506,12 +505,11 @@ def run(self) -> None: function_optimizer = None file_to_funcs_to_optimize, num_optimizable_functions, trace_file_path = self.get_optimizable_functions() - # Set language on TestConfig and global singleton based on discovered functions + # Set language global singleton based on discovered functions if file_to_funcs_to_optimize: for file_path, funcs in file_to_funcs_to_optimize.items(): if funcs and funcs[0].language: set_current_language(funcs[0].language) - self.test_cfg.set_language(funcs[0].language) current_language_support().setup_test_config(self.test_cfg, file_path, self.current_worktree) break @@ -799,7 +797,7 @@ def mirror_paths_for_worktree_mode(self, worktree_dir: Path) -> None: # mirror project_root self.args.project_root = mirror_path(self.args.project_root, original_git_root, worktree_dir) - self.test_cfg.project_root_path = mirror_path(self.test_cfg.project_root_path, original_git_root, worktree_dir) + self.test_cfg.project_root = mirror_path(self.test_cfg.project_root, original_git_root, worktree_dir) # mirror module_root self.args.module_root = mirror_path(self.args.module_root, 
original_git_root, worktree_dir) diff --git a/codeflash/plugin.py b/codeflash/plugin.py new file mode 100644 index 000000000..84cf3357e --- /dev/null +++ b/codeflash/plugin.py @@ -0,0 +1,554 @@ +"""PythonPlugin — adapter wiring codeflash to the codeflash_core LanguagePlugin protocol.""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING + +from codeflash.plugin_ai_ops import PluginAiOpsMixin +from codeflash.plugin_helpers import ( + format_code_with_ruff_or_black, + make_test_env, + read_return_values, + replace_function_simple, +) +from codeflash.plugin_results import PluginResultsMixin +from codeflash.plugin_test_lifecycle import PluginTestLifecycleMixin +from codeflash.verification.test_runner import run_tests +from codeflash_core.models import BenchmarkResults, CodeContext, TestOutcome, TestOutcomeStatus, TestResults + +if TYPE_CHECKING: + import threading + + from codeflash.api.aiservice import AiServiceClient + from codeflash.models.models import CodeOptimizationContext + from codeflash_core.config import TestConfig + from codeflash_core.models import CoverageData, FunctionToOptimize + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Plugin +# --------------------------------------------------------------------------- + + +class PythonPlugin(PluginAiOpsMixin, PluginTestLifecycleMixin, PluginResultsMixin): + """Implements the codeflash_core LanguagePlugin protocol for Python. + + Converts between core types and internal types at the boundary. 
+ """ + + def __init__(self, project_root: Path) -> None: + self.project_root = project_root + self.last_internal_context: CodeOptimizationContext | None = None # cache for get_candidates + self.current_function: FunctionToOptimize | None = None # cache for coverage + self.tests_project_rootdir: Path | None = None # cached from test_config + self.is_numerical_code: bool | None = None # cached from generate_tests + self.ai_client: AiServiceClient | None = None + self.pending_code_markdown: str = "" # set by optimizer before replace_function + self.cancel_event: threading.Event | None = None # set by optimizer for cooperative cancellation + self.dependency_counts: dict[str, int] = {} + + def is_cancelled(self) -> bool: + return self.cancel_event is not None and self.cancel_event.is_set() + + def get_ai_client(self) -> AiServiceClient: + if self.ai_client is not None: + return self.ai_client + from codeflash.api.aiservice import AiServiceClient + + client = AiServiceClient() + self.ai_client = client + return client + + # -- cleanup, comparison, environment validation -------------------------- + + def cleanup_run(self, tests_root: Path) -> None: + import contextlib + import shutil + + from codeflash.code_utils.code_utils import get_run_tmp_file + from codeflash.optimization.optimizer import Optimizer as PyOptimizer + + # Remove leftover instrumented test files + if tests_root.exists(): + leftover = PyOptimizer.find_leftover_instrumented_test_files(tests_root) + for p in leftover: + with contextlib.suppress(OSError): + p.unlink(missing_ok=True) + + # Remove leftover return-value files (indices 0-30 match max_total in evaluate_candidates) + for i in range(31): + with contextlib.suppress(OSError): + get_run_tmp_file(Path(f"test_return_values_{i}.bin")).unlink(missing_ok=True) + with contextlib.suppress(OSError): + get_run_tmp_file(Path(f"test_return_values_{i}.sqlite")).unlink(missing_ok=True) + + # Remove the shared temp directory + if hasattr(get_run_tmp_file, 
"tmpdir_path"): + shutil.rmtree(get_run_tmp_file.tmpdir_path, ignore_errors=True) + del get_run_tmp_file.tmpdir_path + + def compare_outputs(self, baseline_output: object, candidate_output: object) -> bool: + from codeflash.verification.comparator import comparator + + return comparator(baseline_output, candidate_output) + + def validate_environment(self, config: object) -> bool: + from codeflash.code_utils.env_utils import check_formatter_installed + + if hasattr(config, "formatter_cmds") and config.formatter_cmds: + return check_formatter_installed(config.formatter_cmds) + return True + + # -- discover_functions -------------------------------------------------- + + def discover_functions(self, paths: list[Path]) -> list[FunctionToOptimize]: + from codeflash.languages.python.support import PythonSupport + + support = PythonSupport() + results: list[FunctionToOptimize] = [] + for path in paths: + try: + source = path.read_text(encoding="utf-8") + except (OSError, UnicodeDecodeError) as exc: + logger.warning("Skipping %s: %s", path, exc) + continue + + try: + internal_fns = support.discover_functions(source, path) + except Exception as exc: + logger.warning("Skipping %s: failed to parse (%s)", path, exc) + continue + for fn in internal_fns: + # Attach source code so the core optimizer has it + lines = source.splitlines() + if fn.starting_line and fn.ending_line: + fn.source_code = "\n".join(lines[fn.starting_line - 1 : fn.ending_line]) + results.append(fn) + return results + + # -- build_index / rank_functions ----------------------------------------- + + def build_index(self, files: list[Path], on_progress: object = None) -> None: + # CallGraphIndex not available in main repo — no-op for now + pass + + def rank_functions( + self, + functions: list[FunctionToOptimize], + trace_file: Path | None = None, + test_counts: dict[tuple[Path, str], int] | None = None, + ) -> list[FunctionToOptimize]: + if not functions: + return functions + + # Primary: rank by trace-based 
addressable time (filters low-importance functions) + if trace_file and trace_file.exists(): + try: + from codeflash.benchmarking.function_ranker import FunctionRanker + + ranker = FunctionRanker(trace_file) + ranked = ranker.rank_functions(functions) + if test_counts: + ranked.sort( + key=lambda f: ( + -ranker.get_function_addressable_time(f), + -test_counts.get((f.file_path, f.qualified_name), 0), + ) + ) + logger.debug( + "Ranked %d functions by addressable time (filtered %d low-importance)", + len(ranked), + len(functions) - len(ranked), + ) + return ranked + except Exception: + logger.warning("Trace-based ranking failed, falling back to original order") + + # Fallback: return as-is (no CallGraphIndex available) + return functions + + def get_dependency_counts(self) -> dict[str, int]: + return self.dependency_counts + + # -- extract_context ----------------------------------------------------- + + def extract_context(self, function: FunctionToOptimize) -> CodeContext: + from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context + from codeflash.languages.python.support import function_sources_to_helpers + + internal_fn = function + ctx = get_code_optimization_context(internal_fn, self.project_root, call_graph=None) + self.last_internal_context = ctx + self.current_function = function + + helpers = function_sources_to_helpers(ctx.helper_functions) + + return CodeContext( + target_function=function, + target_code=ctx.read_writable_code.flat if ctx.read_writable_code else function.source_code, + target_file=function.file_path, + helper_functions=helpers, + read_only_context=ctx.read_only_context_code, + ) + + # -- run_tests ----------------------------------------------------------- + + def run_tests( + self, + test_config: TestConfig, + test_files: list[Path] | None = None, + test_iteration: int = 0, + enable_coverage: bool = False, + ) -> TestResults | tuple[TestResults, CoverageData | None]: + if test_files is not None: 
+ files_to_run = test_files + else: + files_to_run = sorted(test_config.tests_root.rglob("test_*.py")) + if not files_to_run: + files_to_run = sorted(test_config.tests_root.rglob("*_test.py")) + + if not files_to_run: + return TestResults(passed=True) + + # Clean up stale return-value files before this iteration (matches original) + from codeflash.code_utils.code_utils import get_run_tmp_file + + for ext in (".bin", ".sqlite"): + get_run_tmp_file(Path(f"test_return_values_{test_iteration}{ext}")).unlink(missing_ok=True) + + env = make_test_env(test_config.project_root, test_iteration=test_iteration) + timeout = int(test_config.timeout) + + results, _, cov_db, cov_config = run_tests( + test_files=files_to_run, + cwd=test_config.project_root, + env=env, + timeout=timeout, + enable_coverage=enable_coverage, + ) + + # Read return values from SQLite written by instrumented tests + return_values = read_return_values(test_iteration) + + outcomes = [] + for r in results: + # Match JUnit test name to SQLite test_function_name + # The pytest plugin strips parametrize brackets from CODEFLASH_TEST_FUNCTION + base_name = r.test_name.split("[", 1)[0] if "[" in r.test_name else r.test_name + ret_vals = return_values.get(base_name) + output = tuple(ret_vals) if ret_vals else None + + outcomes.append( + TestOutcome( + test_id=r.test_name, + status=TestOutcomeStatus.PASSED if r.passed else TestOutcomeStatus.FAILED, + duration=r.runtime_ns / 1e9 if r.runtime_ns else 0.0, + error_message=r.error_message or "", + output=output, + ) + ) + + test_results = TestResults(passed=all(r.passed for r in results), outcomes=outcomes, error=None) + + if enable_coverage: + coverage_data = self.load_coverage(cov_db, cov_config) + return test_results, coverage_data + + return test_results + + def load_coverage(self, cov_db: Path | None, cov_config: Path | None) -> CoverageData | None: + """Load coverage data from SQLite database and convert to core CoverageData.""" + if cov_db is None or cov_config 
is None: + return None + + function = self.current_function + code_context = self.last_internal_context + if function is None or code_context is None: + return None + + try: + from codeflash.verification.coverage_utils import CoverageUtils + from codeflash_core.models import CoverageData as CoreCoverageData + from codeflash_core.models import FunctionCoverage as CoreFunctionCoverage + + internal_cov = CoverageUtils.load_from_sqlite_database( + database_path=cov_db, + config_path=cov_config, + function_name=function.qualified_name, + code_context=code_context, + source_code_path=function.file_path, + ) + + main_fc = internal_cov.main_func_coverage + core_main = CoreFunctionCoverage( + name=main_fc.name, + coverage=main_fc.coverage, + executed_lines=list(main_fc.executed_lines), + unexecuted_lines=list(main_fc.unexecuted_lines), + executed_branches=list(main_fc.executed_branches), + unexecuted_branches=list(main_fc.unexecuted_branches), + ) + + core_dep = None + if internal_cov.dependent_func_coverage: + dep = internal_cov.dependent_func_coverage + core_dep = CoreFunctionCoverage( + name=dep.name, + coverage=dep.coverage, + executed_lines=list(dep.executed_lines), + unexecuted_lines=list(dep.unexecuted_lines), + executed_branches=list(dep.executed_branches), + unexecuted_branches=list(dep.unexecuted_branches), + ) + + from codeflash.code_utils.config_consts import COVERAGE_THRESHOLD + + return CoreCoverageData( + file_path=function.file_path, + coverage=internal_cov.coverage, + function_name=function.qualified_name, + main_func_coverage=core_main, + dependent_func_coverage=core_dep, + threshold_percentage=COVERAGE_THRESHOLD, + ) + except Exception: + logger.debug("Failed to load coverage data", exc_info=True) + return None + + # -- replace_function ---------------------------------------------------- + + def replace_function(self, file: Path, function: FunctionToOptimize, new_code: str) -> None: + internal_ctx = self.last_internal_context + code_markdown = 
self.pending_code_markdown + + if internal_ctx is not None and code_markdown: + try: + self.replace_function_full(function, internal_ctx, code_markdown) + return + except Exception: + logger.debug("Full replace_function failed, falling back to simple replacement", exc_info=True) + + # Fallback: simple single-file replacement + source = file.read_text(encoding="utf-8") + internal_fn = function + modified = replace_function_simple(source, internal_fn, new_code) + file.write_text(modified, encoding="utf-8") + + def replace_function_full( + self, function: FunctionToOptimize, internal_ctx: CodeOptimizationContext, code_markdown: str + ) -> None: + """Port of FunctionOptimizer.replace_function_and_helpers_with_optimized_code.""" + from collections import defaultdict + + from codeflash.languages.python.context.unused_definition_remover import ( + detect_unused_helper_functions, + revert_unused_helper_functions, + ) + from codeflash.languages.python.static_analysis.code_replacer import replace_function_definitions_in_module + from codeflash.models.models import CodeStringsMarkdown + + optimized_code = CodeStringsMarkdown.parse_markdown_code(code_markdown) + + internal_fn = function + + # Group functions by file (target + helpers where definition_type in ("function", None)) + functions_by_file: dict[Path, set[str]] = defaultdict(set) + functions_by_file[function.file_path].add(internal_fn.qualified_name) + for helper in internal_ctx.helper_functions: + if helper.definition_type in ("function", None): + functions_by_file[helper.file_path].add(helper.qualified_name) + + # Capture original helper code for unused-helper revert + original_helper_code: dict[Path, str] = {} + for hp in functions_by_file: + if hp != function.file_path and hp.exists(): + original_helper_code[hp] = hp.read_text("utf-8") + + # Replace in each file + for module_abspath, qualified_names in functions_by_file.items(): + replace_function_definitions_in_module( + function_names=list(qualified_names), + 
optimized_code=optimized_code, + module_abspath=module_abspath, + preexisting_objects=internal_ctx.preexisting_objects, + project_root_path=self.project_root, + ) + + # Detect and revert unused helpers + unused_helpers = detect_unused_helper_functions(internal_fn, internal_ctx, optimized_code) + if unused_helpers: + revert_unused_helper_functions(self.project_root, unused_helpers, original_helper_code) + + # -- restore_function ---------------------------------------------------- + + def restore_function(self, file: Path, function: FunctionToOptimize, original_code: str) -> None: + self.replace_function(file, function, original_code) + + # -- run_benchmarks ------------------------------------------------------ + + def run_benchmarks( + self, + function: FunctionToOptimize, + test_config: TestConfig, + test_files: list[Path] | None = None, + test_iteration: int = 0, + ) -> BenchmarkResults: + if test_files is not None: + files_to_run = test_files + else: + files_to_run = sorted(test_config.tests_root.rglob("test_*.py")) + if not files_to_run: + files_to_run = sorted(test_config.tests_root.rglob("*_test.py")) + + if not files_to_run: + return BenchmarkResults() + + env = make_test_env(test_config.project_root, test_iteration=test_iteration) + timeout = int(test_config.timeout) + + results, *_ = run_tests( + test_files=files_to_run, + cwd=test_config.project_root, + env=env, + timeout=timeout, + min_loops=5, + max_loops=100_000, + target_seconds=10.0, + stability_check=True, + ) + + timings: dict[str, float] = {} + total = 0.0 + for r in results: + if r.runtime_ns: + secs = r.runtime_ns / 1e9 + timings[r.test_name] = secs + total += secs + + return BenchmarkResults(timings=timings, total_time=total) + + # -- format_code --------------------------------------------------------- + + def format_code(self, code: str, file: Path) -> str: + return format_code_with_ruff_or_black(code, file) + + def validate_candidate(self, code: str) -> bool: + import ast + + try: + 
ast.parse(code) + return True + except SyntaxError: + return False + + def normalize_code(self, code: str) -> str: + from codeflash.languages.python.normalizer import normalize_python_code + + try: + return normalize_python_code(code, remove_docstrings=True) + except Exception: + return code + + # ======================================================================== + # Phase 2: Split behavioral / performance test running + # ======================================================================== + + def run_behavioral_tests(self, test_files: list[Path], test_config: TestConfig) -> TestResults: + result = self.run_tests(test_config, test_files=test_files) + if isinstance(result, tuple): + return result[0] + return result + + def run_performance_tests( + self, test_files: list[Path], function: FunctionToOptimize, test_config: TestConfig + ) -> BenchmarkResults: + return self.run_benchmarks(function, test_config, test_files=test_files) + + # ======================================================================== + # Phase 3: Line profiler (stays here — uses run_tests directly) + # ======================================================================== + + def run_line_profiler( + self, function: FunctionToOptimize, test_config: TestConfig, test_files: list[Path] | None = None + ) -> str: + """Run line profiler on the target function and return formatted output. + + Returns empty string if profiling fails or is not applicable. 
+ """ + from codeflash.languages.python.parse_line_profile_test_output import parse_line_profile_results + from codeflash.languages.python.static_analysis.line_profile_utils import ( + add_decorator_imports, + contains_jit_decorator, + ) + + internal_fn = function + code_context = self.last_internal_context + if code_context is None: + logger.warning("No code context available for line profiler") + return "" + + # Read original source of function file + helper files for restore + original_sources: dict[Path, str] = {} + try: + original_sources[function.file_path] = function.file_path.read_text("utf-8") + except (OSError, UnicodeDecodeError): + logger.warning("Cannot read function file %s for line profiler", function.file_path) + return "" + + # Check JIT decorators in function file + if contains_jit_decorator(original_sources[function.file_path]): + logger.info("Skipping line profiler for %s - code contains JIT decorator", function.function_name) + return "" + + # Save and check helper file sources + for helper in code_context.helper_functions: + hp = helper.file_path + if hp not in original_sources: + try: + content = hp.read_text("utf-8") + except (OSError, UnicodeDecodeError): + continue + original_sources[hp] = content + if contains_jit_decorator(content): + logger.info( + "Skipping line profiler for %s - helper code contains JIT decorator", function.function_name + ) + return "" + + # Determine test files + if test_files is not None: + files_to_run = test_files + else: + files_to_run = sorted(test_config.tests_root.rglob("test_*.py")) + if not files_to_run: + files_to_run = sorted(test_config.tests_root.rglob("*_test.py")) + if not files_to_run: + return "" + + try: + # Inject line profiler decorators and imports into function + helper files + lprof_output_file = add_decorator_imports(internal_fn, code_context) + + # Run tests with LINE_PROFILE=1 env var + env = make_test_env(test_config.project_root, test_iteration=0) + env["LINE_PROFILE"] = "1" + + 
run_tests(test_files=files_to_run, cwd=test_config.project_root, env=env, timeout=int(test_config.timeout)) + + # Parse line profiler results from .lprof file + results, _ = parse_line_profile_results(lprof_output_file) + return str(results.get("str_out", "")) + except Exception: + logger.debug("Line profiler failed for %s", function.function_name, exc_info=True) + return "" + finally: + # Restore original source files + for file_path, content in original_sources.items(): + try: + file_path.write_text(content, "utf-8") + except OSError: + logger.warning("Failed to restore %s after line profiler", file_path) diff --git a/codeflash/plugin_ai_ops.py b/codeflash/plugin_ai_ops.py new file mode 100644 index 000000000..8b18e3960 --- /dev/null +++ b/codeflash/plugin_ai_ops.py @@ -0,0 +1,242 @@ +"""Mixin: AI candidate generation, repair, refinement, adaptive optimization.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from codeflash.plugin_helpers import format_speedup_pct, map_candidate_source +from codeflash_core.models import Candidate + +if TYPE_CHECKING: + from codeflash.plugin import PythonPlugin as _Base # type: ignore[attr-defined] + from codeflash_core.models import BenchmarkResults, CodeContext, ScoredCandidate, TestDiff +else: + _Base = object + +logger = logging.getLogger(__name__) + + +class PluginAiOpsMixin(_Base): # type: ignore[misc] + def get_candidates(self, context: CodeContext, trace_id: str = "") -> list[Candidate]: + client = self.get_ai_client() + assert trace_id, "trace_id must be provided" + + # Use cached internal context for markdown-formatted code (what the API expects) + internal_ctx = self.last_internal_context + if internal_ctx is not None: + source_code = internal_ctx.read_writable_code.markdown + dependency_code = internal_ctx.read_only_context_code + else: + source_code = context.target_code + dependency_code = context.read_only_context + + optimized = client.optimize_code( + 
source_code=source_code, + dependency_code=dependency_code, + trace_id=trace_id, + language="python", + is_numerical_code=self.is_numerical_code, + ) + + candidates = [] + for opt in optimized: + code = opt.source_code.flat if opt.source_code else "" + code_md = opt.source_code.markdown if opt.source_code else "" + if code: + candidates.append( + Candidate(code=code, explanation=opt.explanation or "", source="optimize", code_markdown=code_md) + ) + return candidates + + def get_line_profiler_candidates( + self, context: CodeContext, line_profile_data: str, trace_id: str = "" + ) -> list[Candidate]: + assert trace_id, "trace_id must be provided" + try: + client = self.get_ai_client() + except Exception: + logger.exception("Failed to create AI client for line profiler") + return [] + + internal_ctx = self.last_internal_context + source_code = internal_ctx.read_writable_code.markdown if internal_ctx else context.target_code + dependency_code = internal_ctx.read_only_context_code if internal_ctx else context.read_only_context + + optimized = client.optimize_python_code_line_profiler( + source_code=source_code, + dependency_code=dependency_code, + trace_id=trace_id, + line_profiler_results=line_profile_data, + n_candidates=3, + ) + + candidates = [] + for opt in optimized: + code = opt.source_code.flat if opt.source_code else "" + code_md = opt.source_code.markdown if opt.source_code else "" + if code: + candidates.append( + Candidate( + code=code, explanation=opt.explanation or "", source="line_profiler", code_markdown=code_md + ) + ) + return candidates + + def repair_candidate( + self, context: CodeContext, candidate: Candidate, test_diffs: list[TestDiff], trace_id: str = "" + ) -> Candidate | None: + assert trace_id, "trace_id must be provided" + from codeflash.models.models import AIServiceCodeRepairRequest, TestDiffScope + from codeflash.models.models import TestDiff as InternalTestDiff + + try: + client = self.get_ai_client() + except Exception: + 
logger.exception("Failed to create AI client for repair") + return None + + internal_ctx = self.last_internal_context + source_code = internal_ctx.read_writable_code.markdown if internal_ctx else context.target_code + modified_code = candidate.code_markdown or candidate.code + + internal_diffs = [ + InternalTestDiff( + scope=TestDiffScope.RETURN_VALUE, + original_pass=True, + candidate_pass=False, + original_value=str(d.baseline_output) if d.baseline_output is not None else None, + candidate_value=str(d.candidate_output) if d.candidate_output is not None else None, + ) + for d in test_diffs + ] + + request = AIServiceCodeRepairRequest( + optimization_id=candidate.candidate_id, + original_source_code=source_code, + modified_source_code=modified_code, + trace_id=trace_id, + test_diffs=internal_diffs, + ) + + try: + result = client.code_repair(request) + except Exception: + logger.exception("Code repair API call failed") + return None + + if result is None: + return None + + code = result.source_code.flat if result.source_code else "" + code_md = result.source_code.markdown if result.source_code else "" + if not code: + return None + + return Candidate( + code=code, + explanation=result.explanation or "", + source="repair", + parent_id=candidate.candidate_id, + code_markdown=code_md, + ) + + def refine_candidate( + self, context: CodeContext, candidate: ScoredCandidate, baseline_bench: BenchmarkResults, trace_id: str = "" + ) -> list[Candidate]: + assert trace_id, "trace_id must be provided" + from codeflash.models.models import AIServiceRefinerRequest + + try: + client = self.get_ai_client() + except Exception: + logger.exception("Failed to create AI client for refinement") + return [] + + internal_ctx = self.last_internal_context + source_code = internal_ctx.read_writable_code.markdown if internal_ctx else context.target_code + dependency_code = internal_ctx.read_only_context_code if internal_ctx else context.read_only_context + optimized_code = 
candidate.candidate.code_markdown or candidate.candidate.code + + request = AIServiceRefinerRequest( + optimization_id=candidate.candidate.candidate_id, + original_source_code=source_code, + read_only_dependency_code=dependency_code, + original_code_runtime=int(baseline_bench.total_time * 1e9), + optimized_source_code=optimized_code, + optimized_explanation=candidate.candidate.explanation, + optimized_code_runtime=int(candidate.benchmark_results.total_time * 1e9), + speedup=format_speedup_pct(candidate.speedup), + trace_id=trace_id, + original_line_profiler_results="", + optimized_line_profiler_results="", + ) + + try: + results = client.optimize_code_refinement([request]) + except Exception: + logger.exception("Code refinement API call failed") + return [] + + candidates = [] + for opt in results: + code = opt.source_code.flat if opt.source_code else "" + code_md = opt.source_code.markdown if opt.source_code else "" + if code: + candidates.append( + Candidate( + code=code, + explanation=opt.explanation or "", + source="refine", + parent_id=candidate.candidate.candidate_id, + code_markdown=code_md, + ) + ) + return candidates + + def adaptive_optimize( + self, context: CodeContext, scored: list[ScoredCandidate], trace_id: str = "" + ) -> Candidate | None: + assert trace_id, "trace_id must be provided" + from codeflash.models.models import AdaptiveOptimizedCandidate, AIServiceAdaptiveOptimizeRequest + + try: + client = self.get_ai_client() + except Exception: + logger.exception("Failed to create AI client for adaptive optimization") + return None + + internal_ctx = self.last_internal_context + source_code = internal_ctx.read_writable_code.flat if internal_ctx else context.target_code + + adaptive_candidates = [ + AdaptiveOptimizedCandidate( + optimization_id=sc.candidate.candidate_id, + source_code=sc.candidate.code, + explanation=sc.candidate.explanation, + source=map_candidate_source(sc.candidate.source), + speedup=f"Performance gain: {int(sc.speedup * 100 + 
0.5)}%" + if sc.speedup > 0 + else "Candidate didn't match the behavior of the original code", + ) + for sc in scored + ] + + request = AIServiceAdaptiveOptimizeRequest( + trace_id=trace_id, original_source_code=source_code, candidates=adaptive_candidates + ) + + try: + result = client.adaptive_optimize(request) + except Exception: + logger.exception("Adaptive optimization API call failed") + return None + + if result is None: + return None + + code = result.source_code.flat if result.source_code else "" + if not code: + return None + + return Candidate(code=code, explanation=result.explanation or "", source="adaptive") diff --git a/codeflash/plugin_helpers.py b/codeflash/plugin_helpers.py new file mode 100644 index 000000000..41a596391 --- /dev/null +++ b/codeflash/plugin_helpers.py @@ -0,0 +1,167 @@ +"""Standalone helper functions used by PythonPlugin methods.""" + +from __future__ import annotations + +import logging +import os +from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any + + from codeflash.models.models import OptimizedCandidateSource + from codeflash_core.models import CoverageData, FunctionToOptimize + +logger = logging.getLogger(__name__) + + +def make_test_env( + project_root: Path | str, *, loop_index: int = 0, test_iteration: int = 0, tracer_disable: int = 1 +) -> dict[str, str]: + """Return a copy of os.environ configured for running codeflash tests. + + Matches original codeflash get_test_env(): prepends project_root to PYTHONPATH + and sets CODEFLASH_* env vars expected by instrumented test harness. 
+ """ + env = os.environ.copy() + project_root_str = str(project_root) + pythonpath = env.get("PYTHONPATH", "") + if pythonpath: + env["PYTHONPATH"] = f"{project_root_str}{os.pathsep}{pythonpath}" + else: + env["PYTHONPATH"] = project_root_str + env["CODEFLASH_LOOP_INDEX"] = str(loop_index) + env["CODEFLASH_TEST_ITERATION"] = str(test_iteration) + env["CODEFLASH_TRACER_DISABLE"] = str(tracer_disable) + return env + + +def format_speedup_pct(speedup: float) -> str: + """Format speedup as percentage string matching original codeflash API format.""" + return f"{int(speedup * 100)}%" + + +def read_return_values(test_iteration: int) -> dict[str, list[object]]: + """Read return values from the SQLite file written by instrumented tests. + + Returns a dict mapping test_function_name -> list of deserialized return values. + Only reads rows with loop_index == 1 (first timing iteration), matching original behavior. + """ + import pickle + import sqlite3 + + from codeflash.code_utils.code_utils import get_run_tmp_file + + sqlite_path = get_run_tmp_file(Path(f"test_return_values_{test_iteration}.sqlite")) + if not sqlite_path.exists(): + return {} + + result: dict[str, list[object]] = {} + db = None + try: + db = sqlite3.connect(sqlite_path) + rows = db.execute("SELECT test_function_name, loop_index, return_value FROM test_results").fetchall() + db.close() + db = None + + for test_fn_name, loop_index, return_value_blob in rows: + if loop_index != 1 or not return_value_blob or not test_fn_name: + continue + try: + ret_val = pickle.loads(return_value_blob) + result.setdefault(test_fn_name, []).append(ret_val) + except Exception: + logger.debug("Failed to deserialize return value for %s", test_fn_name) + except Exception: + logger.debug("Failed to read return values from %s", sqlite_path) + finally: + if db is not None: + db.close() + + return result + + +def map_candidate_source(source: str) -> OptimizedCandidateSource: + """Map core Candidate.source string to 
OptimizedCandidateSource enum value.""" + from codeflash.models.models import OptimizedCandidateSource + + mapping = { + "optimize": OptimizedCandidateSource.OPTIMIZE, + "line_profiler": OptimizedCandidateSource.OPTIMIZE_LP, + "refine": OptimizedCandidateSource.REFINE, + "repair": OptimizedCandidateSource.REPAIR, + "adaptive": OptimizedCandidateSource.ADAPTIVE, + } + return mapping.get(source, OptimizedCandidateSource.OPTIMIZE) + + +def coverage_data_to_details_dict(cov_data: CoverageData) -> dict[str, Any]: + """Convert CoverageData to the dict format expected by the repair API.""" + mc = cov_data.main_func_coverage + details: dict[str, Any] = { + "coverage_percentage": cov_data.coverage, + "threshold_percentage": cov_data.threshold_percentage, + "main_function": { + "name": mc.name, + "coverage": mc.coverage, + "executed_lines": sorted(mc.executed_lines), + "unexecuted_lines": sorted(mc.unexecuted_lines), + "executed_branches": mc.executed_branches, + "unexecuted_branches": mc.unexecuted_branches, + }, + } + dc = cov_data.dependent_func_coverage + if dc: + details["dependent_function"] = { + "name": dc.name, + "coverage": dc.coverage, + "executed_lines": sorted(dc.executed_lines), + "unexecuted_lines": sorted(dc.unexecuted_lines), + "executed_branches": dc.executed_branches, + "unexecuted_branches": dc.unexecuted_branches, + } + return details + + +def replace_function_simple(source: str, function: FunctionToOptimize, new_source: str) -> str: + from codeflash.languages.python.static_analysis.code_replacer import replace_functions_in_file + + try: + return replace_functions_in_file( + source_code=source, + original_function_names=[function.qualified_name], + optimized_code=new_source, + preexisting_objects=set(), + ) + except Exception: + logger.warning("Failed to replace function %s", function.function_name) + return source + + +def format_code_with_ruff_or_black(source: str, file_path: Path | None = None) -> str: + import subprocess + + try: + result = 
subprocess.run( + ["ruff", "format", "-"], check=False, input=source, capture_output=True, text=True, timeout=30 + ) + if result.returncode == 0: + return result.stdout + except (subprocess.TimeoutExpired, FileNotFoundError): + pass + except Exception: + pass + + try: + result = subprocess.run( + ["black", "-q", "-"], check=False, input=source, capture_output=True, text=True, timeout=30 + ) + if result.returncode == 0: + return result.stdout + except (subprocess.TimeoutExpired, FileNotFoundError): + pass + except Exception: + pass + + return source diff --git a/codeflash/plugin_results.py b/codeflash/plugin_results.py new file mode 100644 index 000000000..024ae15fc --- /dev/null +++ b/codeflash/plugin_results.py @@ -0,0 +1,179 @@ +"""Mixin: ranking, explanation, PR creation, result logging.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from codeflash.code_utils.time_utils import humanize_runtime +from codeflash.plugin_helpers import format_speedup_pct, replace_function_simple + +if TYPE_CHECKING: + from codeflash.plugin import PythonPlugin as _Base # type: ignore[attr-defined] + from codeflash_core.models import CodeContext, GeneratedTestSuite, OptimizationResult, ScoredCandidate +else: + _Base = object + +logger = logging.getLogger(__name__) + + +class PluginResultsMixin(_Base): # type: ignore[misc] + def rank_candidates( + self, scored: list[ScoredCandidate], context: CodeContext, trace_id: str = "" + ) -> list[int] | None: + assert trace_id, "trace_id must be provided" + from codeflash_core.diff import unified_diff + + try: + client = self.get_ai_client() + except Exception: + logger.exception("Failed to create AI client for ranking") + return None + + diffs = [unified_diff(context.target_code, sc.candidate.code, context.target_file) for sc in scored] + optimization_ids = [sc.candidate.candidate_id for sc in scored] + speedups = [sc.speedup for sc in scored] + + try: + ranking: list[int] | None = 
client.generate_ranking( + trace_id=trace_id, diffs=diffs, optimization_ids=optimization_ids, speedups=speedups + ) + return ranking + except Exception: + logger.exception("Ranking API call failed") + return None + + def generate_explanation( + self, result: OptimizationResult, context: CodeContext, trace_id: str = "", annotated_tests: str = "" + ) -> str: + assert trace_id, "trace_id must be provided" + + try: + client = self.get_ai_client() + except Exception: + logger.exception("Failed to create AI client for explanation") + return "" + + # Convert runtimes to nanoseconds and humanize, matching original + optimized_ns = int(result.benchmark_results.total_time * 1e9) + baseline_ns = int(optimized_ns * result.speedup) if result.speedup > 0 else 0 + + try: + explanation: str = client.get_new_explanation( + source_code=context.target_code, + optimized_code=result.optimized_code, + dependency_code=context.read_only_context, + trace_id=trace_id, + original_line_profiler_results="", + optimized_line_profiler_results="", + original_code_runtime=humanize_runtime(baseline_ns), + optimized_code_runtime=humanize_runtime(optimized_ns), + speedup=format_speedup_pct(result.speedup), + annotated_tests=annotated_tests, + optimization_id=result.candidate.candidate_id, + original_explanation=result.candidate.explanation, + ) + return explanation + except Exception: + logger.exception("Explanation generation API call failed") + return "" + + def create_pr( + self, + result: OptimizationResult, + context: CodeContext, + trace_id: str = "", + generated_tests: GeneratedTestSuite | None = None, + ) -> str | None: + from codeflash.models.models import TestResults as InternalTestResults + from codeflash.result.create_pr import check_create_pr + from codeflash.result.explanation import Explanation + + try: + # Build original_code: file with original function (optimizer restores before returning) + original_code = {context.target_file: context.target_file.read_text("utf-8")} + + # Build 
new_code: file with optimized function applied in memory + original_source = original_code[context.target_file] + internal_fn = context.target_function + new_source = replace_function_simple(original_source, internal_fn, result.optimized_code) + new_code = {context.target_file: new_source} + + # Build Explanation from optimization result + # Use empty internal TestResults since PR comment uses runtime/speedup fields directly + optimized_ns = int(result.benchmark_results.total_time * 1e9) + baseline_ns = int(optimized_ns * result.speedup) if result.speedup > 0 else 0 + + explanation = Explanation( + raw_explanation_message=result.explanation or result.candidate.explanation, + winning_behavior_test_results=InternalTestResults(), + winning_benchmarking_test_results=InternalTestResults(), + original_runtime_ns=baseline_ns, + best_runtime_ns=optimized_ns, + function_name=context.target_function.qualified_name, + file_path=context.target_file, + ) + + # Collect generated test source + generated_tests_str = "" + if generated_tests and generated_tests.test_files: + generated_tests_str = "\n\n".join( + tf.original_test_source for tf in generated_tests.test_files if tf.original_test_source + ) + + check_create_pr( + original_code=original_code, + new_code=new_code, + explanation=explanation, + existing_tests_source="", + generated_original_test_source=generated_tests_str, + function_trace_id=trace_id, + coverage_message="", + replay_tests="", + root_dir=self.project_root, + git_remote=None, + ) + except Exception: + logger.exception("PR creation failed") + return None + else: + return None + + def log_results( + self, + result: OptimizationResult, + trace_id: str, + all_speedups: dict[str, float] | None = None, + all_runtimes: dict[str, float] | None = None, + all_correct: dict[str, bool] | None = None, + ) -> None: + try: + client = self.get_ai_client() + except Exception: + logger.exception("Failed to create AI client for logging") + return + + # Use accumulated 
all-candidate data if available, otherwise fall back to winner-only + speedup_ratios = all_speedups or {result.candidate.candidate_id: result.speedup} + is_correct = all_correct or {result.candidate.candidate_id: result.test_results.passed} + + # Convert runtimes from seconds to nanoseconds (matching original API contract) + if all_runtimes: + optimized_runtimes = {cid: int(t * 1e9) for cid, t in all_runtimes.items()} + else: + optimized_runtimes = {result.candidate.candidate_id: int(result.benchmark_results.total_time * 1e9)} + + baseline_ns = int(result.benchmark_results.total_time * 1e9 * result.speedup) if result.speedup > 0 else None + + try: + client.log_results( + function_trace_id=trace_id, + speedup_ratio=speedup_ratios, + original_runtime=baseline_ns, + optimized_runtime=optimized_runtimes, + is_correct=is_correct, + optimized_line_profiler_results=None, + metadata={"best_optimization_id": result.candidate.candidate_id}, + ) + except Exception: + logger.exception("Result logging API call failed") diff --git a/codeflash/plugin_test_lifecycle.py b/codeflash/plugin_test_lifecycle.py new file mode 100644 index 000000000..dc1d91e00 --- /dev/null +++ b/codeflash/plugin_test_lifecycle.py @@ -0,0 +1,269 @@ +"""Mixin: test generation, review, and repair.""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING + +from codeflash.plugin_helpers import coverage_data_to_details_dict +from codeflash.verification.test_runner import process_generated_test_strings +from codeflash_core.models import ( + GeneratedTestFile, + GeneratedTestSuite, + TestOutcomeStatus, + TestRepairInfo, + TestReviewResult, +) + +if TYPE_CHECKING: + from codeflash.plugin import PythonPlugin as _Base # type: ignore[attr-defined] + from codeflash_core.config import TestConfig + from codeflash_core.models import CodeContext, CoverageData, FunctionToOptimize, TestResults +else: + _Base = object + +logger = logging.getLogger(__name__) + + 
+class PluginTestLifecycleMixin(_Base): # type: ignore[misc] + def generate_tests( + self, function: FunctionToOptimize, context: CodeContext, test_config: TestConfig, trace_id: str = "" + ) -> GeneratedTestSuite | None: + from codeflash.code_utils.code_utils import module_name_from_file_path + from codeflash.verification.verification_utils import get_test_file_path + from codeflash.verification.verifier import generate_tests as _generate_tests + + try: + client = self.get_ai_client() + except Exception: + logger.exception("Failed to create AI client for test generation") + return None + + assert trace_id, "trace_id must be provided" + internal_fn = function + + internal_ctx = self.last_internal_context + source_code = internal_ctx.read_writable_code.markdown if internal_ctx else context.target_code + + # Compute is_numerical_code matching original analyze_code_characteristics + flat_code = internal_ctx.read_writable_code.flat if internal_ctx else context.target_code + try: + from codeflash.languages.python.static_analysis.code_extractor import ( + is_numerical_code as _is_numerical_code, + ) + + numerical = _is_numerical_code(code_string=flat_code) + except Exception: + numerical = None + + self.is_numerical_code = numerical + + # Cache tests_project_rootdir for use in repair_generated_tests + tests_project_rootdir = test_config.tests_project_rootdir or test_config.project_root + self.tests_project_rootdir: Path | None = tests_project_rootdir + + module_path = Path(module_name_from_file_path(function.file_path, test_config.project_root)) + helper_names = [h.qualified_name for h in context.helper_functions] + + test_dir = test_config.tests_root + test_dir.mkdir(parents=True, exist_ok=True) + + test_files: list[GeneratedTestFile] = [] + num_tests = 2 + + for i in range(num_tests): + behavior_path = get_test_file_path(test_dir, function.function_name, iteration=i, test_type="unit") + perf_path = get_test_file_path(test_dir, function.function_name, iteration=i, 
test_type="perf") + + try: + result = _generate_tests( + aiservice_client=client, + source_code_being_tested=source_code, + function_to_optimize=internal_fn, + helper_function_names=helper_names, + module_path=module_path, + test_cfg=test_config, + test_timeout=int(test_config.timeout), + function_trace_id=trace_id, + test_index=i, + test_path=behavior_path, + test_perf_path=perf_path, + is_numerical_code=numerical, + ) + except Exception: + logger.exception("Failed to generate test %d for %s", i, function.qualified_name) + continue + + if result is None: + continue + + gen_source, behavior_source, perf_source, _raw, _, _ = result + + # Write test files to disk + behavior_path.parent.mkdir(parents=True, exist_ok=True) + behavior_path.write_text(behavior_source, encoding="utf-8") + perf_path.parent.mkdir(parents=True, exist_ok=True) + perf_path.write_text(perf_source, encoding="utf-8") + + test_files.append( + GeneratedTestFile( + behavior_test_path=behavior_path, + perf_test_path=perf_path, + behavior_test_source=behavior_source, + perf_test_source=perf_source, + original_test_source=gen_source, + ) + ) + + if not test_files: + return None + + return GeneratedTestSuite(test_files=test_files) + + def review_generated_tests( + self, suite: GeneratedTestSuite, context: CodeContext, test_results: TestResults, trace_id: str = "" + ) -> list[TestReviewResult]: + assert trace_id, "trace_id must be provided" + + try: + client = self.get_ai_client() + except Exception: + logger.exception("Failed to create AI client for test review") + return [] + + # Collect failing test function names and error messages from test results + failed_test_functions: list[str] = [] + failure_messages: dict[str, str] = {} + for outcome in test_results.outcomes: + if outcome.status != TestOutcomeStatus.PASSED: + failed_test_functions.append(outcome.test_id) + if outcome.error_message: + failure_messages[outcome.test_id] = outcome.error_message + + tests_data = [ + { + "test_index": i, + 
"test_source": tf.original_test_source, + "failed_test_functions": failed_test_functions, + "failure_messages": failure_messages, + } + for i, tf in enumerate(suite.test_files) + ] + + try: + reviews = client.review_generated_tests( + tests=tests_data, + function_source_code=context.target_code, + function_name=context.target_function.function_name, + trace_id=trace_id, + language="python", + ) + except Exception: + logger.exception("Test review API call failed") + return [] + + return [ + TestReviewResult( + test_index=r.test_index, + functions_to_repair=[ + TestRepairInfo(function_name=f.function_name, reason=f.reason) for f in r.functions_to_repair + ], + ) + for r in reviews + ] + + def repair_generated_tests( + self, + suite: GeneratedTestSuite, + reviews: list[TestReviewResult], + context: CodeContext, + trace_id: str = "", + previous_repair_errors: dict[str, str] | None = None, + coverage_data: CoverageData | None = None, + ) -> GeneratedTestSuite | None: + from codeflash.code_utils.code_utils import module_name_from_file_path + from codeflash.models.models import FunctionRepairInfo + + try: + client = self.get_ai_client() + except Exception: + logger.exception("Failed to create AI client for test repair") + return None + + coverage_details = coverage_data_to_details_dict(coverage_data) if coverage_data is not None else None + internal_fn = context.target_function + assert trace_id, "trace_id must be provided" + + new_test_files = list(suite.test_files) + + for review in reviews: + if not review.functions_to_repair: + continue + + idx = review.test_index + if idx >= len(suite.test_files): + continue + + tf = suite.test_files[idx] + + repair_infos = [ + FunctionRepairInfo(function_name=f.function_name, reason=f.reason) for f in review.functions_to_repair + ] + + tests_project_rootdir = self.tests_project_rootdir or self.project_root + module_path = Path(module_name_from_file_path(context.target_file, self.project_root)) + test_module_path = 
Path(module_name_from_file_path(tf.behavior_test_path, tests_project_rootdir)) + + helper_names = [h.qualified_name for h in context.helper_functions] + + try: + result = client.repair_generated_tests( + test_source=tf.original_test_source, + functions_to_repair=repair_infos, + function_source_code=context.target_code, + function_to_optimize=internal_fn, + helper_function_names=helper_names, + module_path=module_path, + test_module_path=test_module_path, + test_framework="pytest", + test_timeout=60, + trace_id=trace_id, + language="python", + previous_repair_errors=previous_repair_errors, + module_source_code=context.target_code, + coverage_details=coverage_details, + ) + except Exception: + logger.exception("Test repair API call failed for test %d", idx) + continue + + if result is None: + continue + + gen_source, behavior_source, perf_source = result + + # Process (replace temp dir placeholders) + gen_source, behavior_source, perf_source = process_generated_test_strings( + generated_test_source=gen_source, + instrumented_behavior_test_source=behavior_source, + instrumented_perf_test_source=perf_source, + function_to_optimize=internal_fn, + test_path=tf.behavior_test_path, + test_cfg=None, + project_module_system=None, + ) + + # Write repaired tests + tf.behavior_test_path.write_text(behavior_source, encoding="utf-8") + tf.perf_test_path.write_text(perf_source, encoding="utf-8") + + new_test_files[idx] = GeneratedTestFile( + behavior_test_path=tf.behavior_test_path, + perf_test_path=tf.perf_test_path, + behavior_test_source=behavior_source, + perf_test_source=perf_source, + original_test_source=gen_source, + ) + + return GeneratedTestSuite(test_files=new_test_files) diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py index 9325110fa..3fd6dc31a 100644 --- a/codeflash/result/create_pr.py +++ b/codeflash/result/create_pr.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: from codeflash.models.models import FunctionCalledInTest, InvocationId, TestFiles from 
codeflash.result.explanation import Explanation - from codeflash.verification.verification_utils import TestConfig + from codeflash_core.config import TestConfig def existing_tests_source_for( @@ -117,7 +117,7 @@ def existing_tests_source_for( if test_module_path.endswith(ext): matched_ext = ext break - if matched_ext: + if matched_ext and test_cfg.tests_project_rootdir is not None: # JavaScript/TypeScript: convert module-style path to file path # "tests.fibonacci__perfinstrumented.test.ts" -> "tests/fibonacci__perfinstrumented.test.ts" base_path = test_module_path[: -len(matched_ext)] @@ -143,7 +143,7 @@ def existing_tests_source_for( lang = current_language_support() # Let language-specific resolution handle non-Python module paths lang_result = lang.resolve_test_module_path_for_pr( - test_module_path, test_cfg.tests_project_rootdir, non_generated_tests + test_module_path, test_cfg.tests_project_rootdir or test_cfg.project_root, non_generated_tests ) if lang_result is not None: abs_path = lang_result diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index 71173926c..85602933b 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -32,7 +32,7 @@ import subprocess from codeflash.models.models import CodeOptimizationContext, CoverageData, TestFiles - from codeflash.verification.verification_utils import TestConfig + from codeflash_core.config import TestConfig def parse_func(file_path: Path) -> XMLParser: diff --git a/codeflash/verification/test_runner.py b/codeflash/verification/test_runner.py new file mode 100644 index 000000000..6ce3c153e --- /dev/null +++ b/codeflash/verification/test_runner.py @@ -0,0 +1,277 @@ +"""Standalone test runner for the PythonPlugin adapter. + +Extracted from the codeflash-next-gen test runner, adapted to use codeflash imports. 
+""" + +from __future__ import annotations + +import logging +import re +import subprocess +from pathlib import Path +from typing import TYPE_CHECKING + +from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file +from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args +from codeflash.languages.base import TestResult + +if TYPE_CHECKING: + import threading + from collections.abc import Sequence + +logger = logging.getLogger(__name__) + +_TIMING_MARKER_PATTERN = re.compile(r"!######.+:(\d+)######!") + +PYTEST_CMD: str = "pytest" + + +def setup_pytest_cmd(pytest_cmd: str | None) -> None: + global PYTEST_CMD + PYTEST_CMD = pytest_cmd or "pytest" + + +def pytest_cmd_tokens(is_posix: bool) -> list[str]: + import shlex + + return shlex.split(PYTEST_CMD, posix=is_posix) + + +def build_pytest_cmd(safe_sys_executable: str, is_posix: bool) -> list[str]: + return [safe_sys_executable, "-m", *pytest_cmd_tokens(is_posix)] + + +def run_tests( + test_files: Sequence[Path], + cwd: Path, + env: dict[str, str], + timeout: int, + *, + min_loops: int = 1, + max_loops: int = 1, + target_seconds: float | None = None, + stability_check: bool = False, + enable_coverage: bool = False, +) -> tuple[list[TestResult], Path, Path | None, Path | None]: + import contextlib + import shlex + import sys + + from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE + from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE + + if target_seconds is None: + target_seconds = TOTAL_LOOPING_TIME_EFFECTIVE + + junit_xml = get_run_tmp_file(Path("pytest_results.xml")) + + pytest_args = [ + "--capture=tee-sys", + "-q", + "--codeflash_loops_scope=session", + f"--codeflash_min_loops={min_loops}", + f"--codeflash_max_loops={max_loops}", + f"--codeflash_seconds={target_seconds}", + ] + if stability_check: + pytest_args.append("--codeflash_stability_check=true") + if timeout: + pytest_args.append(f"--timeout={timeout}") + + result_args = 
[f"--junitxml={junit_xml.as_posix()}", "-o", "junit_logging=all"] + + pytest_env = env.copy() + pytest_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin" + + blocklisted_plugins = ["benchmark", "codspeed", "xdist", "sugar"] + if min_loops > 1: + blocklisted_plugins.extend(["cov", "profiling"]) + + test_file_args = [str(f) for f in test_files] + + coverage_database_file: Path | None = None + coverage_config_file: Path | None = None + + try: + if enable_coverage: + from codeflash.languages.python.static_analysis.coverage_utils import prepare_coverage_files + + coverage_database_file, coverage_config_file = prepare_coverage_files() + pytest_env["NUMBA_DISABLE_JIT"] = str(1) + pytest_env["TORCHDYNAMO_DISABLE"] = str(1) + pytest_env["PYTORCH_JIT"] = str(0) + pytest_env["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0" + pytest_env["TF_ENABLE_ONEDNN_OPTS"] = str(0) + pytest_env["JAX_DISABLE_JIT"] = str(0) + + is_windows = sys.platform == "win32" + if is_windows: + if coverage_database_file.exists(): + with contextlib.suppress(PermissionError, OSError): + coverage_database_file.unlink() + else: + cov_erase = execute_test_subprocess( + shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage erase"), cwd=cwd, env=pytest_env, timeout=30 + ) + logger.debug(cov_erase) + + coverage_cmd = [ + SAFE_SYS_EXECUTABLE, + "-m", + "coverage", + "run", + f"--rcfile={coverage_config_file.as_posix()}", + "-m", + ] + coverage_cmd.extend(pytest_cmd_tokens(IS_POSIX)) + + blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins if plugin != "cov"] + result = execute_test_subprocess( + coverage_cmd + pytest_args + blocklist_args + result_args + test_file_args, + cwd=cwd, + env=pytest_env, + timeout=600, + ) + else: + pytest_cmd_list = build_pytest_cmd(SAFE_SYS_EXECUTABLE, IS_POSIX) + blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins] + + result = execute_test_subprocess( + pytest_cmd_list + pytest_args + blocklist_args + result_args + test_file_args, + cwd=cwd, + 
env=pytest_env, + timeout=600, + ) + + logger.debug("Result return code: %s, %s", result.returncode, result.stderr or "") + results = parse_test_results(junit_xml, result.stdout or "") + return results, junit_xml, coverage_database_file, coverage_config_file + + except Exception as e: + logger.exception("Test execution failed: %s", e) + return [], junit_xml, coverage_database_file, coverage_config_file + + +def parse_test_results(junit_xml_path: Path, stdout: str) -> list[TestResult]: + import xml.etree.ElementTree as ET + + results: list[TestResult] = [] + + if not junit_xml_path.exists(): + return results + + try: + tree = ET.parse(junit_xml_path) + root = tree.getroot() + + for testcase in root.iter("testcase"): + name = testcase.get("name", "unknown") + classname = testcase.get("classname", "") + time_str = testcase.get("time", "0") + + try: + runtime_ns = int(float(time_str) * 1_000_000_000) + except ValueError: + runtime_ns = None + + failure = testcase.find("failure") + error = testcase.find("error") + passed = failure is None and error is None + + error_message = None + if failure is not None: + error_message = failure.get("message", failure.text) + elif error is not None: + error_message = error.get("message", error.text) + + test_file = Path(classname.replace(".", "/") + ".py") if classname else Path("unknown") + + results.append( + TestResult( + test_name=name, + test_file=test_file, + passed=passed, + runtime_ns=runtime_ns, + error_message=error_message, + stdout=stdout, + ) + ) + except Exception as e: + logger.warning("Failed to parse JUnit XML: %s", e) + + return results + + +def process_generated_test_strings( + generated_test_source: str, + instrumented_behavior_test_source: str, + instrumented_perf_test_source: str, + function_to_optimize: object, + test_path: Path, + test_cfg: object, + project_module_system: str | None, +) -> tuple[str, str, str]: + temp_run_dir = get_run_tmp_file(Path()).as_posix() + instrumented_behavior_test_source = 
instrumented_behavior_test_source.replace( + "{codeflash_run_tmp_dir_client_side}", temp_run_dir + ) + instrumented_perf_test_source = instrumented_perf_test_source.replace( + "{codeflash_run_tmp_dir_client_side}", temp_run_dir + ) + return generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source + + +def execute_test_subprocess( + cmd_list: list[str], + cwd: Path, + env: dict[str, str] | None, + timeout: int = 600, + cancel_event: threading.Event | None = None, +) -> subprocess.CompletedProcess[str]: + """Execute a subprocess with the given command list, working directory, environment variables, and timeout. + + If *cancel_event* is provided and becomes set while the process is running, + the subprocess is terminated immediately and a CompletedProcess with + returncode -15 is returned. + """ + import time + + logger.debug("executing test run with command: %s", " ".join(cmd_list)) + with custom_addopts(): + if cancel_event is None: + run_args = get_cross_platform_subprocess_run_args( + cwd=cwd, env=env, timeout=timeout, check=False, text=True, capture_output=True + ) + result: subprocess.CompletedProcess[str] = subprocess.run(cmd_list, **run_args) # type: ignore[call-overload] # noqa: PLW1510 + return result + + # Use Popen so we can poll for cancellation + run_args = get_cross_platform_subprocess_run_args( + cwd=cwd, env=env, timeout=None, check=False, text=True, capture_output=False + ) + # Remove keys that don't apply to Popen + run_args.pop("check", None) + run_args.pop("timeout", None) + run_args.pop("capture_output", None) + proc = subprocess.Popen(cmd_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE, **run_args) # type: ignore[call-overload] + deadline = time.monotonic() + timeout + try: + while proc.poll() is None: + if cancel_event.is_set(): + proc.terminate() + proc.wait(timeout=5) + return subprocess.CompletedProcess(cmd_list, -15, stdout="", stderr="cancelled") + remaining = deadline - time.monotonic() + if remaining 
<= 0: + proc.terminate() + proc.wait(timeout=5) + msg = f"Timed out after {timeout}s" + raise subprocess.TimeoutExpired(cmd_list, timeout, output="", stderr=msg) # noqa: TRY301 + # Poll every 200ms + cancel_event.wait(min(0.2, remaining)) + stdout, stderr = proc.communicate(timeout=5) + return subprocess.CompletedProcess(cmd_list, proc.returncode, stdout=stdout or "", stderr=stderr or "") + except BaseException: + proc.kill() + proc.wait() + raise diff --git a/codeflash/verification/verification_utils.py b/codeflash/verification/verification_utils.py index 76583edad..671b1f0ad 100644 --- a/codeflash/verification/verification_utils.py +++ b/codeflash/verification/verification_utils.py @@ -1,12 +1,15 @@ from __future__ import annotations import ast -from pathlib import Path -from typing import Optional - -from pydantic.dataclasses import dataclass +from typing import TYPE_CHECKING from codeflash.languages import current_language_support +from codeflash_core.config import TestConfig + +if TYPE_CHECKING: + from pathlib import Path + +__all__ = ["TestConfig"] def get_test_file_path( @@ -100,40 +103,3 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: return node node.name = node.name + "Inspired" return node - - -@dataclass -class TestConfig: - tests_root: Path - project_root_path: Path - tests_project_rootdir: Path - # tests_project_rootdir corresponds to pytest rootdir - concolic_test_root_dir: Optional[Path] = None - pytest_cmd: str = "pytest" - benchmark_tests_root: Optional[Path] = None - use_cache: bool = True - _language: Optional[str] = None # Language identifier for multi-language support - js_project_root: Optional[Path] = None # JavaScript project root (directory containing package.json) - - def __post_init__(self) -> None: - self.project_root_path = self.project_root_path.resolve() - self.tests_project_rootdir = self.tests_project_rootdir.resolve() - - @property - def test_framework(self) -> str: - """Returns the appropriate test framework 
based on language.""" - return current_language_support().test_framework - - def set_language(self, language: str) -> None: - """Set the language for this test config. - - Args: - language: Language identifier (e.g., "python", "javascript"). - - """ - self._language = language - - @property - def language(self) -> Optional[str]: - """Get the current language setting.""" - return self._language diff --git a/codeflash/verification/verifier.py b/codeflash/verification/verifier.py index 7cfa8473c..4e78b446e 100644 --- a/codeflash/verification/verifier.py +++ b/codeflash/verification/verifier.py @@ -12,8 +12,8 @@ if TYPE_CHECKING: from codeflash.api.aiservice import AiServiceClient - from codeflash.discovery.functions_to_optimize import FunctionToOptimize - from codeflash.verification.verification_utils import TestConfig + from codeflash_core.config import TestConfig + from codeflash_core.models import FunctionToOptimize def generate_tests( diff --git a/pyproject.toml b/pyproject.toml index b3d9b9969..3bfb4399d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,7 +104,7 @@ tests = [ ] [tool.hatch.build.targets.sdist] -include = ["codeflash", "src/codeflash_core"] +include = ["codeflash", "src/codeflash_core", "src/codeflash_python"] exclude = [ "docs/*", "experiments/*", @@ -154,6 +154,7 @@ exclude = [ ] [tool.hatch.build.targets.wheel] +packages = ["codeflash", "src/codeflash_core", "src/codeflash_python"] exclude = [ "docs/*", "experiments/*", @@ -204,6 +205,7 @@ exclude = [ ] [tool.mypy] +mypy_path = ["src"] show_error_code_links = true pretty = true show_absolute_path = true diff --git a/src/codeflash_core/config.py b/src/codeflash_core/config.py index 791c92c28..f596908d3 100644 --- a/src/codeflash_core/config.py +++ b/src/codeflash_core/config.py @@ -63,9 +63,19 @@ def get_effort_value(key: EffortKeys, effort: EffortLevel | str) -> Any: class TestConfig: tests_root: Path project_root: Path - test_command: str = "" - timeout: float = 60.0 
tests_project_rootdir: Path | None = None + concolic_test_root_dir: Path | None = None + test_command: str = "pytest" + test_framework: str = "pytest" + benchmark_tests_root: Path | None = None + use_cache: bool = True + timeout: float = 60.0 + js_project_root: Path | None = None + + def __post_init__(self) -> None: + self.project_root = Path(self.project_root).resolve() + if self.tests_project_rootdir is not None: + self.tests_project_rootdir = Path(self.tests_project_rootdir).resolve() @dataclass diff --git a/src/codeflash_core/models.py b/src/codeflash_core/models.py index 29489c679..99b5411ce 100644 --- a/src/codeflash_core/models.py +++ b/src/codeflash_core/models.py @@ -3,10 +3,8 @@ import uuid from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from pathlib import Path +from pathlib import Path +from typing import Any class TestOutcomeStatus(Enum): @@ -19,7 +17,7 @@ class TestOutcomeStatus(Enum): @dataclass(frozen=True) class FunctionParent: name: str - type: str = "ClassDef" + type: str def __str__(self) -> str: return f"{self.type}:{self.name}" @@ -36,10 +34,15 @@ class FunctionToOptimize: ending_col: int | None = None is_async: bool = False is_method: bool = False - language: str = "" + language: str = "python" doc_start_line: int | None = None source_code: str = "" + def __post_init__(self) -> None: + if not isinstance(self.file_path, Path): + self.file_path = Path(self.file_path) + self.parents = [p if isinstance(p, FunctionParent) else FunctionParent(**p) for p in self.parents] + @property def qualified_name(self) -> str: if not self.parents: @@ -58,6 +61,16 @@ def class_name(self) -> str | None: return parent.name return None + def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str: + from codeflash.code_utils.code_utils import module_name_from_file_path + + return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}" + 
+ def __str__(self) -> str: + qualified = f"{'.'.join([p.name for p in self.parents])}{'.' if self.parents else ''}{self.function_name}" + line_info = f":{self.starting_line}-{self.ending_line}" if self.starting_line and self.ending_line else "" + return f"{self.file_path}:{qualified}{line_info}" + @dataclass class HelperFunction: diff --git a/src/codeflash_python/__init__.py b/src/codeflash_python/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/codeflash_python/api/__init__.py b/src/codeflash_python/api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/codeflash_python/api/aiservice.py b/src/codeflash_python/api/aiservice.py new file mode 100644 index 000000000..a862aa414 --- /dev/null +++ b/src/codeflash_python/api/aiservice.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +import json +import logging +import os +from itertools import count +from typing import TYPE_CHECKING, Any + +import requests +from pydantic.json import pydantic_encoder + +from codeflash_python.api.aiservice_optimize import AiServiceOptimizeMixin +from codeflash_python.api.aiservice_results import AiServiceResultsMixin +from codeflash_python.api.aiservice_testgen import AiServiceTestgenMixin +from codeflash_python.code_utils.config_consts import PYTHON_LANGUAGE_VERSION +from codeflash_python.code_utils.env_utils import get_codeflash_api_key +from codeflash_python.models.models import CodeStringsMarkdown, OptimizedCandidate + +if TYPE_CHECKING: + from codeflash_python.models.models import OptimizedCandidateSource + +logger = logging.getLogger("codeflash_python") + + +class AiServiceClient(AiServiceOptimizeMixin, AiServiceTestgenMixin, AiServiceResultsMixin): + def __init__(self) -> None: + self.base_url = self.get_aiservice_base_url() + self.headers = {"Authorization": f"Bearer {get_codeflash_api_key()}", "Connection": "close"} + self.llm_call_counter = count(1) + self.is_local = self.base_url == "http://localhost:8000" + 
self.timeout: float | None = 300 if self.is_local else 90 + + def get_next_sequence(self) -> int: + """Get the next LLM call sequence number.""" + return next(self.llm_call_counter) + + @staticmethod + def add_language_metadata( + payload: dict[str, Any], + language_version: str | None = None, + module_system: str | None = None, # noqa: ARG004 + ) -> None: + """Add language version metadata to an API payload.""" + if language_version is None: + language_version = PYTHON_LANGUAGE_VERSION + payload["language_version"] = language_version + payload["python_version"] = language_version + + @staticmethod + def log_error_response(response: requests.Response, action: str, ph_event: str) -> None: + """Log and report an API error response.""" + from codeflash_python.telemetry.posthog_cf import ph + + try: + error = response.json()["error"] + except Exception: + error = response.text + logger.error("Error %s: %s - %s", action, response.status_code, error) + ph(ph_event, {"response_status_code": response.status_code, "error": error}) + + def get_aiservice_base_url(self) -> str: + if os.environ.get("CODEFLASH_AIS_SERVER", default="prod").lower() == "local": + logger.info("Using local AI Service at http://localhost:8000") + + return "http://localhost:8000" + return "https://app.codeflash.ai" + + def make_ai_service_request( + self, + endpoint: str, + method: str = "POST", + payload: dict[str, Any] | list[dict[str, Any]] | None = None, + timeout: float | None = None, + ) -> requests.Response: + """Make an API request to the given endpoint on the AI service. 
+ + Args: + ---- + endpoint: The endpoint to call, e.g., "/optimize" + method: The HTTP method to use ('GET' or 'POST') + payload: Optional JSON payload to include in the POST request body + timeout: The timeout for the request in seconds + + Returns: + ------- + The response object from the API + + Raises: + ------ + requests.exceptions.RequestException: If the request fails + + """ + url = f"{self.base_url}/ai{endpoint}" + if method.upper() == "POST": + json_payload = json.dumps(payload, indent=None, default=pydantic_encoder) + headers = {**self.headers, "Content-Type": "application/json"} + response = requests.post(url, data=json_payload, headers=headers, timeout=timeout) + else: + response = requests.get(url, headers=self.headers, timeout=timeout) + # response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code + return response + + def get_valid_candidates( + self, optimizations_json: list[dict[str, Any]], source: OptimizedCandidateSource, language: str = "python" + ) -> list[OptimizedCandidate]: + candidates: list[OptimizedCandidate] = [] + for opt in optimizations_json: + code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"], expected_language=language) + if not code.code_strings: + continue + candidates.append( + OptimizedCandidate( + source_code=code, + explanation=opt["explanation"], + optimization_id=opt["optimization_id"], + source=source, + parent_id=opt.get("parent_id", None), + model=opt.get("model"), + ) + ) + return candidates + + +class LocalAiServiceClient(AiServiceClient): + """Client for interacting with the local AI service.""" + + def get_aiservice_base_url(self) -> str: + """Get the base URL for the local AI service.""" + return "http://localhost:8000" diff --git a/src/codeflash_python/api/aiservice_optimize.py b/src/codeflash_python/api/aiservice_optimize.py new file mode 100644 index 000000000..c91e15bea --- /dev/null +++ b/src/codeflash_python/api/aiservice_optimize.py @@ -0,0 
+1,366 @@ +"""Mixin: optimization-related API endpoints.""" + +from __future__ import annotations + +import logging +import platform +from typing import TYPE_CHECKING, Any + +import requests + +from codeflash_python.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name +from codeflash_python.code_utils.time_utils import humanize_runtime +from codeflash_python.models.models import OptimizedCandidateSource +from codeflash_python.telemetry.posthog_cf import ph +from codeflash_python.version import __version__ as codeflash_version + +if TYPE_CHECKING: + from codeflash_python.api.types import ( + AIServiceAdaptiveOptimizeRequest, + AIServiceCodeRepairRequest, + AIServiceRefinerRequest, + ) + from codeflash_python.models.experiment_metadata import ExperimentMetadata + from codeflash_python.models.models import OptimizedCandidate +else: + _Base = object + +logger = logging.getLogger("codeflash_python") + + +def safe_get_repo_owner_and_name() -> tuple[str | None, str | None]: + try: + git_repo_owner, git_repo_name = get_repo_owner_and_name() + except Exception as e: + logger.warning("Could not determine repo owner and name: %s", e) + git_repo_owner, git_repo_name = None, None + return git_repo_owner, git_repo_name + + +class AiServiceOptimizeMixin(_Base): # type: ignore[name-defined] + def optimize_code( + self, + source_code: str, + dependency_code: str, + trace_id: str, + experiment_metadata: ExperimentMetadata | None = None, + *, + language: str = "python", + language_version: str | None = None, + module_system: str | None = None, + is_async: bool = False, + n_candidates: int = 5, + is_numerical_code: bool | None = None, + ) -> list[OptimizedCandidate]: + """Optimize the given code for performance by making a request to the Django endpoint. + + Parameters + ---------- + - source_code (str): The code to optimize. 
+ - dependency_code (str): The dependency code used as read-only context for the optimization + - trace_id (str): Trace id of optimization run + - experiment_metadata (Optional[ExperimentalMetadata, None]): Any available experiment metadata for this optimization + - language (str): Programming language (e.g., "python") + - language_version (str | None): Language version (e.g., "3.11.0") + - module_system (str | None): Module system (None for Python) + - is_async (bool): Whether the function being optimized is async + - n_candidates (int): Number of candidates to generate + + Returns + ------- + - List[OptimizationCandidate]: A list of Optimization Candidates. + + """ + logger.info("Generating optimized candidates\u2026") + git_repo_owner, git_repo_name = safe_get_repo_owner_and_name() + + # Build payload with language-specific fields + payload: dict[str, Any] = { + "source_code": source_code, + "dependency_code": dependency_code, + "trace_id": trace_id, + "language": language, + "experiment_metadata": experiment_metadata, + "codeflash_version": codeflash_version, + "current_username": get_last_commit_author_if_pr_exists(None), + "repo_owner": git_repo_owner, + "repo_name": git_repo_name, + "is_async": is_async, + "call_sequence": self.get_next_sequence(), + "n_candidates": n_candidates, + "is_numerical_code": is_numerical_code, + } + + self.add_language_metadata(payload, language_version, module_system) + + # DEBUG: Print payload language field + logger.debug( + "Sending optimize request with language='%s' (type: %s)", payload["language"], type(payload["language"]) + ) + logger.debug("Sending optimize request: trace_id=%s, n_candidates=%s", trace_id, payload["n_candidates"]) + + try: + response = self.make_ai_service_request("/optimize", payload=payload, timeout=self.timeout) + except requests.exceptions.RequestException as e: + logger.exception("Error generating optimized candidates: %s", e) + ph("cli-optimize-error-caught", {"error": str(e)}) + + return [] + + if 
response.status_code == 200: + optimizations_json = response.json()["optimizations"] + return self.get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE, language) + self.log_error_response(response, "generating optimized candidates", "cli-optimize-error-response") + return [] + + # Backward-compatible alias + def optimize_python_code( + self, + source_code: str, + dependency_code: str, + trace_id: str, + experiment_metadata: ExperimentMetadata | None = None, + *, + is_async: bool = False, + n_candidates: int = 5, + ) -> list[OptimizedCandidate]: + """Backward-compatible alias for optimize_code() with language='python'.""" + return self.optimize_code( + source_code=source_code, + dependency_code=dependency_code, + trace_id=trace_id, + experiment_metadata=experiment_metadata, + language="python", + is_async=is_async, + n_candidates=n_candidates, + ) + + def get_jit_rewritten_code(self, source_code: str, trace_id: str) -> list[OptimizedCandidate]: + """Rewrite the given python code for performance via jit compilation by making a request to the Django endpoint. + + Parameters + ---------- + - source_code (str): The python code to optimize. + - trace_id (str): Trace id of optimization run + + Returns + ------- + - List[OptimizationCandidate]: A list of Optimization Candidates. 
def optimize_python_code_line_profiler(
    self,
    source_code: str,
    dependency_code: str,
    trace_id: str,
    line_profiler_results: str,
    n_candidates: int,
    experiment_metadata: ExperimentMetadata | None = None,
    is_numerical_code: bool | None = None,
    language: str = "python",
    language_version: str | None = None,
) -> list[OptimizedCandidate]:
    """Request optimization candidates guided by line profiler results.

    Parameters
    ----------
    - source_code (str): The code to optimize.
    - dependency_code (str): The dependency code used as read-only context for the optimization
    - trace_id (str): Trace id of optimization run
    - line_profiler_results (str): Line profiler output to guide optimization
    - n_candidates (int): Number of candidates to generate
    - experiment_metadata (ExperimentMetadata | None): Any available experiment metadata for this optimization
    - is_numerical_code (bool | None): Forwarded to the service unchanged in the payload
    - language (str): Programming language (e.g., "python")
    - language_version (str | None): Language version (e.g., "3.12.0")

    Returns
    -------
    - list[OptimizedCandidate]: The candidates accepted by get_valid_candidates(). An empty list
      is returned when no profiler output was supplied, when the request raises, or when the
      service answers with a non-200 status.

    """
    # Without profiler output there is nothing to guide the service, so no request is sent.
    if line_profiler_results == "":
        logger.info("No LineProfiler results were provided, Skipping optimization.")
        return []

    logger.info("Generating optimized candidates with line profiler\u2026")

    payload = {
        "source_code": source_code,
        "dependency_code": dependency_code,
        "n_candidates": n_candidates,
        "line_profiler_results": line_profiler_results,
        "trace_id": trace_id,
        "language": language,
        "experiment_metadata": experiment_metadata,
        "codeflash_version": codeflash_version,
        "call_sequence": self.get_next_sequence(),
        "is_numerical_code": is_numerical_code,
    }
    # NOTE(review): presumably adds the language-version fields to the payload in place - confirm.
    self.add_language_metadata(payload, language_version)

    try:
        response = self.make_ai_service_request("/optimize-line-profiler", payload=payload, timeout=self.timeout)
    except requests.exceptions.RequestException as e:
        logger.exception("Error generating optimized candidates: %s", e)
        ph("cli-optimize-error-caught", {"error": str(e)})

        return []

    if response.status_code == 200:
        optimizations_json = response.json()["optimizations"]
        # Pass the language through, as optimize_code() and code_repair() do, so the
        # candidates are validated for the language that was actually requested.
        return self.get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE_LP, language)
    self.log_error_response(response, "generating optimized candidates", "cli-optimize-error-response")
    return []
"trace_id": request.trace_id, + "original_source_code": request.original_source_code, + "candidates": request.candidates, + } + response = self.make_ai_service_request("/adaptive_optimize", payload=payload, timeout=self.timeout) + except (requests.exceptions.RequestException, TypeError) as e: + logger.exception("Error generating adaptive optimized candidates: %s", e) + ph("cli-optimize-error-caught", {"error": str(e)}) + return None + + if response.status_code == 200: + fixed_optimization = response.json() + + valid_candidates = self.get_valid_candidates([fixed_optimization], OptimizedCandidateSource.ADAPTIVE) + if not valid_candidates: + logger.error("Adaptive optimization failed to generate a valid candidate.") + return None + + return valid_candidates[0] + + self.log_error_response(response, "generating optimized candidates", "cli-optimize-error-response") + return None + + def optimize_code_refinement(self, request: list[AIServiceRefinerRequest]) -> list[OptimizedCandidate]: + """Refine optimization candidates for improved performance. + + Refines optimization candidates with optional multi-file context for + better understanding of imports and dependencies. 
def code_repair(self, request: AIServiceCodeRepairRequest) -> OptimizedCandidate | None:
    """Ask the service to fix a candidate whose test results differ from the original code.

    Args:
        request: candidate details for repair

    Returns:
        The repaired candidate, or None when the request raises, the service
        answers with a non-200 status, or no valid candidate comes back.

    """
    try:
        # Payload construction stays inside the try block, exactly as before,
        # so a TypeError raised while building or sending it is handled below.
        repair_payload = {
            "optimization_id": request.optimization_id,
            "original_source_code": request.original_source_code,
            "modified_source_code": request.modified_source_code,
            "trace_id": request.trace_id,
            "test_diffs": request.test_diffs,
            "language": request.language,
        }
        response = self.make_ai_service_request("/code_repair", payload=repair_payload, timeout=self.timeout)
    except (requests.exceptions.RequestException, TypeError) as e:
        logger.exception("Error generating optimization repair: %s", e)
        ph("cli-optimize-error-caught", {"error": str(e)})
        return None

    # Guard clause: anything other than 200 is reported and yields no candidate.
    if response.status_code != 200:
        self.log_error_response(response, "generating optimized candidates", "cli-optimize-error-response")
        return None

    repaired = self.get_valid_candidates([response.json()], OptimizedCandidateSource.REPAIR, request.language)
    if repaired:
        return repaired[0]

    logger.error("Code repair failed to generate a valid candidate.")
    return None
class AiServiceResultsMixin(_Base):  # type: ignore[name-defined]
    """Mixin with the explanation, ranking, logging, review and workflow endpoints.

    The methods use ``self.make_ai_service_request``, ``self.timeout``,
    ``self.get_next_sequence`` and ``self.log_error_response``; none of them is
    defined here, so presumably the concrete client class provides them.
    """

    def get_new_explanation(
        self,
        source_code: str,
        optimized_code: str,
        dependency_code: str,
        trace_id: str,
        original_line_profiler_results: str,
        optimized_line_profiler_results: str,
        original_code_runtime: str,
        optimized_code_runtime: str,
        speedup: str,
        annotated_tests: str,
        optimization_id: str,
        original_explanation: str,
        original_throughput: str | None = None,
        optimized_throughput: str | None = None,
        throughput_improvement: str | None = None,
        function_references: str | None = None,
        acceptance_reason: str | None = None,
        original_concurrency_ratio: str | None = None,
        optimized_concurrency_ratio: str | None = None,
        concurrency_improvement: str | None = None,
        codeflash_version: str = codeflash_version,
    ) -> str:
        """Request a new explanation for an optimization candidate from the /explain endpoint.

        Parameters
        ----------
        - source_code (str): The python code to optimize.
        - optimized_code (str): The python code generated by the AI service.
        - dependency_code (str): The dependency code used as read-only context for the optimization
        - trace_id (str): Trace id of optimization run
        - original_line_profiler_results: str - line profiler results for the baseline code
        - optimized_line_profiler_results: str - line profiler results for the optimized code
        - original_code_runtime: str - runtime for the baseline code
        - optimized_code_runtime: str - runtime for the optimized code
        - speedup: str - speedup of the optimized code
        - annotated_tests: str - test functions annotated with runtime
        - optimization_id: str - unique id of opt candidate
        - original_explanation: str - original_explanation generated for the opt candidate
        - original_throughput: str | None - throughput for the baseline code (operations per second)
        - optimized_throughput: str | None - throughput for the optimized code (operations per second)
        - throughput_improvement: str | None - throughput improvement percentage
        - function_references: str | None - where the function is called in the codebase
        - acceptance_reason: str | None - why the optimization was accepted (runtime, throughput, or concurrency)
        - original_concurrency_ratio: str | None - concurrency ratio for the baseline code
        - optimized_concurrency_ratio: str | None - concurrency ratio for the optimized code
        - concurrency_improvement: str | None - concurrency improvement percentage
        - codeflash_version: str - current codeflash version

        Returns
        -------
        - str: The "explanation" field of the response, or an empty string when the request
          raises or the service answers with a non-200 status.

        """
        payload = {
            "trace_id": trace_id,
            "source_code": source_code,
            "optimized_code": optimized_code,
            "original_line_profiler_results": original_line_profiler_results,
            "optimized_line_profiler_results": optimized_line_profiler_results,
            "original_code_runtime": original_code_runtime,
            "optimized_code_runtime": optimized_code_runtime,
            "speedup": speedup,
            "annotated_tests": annotated_tests,
            "optimization_id": optimization_id,
            "original_explanation": original_explanation,
            "dependency_code": dependency_code,
            "original_throughput": original_throughput,
            "optimized_throughput": optimized_throughput,
            "throughput_improvement": throughput_improvement,
            "function_references": function_references,
            "acceptance_reason": acceptance_reason,
            "original_concurrency_ratio": original_concurrency_ratio,
            "optimized_concurrency_ratio": optimized_concurrency_ratio,
            "concurrency_improvement": concurrency_improvement,
            "codeflash_version": codeflash_version,
            "call_sequence": self.get_next_sequence(),
        }
        logger.info("loading|Generating explanation")

        try:
            response = self.make_ai_service_request("/explain", payload=payload, timeout=self.timeout)
        except requests.exceptions.RequestException as e:
            logger.exception("Error generating explanations: %s", e)
            ph("cli-optimize-error-caught", {"error": str(e)})
            return ""

        if response.status_code == 200:
            explanation: str = response.json()["explanation"]

            return explanation
        # Describe the action that actually failed; this used to say "generating optimized candidates".
        self.log_error_response(response, "generating explanation", "cli-optimize-error-response")

        return ""

    def generate_ranking(
        self,
        trace_id: str,
        diffs: list[str],
        optimization_ids: list[str],
        speedups: list[float],
        function_references: str | None = None,
    ) -> list[int] | None:
        """Request a ranking of optimization candidates from the /rank endpoint.

        Parameters
        ----------
        - trace_id : unique uuid of function
        - diffs : list of unified diff strings of opt candidates
        - optimization_ids : list of ids of the opt candidates
        - speedups : list of speedups of opt candidates
        - function_references : where the function is called in the codebase

        Returns
        -------
        - list[int] | None: The "ranking" field of the response (per the original documentation,
          candidates in decreasing order), or None when the request raises or the service
          answers with a non-200 status.

        """
        payload = {
            "trace_id": trace_id,
            "diffs": diffs,
            "speedups": speedups,
            "optimization_ids": optimization_ids,
            "python_version": platform.python_version(),  # backward compat
            "function_references": function_references,
        }
        logger.info("loading|Generating ranking")

        try:
            response = self.make_ai_service_request("/rank", payload=payload, timeout=self.timeout)
        except requests.exceptions.RequestException as e:
            logger.exception("Error generating ranking: %s", e)
            ph("cli-optimize-error-caught", {"error": str(e)})
            return None

        if response.status_code == 200:
            ranking: list[int] = response.json()["ranking"]

            return ranking
        self.log_error_response(response, "generating ranking", "cli-optimize-error-response")

        return None

    def log_results(
        self,
        function_trace_id: str,
        speedup_ratio: dict[str, float | None] | None,
        original_runtime: float | None,
        optimized_runtime: dict[str, float | None] | None,
        is_correct: dict[str, bool] | None,
        optimized_line_profiler_results: dict[str, str] | None,
        metadata: dict[str, Any] | None,
        optimizations_post: dict[str, str] | None = None,
    ) -> None:
        """Send the results of an optimization run to the /log_features endpoint.

        The response is not inspected. A request exception is logged and
        swallowed, so a logging failure does not propagate to the caller.

        Parameters
        ----------
        - function_trace_id (str): The UUID.
        - speedup_ratio (dict[str, float | None] | None): The speedup.
        - original_runtime (float | None): The original runtime.
        - optimized_runtime (dict[str, float | None] | None): The optimized runtime.
        - is_correct (dict[str, bool] | None): Whether the optimized code is correct.
        - optimized_line_profiler_results: line_profiler results for every candidate mapped to their optimization_id
        - metadata: contains the best optimization id
        - optimizations_post - dict mapping opt id to code str after postprocessing

        """
        payload = {
            "trace_id": function_trace_id,
            "speedup_ratio": speedup_ratio,
            "original_runtime": original_runtime,
            "optimized_runtime": optimized_runtime,
            "is_correct": is_correct,
            "codeflash_version": codeflash_version,
            "optimized_line_profiler_results": optimized_line_profiler_results,
            "metadata": metadata,
            "optimizations_post": optimizations_post,
        }
        try:
            self.make_ai_service_request("/log_features", payload=payload, timeout=self.timeout)
        except requests.exceptions.RequestException as e:
            logger.exception("Error logging features: %s", e)

    def get_optimization_review(
        self,
        original_code: dict[Path, str],
        new_code: dict[Path, str],
        explanation: Explanation,
        existing_tests_source: str,
        generated_original_test_source: str,
        function_trace_id: str,
        coverage_message: str,
        replay_tests: str,
        calling_fn_details: str,
        language: str = "python",
        **_kwargs: Any,
    ) -> OptimizationReviewResult:
        """Compute the optimization review of current Pull Request.

        Args:
            original_code: dict -> data structure mapping file paths to function definition for original code
            new_code: dict -> data structure mapping file paths to function definition for optimized code
            explanation: Explanation -> data structure containing runtime information
            existing_tests_source: str -> existing tests table
            generated_original_test_source: str -> annotated generated tests
            function_trace_id: str -> traceid of function
            coverage_message: str -> coverage information
            replay_tests: str -> replay test table
            calling_fn_details: str -> filenames and definitions of functions which call the function_to_optimize
            language: str -> programming language sent to the service (defaults to "python")
            **_kwargs: extra keyword arguments are accepted and ignored

        Returns:
            OptimizationReviewResult built from the response's "review" and "review_explanation"
            fields (the original documentation lists 'high', 'medium' and 'low' as review values).
            Both fields are empty strings when the request raises or the service answers with a
            non-200 status.

        """
        # Concatenate the per-file snippets in the dicts' iteration order.
        original_code_str = "\n\n".join(original_code.values())
        optimized_code_str = "\n\n".join(new_code.values())

        logger.info("loading|Reviewing Optimization\u2026")
        payload = {
            "original_code": original_code_str,
            "optimized_code": optimized_code_str,
            "explanation": explanation.raw_explanation_message,
            "existing_tests": existing_tests_source,
            "generated_tests": generated_original_test_source,
            "trace_id": function_trace_id,
            "coverage_message": coverage_message,
            "replay_tests": replay_tests,
            "speedup": f"{(100 * float(explanation.speedup)):.2f}%",
            "loop_count": explanation.winning_benchmarking_test_results.number_of_loops(),
            "benchmark_details": explanation.benchmark_details if explanation.benchmark_details else None,
            "optimized_runtime": humanize_runtime(explanation.best_runtime_ns),
            "original_runtime": humanize_runtime(explanation.original_runtime_ns),
            "codeflash_version": codeflash_version,
            "calling_fn_details": calling_fn_details,
            "language": language,
            "language_version": platform.python_version(),
            "python_version": platform.python_version(),
            "call_sequence": self.get_next_sequence(),
        }

        try:
            response = self.make_ai_service_request("/optimization_review", payload=payload, timeout=self.timeout)
        except requests.exceptions.RequestException as e:
            # The message used to say "refinements", copied from the refinement endpoint.
            logger.exception("Error generating optimization review: %s", e)
            ph("cli-optimize-error-caught", {"error": str(e)})
            return OptimizationReviewResult(review="", explanation="")

        if response.status_code == 200:
            data = response.json()
            return OptimizationReviewResult(
                review=cast("str", data["review"]), explanation=cast("str", data.get("review_explanation", ""))
            )
        self.log_error_response(response, "generating optimization review", "cli-optimize-error-response")

        return OptimizationReviewResult(review="", explanation="")

    def generate_workflow_steps(
        self,
        repo_files: dict[str, str],
        directory_structure: dict[str, Any],
        codeflash_config: dict[str, Any] | None = None,
    ) -> str | None:
        """Generate GitHub Actions workflow steps based on repository analysis.

        Failures are logged at debug level only, because the caller is expected
        to fall back to a static workflow.

        :param repo_files: Dictionary mapping file paths to their contents
        :param directory_structure: 2-level nested directory structure
        :param codeflash_config: Optional codeflash configuration
        :return: YAML string for workflow steps section, or None on error
        """
        payload = {
            "repo_files": repo_files,
            "directory_structure": directory_structure,
            "codeflash_config": codeflash_config,
        }

        logger.debug(
            "[aiservice.py:generate_workflow_steps] Sending request to AI service with %s files, "
            "%s top-level directories",
            len(repo_files),
            len(directory_structure),
        )

        try:
            response = self.make_ai_service_request("/workflow-gen", payload=payload, timeout=self.timeout)
        except requests.exceptions.RequestException as e:
            # AI service unavailable - this is expected, will fall back to static workflow
            logger.debug(
                "[aiservice.py:generate_workflow_steps] Request exception (falling back to static workflow): %s", e
            )
            return None

        if response.status_code == 200:
            response_data = response.json()
            workflow_steps = cast("str", response_data.get("workflow_steps"))
            logger.debug(
                "[aiservice.py:generate_workflow_steps] Successfully received workflow steps (%s chars)",
                len(workflow_steps) if workflow_steps else 0,
            )
            return workflow_steps
        # AI service unavailable or endpoint not found - this is expected, will fall back to static workflow
        logger.debug(
            "[aiservice.py:generate_workflow_steps] AI service returned status %s, "
            "falling back to static workflow generation",
            response.status_code,
        )
        try:
            error_response = response.json()
            error = cast("str", error_response.get("error", "Unknown error"))
            logger.debug("[aiservice.py:generate_workflow_steps] Error: %s", error)
        except Exception:
            # Deliberately broad: reading the error body is best-effort diagnostics only.
            logger.debug("[aiservice.py:generate_workflow_steps] Could not parse error response")
        return None
class AiServiceTestgenMixin(_Base):  # type: ignore[name-defined]
    """Mixin with the test generation, review and repair endpoints.

    The methods use ``self.make_ai_service_request``, ``self.timeout``,
    ``self.get_next_sequence``, ``self.add_language_metadata`` and
    ``self.log_error_response``; none of them is defined here, so presumably
    the concrete client class provides them.
    """

    def generate_regression_tests(
        self,
        source_code_being_tested: str,
        function_to_optimize: FunctionToOptimize,
        helper_function_names: list[str],
        module_path: Path,
        test_module_path: Path,
        test_framework: str,
        test_timeout: int,
        trace_id: str,
        test_index: int,
        *,
        language: str = "python",
        language_version: str | None = None,
        module_system: str | None = None,
        is_numerical_code: bool | None = None,
    ) -> tuple[str, str, str, str | None] | None:
        """Generate regression tests for the given function via the /testgen endpoint.

        Parameters
        ----------
        - source_code_being_tested (str): The source code of the function being tested.
        - function_to_optimize (FunctionToOptimize): The function to optimize.
        - helper_function_names (list[str]): List of helper function names.
        - module_path (Path): The module path where the function is located.
        - test_module_path (Path): The module path for the test code.
        - test_framework (str): The test framework to use, e.g., "pytest".
        - test_timeout (int): The timeout for each test in seconds.
        - trace_id (str): Trace id of optimization run
        - test_index (int): The index from 0-(n-1) if n tests are generated for a single trace_id
        - language (str): Programming language (e.g., "python")
        - language_version (str | None): Language version (e.g., "3.11.0")
        - module_system (str | None): Module system (None for Python)
        - is_numerical_code (bool | None): Forwarded to the service unchanged in the payload

        Returns
        -------
        - tuple[str, str, str, str | None] | None: The response fields "generated_tests",
          "instrumented_behavior_tests", "instrumented_perf_tests" and the optional
          "raw_generated_tests", in that order; None when the request raises or the
          service answers with a non-200 status.

        Raises
        ------
        - AssertionError: if test_framework is not in PYTHON_VALID_TEST_FRAMEWORKS.

        """
        valid_frameworks = PYTHON_VALID_TEST_FRAMEWORKS
        # Raised explicitly instead of using `assert`, so the check also runs under `python -O`;
        # the exception type is unchanged for existing callers.
        if test_framework not in valid_frameworks:
            msg = f"Invalid test framework for python, got {test_framework} but expected one of {list(valid_frameworks)}"
            raise AssertionError(msg)

        # NOTE(review): the payload carries a FunctionToOptimize and Path objects as-is;
        # presumably make_ai_service_request serialises them - confirm.
        payload: dict[str, Any] = {
            "source_code_being_tested": source_code_being_tested,
            "function_to_optimize": function_to_optimize,
            "helper_function_names": helper_function_names,
            "module_path": module_path,
            "test_module_path": test_module_path,
            "test_framework": test_framework,
            "test_timeout": test_timeout,
            "trace_id": trace_id,
            "test_index": test_index,
            "language": language,
            "codeflash_version": codeflash_version,
            "is_async": function_to_optimize.is_async,
            "call_sequence": self.get_next_sequence(),
            "is_numerical_code": is_numerical_code,
            "class_name": function_to_optimize.class_name,
            "qualified_name": function_to_optimize.qualified_name,
        }

        self.add_language_metadata(payload, language_version, module_system)

        logger.debug("Sending testgen request with language='%s', framework='%s'", payload["language"], test_framework)
        try:
            # the timeout should be the same as the timeout for the AI service backend
            response = self.make_ai_service_request("/testgen", payload=payload, timeout=self.timeout)
        except requests.exceptions.RequestException as e:
            logger.exception("Error generating tests: %s", e)
            ph("cli-testgen-error-caught", {"error": str(e)})
            return None

        if response.status_code == 200:
            response_json = response.json()
            logger.debug("Generated tests for function %s", function_to_optimize.function_name)
            return (
                response_json["generated_tests"],
                response_json["instrumented_behavior_tests"],
                response_json["instrumented_perf_tests"],
                response_json.get("raw_generated_tests"),
            )
        self.log_error_response(response, "generating tests", "cli-testgen-error-response")
        return None

    def review_generated_tests(
        self,
        tests: list[dict[str, Any]],
        function_source_code: str,
        function_name: str,
        trace_id: str,
        coverage_summary: str = "",
        coverage_details: dict[str, Any] | None = None,
        language: str = "python",
    ) -> list[TestFileReview]:
        """Ask the /testgen_review endpoint which generated test functions need repair.

        ``coverage_summary`` and ``coverage_details`` are added to the payload
        only when they are truthy.

        Returns one TestFileReview per entry of the response's "reviews" list,
        each holding the entry's "test_index" and a FunctionRepairInfo for every
        item of its "functions" list (a missing "reason" becomes an empty string).
        An empty list is returned when the request raises or the service answers
        with a non-200 status.
        """
        payload: dict[str, Any] = {
            "tests": tests,
            "function_source_code": function_source_code,
            "function_name": function_name,
            "trace_id": trace_id,
            "language": language,
            "codeflash_version": codeflash_version,
            "call_sequence": self.get_next_sequence(),
        }
        if coverage_summary:
            payload["coverage_summary"] = coverage_summary
        if coverage_details:
            payload["coverage_details"] = coverage_details
        self.add_language_metadata(payload)
        try:
            response = self.make_ai_service_request("/testgen_review", payload=payload, timeout=self.timeout)
        except requests.exceptions.RequestException as e:
            logger.exception("Error reviewing generated tests: %s", e)
            ph("cli-testgen-review-error-caught", {"error": str(e)})
            return []

        if response.status_code == 200:
            data = response.json()
            return [
                TestFileReview(
                    test_index=r["test_index"],
                    functions_to_repair=[
                        FunctionRepairInfo(function_name=f["function_name"], reason=f.get("reason", ""))
                        for f in r.get("functions", [])
                    ],
                )
                for r in data.get("reviews", [])
            ]
        self.log_error_response(response, "reviewing generated tests", "cli-testgen-review-error-response")
        return []

    def repair_generated_tests(
        self,
        test_source: str,
        functions_to_repair: list[FunctionRepairInfo],
        function_source_code: str,
        function_to_optimize: FunctionToOptimize,
        helper_function_names: list[str],
        module_path: Path,
        test_module_path: Path,
        test_framework: str,
        test_timeout: int,
        trace_id: str,
        language: str = "python",
        coverage_details: dict[str, Any] | None = None,
        previous_repair_errors: dict[str, str] | None = None,
        module_source_code: str = "",
    ) -> tuple[str, str, str] | None:
        """Ask the /testgen_repair endpoint to repair the listed test functions.

        ``module_source_code``, ``coverage_details`` and ``previous_repair_errors``
        are added to the payload only when they are truthy.

        Returns the response fields "generated_tests", "instrumented_behavior_tests"
        and "instrumented_perf_tests" as a tuple, or None when the request raises
        or the service answers with a non-200 status.
        """
        payload: dict[str, Any] = {
            "test_source": test_source,
            "functions_to_repair": [
                {"function_name": f.function_name, "reason": f.reason} for f in functions_to_repair
            ],
            "function_source_code": function_source_code,
            "function_to_optimize": function_to_optimize,
            "helper_function_names": helper_function_names,
            "module_path": module_path,
            "test_module_path": test_module_path,
            "test_framework": test_framework,
            "test_timeout": test_timeout,
            "trace_id": trace_id,
            "language": language,
            "codeflash_version": codeflash_version,
            "call_sequence": self.get_next_sequence(),
        }
        if module_source_code:
            payload["module_source_code"] = module_source_code
        if coverage_details:
            payload["coverage_details"] = coverage_details
        if previous_repair_errors:
            payload["previous_repair_errors"] = previous_repair_errors
        self.add_language_metadata(payload)
        try:
            response = self.make_ai_service_request("/testgen_repair", payload=payload, timeout=self.timeout)
        except requests.exceptions.RequestException as e:
            logger.exception("Error repairing generated tests: %s", e)
            ph("cli-testgen-repair-error-caught", {"error": str(e)})
            return None

        if response.status_code == 200:
            data = response.json()
            return (data["generated_tests"], data["instrumented_behavior_tests"], data["instrumented_perf_tests"])
        self.log_error_response(response, "repairing generated tests", "cli-testgen-repair-error-response")
        return None
@dataclass
class BaseUrls:
    # Base URL that make_cfapi_request prefixes to "/cfapi<endpoint>".
    cfapi_base_url: str | None = None
    # Base URL of the web application.
    cfwebapp_base_url: str | None = None


@lru_cache(maxsize=1)
def get_cfapi_base_urls() -> BaseUrls:
    """Return the CF API and web-app base URLs for the configured server.

    The CODEFLASH_CFAPI_SERVER environment variable selects the target: the
    value "local" (compared case-insensitively) picks the localhost URLs, and
    anything else, including an unset variable, picks production. The result
    is cached, so the variable is only read on the first call.
    """
    server = os.environ.get("CODEFLASH_CFAPI_SERVER", "prod")
    if server.lower() != "local":
        return BaseUrls(cfapi_base_url="https://app.codeflash.ai", cfwebapp_base_url="https://app.codeflash.ai")

    local_api_url = "http://localhost:3001"
    local_webapp_url = "http://localhost:3000"
    logger.info("Using local CF API at %s.", local_api_url)
    return BaseUrls(cfapi_base_url=local_api_url, cfwebapp_base_url=local_webapp_url)
+ :param method: The HTTP method to use ('GET', 'POST', etc.). + :param payload: Optional JSON payload to include in the POST request body. + :param extra_headers: Optional extra headers to include in the request. + :param api_key: Optional API key to use for authentication. + :param suppress_errors: If True, suppress error logging for HTTP errors. + :param params: Optional query parameters for GET requests. + :return: The response object from the API. + """ + url = f"{get_cfapi_base_urls().cfapi_base_url}/cfapi{endpoint}" + final_api_key = api_key or get_codeflash_api_key() + cfapi_headers = {"Authorization": f"Bearer {final_api_key}"} + if extra_headers: + cfapi_headers.update(extra_headers) + try: + if method.upper() == "POST": + json_payload = json.dumps(payload, indent=None, default=pydantic_encoder) + cfapi_headers["Content-Type"] = "application/json" + response = requests.post(url, data=json_payload, headers=cfapi_headers, timeout=60) + else: + response = requests.get(url, headers=cfapi_headers, params=params, timeout=60) + response.raise_for_status() + return response + except requests.exceptions.HTTPError: + # response may be either a string or JSON, so we handle both cases + error_message = "" + try: + json_response = response.json() + if "error" in json_response: + error_message = json_response["error"] + elif "message" in json_response: + error_message = json_response["message"] + except (ValueError, TypeError): + error_message = response.text + + if not suppress_errors: + logger.exception( + "CF_API_Error:: making request to Codeflash API (url: %s, method: %s, status %s): %s", + url, + method, + response.status_code, + error_message, + ) + return response + + +@lru_cache(maxsize=1) +def get_user_id(api_key: str | None = None) -> str | None: + """Retrieve the user's userid by making a request to the /cfapi/cli-get-user endpoint. + + :param api_key: The API key to use. If None, uses get_codeflash_api_key(). 
@lru_cache(maxsize=1)
def get_user_id(api_key: str | None = None) -> str | None:
    """Look up the user id behind an API key via the ``/cfapi/cli-get-user`` endpoint.

    The result (including ``None``) is cached by ``lru_cache``.

    :param api_key: The API key to use. If None, uses get_codeflash_api_key().
    :return: The userid, or None if it could not be determined.
    """
    if not api_key and not ensure_codeflash_api_key():
        return None

    response = make_cfapi_request(
        endpoint="/cli-get-user",
        method="GET",
        extra_headers={"cli_version": __version__},
        api_key=api_key,
        suppress_errors=True,
    )
    status = response.status_code

    if status == 200:
        # Responses that do not mention "min_version" carry the user id as plain text.
        if "min_version" not in response.text:
            return response.text
        resp_json = response.json()
        userid: str | None = resp_json.get("userId")
        min_version: str | None = resp_json.get("min_version")
        if not userid:
            logger.error("Failed to retrieve userid from the response.")
            return None
        if min_version and version.parse(min_version) > version.parse(__version__):
            msg = "Your Codeflash CLI version is outdated. Please update to the latest version using `pip install --upgrade codeflash`."
            exit_with_message(msg, error_on_exit=True)
        return userid

    if status == 403:
        error_title = "Invalid Codeflash API key. The API key you provided is not valid."
        msg = (
            f"{error_title}\n"
            "Please generate a new one at https://app.codeflash.ai/app/apikeys ,\n"
            "then set it as a CODEFLASH_API_KEY environment variable.\n"
            "For more information, refer to the documentation at \n"
            "https://docs.codeflash.ai/optimizing-with-codeflash/codeflash-github-actions#manual-setup\n"
            "or\n"
            "https://docs.codeflash.ai/optimizing-with-codeflash/codeflash-github-actions#automated-setup-recommended"
        )
        exit_with_message(msg, error_on_exit=True)

    # Any other status: log and return None (kept for backward compatibility).
    logger.error("Failed to look up your userid; is your CF API key valid? (%s)", response.reason)
    return None
def suggest_changes(
    owner: str,
    repo: str,
    pr_number: int,
    file_changes: dict[str, FileDiffContent],
    pr_comment: PrComment,
    existing_tests: str,
    generated_tests: str,
    trace_id: str,
    coverage_message: str,
    replay_tests: str = "",
    concolic_tests: str = "",
    optimization_review: str = "",
    original_line_profiler: str | None = None,
    optimized_line_profiler: str | None = None,
) -> Response:
    """Post suggested changes for an existing pull request to ``/suggest-pr-changes``.

    Will make a review suggestion when possible;
    or create a new dependent pull request with the suggested changes.

    :param owner: The owner of the repository.
    :param repo: The name of the repository.
    :param pr_number: The number of the pull request.
    :param file_changes: A dictionary of file changes.
    :param pr_comment: The pull request comment object, containing the optimization explanation, best runtime, etc.
    :param existing_tests: Sent as ``existingTests``.
    :param generated_tests: The generated tests.
    :param trace_id: Sent as ``traceId``.
    :param coverage_message: Sent as ``coverage_message``.
    :param replay_tests: Sent as ``replayTests``.
    :param concolic_tests: Sent as ``concolicTests``.
    :param optimization_review: Sent as ``optimizationReview``.
    :param original_line_profiler: Line profiler results for original code (markdown format).
    :param optimized_line_profiler: Line profiler results for optimized code (markdown format).
    :return: The response object.
    """
    return make_cfapi_request(
        endpoint="/suggest-pr-changes",
        method="POST",
        payload={
            "owner": owner,
            "repo": repo,
            "pullNumber": pr_number,
            "diffContents": file_changes,
            "prCommentFields": pr_comment.to_json(),
            "existingTests": existing_tests,
            "generatedTests": generated_tests,
            "traceId": trace_id,
            "coverage_message": coverage_message,
            "replayTests": replay_tests,
            "concolicTests": concolic_tests,
            "optimizationReview": optimization_review,
            "originalLineProfiler": original_line_profiler,
            "optimizedLineProfiler": optimized_line_profiler,
        },
    )
def create_pr(
    owner: str,
    repo: str,
    base_branch: str,
    file_changes: dict[str, FileDiffContent],
    pr_comment: PrComment,
    existing_tests: str,
    generated_tests: str,
    trace_id: str,
    coverage_message: str,
    replay_tests: str = "",
    concolic_tests: str = "",
    optimization_review: str = "",
    original_line_profiler: str | None = None,
    optimized_line_profiler: str | None = None,
) -> Response:
    """Request a new pull request via ``/create-pr``, targeting ``base_branch`` (usually 'main').

    :param owner: The owner of the repository.
    :param repo: The name of the repository.
    :param base_branch: The base branch to target.
    :param file_changes: A dictionary of file changes.
    :param pr_comment: The pull request comment object, containing the optimization explanation, best runtime, etc.
    :param existing_tests: Sent as ``existingTests``.
    :param generated_tests: The generated tests.
    :param trace_id: Sent as ``traceId``.
    :param coverage_message: Sent as ``coverage_message``.
    :param replay_tests: Sent as ``replayTests``.
    :param concolic_tests: Sent as ``concolicTests``.
    :param optimization_review: Sent as ``optimizationReview``.
    :param original_line_profiler: Line profiler results for original code (markdown format).
    :param optimized_line_profiler: Line profiler results for optimized code (markdown format).
    :return: The response object.
    """
    return make_cfapi_request(
        endpoint="/create-pr",
        method="POST",
        payload={
            "owner": owner,
            "repo": repo,
            "baseBranch": base_branch,
            "diffContents": file_changes,
            "prCommentFields": pr_comment.to_json(),
            "existingTests": existing_tests,
            "generatedTests": generated_tests,
            "traceId": trace_id,
            "coverage_message": coverage_message,
            "replayTests": replay_tests,
            "concolicTests": concolic_tests,
            "optimizationReview": optimization_review,
            "originalLineProfiler": original_line_profiler,
            "optimizedLineProfiler": optimized_line_profiler,
        },
    )
def setup_github_actions(owner: str, repo: str, base_branch: str, workflow_content: str) -> Response:
    """Ask the backend to open a PR that adds the GitHub Actions workflow file.

    :param owner: Repository owner (username or organization)
    :param repo: Repository name
    :param base_branch: Base branch to create PR against (e.g., "main", "master")
    :param workflow_content: Content of the GitHub Actions workflow file (YAML)
    :return: Response object with pr_url and pr_number on success
    """
    return make_cfapi_request(
        endpoint="/setup-github-actions",
        method="POST",
        payload={
            "owner": owner,
            "repo": repo,
            "baseBranch": base_branch,
            "workflowContent": workflow_content,
        },
    )
def create_staging(
    original_code: dict[Path, str],
    new_code: dict[Path, str],
    explanation: Explanation,
    existing_tests_source: str,
    generated_original_test_source: str,
    function_trace_id: str,
    coverage_message: str,
    replay_tests: str,
    concolic_tests: str,
    root_dir: Path,
    optimization_review: str = "",
    original_line_profiler: str | None = None,
    optimized_line_profiler: str | None = None,
    **_kwargs: Any,
) -> Response:
    """Submit an optimization to ``/create-staging`` against the current git branch.

    File paths are sent relative to ``root_dir`` in POSIX form.

    :param original_code: A mapping of file paths to original source code.
    :param new_code: A mapping of file paths to optimized source code.
    :param explanation: An Explanation object with optimization details.
    :param existing_tests_source: Existing test code.
    :param generated_original_test_source: Generated tests for the original function.
    :param function_trace_id: Unique identifier for this optimization trace.
    :param coverage_message: Coverage report or summary.
    :param replay_tests: Sent as ``replayTests``.
    :param concolic_tests: Sent as ``concolicTests``.
    :param root_dir: Directory that all file paths are made relative to.
    :param optimization_review: Sent as ``optimizationReview``.
    :param original_line_profiler: Line profiler results for original code (markdown format).
    :param optimized_line_profiler: Line profiler results for optimized code (markdown format).
    :param _kwargs: Extra keyword arguments are accepted and ignored.
    :return: The response object from the backend.
    """
    relative_path = explanation.file_path.relative_to(root_dir).as_posix()

    # Every file present in original_code must also be present in new_code.
    diff_contents = {}
    for path in original_code:
        posix_relative = Path(path).relative_to(root_dir).as_posix()
        diff_contents[posix_relative] = FileDiffContent(oldContent=original_code[path], newContent=new_code[path])

    base_branch = get_current_branch()
    comment_fields = PrComment(
        optimization_explanation=explanation.explanation_message(),
        best_runtime=explanation.best_runtime_ns,
        original_runtime=explanation.original_runtime_ns,
        function_name=explanation.function_name,
        relative_file_path=relative_path,
        speedup_x=explanation.speedup_x,
        speedup_pct=explanation.speedup_pct,
        winning_behavior_test_results=explanation.winning_behavior_test_results,
        winning_benchmarking_test_results=explanation.winning_benchmarking_test_results,
        benchmark_details=explanation.benchmark_details,
    ).to_json()

    payload = {
        "baseBranch": base_branch,
        "diffContents": diff_contents,
        "prCommentFields": comment_fields,
        "existingTests": existing_tests_source,
        "generatedTests": generated_original_test_source,
        "traceId": function_trace_id,
        "coverage_message": coverage_message,
        "replayTests": replay_tests,
        "concolicTests": concolic_tests,
        "optimizationReview": optimization_review,
        "originalLineProfiler": original_line_profiler,
        "optimizedLineProfiler": optimized_line_profiler,
    }

    return make_cfapi_request(endpoint="/create-staging", method="POST", payload=payload)
def is_github_app_installed_on_repo(owner: str, repo: str, *, suppress_errors: bool = False) -> bool:
    """Report whether the Codeflash GitHub App is installed on the given repository.

    :param owner: The owner of the repository.
    :param repo: The name of the repository.
    :param suppress_errors: If True, suppress error logging when the app is not installed.
    :return: True only if the request succeeded and the body is exactly ``true``.
    """
    endpoint = f"/is-github-app-installed?repo={repo}&owner={owner}"
    response = make_cfapi_request(endpoint=endpoint, method="GET", suppress_errors=suppress_errors)
    if not response.ok:
        return False
    return response.text == "true"


def get_blocklisted_functions() -> dict[str, set[str]] | dict[str, Any]:
    """Retrieve blocklisted functions for the current pull request.

    Returns a dictionary mapping file names (base name only) to sets of
    blocklisted function names with any ``()`` removed. An empty dictionary is
    returned when there is no PR number, when the backend reports an error, or
    when anything raises (the exception is logged and sent to Sentry).
    """
    pr_number = get_pr_number()
    if pr_number is None:
        return {}

    try:
        owner, repo = get_repo_owner_and_name()
        req = make_cfapi_request(
            endpoint="/verify-existing-optimizations",
            method="POST",
            payload={"pr_number": pr_number, "repo_owner": owner, "repo_name": repo},
            suppress_errors=True,
        )
        assert req is not None
        status = req.status_code
        if status >= 500:  # type: ignore[unsupported-operator]
            logger.error("Server error getting blocklisted functions: %s", req.status_code)
            sentry_sdk.capture_message(f"Server error in verify-existing-optimizations: {req.status_code}")
            return {}
        if not req.ok:
            if status == 401:
                logger.debug("Not authorized to check blocklisted functions for %s/%s PR #%s", owner, repo, pr_number)
            elif status == 404:
                logger.debug("PR #%s not found for %s/%s", pr_number, owner, repo)
            else:
                logger.warning("Unexpected response %s from verify-existing-optimizations", req.status_code)
            return {}

        content: dict[str, list[str]] = req.json()

        if "error" in content:
            logger.debug("No existing optimizations found for PR #%s", pr_number)
            return {}

        logger.debug("Found %s files with blocklisted functions for PR #%s", len(content), pr_number)
        blocklist = {}
        for file_key, function_names in content.items():
            blocklist[Path(file_key).name] = {name.replace("()", "") for name in function_names}
        return blocklist

    except Exception as e:
        logger.exception("Error getting blocklisted functions: %s", e)
        sentry_sdk.capture_exception(e)
        return {}
def is_function_being_optimized_again(
    owner: str, repo: str, pr_number: int, code_contexts: list[dict[str, str]]
) -> Any:
    """Ask ``/is-already-optimized`` about the given code contexts and return the decoded JSON.

    Raises ``requests.exceptions.HTTPError`` when the backend answers with an error status.
    """
    request_body = {"owner": owner, "repo": repo, "pr_number": pr_number, "code_contexts": code_contexts}
    response = make_cfapi_request("/is-already-optimized", "POST", request_body)
    response.raise_for_status()
    return response.json()


def add_code_context_hash(code_context_hash: str) -> None:
    """Add code context to the DB cache.

    Does nothing when there is no PR number or when the working directory is
    not a valid git repository.
    """
    pr_number = get_pr_number()
    if pr_number is None:
        return
    try:
        owner, repo = get_repo_owner_and_name()
        pr_number = get_pr_number()
    except git.exc.InvalidGitRepositoryError:
        return

    if not (owner and repo and pr_number is not None):
        return
    make_cfapi_request(
        "/add-code-hash",
        "POST",
        {"owner": owner, "repo": repo, "pr_number": pr_number, "code_hash": code_context_hash},
    )


def mark_optimization_success(trace_id: str, *, is_optimization_found: bool) -> Response:
    """Report to ``/mark-as-success`` whether an optimization was found for a trace.

    :param trace_id: The unique identifier for the optimization event.
    :param is_optimization_found: Boolean indicating whether the optimization was found.
    :return: The response object from the API.
    """
    return make_cfapi_request(
        endpoint="/mark-as-success",
        method="POST",
        payload={"trace_id": trace_id, "is_optimization_found": is_optimization_found},
    )
def send_completion_email() -> Response:
    """Send an email notification when codeflash --all completes.

    If the repository owner/name cannot be determined, the exception is sent to
    Sentry and a locally built response with status 500 is returned instead of
    contacting the backend.
    """
    try:
        owner, repo = get_repo_owner_and_name()
    except Exception as e:
        sentry_sdk.capture_exception(e)
        failure = requests.Response()
        failure.status_code = 500
        return failure
    return make_cfapi_request(
        endpoint="/send-completion-email", method="POST", payload={"owner": owner, "repo": repo}
    )
@dataclass(frozen=True)
class AIServiceRefinerRequest:
    """Request model for code refinement API.

    Supports multi-language optimization refinement with optional multi-file context.
    Instances are immutable (``frozen=True``).
    """

    optimization_id: str
    original_source_code: str
    read_only_dependency_code: str
    original_code_runtime: int
    optimized_source_code: str
    optimized_explanation: str
    optimized_code_runtime: int
    speedup: str
    trace_id: str
    original_line_profiler_results: str
    optimized_line_profiler_results: str
    # Optional fields; everything below has a default.
    function_references: str | None = None
    call_sequence: int | None = None
    language: str = "python"
    language_version: str | None = None
    # NOTE(review): presumably maps file path -> file content; confirm against the caller.
    additional_context_files: dict[str, str] | None = None


@dataclass(frozen=True)
class AdaptiveOptimizedCandidate:
    """One candidate entry of an ``AIServiceAdaptiveOptimizeRequest`` (immutable)."""

    optimization_id: str
    source_code: str
    explanation: str
    source: OptimizedCandidateSource
    speedup: str


@dataclass(frozen=True)
class AIServiceAdaptiveOptimizeRequest:
    """Request model carrying the original source and a list of candidates (immutable)."""

    trace_id: str
    original_source_code: str
    candidates: list[AdaptiveOptimizedCandidate]


class TestDiffScope(str, Enum):
    """String-valued enum naming which aspect of a test a ``TestDiff`` refers to."""

    RETURN_VALUE = "return_value"
    STDOUT = "stdout"
    DID_PASS = "did_pass"  # noqa: S105


@dataclass
class TestDiff:
    """A single difference record between original and candidate test behavior.

    Unlike the request models in this module, this dataclass is not frozen.
    """

    scope: TestDiffScope
    original_pass: bool
    candidate_pass: bool

    # Optional details; all default to None.
    original_value: str | None = None
    candidate_value: str | None = None
    test_src_code: str | None = None
    candidate_pytest_error: str | None = None
    original_pytest_error: str | None = None


@dataclass(frozen=True)
class AIServiceCodeRepairRequest:
    """Request model for code repair, carrying both sources and the test diffs (immutable)."""

    optimization_id: str
    original_source_code: str
    modified_source_code: str
    trace_id: str
    test_diffs: list[TestDiff]
    language: str = "python"


class OptimizationReviewResult(NamedTuple):
    """Result from the optimization review API."""

    review: str  # "high", "medium", "low", or ""
    explanation: str


class FunctionRepairInfo(NamedTuple):
    """A function name paired with the reason it is listed for repair."""

    function_name: str
    reason: str


class TestFileReview(NamedTuple):
    """Review entry for one test: its index and the functions listed for repair."""

    test_index: int
    functions_to_repair: list[FunctionRepairInfo]
class CodeflashTrace:
    """Decorator class that traces and profiles function execution.

    Calling an instance on a function returns a wrapper that times each
    non-recursive call with ``time.thread_time_ns``. When the
    ``CODEFLASH_BENCHMARKING`` environment variable is anything other than
    ``"False"``, a timing record (optionally with pickled arguments) is buffered
    in ``function_calls_data`` and periodically written to the SQLite table
    ``benchmark_function_timings``.
    """

    def __init__(self) -> None:
        self.function_calls_data = []
        self.function_call_count = 0
        self.pickle_count_limit = 1000
        self._connection = None
        self._trace_path = None
        self._thread_local = threading.local()
        # Only initializes the creating thread; other threads are set up lazily in the wrapper.
        self._thread_local.active_functions = set()

    def setup(self, trace_path: str) -> None:
        """Open the trace database and create the timings table if needed.

        Args:
        ----
            trace_path: Path to the trace database file

        Raises:
        ------
            Exception: any database error is logged, the connection is closed, and the error re-raised.

        """
        try:
            self._trace_path = trace_path
            self._connection = sqlite3.connect(self._trace_path)
            cur = self._connection.cursor()
            cur.execute("PRAGMA synchronous = OFF")
            cur.execute("PRAGMA journal_mode = MEMORY")
            cur.execute(
                "CREATE TABLE IF NOT EXISTS benchmark_function_timings("
                "function_name TEXT, class_name TEXT, module_name TEXT, file_path TEXT,"
                "benchmark_function_name TEXT, benchmark_module_path TEXT, benchmark_line_number INTEGER,"
                "function_time_ns INTEGER, overhead_time_ns INTEGER, args BLOB, kwargs BLOB)"
            )
            self._connection.commit()
        except Exception as e:
            logger.warning("Database setup error: %s", e)
            if self._connection:
                self._connection.close()
                self._connection = None
            raise

    def write_function_timings(self) -> None:
        """Write the buffered records in ``function_calls_data`` to the database and clear the buffer.

        Does nothing when the buffer is empty. Reconnects to the trace path if
        the connection was closed. On failure the transaction is rolled back
        and the error re-raised; the buffer is left untouched in that case.
        """
        if not self.function_calls_data:
            return  # No data to write

        if self._connection is None:
            assert self._trace_path is not None
            self._connection = sqlite3.connect(self._trace_path)

        try:
            assert self._connection is not None
            cur = self._connection.cursor()
            # Insert data into the benchmark_function_timings table
            cur.executemany(
                "INSERT INTO benchmark_function_timings"
                "(function_name, class_name, module_name, file_path, benchmark_function_name, "
                "benchmark_module_path, benchmark_line_number, function_time_ns, overhead_time_ns, args, kwargs) "
                "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
                self.function_calls_data,
            )
            self._connection.commit()
            self.function_calls_data = []
        except Exception as e:
            logger.warning("Error writing to function timings database: %s", e)
            if self._connection:
                self._connection.rollback()
            raise

    def open(self) -> None:
        """Open the database connection."""
        if self._connection is None:
            assert self._trace_path is not None
            self._connection = sqlite3.connect(self._trace_path)

    def close(self) -> None:
        """Close the database connection."""
        if self._connection:
            self._connection.close()
            self._connection = None

    def _pickle_call_arguments(self, func: Any, args: tuple, kwargs: dict) -> tuple[Any, Any]:
        """Return ``(pickled_args, pickled_kwargs)``, or ``(None, None)`` when they are not captured."""
        # Limit pickle count so memory does not explode
        if self.function_call_count > self.pickle_count_limit:
            logger.warning("Pickle limit reached")
            return None, None
        try:
            return (
                PicklePatcher.dumps(args, protocol=pickle.HIGHEST_PROTOCOL),
                PicklePatcher.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL),
            )
        except Exception as e:
            # The record is still kept without pickled args. Used for timing info only
            logger.warning("Error pickling arguments for function %s: %s", func.__name__, e)
            return None, None

    def __call__(self, func: F) -> F:
        """Use as a decorator to trace function execution.

        Args:
        ----
            func: The function to be decorated

        Returns:
        -------
            The wrapped function

        """
        # Use the qualified name so same-named methods of different classes in one
        # module are not mistaken for a recursive call of each other.
        func_id = (func.__module__, func.__qualname__)

        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:  # noqa: ANN002, ANN003
            # Initialize thread-local active functions set if it doesn't exist
            if not hasattr(self._thread_local, "active_functions"):
                self._thread_local.active_functions = set()
            active_functions = self._thread_local.active_functions
            # If it's in a recursive function, just return the result
            if func_id in active_functions:
                return func(*args, **kwargs)
            # Track active functions so we can detect recursive functions
            active_functions.add(func_id)
            try:
                start_time = time.thread_time_ns()
                result = func(*args, **kwargs)
                end_time = time.thread_time_ns()
                execution_time = end_time - start_time
                self.function_call_count += 1

                # Check if currently in pytest benchmark fixture
                if os.environ.get("CODEFLASH_BENCHMARKING", "False") == "False":
                    return result

                # Get benchmark info from environment
                benchmark_function_name = os.environ.get("CODEFLASH_BENCHMARK_FUNCTION_NAME", "")
                benchmark_module_path = os.environ.get("CODEFLASH_BENCHMARK_MODULE_PATH", "")
                benchmark_line_number = os.environ.get("CODEFLASH_BENCHMARK_LINE_NUMBER", "")
                qualname = func.__qualname__
                class_name = qualname.split(".")[0] if "." in qualname else ""

                pickled_args, pickled_kwargs = self._pickle_call_arguments(func, args, kwargs)

                # Flush on every recording path, so the buffer stays bounded even after
                # the pickle limit is reached or when arguments cannot be pickled.
                if len(self.function_calls_data) > 100:
                    self.write_function_timings()

                overhead_time = time.thread_time_ns() - end_time
                normalized_file_path = Path(func.__code__.co_filename).as_posix()
                self.function_calls_data.append(
                    (
                        func.__name__,
                        class_name,
                        func.__module__,
                        normalized_file_path,
                        benchmark_function_name,
                        benchmark_module_path,
                        benchmark_line_number,
                        execution_time,
                        overhead_time,
                        pickled_args,
                        pickled_kwargs,
                    )
                )
                return result
            finally:
                # Always clear the marker, including when func raises; otherwise every
                # later call would be treated as recursion and never traced again.
                active_functions.discard(func_id)

        return wrapper  # type: ignore[return-value]
# Substrings of a profiled filename that mark test/runner infrastructure.
# None of these may be empty: an empty string is a substring of every filename
# and would classify every function as infrastructure.
pytest_patterns = {
    "<string>",  # Dynamically evaluated code
    "_pytest/",  # Pytest internals
    "pytest",  # Pytest files
    "pluggy/",  # Plugin system
    "_pydev",  # PyDev debugger
    "runpy.py",  # Python module runner
}
# Substrings of a lower-cased function name that mark pytest hooks and internals.
pytest_func_patterns = {"pytest_", "_pytest", "runtest"}


def is_pytest_infrastructure(filename: str, function_name: str) -> bool:
    """Check if a function is part of pytest infrastructure that should be excluded from ranking.

    This filters out pytest internal functions, hooks, and test framework code that
    would otherwise dominate the ranking but aren't candidates for optimization.

    The filename is matched case-sensitively against ``pytest_patterns``; the
    function name is lower-cased before matching against ``pytest_func_patterns``.
    """
    if any(pattern in filename for pattern in pytest_patterns):
        return True

    lowered_name = function_name.lower()
    return any(pattern in lowered_name for pattern in pytest_func_patterns)
class FunctionRanker:
    """Ranks and filters functions based on % of addressable time derived from profiling data.

    The addressable time of a function is calculated as:
        addressable_time = own_time + (time_spent_in_callees / call_count)

    This represents the runtime of a function plus the average runtime of its immediate
    callees. It prioritizes functions that are computationally heavy themselves
    (high `own_time`) or that make expensive calls to other functions.

    Functions are first filtered by an importance threshold: their addressable time as a
    fraction of the summed `own_time` of all profiled functions in the same file(s).
    The remaining functions are then ranked by addressable time, highest first.
    """

    def __init__(self, trace_file_path: Path) -> None:
        self.trace_file_path = trace_file_path
        self._profile_stats = ProfileStats(trace_file_path.as_posix())
        self._function_stats: dict[str, dict] = {}
        self._function_stats_by_name: dict[str, list[tuple[str, dict]]] = {}
        self.load_function_stats()

        # Build index for faster lookups: map function_name to list of (key, stats)
        for key, stats in self._function_stats.items():
            func_name = stats.get("function_name")
            if func_name:
                self._function_stats_by_name.setdefault(func_name, []).append((key, stats))

    @staticmethod
    def _normalize(filename: object) -> str:
        """Return the filename as a string with forward slashes only."""
        return str(filename).replace("\\", "/")

    @classmethod
    def _basename(cls, filename: object) -> str:
        """Return the last path component of a profiled filename."""
        return cls._normalize(filename).rsplit("/", 1)[-1]

    def load_function_stats(self) -> None:
        """Populate ``_function_stats`` from the profile, keyed by ``"<filename>:<function name>"``.

        Entries with a non-positive call count and pytest infrastructure are
        skipped. Any error while processing is logged and leaves the stats empty.
        """
        try:
            pytest_filtered_count = 0
            for (filename, line_number, func_name), (
                call_count,
                _num_callers,
                total_time_ns,
                cumulative_time_ns,
                _callers,
            ) in self._profile_stats.stats.items():
                if call_count <= 0:
                    continue

                if is_pytest_infrastructure(filename, func_name):
                    pytest_filtered_count += 1
                    continue

                # Parse function name to handle methods within classes
                class_name, qualified_name, base_function_name = (None, func_name, func_name)
                if "." in func_name and not func_name.startswith("<"):
                    parts = func_name.split(".", 1)
                    if len(parts) == 2:
                        class_name, base_function_name = parts

                # total_time_ns is used as the function's own time; the remainder of the
                # cumulative time is attributed to its callees.
                own_time_ns = total_time_ns
                time_in_callees_ns = cumulative_time_ns - total_time_ns

                # Calculate addressable time (own time + avg time in immediate callees)
                addressable_time_ns = own_time_ns + (time_in_callees_ns / call_count)

                function_key = f"{filename}:{qualified_name}"
                self._function_stats[function_key] = {
                    "filename": filename,
                    "function_name": base_function_name,
                    "qualified_name": qualified_name,
                    "class_name": class_name,
                    "line_number": line_number,
                    "call_count": call_count,
                    "own_time_ns": own_time_ns,
                    "cumulative_time_ns": cumulative_time_ns,
                    "time_in_callees_ns": time_in_callees_ns,
                    "addressable_time_ns": addressable_time_ns,
                }

            logger.debug(
                "Loaded timing stats for %s functions from trace using ProfileStats "
                "(filtered %s pytest infrastructure functions)",
                len(self._function_stats),
                pytest_filtered_count,
            )

        except Exception as e:
            logger.warning("Failed to process function stats from trace file %s: %s", self.trace_file_path, e)
            self._function_stats = {}

    def get_function_stats_summary(self, function_to_optimize: FunctionToOptimize) -> dict | None:
        """Return the stats entry that best matches the function, or None if there is none.

        Candidates share the function's base name and must live in a file with
        the same base name (a plain substring test would let ``utils.py`` match
        ``my_utils.py``). Among those, an identical full path and an identical
        qualified name are preferred, so same-named methods of different
        classes do not receive each other's timings. Ties keep the first candidate.
        """
        target_filename = function_to_optimize.file_path.name
        target_path = function_to_optimize.file_path.as_posix()
        target_qualified_name = getattr(function_to_optimize, "qualified_name", None)
        candidates = self._function_stats_by_name.get(function_to_optimize.function_name) or []

        best_stats = None
        best_score = -1
        for _key, stats in candidates:
            filename = self._normalize(stats.get("filename", ""))
            if self._basename(filename) != target_filename:
                continue
            score = 2 * int(filename == target_path) + int(stats.get("qualified_name") == target_qualified_name)
            if score > best_score:
                best_stats, best_score = stats, score

        if best_stats is None:
            logger.debug(
                "Could not find stats for function %s in file %s", function_to_optimize.function_name, target_filename
            )
        return best_stats

    def get_function_addressable_time(self, function_to_optimize: FunctionToOptimize) -> float:
        """Get the addressable time in nanoseconds for a function.

        Addressable time = own_time + (time_in_callees / call_count)
        This represents the runtime of the function plus runtime of immediate dependent functions.
        Returns 0.0 when no stats are available for the function.
        """
        stats = self.get_function_stats_summary(function_to_optimize)
        return stats["addressable_time_ns"] if stats else 0.0

    def rank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
        """Ranks and filters functions based on their % of addressable time and importance.

        Filters out functions whose addressable time is less than DEFAULT_IMPORTANCE_THRESHOLD
        of file-relative runtime, then ranks the remaining functions by addressable time.

        Importance is calculated relative to functions in the same file(s) rather than
        total program time. This avoids filtering out functions due to test infrastructure
        overhead. If that total is zero, no filtering is applied.

        Args:
            functions_to_optimize: List of functions to rank.

        Returns:
            Important functions sorted in descending order of their addressable time.
            An empty list when no profiling stats were loaded.

        """
        if not self._function_stats:
            logger.warning("No function stats available to rank functions.")
            return []

        positive_stats = [s for s in self._function_stats.values() if s.get("own_time_ns", 0) > 0]
        if functions_to_optimize:
            # Only count time from files that contain a function to optimize.
            target_files = {func.file_path.name for func in functions_to_optimize}
            total_program_time = sum(
                s["own_time_ns"] for s in positive_stats if self._basename(s.get("filename", "")) in target_files
            )
            logger.debug(
                "Using file-relative importance for %s file(s): %s. Total file time: %s ns",
                len(target_files),
                target_files,
                format(total_program_time, ","),
            )
        else:
            total_program_time = sum(s["own_time_ns"] for s in positive_stats)

        if total_program_time == 0:
            logger.warning("Total program time is zero, cannot determine function importance.")
            functions_to_rank = functions_to_optimize
        else:
            functions_to_rank = []
            for func in functions_to_optimize:
                func_stats = self.get_function_stats_summary(func)
                if func_stats and func_stats.get("addressable_time_ns", 0) > 0:
                    importance = func_stats["addressable_time_ns"] / total_program_time
                    if importance >= DEFAULT_IMPORTANCE_THRESHOLD:
                        functions_to_rank.append(func)
                    else:
                        logger.debug(
                            "Filtering out function %s with importance %.2f%% (below threshold %.2f%%)",
                            func.qualified_name,
                            importance * 100,
                            DEFAULT_IMPORTANCE_THRESHOLD * 100,
                        )

            logger.info(
                "Filtered down to %s important functions from %s total functions",
                len(functions_to_rank),
                len(functions_to_optimize),
            )

        # Look each time up once; the sort is stable, so ties keep their input order.
        timed = [(self.get_function_addressable_time(func), func) for func in functions_to_rank]
        timed.sort(key=lambda pair: pair[0], reverse=True)
        logger.debug(
            "Function ranking order: %s",
            [f"{func.function_name} (addressable_time={addressable:.2f}ns)" for addressable, func in timed],
        )
        return [func for _addressable, func in timed]
class AddDecoratorTransformer(cst.CSTTransformer):
    """Add ``@codeflash_trace`` to selected top-level functions and methods of top-level classes.

    ``target_functions`` holds ``(class_name, function_name)`` pairs, with an
    empty class name for module-level functions. Nested functions, methods of
    nested classes, and methods of classes defined inside a function are never
    decorated. If at least one decorator was added, the import of
    ``codeflash_trace`` is inserted at the top of the module.
    """

    def __init__(self, target_functions: set[tuple[str, str]]) -> None:
        super().__init__()
        self.target_functions = target_functions
        self.added_codeflash_trace = False
        self.class_name = ""
        self.function_name = ""
        # Depth counters are used instead of name comparison, because a nested
        # definition may have the same name as the one enclosing it.
        self._class_depth = 0
        self._function_depth = 0
        self.decorator = cst.Decorator(decorator=cst.Name(value="codeflash_trace"))

    def visit_ClassDef(self, node: ClassDef) -> bool | None:
        self._class_depth += 1
        if self._class_depth > 1:  # Don't go into nested class
            return False
        self.class_name = node.name.value
        return None

    def leave_ClassDef(
        self, original_node: ClassDef, updated_node: ClassDef
    ) -> BaseStatement | FlattenSentinel[BaseStatement] | RemovalSentinel:
        # Even if nested classes are not visited, this function is still called on them
        self._class_depth -= 1
        if self._class_depth == 0:
            self.class_name = ""
        return updated_node

    def visit_FunctionDef(self, node: FunctionDef) -> bool | None:
        self._function_depth += 1
        if self._function_depth > 1:  # Don't go into nested function
            return False
        self.function_name = node.name.value
        return None

    def leave_FunctionDef(self, original_node: FunctionDef, updated_node: FunctionDef) -> FunctionDef:
        # This is also called for nested functions whose bodies were skipped.
        self._function_depth -= 1
        if self._function_depth > 0:
            # Nested function: never a target, even if its name matches one.
            return updated_node

        self.function_name = ""
        if (self.class_name, original_node.name.value) in self.target_functions:
            # Add the new decorator after any existing decorators, so it gets executed first
            updated_decorators = [*list(updated_node.decorators), self.decorator]
            self.added_codeflash_trace = True
            return updated_node.with_changes(decorators=updated_decorators)

        return updated_node

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
        # Only add the import when a decorator was actually inserted.
        if not self.added_codeflash_trace:
            return updated_node
        # from codeflash_python.benchmarking.codeflash_trace import codeflash_trace
        import_stmt = cst.SimpleStatementLine(
            body=[
                cst.ImportFrom(
                    module=cst.Attribute(
                        value=cst.Attribute(
                            value=cst.Name(value="codeflash_python"), attr=cst.Name(value="benchmarking")
                        ),
                        attr=cst.Name(value="codeflash_trace"),
                    ),
                    names=[cst.ImportAlias(name=cst.Name(value="codeflash_trace"))],
                )
            ]
        )

        # Insert at the beginning of the file. We'll use isort later to sort the imports.
        new_body = [import_stmt, *list(updated_node.body)]

        return updated_node.with_changes(body=new_body)
+ + Args: + ---- + code: The source code as a string + functions_to_optimize: List of FunctionToOptimize instances containing function details + + Returns: + ------- + The modified source code as a string + + """ + target_functions = set() + for function_to_optimize in functions_to_optimize: + class_name = "" + if len(function_to_optimize.parents) == 1 and function_to_optimize.parents[0].type == "ClassDef": + class_name = function_to_optimize.parents[0].name + target_functions.add((class_name, function_to_optimize.function_name)) + + transformer = AddDecoratorTransformer(target_functions=target_functions) + + module = cst.parse_module(code) + modified_module = module.visit(transformer) + return modified_module.code + + +def instrument_codeflash_trace_decorator(file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]) -> None: + """Instrument codeflash_trace decorator to functions to optimize.""" + for file_path, functions_to_optimize in file_to_funcs_to_optimize.items(): + # Skip codeflash's own benchmarking and picklepatch modules to avoid circular imports + # (codeflash_trace.py imports from picklepatch, and instrumenting these would cause + # them to import codeflash_trace back, creating a circular import) + posix_path = file_path.as_posix() + skip = False + for pkg_marker in ("/codeflash_python/", "/codeflash/"): + _, sep, after = posix_path.rpartition(pkg_marker) + if sep: + submodule = after.partition("/")[0] + if submodule in ("benchmarking", "picklepatch"): + skip = True + break + if skip: + continue + original_code = file_path.read_text(encoding="utf-8") + new_code = add_codeflash_decorator_to_code(original_code, functions_to_optimize) + # Modify the code + modified_code = sort_imports(code=new_code, float_to_top=True) + + # Write the modified code back to the file + file_path.write_text(modified_code, encoding="utf-8") diff --git a/src/codeflash_python/benchmarking/parse_line_profile_test_output.py 
b/src/codeflash_python/benchmarking/parse_line_profile_test_output.py new file mode 100644 index 000000000..593012b2b --- /dev/null +++ b/src/codeflash_python/benchmarking/parse_line_profile_test_output.py @@ -0,0 +1,131 @@ +"""Adapted from line_profiler (https://github.com/pyutils/line_profiler) written by Enthought, Inc. (BSD License).""" + +from __future__ import annotations + +import inspect +import linecache +from pathlib import Path + +import dill as pickle + +from codeflash_python.code_utils.tabulate import tabulate + + +def show_func( + filename: str, start_lineno: int, func_name: str, timings: list[tuple[int, int, float]], unit: float +) -> str: + total_hits = sum(t[1] for t in timings) + total_time = sum(t[2] for t in timings) + out_table = "" + table_rows = [] + if total_hits == 0: + return "" + scalar = 1 + sublines = [] + if Path(filename).exists(): + out_table += f"## Function: {func_name}\n" + # Clear the cache to ensure that we get up-to-date results. + linecache.clearcache() + all_lines = linecache.getlines(filename) + sublines = inspect.getblock(all_lines[start_lineno - 1 :]) + out_table += "## Total time: %g s\n" % (total_time * unit) + # Define minimum column sizes so text fits and usually looks consistent + default_column_sizes = {"hits": 9, "time": 12, "perhit": 8, "percent": 8} + display = {} + # Loop over each line to determine better column formatting. + # Fallback to scientific notation if columns are larger than a threshold. 
+ for lineno, nhits, time in timings: + percent = "" if total_time == 0 else "%5.1f" % (100 * time / total_time) + + time_disp = "%5.1f" % (time * scalar) + if len(time_disp) > default_column_sizes["time"]: + time_disp = "%5.1g" % (time * scalar) + perhit_disp = "%5.1f" % (float(time) * scalar / nhits) + if len(perhit_disp) > default_column_sizes["perhit"]: + perhit_disp = "%5.1g" % (float(time) * scalar / nhits) + nhits_disp = str(nhits) + if len(nhits_disp) > default_column_sizes["hits"]: + nhits_disp = f"{nhits:g}" + display[lineno] = (nhits_disp, time_disp, perhit_disp, percent) + linenos = range(start_lineno, start_lineno + len(sublines)) + empty = ("", "", "", "") + table_cols = ("Hits", "Time", "Per Hit", "% Time", "Line Contents") + for lineno, line in zip(linenos, sublines): + nhits, time, per_hit, percent = display.get(lineno, empty) + line_ = line.rstrip("\n").rstrip("\r") + if "def" in line_ or nhits != "": + table_rows.append((nhits, time, per_hit, percent, line_)) + out_table += tabulate( + headers=table_cols, tabular_data=table_rows, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True + ) + out_table += "\n" + return out_table + + +def show_text(stats: dict) -> str: + """Show text for the given timings.""" + out_table = "" + out_table += "# Timer unit: {:g} s\n".format(stats["unit"]) + stats_order = sorted(stats["timings"].items()) + # Show detailed per-line information for each function. 
+ for (fn, lineno, name), _timings in stats_order: + table_md = show_func(fn, lineno, name, stats["timings"][fn, lineno, name], stats["unit"]) + out_table += table_md + return out_table + + +def show_text_non_python(stats: dict, line_contents: dict[tuple[str, int], str]) -> str: + """Show text for non-Python timings using profiler-provided line contents.""" + out_table = "" + out_table += "# Timer unit: {:g} s\n".format(stats["unit"]) + stats_order = sorted(stats["timings"].items()) + for (fn, _lineno, name), timings in stats_order: + total_hits = sum(t[1] for t in timings) + total_time = sum(t[2] for t in timings) + if total_hits == 0: + continue + + out_table += f"## Function: {name}\n" + out_table += "## Total time: %g s\n" % (total_time * stats["unit"]) + + default_column_sizes = {"hits": 9, "time": 12, "perhit": 8, "percent": 8} + table_rows = [] + for lineno, nhits, time in timings: + if nhits == 0: + table_rows.append(("", "", "", "", line_contents.get((fn, lineno), ""))) + continue + percent = "" if total_time == 0 else "%5.1f" % (100 * time / total_time) + time_disp = f"{time:5.1f}" + if len(time_disp) > default_column_sizes["time"]: + time_disp = f"{time:5.1g}" + perhit = (float(time) / nhits) if nhits > 0 else 0.0 + perhit_disp = f"{perhit:5.1f}" + if len(perhit_disp) > default_column_sizes["perhit"]: + perhit_disp = f"{perhit:5.1g}" + nhits_disp = str(nhits) + if len(nhits_disp) > default_column_sizes["hits"]: + nhits_disp = f"{nhits:g}" + + table_rows.append((nhits_disp, time_disp, perhit_disp, percent, line_contents.get((fn, lineno), ""))) + + table_cols = ("Hits", "Time", "Per Hit", "% Time", "Line Contents") + out_table += tabulate( + headers=table_cols, tabular_data=table_rows, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True + ) + out_table += "\n" + return out_table + + +def parse_line_profile_results(line_profiler_output_file: Path | None) -> tuple[dict, None]: + assert line_profiler_output_file is not None + 
line_profiler_output_file = line_profiler_output_file.with_suffix(".lprof") + stats_dict: dict = {} + if not line_profiler_output_file.exists(): + return {"timings": {}, "unit": 0, "str_out": ""}, None + with line_profiler_output_file.open("rb") as f: + stats = pickle.load(f) + stats_dict["timings"] = stats.timings + stats_dict["unit"] = stats.unit + str_out = show_text(stats_dict) + stats_dict["str_out"] = str_out + return stats_dict, None diff --git a/src/codeflash_python/benchmarking/plugin/__init__.py b/src/codeflash_python/benchmarking/plugin/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/codeflash_python/benchmarking/plugin/plugin.py b/src/codeflash_python/benchmarking/plugin/plugin.py new file mode 100644 index 000000000..48866e627 --- /dev/null +++ b/src/codeflash_python/benchmarking/plugin/plugin.py @@ -0,0 +1,300 @@ +from __future__ import annotations + +import importlib.util +import os +import sqlite3 +import sys +import time +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import pytest + +from codeflash_python.benchmarking.codeflash_trace import codeflash_trace +from codeflash_python.code_utils.code_utils import module_name_from_file_path + +if TYPE_CHECKING: + from codeflash_python.models.models import BenchmarkKey + +PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None + + +class CodeFlashBenchmarkPlugin: + def __init__(self) -> None: + self.trace_path = None + self.connection = None + self.project_root = None + self.benchmark_timings = [] + + @staticmethod + @pytest.hookimpl + def pytest_addoption(parser: pytest.Parser) -> None: + parser.addoption( + "--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing for benchmarks" + ) + + def setup(self, trace_path: str | Path, project_root: str | Path) -> None: + try: + # Open connection + self.project_root = project_root + self.trace_path = trace_path + self.connection = 
sqlite3.connect(self.trace_path) + cur = self.connection.cursor() + cur.execute("PRAGMA synchronous = OFF") + cur.execute("PRAGMA journal_mode = MEMORY") + cur.execute( + "CREATE TABLE IF NOT EXISTS benchmark_timings(" + "benchmark_module_path TEXT, benchmark_function_name TEXT, benchmark_line_number INTEGER," + "benchmark_time_ns INTEGER)" + ) + self.connection.commit() + self.close() # Reopen only at the end of pytest session + except Exception as e: + print(f"Database setup error: {e}") + if self.connection: + self.connection.close() + self.connection = None + raise + + def write_benchmark_timings(self) -> None: + if not self.benchmark_timings: + return # No data to write + + if self.connection is None: + assert self.trace_path is not None + self.connection = sqlite3.connect(self.trace_path) + + try: + cur = self.connection.cursor() + # Insert data into the benchmark_timings table + cur.executemany( + "INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)", + self.benchmark_timings, + ) + self.connection.commit() + self.benchmark_timings = [] # Clear the benchmark timings list + except Exception as e: + print(f"Error writing to benchmark timings database: {e}") + self.connection.rollback() + raise + + def close(self) -> None: + if self.connection: + self.connection.close() + self.connection = None + + @staticmethod + def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[BenchmarkKey, int]]: + from codeflash_python.models.models import BenchmarkKey + + """Process the trace file and extract timing data for all functions. 
+ + Args: + ---- + trace_path: Path to the trace file + + Returns: + ------- + A nested dictionary where: + - Outer keys are module_name.qualified_name (module.class.function) + - Inner keys are of type BenchmarkKey + - Values are function timing in milliseconds + + """ + # Initialize the result dictionary + result = {} + + # Connect to the SQLite database + connection = sqlite3.connect(trace_path) + cursor = connection.cursor() + + try: + # Query the function_calls table for all function calls + cursor.execute( + "SELECT module_name, class_name, function_name, " + "benchmark_module_path, benchmark_function_name, benchmark_line_number, function_time_ns " + "FROM benchmark_function_timings" + ) + + # Process each row + for row in cursor.fetchall(): + module_name, class_name, function_name, benchmark_file, benchmark_func, _benchmark_line, time_ns = row + + # Create the function key (module_name.class_name.function_name) + if class_name: + qualified_name = f"{module_name}.{class_name}.{function_name}" + else: + qualified_name = f"{module_name}.{function_name}" + + # Create the benchmark key (file::function::line) + benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func) + # Initialize the inner dictionary if needed + if qualified_name not in result: + result[qualified_name] = {} + + # If multiple calls to the same function in the same benchmark, + # add the times together + if benchmark_key in result[qualified_name]: + result[qualified_name][benchmark_key] += time_ns + else: + result[qualified_name][benchmark_key] = time_ns + + finally: + # Close the connection + connection.close() + + return result + + @staticmethod + def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]: + from codeflash_python.models.models import BenchmarkKey + + """Extract total benchmark timings from trace files. 
+ + Args: + ---- + trace_path: Path to the trace file + + Returns: + ------- + A dictionary mapping where: + - Keys are of type BenchmarkKey + - Values are total benchmark timing in milliseconds (with overhead subtracted) + + """ + # Initialize the result dictionary + result = {} + overhead_by_benchmark = {} + + # Connect to the SQLite database + connection = sqlite3.connect(trace_path) + cursor = connection.cursor() + + try: + # Query the benchmark_function_timings table to get total overhead for each benchmark + cursor.execute( + "SELECT benchmark_module_path, benchmark_function_name, benchmark_line_number, SUM(overhead_time_ns) " + "FROM benchmark_function_timings " + "GROUP BY benchmark_module_path, benchmark_function_name, benchmark_line_number" + ) + + # Process overhead information + for row in cursor.fetchall(): + benchmark_file, benchmark_func, _benchmark_line, total_overhead_ns = row + benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func) + overhead_by_benchmark[benchmark_key] = total_overhead_ns or 0 # Handle NULL sum case + + # Query the benchmark_timings table for total times + cursor.execute( + "SELECT benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns " + "FROM benchmark_timings" + ) + + # Process each row and subtract overhead + for row in cursor.fetchall(): + benchmark_file, benchmark_func, _benchmark_line, time_ns = row + + # Create the benchmark key (file::function::line) + benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func) + # Subtract overhead from total time + overhead = overhead_by_benchmark.get(benchmark_key, 0) + result[benchmark_key] = time_ns - overhead + + finally: + # Close the connection + connection.close() + + return result + + # Pytest hooks + @pytest.hookimpl + def pytest_sessionfinish(self, session, exitstatus) -> None: + """Execute after whole test run is completed.""" + # Write any remaining benchmark timings to the 
database + codeflash_trace.close() + if self.benchmark_timings: + self.write_benchmark_timings() + # Close the database connection + self.close() + + @staticmethod + def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None: + # Skip tests that don't have the benchmark fixture + if not config.getoption("--codeflash-trace"): + return + + skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture") + for item in items: + # Check for direct benchmark fixture usage + has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames # ty:ignore[unsupported-operator] + + # Check for @pytest.mark.benchmark marker + has_marker = False + if hasattr(item, "get_closest_marker"): + marker = item.get_closest_marker("benchmark") + if marker is not None: + has_marker = True + + # Skip if neither fixture nor marker is present + if not (has_fixture or has_marker): + item.add_marker(skip_no_benchmark) + + @pytest.fixture + def benchmark(self, request: pytest.FixtureRequest) -> CodeFlashBenchmarkPlugin.Benchmark: + return self.Benchmark(request) + + # Benchmark fixture + class Benchmark: # noqa: D106 + def __init__(self, request: pytest.FixtureRequest) -> None: + self.request = request + + def __call__(self, func, *args, **kwargs): # noqa: ANN002, ANN003, ANN204 + """Handle both direct function calls and decorator usage.""" + if args or kwargs: + # Used as benchmark(func, *args, **kwargs) + return self.run_benchmark(func, *args, **kwargs) + + # Used as @benchmark decorator + def wrapped_func(*args, **kwargs): # noqa: ANN002, ANN003 + return func(*args, **kwargs) + + self.run_benchmark(func) + return wrapped_func + + def run_benchmark(self, func, *args, **kwargs) -> Any: # noqa: ANN002, ANN003 + """Actual benchmark implementation.""" + node_path = getattr(self.request.node, "path", None) or getattr(self.request.node, "fspath", None) + if node_path is None: + raise RuntimeError("Unable to determine test file path from 
pytest node") + + benchmark_module_path = module_name_from_file_path( + Path(str(node_path)), Path(str(codeflash_benchmark_plugin.project_root)), traverse_up=True + ) + + benchmark_function_name = self.request.node.name + line_number = int(str(sys._getframe(2).f_lineno)) # 2 frames up in the call stack # noqa: SLF001 + # Set env vars + os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = benchmark_function_name + os.environ["CODEFLASH_BENCHMARK_MODULE_PATH"] = benchmark_module_path + os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number) + os.environ["CODEFLASH_BENCHMARKING"] = "True" + # Run the function + start = time.perf_counter_ns() + result = func(*args, **kwargs) + end = time.perf_counter_ns() + # Reset the environment variable + os.environ["CODEFLASH_BENCHMARKING"] = "False" + + # Write function calls + codeflash_trace.write_function_timings() + # Reset function call count + codeflash_trace.function_call_count = 0 + # Add to the benchmark timings buffer + codeflash_benchmark_plugin.benchmark_timings.append( + (benchmark_module_path, benchmark_function_name, line_number, end - start) + ) + + return result + + +codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin() diff --git a/src/codeflash_python/benchmarking/profile_stats.py b/src/codeflash_python/benchmarking/profile_stats.py new file mode 100644 index 000000000..7f6053037 --- /dev/null +++ b/src/codeflash_python/benchmarking/profile_stats.py @@ -0,0 +1,93 @@ +import json +import logging +import pstats +import sqlite3 +from copy import copy +from pathlib import Path + +logger = logging.getLogger("codeflash_python") + + +class ProfileStats(pstats.Stats): + def __init__(self, trace_file_path: str, time_unit: str = "ns") -> None: + assert Path(trace_file_path).is_file(), f"Trace file {trace_file_path} does not exist" + assert time_unit in {"ns", "us", "ms", "s"}, f"Invalid time unit {time_unit}" + self.trace_file_path = trace_file_path + self.time_unit = time_unit + logger.debug(hasattr(self, 
"create_stats")) + super().__init__(copy(self)) # type: ignore[arg-type] + + def create_stats(self) -> None: + self.con = sqlite3.connect(self.trace_file_path) + cur = self.con.cursor() + pdata = cur.execute("SELECT * FROM pstats").fetchall() + self.con.close() + time_conversion_factor = {"ns": 1, "us": 1e3, "ms": 1e6, "s": 1e9}[self.time_unit] + self.stats = {} + for ( + filename, + line_number, + function, + class_name, + call_count_nonrecursive, + num_callers, + total_time_ns, + cumulative_time_ns, + callers, + ) in pdata: + loaded_callers = json.loads(callers) + unmapped_callers = {} + for caller in loaded_callers: + caller_key = caller["key"] + if isinstance(caller_key, list): + caller_key = tuple(caller_key) + elif not isinstance(caller_key, tuple): + caller_key = (caller_key,) if not isinstance(caller_key, (list, tuple)) else tuple(caller_key) + unmapped_callers[caller_key] = caller["value"] + + # Create function key with class name if present (matching tracer.py format) + function_name = f"{class_name}.{function}" if class_name else function + + self.stats[(filename, line_number, function_name)] = ( + call_count_nonrecursive, + num_callers, + total_time_ns / time_conversion_factor if time_conversion_factor != 1 else total_time_ns, + cumulative_time_ns / time_conversion_factor if time_conversion_factor != 1 else cumulative_time_ns, + unmapped_callers, + ) + + def print_stats(self, *amount) -> pstats.Stats: # noqa: ANN002 + # Copied from pstats.Stats.print_stats and modified to print the correct time unit + for filename in self.files: # type: ignore[attr-defined] + print(filename, file=self.stream) # type: ignore[attr-defined] + if self.files: # type: ignore[attr-defined] + print(file=self.stream) # type: ignore[attr-defined] + indent = " " * 8 + for func in self.top_level: # type: ignore[attr-defined] + print(indent, func[2], file=self.stream) # type: ignore[attr-defined] + + print(indent, self.total_calls, "function calls", end=" ", file=self.stream) # 
type: ignore[attr-defined] + if self.total_calls != self.prim_calls: # type: ignore[attr-defined] + print(f"({self.prim_calls:d} primitive calls)", end=" ", file=self.stream) # type: ignore[attr-defined] + time_unit = {"ns": "nanoseconds", "us": "microseconds", "ms": "milliseconds", "s": "seconds"}[self.time_unit] + print(f"in {self.total_tt:.3f} {time_unit}", file=self.stream) # type: ignore[attr-defined] + print(file=self.stream) # type: ignore[attr-defined] + _width, list_ = self.get_print_list(amount) + if list_: + self.print_title() + for func in list_: + self.print_line(func) + print(file=self.stream) # type: ignore[attr-defined] + print(file=self.stream) # type: ignore[attr-defined] + return self + + +def get_trace_total_run_time_ns(trace_file_path: Path) -> int: + if not trace_file_path.is_file(): + return 0 + con = sqlite3.connect(trace_file_path) + cur = con.cursor() + time_data = cur.execute("SELECT time_ns FROM total_time").fetchone() + con.close() + time_data = time_data[0] if time_data else 0 + return int(time_data) diff --git a/src/codeflash_python/benchmarking/pytest_new_process_trace_benchmarks.py b/src/codeflash_python/benchmarking/pytest_new_process_trace_benchmarks.py new file mode 100644 index 000000000..44b658396 --- /dev/null +++ b/src/codeflash_python/benchmarking/pytest_new_process_trace_benchmarks.py @@ -0,0 +1,50 @@ +import logging +import sys +from pathlib import Path + +from codeflash_python.benchmarking.codeflash_trace import codeflash_trace +from codeflash_python.benchmarking.plugin.plugin import codeflash_benchmark_plugin + +logger = logging.getLogger("codeflash_python") + +benchmarks_root = sys.argv[1] +tests_root = sys.argv[2] +trace_file = sys.argv[3] +project_root = Path.cwd() + +if __name__ == "__main__": + import pytest + + orig_recursion_limit = sys.getrecursionlimit() + sys.setrecursionlimit(orig_recursion_limit * 2) + + try: + codeflash_benchmark_plugin.setup(trace_file, project_root) + codeflash_trace.setup(trace_file) + 
exitcode = pytest.main( + [ + benchmarks_root, + "--codeflash-trace", + "-p", + "no:benchmark", + "-p", + "no:codspeed", + "-p", + "no:cov", + "-p", + "no:profiling", + "-p", + "no:codeflash-benchmark", + "-s", + "-o", + "addopts=", + ], + plugins=[codeflash_benchmark_plugin], + ) + except Exception as e: + logger.warning("Failed to collect tests: %s", e) + exitcode = -1 + finally: + sys.setrecursionlimit(orig_recursion_limit) + + sys.exit(exitcode) diff --git a/src/codeflash_python/benchmarking/replay_test.py b/src/codeflash_python/benchmarking/replay_test.py new file mode 100644 index 000000000..90c1ce987 --- /dev/null +++ b/src/codeflash_python/benchmarking/replay_test.py @@ -0,0 +1,305 @@ +from __future__ import annotations + +import logging +import re +import sqlite3 +import textwrap +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from codeflash_python.code_utils.formatter import sort_imports +from codeflash_python.discovery.function_visitors import inspect_top_level_functions_or_methods +from codeflash_python.verification.verification_utils import get_test_file_path + +if TYPE_CHECKING: + from collections.abc import Generator + + +logger = logging.getLogger("codeflash_python") + +benchmark_context_cleaner = re.compile(r"[^a-zA-Z0-9_]+") + + +def get_next_arg_and_return( + trace_file: str, + benchmark_function_name: str, + function_name: str, + file_path: str, + class_name: str | None = None, + num_to_get: int = 256, +) -> Generator[Any]: + db = sqlite3.connect(trace_file) + cur = db.cursor() + limit = num_to_get + + normalized_file_path = Path(file_path).as_posix() + + if class_name is not None: + cursor = cur.execute( + "SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = ? 
LIMIT ?", + (benchmark_function_name, function_name, normalized_file_path, class_name, limit), + ) + else: + cursor = cur.execute( + "SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = '' LIMIT ?", + (benchmark_function_name, function_name, normalized_file_path, limit), + ) + + try: + while (val := cursor.fetchone()) is not None: + yield val[9], val[10] # pickled_args, pickled_kwargs + finally: + db.close() + + +def get_function_alias(module: str, function_name: str) -> str: + return "_".join(module.split(".")) + "_" + function_name + + +def get_unique_test_name(module: str, function_name: str, benchmark_name: str, class_name: str | None = None) -> str: + clean_benchmark = benchmark_context_cleaner.sub("_", benchmark_name).strip("_") + + base_alias = get_function_alias(module, function_name) + if class_name: + class_alias = get_function_alias(module, class_name) + return f"{class_alias}_{function_name}_{clean_benchmark}" + return f"{base_alias}_{clean_benchmark}" + + +def create_trace_replay_test_code( + trace_file: str, functions_data: list[dict[str, Any]], max_run_count: int = 256 +) -> str: + """Create a replay test for functions based on trace data. 
+ + Args: + ---- + trace_file: Path to the SQLite database file + functions_data: List of dictionaries with function info extracted from DB + max_run_count: Maximum number of runs to include in the test + + Returns: + ------- + A string containing the test code + + """ + # Create Imports + imports = """from codeflash_python.picklepatch.pickle_patcher import PicklePatcher as pickle +from codeflash_python.benchmarking.replay_test import get_next_arg_and_return +""" + + function_imports = [] + for func in functions_data: + module_name = func.get("module_name") + function_name = func.get("function_name") + class_name = func.get("class_name", "") + assert module_name is not None + assert function_name is not None + if class_name: + function_imports.append( + f"from {module_name} import {class_name} as {get_function_alias(module_name, class_name)}" + ) + else: + function_imports.append( + f"from {module_name} import {function_name} as {get_function_alias(module_name, function_name)}" + ) + + imports += "\n".join(function_imports) + + functions_to_optimize = sorted( + {func.get("function_name") for func in functions_data if func.get("function_name") != "__init__"} + ) + metadata = f"""functions = {functions_to_optimize} +trace_file_path = r"{trace_file}" +""" + # Templates for different types of tests + test_function_body = textwrap.dedent( + """\ + for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", num_to_get={max_run_count}): + args = pickle.loads(args_pkl) + kwargs = pickle.loads(kwargs_pkl) + ret = {function_name}(*args, **kwargs) + """ + ) + + test_method_body = textwrap.dedent( + """\ + for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", 
num_to_get={max_run_count}): + args = pickle.loads(args_pkl) + kwargs = pickle.loads(kwargs_pkl){filter_variables} + function_name = "{orig_function_name}" + if not args: + raise ValueError("No arguments provided for the method.") + if function_name == "__init__": + ret = {class_name_alias}(*args[1:], **kwargs) + else: + ret = {class_name_alias}{method_name}(*args, **kwargs) + """ + ) + + test_class_method_body = textwrap.dedent( + """\ + for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}): + args = pickle.loads(args_pkl) + kwargs = pickle.loads(kwargs_pkl){filter_variables} + if not args: + raise ValueError("No arguments provided for the method.") + ret = {class_name_alias}{method_name}(*args[1:], **kwargs) + """ + ) + test_static_method_body = textwrap.dedent( + """\ + for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}): + args = pickle.loads(args_pkl) + kwargs = pickle.loads(kwargs_pkl){filter_variables} + ret = {class_name_alias}{method_name}(*args, **kwargs) + """ + ) + + # Create main body + test_template = "" + + for func in functions_data: + module_name = func.get("module_name") + function_name = func.get("function_name") + class_name = func.get("class_name") + file_path_raw = func.get("file_path") + benchmark_function_name = func.get("benchmark_function_name") + function_properties = func.get("function_properties") + assert module_name is not None + assert function_name is not None + assert file_path_raw is not None + assert benchmark_function_name is not None + assert function_properties is not None + file_path = Path(file_path_raw).as_posix() + if not 
def generate_replay_test(trace_file_path: Path, output_dir: Path, max_run_count: int = 100) -> int:
    """Generate replay tests from the traced function calls, one test file per benchmark module.

    Args:
    ----
        trace_file_path: Path to the SQLite database file
        output_dir: Directory to write the generated tests to (created on the first write)
        max_run_count: Maximum number of runs to include per function

    Returns:
    -------
        The number of replay test files written. Any error is logged and the count reached so far is returned.

    """
    count = 0
    conn = None
    try:
        # Connect to the database
        conn = sqlite3.connect(trace_file_path.as_posix())
        cursor = conn.cursor()

        # Get distinct benchmark file paths
        cursor.execute("SELECT DISTINCT benchmark_module_path FROM benchmark_function_timings")
        benchmark_files = cursor.fetchall()

        # Generate a test for each benchmark file
        for benchmark_file in benchmark_files:
            benchmark_module_path = benchmark_file[0]
            # Get all benchmarks and functions associated with this file path
            cursor.execute(
                "SELECT DISTINCT benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number FROM benchmark_function_timings "
                "WHERE benchmark_module_path = ?",
                (benchmark_module_path,),
            )

            functions_data = []
            for row in cursor.fetchall():
                benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number = row
                # Add this function to our list
                functions_data.append(
                    {
                        "function_name": function_name,
                        "class_name": class_name,
                        "file_path": file_path,
                        "module_name": module_name,
                        "benchmark_function_name": benchmark_function_name,
                        "benchmark_module_path": benchmark_module_path,
                        "benchmark_line_number": benchmark_line_number,
                        "function_properties": inspect_top_level_functions_or_methods(
                            file_name=Path(file_path), function_or_method_name=function_name, class_name=class_name
                        ),
                    }
                )

            if not functions_data:
                logger.info("No benchmark test functions found in %s", benchmark_module_path)
                continue
            # Generate the test code for this benchmark
            test_code = create_trace_replay_test_code(
                trace_file=trace_file_path.as_posix(), functions_data=functions_data, max_run_count=max_run_count
            )
            test_code = sort_imports(code=test_code)
            output_file = get_test_file_path(
                test_dir=Path(output_dir), function_name=benchmark_module_path, test_type="replay"
            )
            # Write test code to file, parents = true
            output_dir.mkdir(parents=True, exist_ok=True)
            output_file.write_text(test_code, "utf-8")
            count += 1
    except Exception as e:
        # Best effort: a failure for one benchmark must not crash the caller.
        logger.info("Error generating replay tests: %s", e)
    finally:
        # Release the database handle on the error path as well as on success.
        if conn is not None:
            conn.close()

    return count
result.stderr if combined_output else result.stderr + + if "ERROR collecting" in combined_output: + # Pattern matches "===== ERRORS =====" (any number of =) and captures everything after + error_pattern = r"={3,}\s*ERRORS\s*={3,}\n([\s\S]*?)(?:={3,}|$)" + match = re.search(error_pattern, combined_output) + error_section = match.group(1) if match else combined_output + elif "FAILURES" in combined_output: + # Pattern matches "===== FAILURES =====" (any number of =) and captures everything after + error_pattern = r"={3,}\s*FAILURES\s*={3,}\n([\s\S]*?)(?:={3,}|$)" + match = re.search(error_pattern, combined_output) + error_section = match.group(1) if match else combined_output + else: + error_section = combined_output + logger.warning("Error collecting benchmarks - Pytest Exit code: %s, %s", result.returncode, error_section) + logger.debug("Full pytest output:\n%s", combined_output) diff --git a/src/codeflash_python/benchmarking/tracing_new_process.py b/src/codeflash_python/benchmarking/tracing_new_process.py new file mode 100644 index 000000000..6393cdc56 --- /dev/null +++ b/src/codeflash_python/benchmarking/tracing_new_process.py @@ -0,0 +1,864 @@ +from __future__ import annotations + +import contextlib +import datetime +import importlib +import importlib.machinery +import io +import json +import logging +import os +import pickle +import re +import sqlite3 +import sys +import threading +import time +import warnings +from collections import defaultdict +from importlib.util import find_spec +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, ClassVar + +from codeflash_python.benchmarking.tracing_utils import ( + FunctionModules, + filter_files_optimized, + module_name_from_file_path, +) +from codeflash_python.picklepatch.pickle_patcher import PicklePatcher + +logger = logging.getLogger("codeflash_python") + +# Suppress dill PicklingWarning +warnings.filterwarnings("ignore", message="Cannot locate reference to") +warnings.filterwarnings("ignore", 
class FakeCode:
    """Minimal stand-in for a code object, exposing only a few ``co_*`` attributes."""

    def __init__(self, filename: str, line: int, name: str) -> None:
        self.co_filename, self.co_line, self.co_name = filename, line, name
        # A synthetic code object has no real definition line.
        self.co_firstlineno = 0

    def __repr__(self) -> str:
        # Rendered as a 4-tuple whose last slot is always None.
        summary = (self.co_filename, self.co_line, self.co_name, None)
        return repr(summary)
If None, trace all functions + :param disable: Disable the tracer if True + :param max_function_count: Maximum number of times to trace one function + :param timeout: Timeout in seconds for the tracer, if the traced code takes more than this time, then tracing + stops and normal execution continues. If this is None then no timeout applies + :param command: The command that initiated the tracing (for metadata storage) + """ + if functions is None: + functions = [] + if os.environ.get("CODEFLASH_TRACER_DISABLE", "0") == "1": + logger.warning("Codeflash: Tracer disabled by environment variable CODEFLASH_TRACER_DISABLE") + disable = True + self.disable = disable + self._db_lock: threading.Lock | None = None + if self.disable: + return + if sys.getprofile() is not None or sys.gettrace() is not None: + logger.warning( + "Codeflash: Another profiler, debugger or coverage tool is already running. " + "Please disable it before starting the Codeflash Tracer, both can't run. Codeflash Tracer is DISABLED." 
+ ) + self.disable = True + return + + self._db_lock = threading.Lock() + + self.con = None + self.functions = functions + self.function_modules: list[FunctionModules] = [] + self.function_count = defaultdict(int) + self.current_file_path = Path(__file__).resolve() + self.ignored_qualified_functions = { + f"{self.current_file_path}:Tracer.__exit__", + f"{self.current_file_path}:Tracer.__enter__", + } + self.max_function_count = max_function_count + self.config = config + self.project_root = project_root + self.project_root_str = str(project_root) + os.sep if project_root else "" + logger.info("Project Root: %s", self.project_root) + self.ignored_functions = {"", "", "", "", "", ""} + + self.sanitized_filename = self.sanitize_to_filename(command) + # Place trace file next to replay tests in the tests directory + try: + from codeflash_python.verification.verification_utils import get_test_file_path + except ImportError: + # If verification_utils doesn't exist yet, create a simple fallback + def get_test_file_path(test_dir: Path, function_name: str, test_type: str) -> Path: + return test_dir / "codeflash_replay_tests" / f"test_{function_name}_{test_type}.py" + + function_path = "_".join(functions) if functions else self.sanitized_filename + test_file_path = get_test_file_path( + test_dir=Path(config["tests_root"]), function_name=function_path, test_type="replay" + ) + test_file_path.parent.mkdir(parents=True, exist_ok=True) + trace_filename = test_file_path.stem + ".trace" + self.output_file = test_file_path.parent / trace_filename + self.result_pickle_file_path = result_pickle_file_path + + assert timeout is None or timeout > 0, "Timeout should be greater than 0" + self.timeout = timeout + self.next_insert = 1000 + self.trace_count = 0 + self.path_cache = {} # Cache for resolved file paths + + # Profiler variables + self.bias = 0 # calibration constant + self.timings = {} + self.cur = None + self.start_time = None + self.timer = time.process_time_ns + self.total_tt = 
def __enter__(self) -> None:
    """Start tracing: create a fresh trace database and install the profile hooks.

    Does nothing when the tracer is disabled. A second activation in the same program run is refused:
    a warning is logged and this instance is marked disabled.
    """
    if self.disable:
        return
    if getattr(Tracer, "used_once", False):
        logger.warning(
            "Codeflash: Tracer can only be used once per program run. "
            "Please only enable the Tracer once. Skipping tracing this section."
        )
        self.disable = True
        return
    Tracer.used_once = True

    # Start from an empty trace file.
    trace_path = Path(self.output_file)
    if trace_path.exists():
        logger.info("Removing existing trace file")
        trace_path.unlink(missing_ok=True)

    self.con = sqlite3.connect(self.output_file, check_same_thread=False)
    cursor = self.con.cursor()
    for pragma in ("PRAGMA synchronous = OFF", "PRAGMA journal_mode = WAL"):
        cursor.execute(pragma)
    cursor.execute(
        "CREATE TABLE function_calls(type TEXT, function TEXT, classname TEXT, filename TEXT, "
        "line_number INTEGER, last_frame_address INTEGER, time_ns INTEGER, args BLOB)"
    )

    # Metadata table: key/value pairs describing this tracing run.
    cursor.execute("CREATE TABLE metadata(key TEXT PRIMARY KEY, value TEXT)")
    functions_filter = json.dumps(self.functions) if self.functions else None
    metadata_rows = (
        ("command", self.command),
        ("program_name", self.sanitized_filename),
        ("functions_filter", functions_filter),
        ("timestamp", datetime.datetime.now(datetime.timezone.utc).isoformat()),
        ("project_root", str(self.project_root)),
    )
    for row in metadata_rows:
        cursor.execute("INSERT INTO metadata VALUES (?, ?)", row)

    logger.info("Codeflash: Traced Program Output Begin")
    # Must be taken here (not in a helper) so the simulated call refers to this very frame.
    current_frame = sys._getframe(0)  # noqa: SLF001
    self.dispatch["call"](self, current_frame, 0)
    self.start_time = time.time()
    # Hook the current thread and every thread started from now on.
    sys.setprofile(self.trace_callback)
    threading.setprofile(self.trace_callback)
type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> None: + if self.disable or self._db_lock is None: + return + sys.setprofile(None) + threading.setprofile(None) + + with self._db_lock: + if self.con is None: + return + + self.con.commit() # Commit any pending from tracer_logic + logger.info("Codeflash: Traced Program Output End") + self.create_stats() # This calls snapshot_stats which uses self.timings + + cur = self.con.cursor() + cur.execute( + "CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, " + "call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, " + "cumulative_time_ns INTEGER, callers BLOB)" + ) + # self.stats is populated by snapshot_stats() called within create_stats() + # Ensure self.stats is accessed after create_stats() and within the lock if it involves DB data + # For now, assuming self.stats is primarily in-memory after create_stats() + for func, (cc, nc, tt, ct, callers) in self.stats.items(): + remapped_callers = [{"key": k, "value": v} for k, v in callers.items()] + cur.execute( + "INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + str(Path(func[0]).resolve()), + func[1], + func[2], + func[3], + cc, + nc, + tt, + ct, + json.dumps(remapped_callers), + ), + ) + self.con.commit() + + self.make_pstats_compatible() # Modifies self.stats and self.timings in-memory + self.print_stats("tottime") # Uses self.stats, prints to console + + cur = self.con.cursor() # New cursor + cur.execute("CREATE TABLE total_time (time_ns INTEGER)") + cur.execute("INSERT INTO total_time VALUES (?)", (self.total_tt,)) + self.con.commit() + self.con.close() + self.con = None # Mark connection as closed + + # filter any functions where we did not capture the return + self.function_modules = [ + function + for function in self.function_modules + if self.function_count[ + str(function.file_name) + + ":" + + (function.class_name + "." 
def tracer_logic(self, frame: FrameType, event: str) -> None:
    """Store the pickled arguments of one traced "call" event in the trace database.

    A call is stored only when its file lies under the project root, it passes the optional
    ``self.functions`` name filter, and it has been seen fewer than ``max_function_count`` times.
    The first time a function is seen it is also appended to ``self.function_modules``, unless
    ``filter_files_optimized`` rejects its file.

    :param frame: Frame being entered; its code object and locals are inspected.
    :param event: Profiler event name; every event other than "call" is ignored.
    """
    if event != "call":
        return
    if self.start_time is not None and self.timeout is not None and (time.time() - self.start_time) > self.timeout:
        sys.setprofile(None)
        threading.setprofile(None)
        logger.warning("Codeflash: Timeout reached! Stopping tracing at %s seconds.", self.timeout)
        return
    if self.disable or self._db_lock is None or self.con is None:
        return

    code = frame.f_code

    # Check function name first before resolving path
    if code.co_name in self.ignored_functions:
        return

    # Resolve file path and check validity (cached)
    co_filename = code.co_filename
    if co_filename in self.path_cache:
        file_name, is_valid = self.path_cache[co_filename]
        if not is_valid:
            return
    else:
        resolved = os.path.realpath(co_filename)
        # startswith is cheaper than Path.is_relative_to, os.path.exists avoids Path construction
        is_valid = resolved.startswith(self.project_root_str) and os.path.exists(resolved)
        self.path_cache[co_filename] = (resolved, is_valid)
        if not is_valid:
            return
        file_name = resolved
    if self.functions and code.co_name not in self.functions:
        return
    class_name = None
    arguments = frame.f_locals
    try:
        self_arg = arguments.get("self")
        if self_arg is not None:
            try:
                class_name = self_arg.__class__.__name__
            except AttributeError:
                cls_arg = arguments.get("cls")
                if cls_arg is not None:
                    with contextlib.suppress(AttributeError):
                        class_name = cls_arg.__name__
        else:
            cls_arg = arguments.get("cls")
            if cls_arg is not None:
                with contextlib.suppress(AttributeError):
                    class_name = cls_arg.__name__
    except:  # noqa: E722
        # someone can override the getattr method and raise an exception. I'm looking at you wrapt
        return

    # Extract class name from co_qualname for static methods that lack self/cls
    co_qualname = getattr(code, "co_qualname", "")
    if class_name is None and "." in co_qualname:
        qualname_parts = co_qualname.split(".")
        if len(qualname_parts) >= 2:
            class_name = qualname_parts[-2]

    # co_qualname is read with a "" default, so a missing attribute shows up as an empty string,
    # never as AttributeError. Without this fallback every function of a file would share the
    # key "<file>:" on interpreters whose code objects lack co_qualname.
    if co_qualname:
        function_qualified_name = f"{file_name}:{co_qualname}"
    else:
        function_qualified_name = f"{file_name}:{(class_name + '.' if class_name else '')}{code.co_name}"
    if function_qualified_name in self.ignored_qualified_functions:
        return
    if function_qualified_name not in self.function_count:
        # seeing this function for the first time — Path construction only happens here
        self.function_count[function_qualified_name] = 1
        file_path = Path(file_name)
        file_valid = filter_files_optimized(
            file_path=file_path,
            tests_root=Path(self.config["tests_root"]),
            ignore_paths=[Path(p) for p in self.config["ignore_paths"]],
            module_root=Path(self.config["module_root"]),
        )
        if not file_valid:
            # we don't want to trace this function because it cannot be optimized
            self.ignored_qualified_functions.add(function_qualified_name)
            return
        assert self.project_root is not None
        self.function_modules.append(
            FunctionModules(
                function_name=code.co_name,
                file_name=file_path,
                module_name=module_name_from_file_path(file_path, project_root_path=self.project_root),
                class_name=class_name,
                line_no=code.co_firstlineno,
            )
        )
    else:
        self.function_count[function_qualified_name] += 1
        if self.function_count[function_qualified_name] >= self.max_function_count:
            self.ignored_qualified_functions.add(function_qualified_name)
            return

    with self._db_lock:
        # Check connection again inside lock, in case __exit__ closed it.
        if self.con is None:
            return

        cur = self.con.cursor()

        t_ns = time.perf_counter_ns()
        original_recursion_limit = sys.getrecursionlimit()
        try:
            # pickling can be a recursive operator, so we need to increase the recursion limit
            sys.setrecursionlimit(10000)
            # We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
            # directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
            # leaks, bad references or side effects when unpickling.
            arguments_copy = dict(arguments.items())  # Use the local 'arguments' from frame.f_locals
            if class_name and code.co_name == "__init__" and "self" in arguments_copy:
                del arguments_copy["self"]
            local_vars = PicklePatcher.dumps(arguments_copy, protocol=pickle.HIGHEST_PROTOCOL)
        except Exception:
            # This call could not be captured, so it must not count towards the per-function total.
            self.function_count[function_qualified_name] -= 1
            return
        finally:
            # Restore the limit on every exit path, including non-Exception interrupts.
            sys.setrecursionlimit(original_recursion_limit)

        cur.execute(
            "INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)",
            (event, code.co_name, class_name, file_name, frame.f_lineno, frame.f_back.__hash__(), t_ns, local_vars),
        )
        self.trace_count += 1
        self.next_insert -= 1
        # Commit in batches of 1000 inserts.
        if self.next_insert == 0:
            self.next_insert = 1000
            self.con.commit()
def trace_dispatch_c_call(self, frame: FrameType, t: int) -> int:
    """Handle a ``c_call`` event: push the C function named by ``self.c_func_name`` onto the stack.

    Always returns 1.
    """
    # C functions have no file, line or class; only the name identifies them.
    key = ("", 0, self.c_func_name, None)
    # The previous top of the stack becomes the parent link of the new entry.
    self.cur = (t, 0, 0, key, frame, self.cur)

    if key not in self.timings:
        self.timings[key] = (0, 0, 0, 0, {})
        return 1

    call_count, active, total, cumulative, callers = self.timings[key]
    self.timings[key] = (call_count, active + 1, total, cumulative, callers)
    return 1
0 + + # In multi-threaded environments, frames can get mismatched + if frame is not self.cur[-2]: + # Don't assert in threaded environments - frames can legitimately differ + if hasattr(frame, "f_back") and hasattr(self.cur[-2], "f_back") and frame is self.cur[-2].f_back: + self.trace_dispatch_return(self.cur[-2], 0) + else: + # We're in a different thread or context, can't continue with this frame + return 0 + # Prefix "r" means part of the Returning or exiting frame. + # Prefix "p" means part of the Previous or Parent or older frame. + + rpt, rit, ret, rfn, frame, rcur = self.cur + + # Guard against invalid rcur (w threading) + if not rcur: + return 0 + + rit = rit + t + frame_total = rit + ret + + ppt, pit, pet, pfn, pframe, pcur = rcur + self.cur = ppt, pit + rpt, pet + frame_total, pfn, pframe, pcur + + timings = self.timings + if rfn not in timings: + # w threading, rfn can be missing + timings[rfn] = 0, 0, 0, 0, {} + cc, ns, tt, ct, callers = timings[rfn] + if not ns: + # This is the only occurrence of the function on the stack. + # Else this is a (directly or indirectly) recursive call, and + # its cumulative time will get updated when the topmost call to + # it returns. + ct = ct + frame_total + cc = cc + 1 + + if pfn in callers: + # Increment call count between these functions + callers[pfn] = callers[pfn] + 1 + # Note: This tracks stats such as the amount of time added to ct + # of this specific call, and the contribution to cc + # courtesy of this call. 
def simulate_cmd_complete(self) -> None:
    """Dispatch synthetic "return" events until the top stack entry has no parent left."""
    clock = self.timer
    pending = clock() - self.t
    while True:
        top = self.cur
        if not top or not top[-1]:
            break
        # We *can* cause assertion errors here if
        # dispatch_trace_return checks for a frame match!
        self.dispatch["return"](self, top[-2], pending)
        # Only the first synthetic return is charged the elapsed time.
        pending = 0
    self.t = clock() - pending
func_name + new_func = (file_name, line_num, new_func_name) + else: + new_func = func # Keep as is if already in correct format + + new_callers = {} + callers_items = list(callers.items()) + for caller_func, count in callers_items: + if isinstance(caller_func, tuple): + if len(caller_func) == 4: + caller_file, caller_line, caller_name, caller_class = caller_func + caller_new_name = f"{caller_class}.{caller_name}" if caller_class else caller_name + new_caller_func = (caller_file, caller_line, caller_new_name) + else: + new_caller_func = caller_func + else: + logger.debug("Unexpected caller format: %s", caller_func) + new_caller_func = str(caller_func) + + new_callers[new_caller_func] = count + + # Store with new format + new_stats[new_func] = (cc, nc, tt, ct, new_callers) + except Exception as e: + logger.debug("Error converting stats for %s: %s", func, e) + continue + + timings_items = list(self.timings.items()) + for func, timing_data in timings_items: + try: + if len(timing_data) != 5: + logger.debug("Skipping malformed timing data for %s: %s", func, timing_data) + continue + + cc, ns, tt, ct, callers = timing_data + + if len(func) == 4: + file_name, line_num, func_name, class_name = func + new_func_name = f"{class_name}.{func_name}" if class_name else func_name + new_func = (file_name, line_num, new_func_name) + else: + new_func = func + + new_callers = {} + callers_items = list(callers.items()) + for caller_func, count in callers_items: + if isinstance(caller_func, tuple): + if len(caller_func) == 4: + caller_file, caller_line, caller_name, caller_class = caller_func + caller_new_name = f"{caller_class}.{caller_name}" if caller_class else caller_name + new_caller_func = (caller_file, caller_line, caller_new_name) + else: + new_caller_func = caller_func + else: + logger.debug("Unexpected caller format: %s", caller_func) + new_caller_func = str(caller_func) + + new_callers[new_caller_func] = count + + new_timings[new_func] = (cc, ns, tt, ct, new_callers) + except 
def make_pstats_compatible(self) -> None:
    """Rewrite ``self.stats`` and ``self.timings`` so every function key is a 3-tuple.

    The extra fourth item of each key (the class name) is dropped, both for the functions themselves
    and for the keys of their caller dictionaries. ``self.files`` and ``self.top_level`` are reset to
    empty lists.
    """
    self.files = []
    self.top_level = []

    def first_three(key):
        return (key[0], key[1], key[2])

    def convert(table):
        converted = {}
        for func, (cc, ns, tt, ct, callers) in table.items():
            trimmed_callers = {first_three(caller): hits for caller, hits in callers.items()}
            converted[first_three(func)] = (cc, ns, tt, ct, trimmed_callers)
        return converted

    # Build both tables before assigning either, as the original does.
    stats_3 = convert(self.stats)
    timings_3 = convert(self.timings)
    self.stats = stats_3
    self.timings = timings_3
def sanitize_to_filename(self, arg: str) -> str:
    """Turn an arbitrary command string into a short, filesystem-friendly name.

    :param arg: Text to sanitize (typically the command that started tracing).
    :return: At most 100 characters made of word characters and dots, never starting or ending
        with a dot or underscore; ``"untitled"`` when nothing usable remains.
    """
    # Replace newlines with underscores
    arg = arg.replace("\n", "_").replace("\r", "_")

    # Collapse each whitespace run into one underscore, keeping only the first 5 pieces
    parts = re.split(r"\s+", arg)[:5]
    arg = "_".join(parts)

    # Remove all characters that are not alphanumeric, underscore, or dot
    arg = re.sub(r"[^\w._]", "", arg)

    # Avoid filenames starting or ending with a dot or underscore
    arg = arg.strip("._")

    # Limit to 100 characters; the cut can land right after a dot/underscore, so trim the tail again
    arg = arg[:100].rstrip("._")

    # Fallback if resulting name is empty
    return arg or "untitled"
"__name__": spec.name, + "__package__": None, + "__cached__": None, + } + args_dict["config"]["module_root"] = Path(args_dict["config"]["module_root"]) + args_dict["config"]["tests_root"] = Path(args_dict["config"]["tests_root"]) + tracer = Tracer( + config=args_dict["config"], + functions=args_dict["functions"], + max_function_count=args_dict["max_function_count"], + timeout=args_dict["timeout"], + command=args_dict["command"], + disable=args_dict["disable"], + result_pickle_file_path=Path(args_dict["result_pickle_file_path"]), + project_root=Path(args_dict["project_root"]), + ) + # code is either str (for module) or CodeType (for file), but runctx expects str + # When code is CodeType, exec() in runctx will handle it correctly + tracer.runctx(code, globs, {}) # type: ignore[arg-type] diff --git a/src/codeflash_python/benchmarking/tracing_utils.py b/src/codeflash_python/benchmarking/tracing_utils.py new file mode 100644 index 000000000..b084f6ce7 --- /dev/null +++ b/src/codeflash_python/benchmarking/tracing_utils.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import logging +import os +import site +from dataclasses import dataclass +from functools import cache +from pathlib import Path +from typing import cast + +import git + +logger = logging.getLogger("codeflash_python") + + +# This can't be pydantic dataclass because then conflicts with the logfire pytest plugin +# for pydantic tracing. We want to not use pydantic in the tracing code. 
def path_belongs_to_site_packages(file_path: Path) -> bool:
    """Return True when *file_path* lies inside one of the interpreter's site-packages directories.

    Both the candidate path and every directory reported by ``site.getsitepackages()``
    are resolved before comparison, so a site-packages directory reached through a
    symlink still matches the (resolved) file path.
    """
    # Resolve once up front instead of once per site-packages directory.
    resolved_file = file_path.resolve()
    site_package_dirs = [Path(p).resolve() for p in site.getsitepackages()]
    return any(resolved_file.is_relative_to(site_dir) for site_dir in site_package_dirs)
def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list[Path], module_root: Path) -> bool:
    """Decide whether a single file is eligible for optimization.

    Returns True when the file should be kept, and False when it is rejected by
    any of the checks below (evaluated in this order):

    * it is a test file -- matched by naming pattern when ``tests_root`` equals or
      contains ``module_root``, otherwise by living under ``tests_root``;
    * it equals, or lives under, one of ``ignore_paths``;
    * it lives in a site-packages directory;
    * it is outside ``module_root``;
    * it lives inside a git submodule of the repository containing ``module_root``.
    """
    tests_root_overlaps = tests_root == module_root or module_root.is_relative_to(tests_root)
    if tests_root_overlaps:
        # Directory-based filtering would reject every source file here, so fall
        # back to file-name conventions.
        if is_test_file_by_pattern(file_path):
            return False
    elif file_path.is_relative_to(tests_root):
        return False

    # A path is relative to itself, so this also covers an exact match.
    if any(file_path.is_relative_to(ignore_path) for ignore_path in ignore_paths):
        return False
    if path_belongs_to_site_packages(file_path):
        return False
    if not file_path.is_relative_to(module_root):
        return False

    submodule_paths = ignored_submodule_paths(module_root)
    return not any(file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths)
from codeflash_python.models.models import BenchmarkKey + + +logger = logging.getLogger("codeflash_python") + + +def validate_and_format_benchmark_table( + function_benchmark_timings: dict[str, dict[BenchmarkKey, int]], total_benchmark_timings: dict[BenchmarkKey, int] +) -> dict[str, list[tuple[BenchmarkKey, float, float, float]]]: + function_to_result = {} + # Process each function's benchmark data + for func_path, test_times in function_benchmark_timings.items(): + # Sort by percentage (highest first) + sorted_tests = [] + for benchmark_key, func_time in test_times.items(): + total_time = total_benchmark_timings.get(benchmark_key, 0) + if func_time > total_time: + logger.debug( + "Skipping test %s due to func_time %s > total_time %s", benchmark_key, func_time, total_time + ) + # If the function time is greater than total time, likely to have multithreading / multiprocessing issues. + # Do not try to project the optimization impact for this function. + sorted_tests.append((benchmark_key, 0.0, 0.0, 0.0)) + elif total_time > 0: + percentage = (func_time / total_time) * 100 + # Convert nanoseconds to milliseconds + func_time_ms = func_time / 1_000_000 + total_time_ms = total_time / 1_000_000 + sorted_tests.append((benchmark_key, total_time_ms, func_time_ms, percentage)) + sorted_tests.sort(key=lambda x: x[3], reverse=True) + function_to_result[func_path] = sorted_tests + return function_to_result + + +def print_benchmark_table(function_to_results: dict[str, list[tuple[BenchmarkKey, float, float, float]]]) -> None: + headers = ["Benchmark Module Path", "Test Function", "Total Time (ms)", "Function Time (ms)", "Percentage (%)"] + for func_path, sorted_tests in function_to_results.items(): + function_name = func_path.split(":")[-1] + + rows = [] + for benchmark_key, total_time, func_time, percentage in sorted_tests: + module_path = benchmark_key.module_path + test_function = benchmark_key.function_name + if total_time == 0.0: + rows.append([module_path, test_function, 
"N/A", "N/A", "N/A"]) + else: + rows.append([module_path, test_function, f"{total_time:.3f}", f"{func_time:.3f}", f"{percentage:.2f}"]) + + col_widths = [len(h) for h in headers] + for row in rows: + for i, cell in enumerate(row): + col_widths[i] = max(col_widths[i], len(cell)) + fmt = " ".join(f"{{:<{w}}}" for w in col_widths) + lines = [f"Function: {function_name}", fmt.format(*headers), "-" * sum([*col_widths, 2 * (len(headers) - 1)])] + for row in rows: + lines.append(fmt.format(*row)) + logger.info("\n".join(lines)) + + +def process_benchmark_data( + replay_performance_gain: dict[BenchmarkKey, float], + fto_benchmark_timings: dict[BenchmarkKey, int], + total_benchmark_timings: dict[BenchmarkKey, int], +) -> ProcessedBenchmarkInfo | None: + """Process benchmark data and generate detailed benchmark information. + + Args: + ---- + replay_performance_gain: The performance gain from replay + fto_benchmark_timings: Function to optimize benchmark timings + total_benchmark_timings: Total benchmark timings + + Returns: + ------- + ProcessedBenchmarkInfo containing processed benchmark details + + """ + if not replay_performance_gain or not fto_benchmark_timings or not total_benchmark_timings: + return None + + benchmark_details = [] + + for benchmark_key, og_benchmark_timing in fto_benchmark_timings.items(): + total_benchmark_timing = total_benchmark_timings.get(benchmark_key, 0) + + if total_benchmark_timing == 0: + continue # Skip benchmarks with zero timing + + # Calculate expected new benchmark timing + expected_new_benchmark_timing = ( + total_benchmark_timing + - og_benchmark_timing + + (1 / (replay_performance_gain[benchmark_key] + 1)) * og_benchmark_timing + ) + + # Calculate speedup + benchmark_speedup_percent = ( + performance_gain( + original_runtime_ns=total_benchmark_timing, optimized_runtime_ns=int(expected_new_benchmark_timing) + ) + * 100 + ) + + benchmark_details.append( + BenchmarkDetail( + benchmark_name=benchmark_key.module_path, + 
test_function=benchmark_key.function_name, + original_timing=humanize_runtime(int(total_benchmark_timing)), + expected_new_timing=humanize_runtime(int(expected_new_benchmark_timing)), + speedup_percent=benchmark_speedup_percent, + ) + ) + + return ProcessedBenchmarkInfo(benchmark_details=benchmark_details) diff --git a/src/codeflash_python/cli.py b/src/codeflash_python/cli.py new file mode 100644 index 000000000..2976fe258 --- /dev/null +++ b/src/codeflash_python/cli.py @@ -0,0 +1,191 @@ +"""CLI module for codeflash_python - minimal stub for test support.""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from argparse import Namespace + + +def process_pyproject_config(args: Namespace) -> Namespace: + """Process pyproject.toml config and populate args. + + Args: + args: Parsed command-line arguments. + + Returns: + Updated args namespace with config values. + + """ + try: + from codeflash_python.code_utils.config_parser import parse_config_file + except ImportError: + # Minimal fallback if config_parser not available + pyproject_config = {} + if hasattr(args, "config_file") and args.config_file: + try: + import tomlkit + + with Path(args.config_file).open() as f: + data = tomlkit.load(f) + pyproject_config = data.get("tool", {}).get("codeflash", {}) + except Exception: + pass + else: + try: + pyproject_config, _ = parse_config_file(args.config_file) + except (ValueError, FileNotFoundError): + pyproject_config = {} + + supported_keys = [ + "module_root", + "tests_root", + "benchmarks_root", + "ignore_paths", + "pytest_cmd", + "formatter_cmds", + "disable_telemetry", + "disable_imports_sorting", + "git_remote", + "override_fixtures", + ] + + for key in supported_keys: + normalized_key = key.replace("-", "_") + if key in pyproject_config and ( + (hasattr(args, normalized_key) and getattr(args, normalized_key) is None) + or not hasattr(args, normalized_key) + ): + setattr(args, normalized_key, 
pyproject_config[key]) + + # Set defaults + if not hasattr(args, "module_root") or args.module_root is None: + args.module_root = str(Path.cwd()) + + if not hasattr(args, "tests_root") or args.tests_root is None: + args.tests_root = str(Path.cwd() / "tests") + + pyproject_file_path = Path(args.config_file) if hasattr(args, "config_file") and args.config_file else None + + if not hasattr(args, "project_root") or args.project_root is None: + args.project_root = str(project_root_from_module_root(Path(args.module_root), pyproject_file_path)) + + if not hasattr(args, "test_project_root") or args.test_project_root is None: + args.test_project_root = str(project_root_from_module_root(Path(args.tests_root), pyproject_file_path)) + + return args + + +def project_root_from_module_root(module_root: Path, pyproject_file_path: Path | None) -> Path: + """Find the project root by walking up from module_root.""" + module_root = module_root.resolve() + if pyproject_file_path is not None and pyproject_file_path.parent.resolve() == module_root: + return module_root + + current = module_root + while current != current.parent: + if (current / "codeflash.toml").exists(): + return current + current = current.parent + + return module_root.parent.resolve() + + +def handle_show_config() -> None: + """Show current or auto-detected Codeflash configuration.""" + from codeflash_python.setup.detector import detect_project, has_existing_config + + project_root = Path.cwd() + config_exists, _ = has_existing_config(project_root) + + if config_exists: + from codeflash_python.code_utils.config_parser import parse_config_file + + config, config_file_path = parse_config_file() + status = "Saved config" + + print() + print(f"Codeflash Configuration ({status})") + print(f"Config file: {config_file_path}") + print() + + print(f" {'Setting':<20} {'Value'}") + print(f" {'-' * 20} {'-' * 40}") + print(f" {'Project root':<20} {project_root}") + print(f" {'Module root':<20} {config.get('module_root', '(not 
set)')}") + print(f" {'Tests root':<20} {config.get('tests_root', '(not set)')}") + print(f" {'Test runner':<20} {config.get('test_framework', config.get('pytest_cmd', '(not set)'))}") + print( + f" {'Formatter':<20} {', '.join(config['formatter_cmds']) if config.get('formatter_cmds') else '(not set)'}" + ) + ignore_paths = config.get("ignore_paths", []) + print(f" {'Ignore paths':<20} {', '.join(str(p) for p in ignore_paths) if ignore_paths else '(none)'}") + else: + detected = detect_project(project_root) + status = "Auto-detected (not saved)" + + print() + print(f"Codeflash Configuration ({status})") + print() + + print(f" {'Setting':<20} {'Value'}") + print(f" {'-' * 20} {'-' * 40}") + print(f" {'Language':<20} {detected.language}") + print(f" {'Project root':<20} {detected.project_root}") + print(f" {'Module root':<20} {detected.module_root}") + print(f" {'Tests root':<20} {detected.tests_root if detected.tests_root else '(not detected)'}") + print(f" {'Test runner':<20} {detected.test_runner or '(not detected)'}") + print( + f" {'Formatter':<20} {', '.join(detected.formatter_cmds) if detected.formatter_cmds else '(not detected)'}" + ) + print( + f" {'Ignore paths':<20} {', '.join(str(p) for p in detected.ignore_paths) if detected.ignore_paths else '(none)'}" + ) + print(f" {'Confidence':<20} {detected.confidence:.0%}") + + print() + + if not config_exists: + print("Run codeflash --file to auto-save this config.") + + +def handle_reset_config(confirm: bool = True) -> None: + """Remove Codeflash configuration from project config file.""" + from codeflash_python.setup.config_writer import remove_config + from codeflash_python.setup.detector import detect_project, has_existing_config + + project_root = Path.cwd() + + config_exists, _ = has_existing_config(project_root) + if not config_exists: + print("No Codeflash configuration found to remove.") + return + + detected = detect_project(project_root) + + if confirm: + print("This will remove Codeflash configuration 
from your project.") + print() + + config_file = "pyproject.toml" + print(f" Config file: {project_root / config_file}") + print() + + try: + response = input("Are you sure you want to remove the config? [y/N] ") + except (EOFError, KeyboardInterrupt): + print("\nCancelled.") + return + + if response.lower() not in ("y", "yes"): + print("Cancelled.") + return + + result = remove_config(project_root) + + if result.is_ok(): + print(f"Done: {result.unwrap()}") + else: + print(f"Failed: {result.error}") # type: ignore[attr-defined] diff --git a/src/codeflash_python/cli_common.py b/src/codeflash_python/cli_common.py new file mode 100644 index 000000000..8facf507d --- /dev/null +++ b/src/codeflash_python/cli_common.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +import logging +import sys +from typing import NoReturn + +logger = logging.getLogger("codeflash_python") + + +def apologize_and_exit() -> NoReturn: + logger.info( + "\U0001f4a1 If you're having trouble, see https://docs.codeflash.ai/getting-started/local-installation for further help getting started with Codeflash!" 
class CodeflashRunCheckpoint:
    """JSONL record of the functions processed during one optimization run.

    The first line of the checkpoint file is a metadata record (module root,
    creation time, last-update time); every following line describes one
    processed function.
    """

    def __init__(self, module_root: Path, checkpoint_dir: Path | None = None) -> None:
        """Create a fresh, uniquely named checkpoint file.

        Args:
        ----
            module_root: Module root the run operates on; stored in the metadata record.
            checkpoint_dir: Directory to write into; defaults to ``codeflash_temp_dir``.

        """
        if checkpoint_dir is None:
            checkpoint_dir = codeflash_temp_dir
        self.module_root = module_root
        self.checkpoint_dir = Path(checkpoint_dir)
        # Create a unique checkpoint file name
        unique_id = str(uuid.uuid4())[:8]
        checkpoint_filename = f"codeflash_checkpoint_{unique_id}.jsonl"
        self.checkpoint_path = self.checkpoint_dir / checkpoint_filename

        # Initialize the checkpoint file with metadata
        self.initialize_checkpoint_file()

    def initialize_checkpoint_file(self) -> None:
        """Create a new checkpoint file with metadata."""
        now = time.time()
        metadata = {"type": "metadata", "module_root": str(self.module_root), "created_at": now, "last_updated": now}

        with self.checkpoint_path.open("w", encoding="utf-8") as f:
            f.write(json.dumps(metadata) + "\n")

    def add_function_to_checkpoint(
        self,
        function_fully_qualified_name: str,
        status: str = "optimized",
        additional_info: dict[str, Any] | None = None,
    ) -> None:
        """Add a function to the checkpoint after it has been processed.

        Args:
        ----
            function_fully_qualified_name: The fully qualified name of the function
            status: Status of optimization (e.g., "optimized", "failed", "skipped")
            additional_info: Any additional information to store about the function

        """
        if additional_info is None:
            additional_info = {}

        # additional_info is spread last, so its keys override the standard ones.
        function_data = {
            "type": "function",
            "function_name": function_fully_qualified_name,
            "status": status,
            "timestamp": time.time(),
            **additional_info,
        }

        with self.checkpoint_path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(function_data) + "\n")

        # Update the metadata last_updated timestamp
        self.update_metadata_timestamp()

    def update_metadata_timestamp(self) -> None:
        """Update the last_updated timestamp in the metadata."""
        # Read the first line (metadata) and keep the remaining records verbatim.
        with self.checkpoint_path.open(encoding="utf-8") as f:
            metadata = json.loads(f.readline())
            rest_content = f.read()

        # Update the timestamp
        metadata["last_updated"] = time.time()

        # Rewrite the checkpoint file in place: refreshed metadata first, then the untouched records.
        with self.checkpoint_path.open("w", encoding="utf-8") as f:
            f.write(json.dumps(metadata) + "\n")
            f.write(rest_content)

    def cleanup(self) -> None:
        """Unlink all the checkpoint files for this module_root.

        Checkpoint files whose metadata line is missing, unreadable or not a JSON
        object are left in place: their owner cannot be determined, and one bad
        file must not abort the cleanup of the others.
        """
        self.checkpoint_path.unlink(missing_ok=True)

        to_delete = []
        for file in self.checkpoint_dir.glob("codeflash_checkpoint_*.jsonl"):
            try:
                with file.open(encoding="utf-8") as f:
                    # Only the first line (metadata) is needed
                    metadata = json.loads(f.readline())
            except (OSError, ValueError):
                # ValueError covers json.JSONDecodeError (empty/corrupt line) and UnicodeDecodeError.
                continue
            if not isinstance(metadata, dict):
                continue
            # A metadata record without "module_root" is treated as belonging to this module root.
            if metadata.get("module_root", str(self.module_root)) == str(self.module_root):
                to_delete.append(file)
        for file in to_delete:
            file.unlink(missing_ok=True)
+ + Returns + ------- + Dictionary mapping function names to their processing information + + """ + processed_functions = {} + to_delete = [] + + for file in checkpoint_dir.glob("codeflash_checkpoint_*.jsonl"): + with file.open(encoding="utf-8") as f: + # Skip the first line (metadata) + first_line = next(f) + metadata = json.loads(first_line) + if metadata.get("last_updated"): + last_updated = datetime.datetime.fromtimestamp(metadata["last_updated"]) # noqa: DTZ006 + if datetime.datetime.now() - last_updated >= datetime.timedelta(days=7): # noqa: DTZ005 + to_delete.append(file) + continue + if metadata.get("module_root") != str(module_root): + continue + + for line in f: + entry = json.loads(line) + if entry.get("type") == "function": + processed_functions[entry["function_name"]] = entry + for file in to_delete: + file.unlink(missing_ok=True) + return processed_functions + + +def ask_should_use_checkpoint_get_functions(args: argparse.Namespace) -> dict[str, dict[str, str]] | None: + previous_checkpoint_functions = None + if getattr(args, "subagent", False): + return None + if args.all and codeflash_temp_dir.is_dir(): + previous_checkpoint_functions = get_all_historical_functions(args.module_root, codeflash_temp_dir) + if previous_checkpoint_functions and ( + getattr(args, "yes", False) + or input( + "Previous Checkpoint detected from an incomplete optimization run, shall I continue the optimization from that point? 
def unified_diff_strings(code1: str, code2: str, fromfile: str = "original", tofile: str = "modified") -> str:
    """Return the unified diff between two code strings as a single string.

    Every line of the result -- file headers, hunk headers and content lines -- is
    newline-terminated, so the string can be printed or written out directly.
    An empty string is returned when the inputs are identical.

    :param code1: First code string (original).
    :param code2: Second code string (modified).
    :param fromfile: Label for the first code string.
    :param tofile: Label for the second code string.
    :return: Unified diff as a string.
    """
    code1_lines = code1.splitlines(keepends=True)
    code2_lines = code2.splitlines(keepends=True)

    # The inputs keep their line endings, so the control lines ("---", "+++", "@@")
    # must use difflib's default "\n" terminator; with lineterm="" they would be
    # glued onto the following line when joined.
    diff = difflib.unified_diff(code1_lines, code2_lines, fromfile=fromfile, tofile=tofile)

    # A final source line without a trailing newline comes back unterminated;
    # terminate it so it does not run into the next diff line.
    return "".join(line if line.endswith("\n") else line + "\n" for line in diff)
+ """ + return int(len(s) * 0.25) + + +def module_name_from_file_path(file_path: Path, project_root_path: Path, *, traverse_up: bool = False) -> str: + try: + relative_path = file_path.resolve().relative_to(project_root_path.resolve()) + return relative_path.with_suffix("").as_posix().replace("/", ".") + except ValueError: + if traverse_up: + parent = file_path.parent + while parent not in (project_root_path, parent.parent): + try: + relative_path = file_path.resolve().relative_to(parent.resolve()) + return relative_path.with_suffix("").as_posix().replace("/", ".") + except ValueError: + parent = parent.parent + msg = f"File {file_path} is not within the project root {project_root_path}." + raise ValueError(msg) from None + + +def get_imports_from_file( + file_path: Path | None = None, file_string: str | None = None, file_ast: ast.AST | None = None +) -> list[ast.Import | ast.ImportFrom]: + assert sum([file_path is not None, file_string is not None, file_ast is not None]) == 1, ( + "Must provide exactly one of file_path, file_string, or file_ast" + ) + if file_path: + with file_path.open(encoding="utf8") as file: + file_string = file.read() + if file_ast is None: + if file_string is None: + logger.error("file_string cannot be None when file_ast is not provided") + return [] + try: + file_ast = ast.parse(file_string) + except SyntaxError as e: + logger.exception("Syntax error in code: %s", e) + return [] + return [node for node in ast.walk(file_ast) if isinstance(node, (ast.Import, ast.ImportFrom))] + + +def get_all_function_names(code: str) -> tuple[bool, list[str]]: + try: + module = ast.parse(code) + except SyntaxError as e: + logger.exception("Syntax error in code: %s", e) + return False, [] + + function_names = [ + node.name for node in ast.walk(module) if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + ] + return True, function_names + + +def get_run_tmp_file(file_path: Path | str) -> Path: + import tempfile + + if isinstance(file_path, str): + 
file_path = Path(file_path) + if not hasattr(get_run_tmp_file, "tmpdir_path"): + # Use mkdtemp instead of TemporaryDirectory to avoid auto-cleanup + # which can delete the dir before subprocess tests finish using it + get_run_tmp_file.tmpdir_path = Path(tempfile.mkdtemp(prefix="codeflash_")) # type: ignore[attr-defined] + return get_run_tmp_file.tmpdir_path / file_path # type: ignore[attr-defined] + + +def path_belongs_to_site_packages(file_path: Path) -> bool: + file_path_resolved = file_path.resolve() + site_packages = [Path(p).resolve() for p in site.getsitepackages()] + return any(file_path_resolved.is_relative_to(site_package_path) for site_package_path in site_packages) + + +def validate_python_code(code: str) -> str: + """Validate a string of Python code by attempting to compile it.""" + try: + compile(code, "", "exec") + except SyntaxError as e: + msg = f"Invalid Python code: {e.msg} (line {e.lineno}, column {e.offset})" + raise ValueError(msg) from e + return code + + +def cleanup_paths(paths: list[Path]) -> None: + for path in paths: + if path and path.exists(): + if path.is_dir(): + shutil.rmtree(path, ignore_errors=True) + else: + path.unlink(missing_ok=True) + + +def restore_conftest(path_to_content_map: dict[Path, str]) -> None: + for path, file_content in path_to_content_map.items(): + path.write_text(file_content, encoding="utf8") + + +def exit_with_message(message: str, *, error_on_exit: bool = False) -> None: + """Don't Call it inside the lsp process, it will terminate the lsp server.""" + print(message) + + sys.exit(1 if error_on_exit else 0) diff --git a/src/codeflash_python/code_utils/codeflash_wrap_decorator.py b/src/codeflash_python/code_utils/codeflash_wrap_decorator.py new file mode 100644 index 000000000..1d52309a4 --- /dev/null +++ b/src/codeflash_python/code_utils/codeflash_wrap_decorator.py @@ -0,0 +1,210 @@ +from __future__ import annotations + +import asyncio +import gc +import os +import sqlite3 +import time +from enum import Enum 
def extract_test_context_from_env() -> tuple[str, str | None, str]:
    """Read the current test's identity from the CODEFLASH_TEST_* environment variables.

    Returns:
        A ``(test_module, test_class, test_function)`` tuple. ``test_class`` is
        ``None`` when ``CODEFLASH_TEST_CLASS`` is unset or empty.

    Raises:
        RuntimeError: If ``CODEFLASH_TEST_MODULE`` or ``CODEFLASH_TEST_FUNCTION``
            is unset or empty.

    """
    # Use .get() so that a missing variable reaches the explanatory RuntimeError
    # below instead of surfacing as a bare KeyError.
    test_module = os.environ.get("CODEFLASH_TEST_MODULE")
    test_class = os.environ.get("CODEFLASH_TEST_CLASS")
    test_function = os.environ.get("CODEFLASH_TEST_FUNCTION")

    if test_module and test_function:
        return (test_module, test_class or None, test_function)

    raise RuntimeError(
        "Test context environment variables not set - ensure tests are run through codeflash test runner"
    )
async_wrapper.index = {} # type: ignore[attr-defined] + if test_id in async_wrapper.index: # type: ignore[attr-defined] + async_wrapper.index[test_id] += 1 # type: ignore[attr-defined] + else: + async_wrapper.index[test_id] = 0 # type: ignore[attr-defined] + + codeflash_test_index = async_wrapper.index[test_id] # type: ignore[attr-defined] + invocation_id = f"{line_id}_{codeflash_test_index}" + test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}" + + print(f"!$######{test_stdout_tag}######$!") + + iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0") + db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite")) + codeflash_con = sqlite3.connect(db_path) + codeflash_cur = codeflash_con.cursor() + + codeflash_cur.execute( + "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, " + "test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + "runtime INTEGER, return_value BLOB, verification_type TEXT)" + ) + + exception = None + counter = loop.time() + gc.disable() + try: + ret = func(*args, **kwargs) # coroutine creation has some overhead, though it is very small + counter = loop.time() + return_value = await ret # let's measure the actual execution time of the code + codeflash_duration = int((loop.time() - counter) * 1_000_000_000) + except Exception as e: + codeflash_duration = int((loop.time() - counter) * 1_000_000_000) + exception = e + finally: + gc.enable() + + print(f"!######{test_stdout_tag}######!") + + pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps((args, kwargs, return_value)) + codeflash_cur.execute( + "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + test_module_name, + test_class_name, + test_name, + function_name, + loop_index, + invocation_id, + codeflash_duration, + pickled_return_value, + 
VerificationType.FUNCTION_CALL.value, + ), + ) + codeflash_con.commit() + codeflash_con.close() + + if exception: + raise exception + return return_value + + return async_wrapper # type: ignore[return-value] + + +def codeflash_performance_async(func: F) -> F: + @wraps(func) + async def async_wrapper(*args: Any, **kwargs: Any) -> Any: + loop = asyncio.get_running_loop() + function_name = func.__name__ + line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"] + loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"]) + + test_module_name, test_class_name, test_name = extract_test_context_from_env() + + test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}" + + if not hasattr(async_wrapper, "index"): + async_wrapper.index = {} # type: ignore[attr-defined] + if test_id in async_wrapper.index: # type: ignore[attr-defined] + async_wrapper.index[test_id] += 1 # type: ignore[attr-defined] + else: + async_wrapper.index[test_id] = 0 # type: ignore[attr-defined] + + codeflash_test_index = async_wrapper.index[test_id] # type: ignore[attr-defined] + invocation_id = f"{line_id}_{codeflash_test_index}" + test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' 
def codeflash_concurrency_async(func: F) -> F:
    """Measures concurrent vs sequential execution performance for async functions.

    The wrapped coroutine function is awaited ``CODEFLASH_CONCURRENCY_FACTOR``
    times one after another, then the same number of times concurrently via
    ``asyncio.gather``. Both wall-clock durations (nanoseconds) are printed in a
    ``!@######CONC:...######@!`` marker line, and the value of the last sequential
    call is returned to the caller.
    """

    @wraps(func)
    async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
        function_name = func.__name__
        # Clamp to at least one run: with a factor of 0 (or less) the sequential loop
        # would never execute and there would be no result to return.
        concurrency_factor = max(1, int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10")))

        test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "")
        test_class_name = os.environ.get("CODEFLASH_TEST_CLASS", "")
        test_function = os.environ.get("CODEFLASH_TEST_FUNCTION", "")
        loop_index = os.environ.get("CODEFLASH_LOOP_INDEX", "0")

        # Phase 1: Sequential execution timing (garbage collector paused while timing)
        gc.disable()
        try:
            seq_start = time.perf_counter_ns()
            for _ in range(concurrency_factor):
                result = await func(*args, **kwargs)
            sequential_time = time.perf_counter_ns() - seq_start
        finally:
            gc.enable()

        # Phase 2: Concurrent execution timing
        gc.disable()
        try:
            conc_start = time.perf_counter_ns()
            tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)]
            await asyncio.gather(*tasks)
            concurrent_time = time.perf_counter_ns() - conc_start
        finally:
            gc.enable()

        # Output parseable metrics
        tag = f"{test_module_name}:{test_class_name}:{test_function}:{function_name}:{loop_index}"
        print(f"!@######CONC:{tag}:{sequential_time}:{concurrent_time}:{concurrency_factor}######@!")

        return result

    return async_wrapper  # type: ignore[return-value]
64000 +TESTGEN_CONTEXT_TOKEN_LIMIT = 64000 +READ_WRITABLE_LIMIT_ERROR = "Read-writable code has exceeded token limit, cannot proceed" +TESTGEN_LIMIT_ERROR = "Testgen code context has exceeded token limit, cannot proceed" +INDIVIDUAL_TESTCASE_TIMEOUT = 15 +MAX_FUNCTION_TEST_SECONDS = 60 +MIN_IMPROVEMENT_THRESHOLD = 0.05 +MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD = 0.10 # 10% minimum improvement for async throughput +MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD = 0.20 # 20% concurrency ratio improvement required +CONCURRENCY_FACTOR = 10 # Number of concurrent executions for concurrency benchmark +MAX_TEST_FUNCTION_RUNS = 50 +MAX_CUMULATIVE_TEST_RUNTIME_NANOSECONDS = 100e6 # 100ms +TOTAL_LOOPING_TIME = 10.0 # 10 second candidate benchmarking budget +COVERAGE_THRESHOLD = 60.0 +MIN_TESTCASE_PASSED_THRESHOLD = 6 +REPEAT_OPTIMIZATION_PROBABILITY = 0.1 +MAX_TEST_REPAIR_CYCLES = 2 +DEFAULT_IMPORTANCE_THRESHOLD = 0.001 + +# pytest loop stability +# For now, we use strict thresholds (large windows and low tolerances), since this is still experimental. 
class EffortLevel(str, Enum):
    """Effort tiers a caller can select; each tier picks one column of EFFORT_VALUES."""

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"


class EffortKeys(str, Enum):
    """Names of the settings whose value depends on the selected effort level."""

    N_OPTIMIZER_CANDIDATES = "N_OPTIMIZER_CANDIDATES"
    N_OPTIMIZER_LP_CANDIDATES = "N_OPTIMIZER_LP_CANDIDATES"
    N_GENERATED_TESTS = "N_GENERATED_TESTS"
    MAX_CODE_REPAIRS_PER_TRACE = "MAX_CODE_REPAIRS_PER_TRACE"
    REPAIR_UNMATCHED_PERCENTAGE_LIMIT = "REPAIR_UNMATCHED_PERCENTAGE_LIMIT"
    TOP_VALID_CANDIDATES_FOR_REFINEMENT = "TOP_VALID_CANDIDATES_FOR_REFINEMENT"
    ADAPTIVE_OPTIMIZATION_THRESHOLD = "ADAPTIVE_OPTIMIZATION_THRESHOLD"
    MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE = "MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE"


# Setting name (the plain string value of an EffortKeys member) -> value per effort level.
# Every key is written as `.value` so the mapping really holds plain `str` keys,
# matching its annotation.
EFFORT_VALUES: dict[str, dict[EffortLevel, Any]] = {
    EffortKeys.N_OPTIMIZER_CANDIDATES.value: {EffortLevel.LOW: 3, EffortLevel.MEDIUM: 5, EffortLevel.HIGH: 6},
    EffortKeys.N_OPTIMIZER_LP_CANDIDATES.value: {EffortLevel.LOW: 4, EffortLevel.MEDIUM: 6, EffortLevel.HIGH: 7},
    # we don't use effort with generated tests for now
    EffortKeys.N_GENERATED_TESTS.value: {EffortLevel.LOW: 2, EffortLevel.MEDIUM: 2, EffortLevel.HIGH: 2},
    # maximum number of repairs we will do for each function (in case the valid candidates is less than MIN_CORRECT_CANDIDATES)
    EffortKeys.MAX_CODE_REPAIRS_PER_TRACE.value: {EffortLevel.LOW: 2, EffortLevel.MEDIUM: 3, EffortLevel.HIGH: 5},
    # if the percentage of unmatched tests is greater than this, we won't fix it (lowering this value makes the repair stricter)
    # on the low effort we lower the limit to 20% to be more strict (fewer repairs, less time)
    EffortKeys.REPAIR_UNMATCHED_PERCENTAGE_LIMIT.value: {
        EffortLevel.LOW: 0.2,
        EffortLevel.MEDIUM: 0.3,
        EffortLevel.HIGH: 0.4,
    },
    # Top valid candidates for refinements
    EffortKeys.TOP_VALID_CANDIDATES_FOR_REFINEMENT.value: {
        EffortLevel.LOW: 2,
        EffortLevel.MEDIUM: 3,
        EffortLevel.HIGH: 4,
    },
    # max number of adaptive optimization calls to make per a single candidates tree
    EffortKeys.ADAPTIVE_OPTIMIZATION_THRESHOLD.value: {EffortLevel.LOW: 0, EffortLevel.MEDIUM: 0, EffortLevel.HIGH: 2},
    # max number of adaptive optimization calls to make per a single trace
    EffortKeys.MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE.value: {
        EffortLevel.LOW: 0,
        EffortLevel.MEDIUM: 0,
        EffortLevel.HIGH: 4,
    },
}


def get_effort_value(key: EffortKeys, effort: EffortLevel | str) -> Any:
    """Return the configured value of ``key`` for the given effort level.

    Args:
        key: The setting to look up.
        effort: An ``EffortLevel`` member or its string value
            (``"low"``, ``"medium"``, ``"high"``).

    Raises:
        ValueError: If ``effort`` is not a valid effort level, or ``key`` has
            no entry in ``EFFORT_VALUES``.
    """
    key_str = key.value

    # EffortLevel members are str instances too, so this also passes them
    # through the constructor (which returns the same member).
    if isinstance(effort, str):
        try:
            effort = EffortLevel(effort)
        except ValueError:
            msg = f"Invalid effort level: {effort}"
            raise ValueError(msg) from None

    if key_str not in EFFORT_VALUES:
        msg = f"Invalid key: {key_str}"
        raise ValueError(msg)

    return EFFORT_VALUES[key_str][effort]
Please recheck the path to pyproject.toml" + raise ValueError(msg) + if not config_file.exists(): + msg = f"Config file {config_file} does not exist. Please recheck the path to pyproject.toml" + raise ValueError(msg) + return config_file + dir_path = Path.cwd() + cur_path = dir_path + # see if it was encountered before in search + if cur_path in PYPROJECT_TOML_CACHE: + return PYPROJECT_TOML_CACHE[cur_path] + # map current path to closest file - check both pyproject.toml and codeflash.toml + while dir_path != dir_path.parent: + # First check pyproject.toml (Python projects) + config_file = dir_path / "pyproject.toml" + if config_file.exists(): + PYPROJECT_TOML_CACHE[cur_path] = config_file + return config_file + # Then check codeflash.toml (alternative config format) + config_file = dir_path / "codeflash.toml" + if config_file.exists(): + PYPROJECT_TOML_CACHE[cur_path] = config_file + return config_file + # Search in parent directories + dir_path = dir_path.parent + msg = f"Could not find pyproject.toml or codeflash.toml in the current directory {Path.cwd()} or any of the parent directories. Please create it by running `codeflash init`, or pass the path to the config file with the --config-file argument." 
def find_conftest_files(test_paths: list[Path]) -> list[Path]:
    """Collect ``conftest.py`` files found above the given test paths.

    For each path, walk from the path itself up through its parents and record
    every existing ``<dir>/conftest.py``. The walk stops when it reaches the
    current working directory (which is itself not inspected) or when it runs
    out of parents.

    The second stop condition matters for paths that are relative or that live
    outside the working directory: their parent chain never equals the working
    directory, so without it the walk would spin forever on the root.

    Args:
        test_paths: Directories (or files) to start the upward search from.

    Returns:
        The distinct conftest paths found, in no particular order.
    """
    stop_dir = Path.cwd()
    conftest_files: set[Path] = set()
    for test_path in test_paths:
        cur_path = test_path
        while cur_path != stop_dir:
            candidate = cur_path / "conftest.py"
            if candidate.exists():
                conftest_files.add(candidate)
            parent = cur_path.parent
            if parent == cur_path:
                # Reached the top of the path ("/" or "."): nothing left to climb.
                break
            cur_path = parent
    return list(conftest_files)
config file {config_file_path}. Please recheck the file for syntax errors. Error: {e}" + raise ValueError(msg) from None + + try: + tool = data["tool"] + assert isinstance(tool, dict) + tool = cast("dict[str, Any]", tool) + config = tool["codeflash"] + except tomlkit.exceptions.NonExistentKey as e: # type: ignore[attr-defined] + msg = f"Could not find the 'codeflash' block in the config file {config_file_path}. Please run 'codeflash init' to add Codeflash config." + raise ValueError(msg) from e + assert isinstance(config, dict) + config = cast("dict[str, Any]", config) + + # Preserve language field if present + # default values: + path_keys = ["module-root", "tests-root", "benchmarks-root"] + path_list_keys = ["ignore-paths"] + str_keys = {"pytest-cmd": "pytest", "git-remote": "origin"} + bool_keys = { + "override-fixtures": False, + "disable-telemetry": False, + "disable-imports-sorting": False, + "benchmark": False, + } + # Note: formatter-cmds defaults to empty list. Black is typically detected by the project detector. 
+ list_str_keys = {"formatter-cmds": []} + + for key, default_value in str_keys.items(): + if key in config: + config[key] = str(config[key]) + else: + config[key] = default_value + for key, default_value in bool_keys.items(): + if key in config: + config[key] = bool(config[key]) + else: + config[key] = default_value + for key in path_keys: + if key in config: + config[key] = str((Path(config_file_path).parent / Path(config[key])).resolve()) + for key, default_value in list_str_keys.items(): + if key in config: + config[key] = [str(cmd) for cmd in config[key]] + else: + config[key] = default_value + + for key in path_list_keys: + if key in config: + config[key] = [str((Path(config_file_path).parent / path).resolve()) for path in config[key]] + else: + config[key] = [] + + # see if this is happening during GitHub actions setup + formatter_cmds = config.get("formatter-cmds") + if formatter_cmds and len(formatter_cmds) > 0 and not override_formatter_check: + assert formatter_cmds[0] != "your-formatter $file", ( + "The formatter command is not set correctly in pyproject.toml. Please set the " + "formatter command in the 'formatter-cmds' key. 
More info - https://docs.codeflash.ai/configuration" + ) + for key in list(config.keys()): + if "-" in key: + config[key.replace("-", "_")] = config[key] + del config[key] + + return config, config_file_path diff --git a/src/codeflash_python/code_utils/env_utils.py b/src/codeflash_python/code_utils/env_utils.py new file mode 100644 index 000000000..e9636f3d0 --- /dev/null +++ b/src/codeflash_python/code_utils/env_utils.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +import json +import logging +import os +import shlex +import shutil +import tempfile +from functools import lru_cache +from pathlib import Path +from typing import Any + +from codeflash_python.code_utils.code_utils import exit_with_message +from codeflash_python.code_utils.formatter import format_code +from codeflash_python.code_utils.shell_utils import read_api_key_from_shell_config, save_api_key_to_rc + +logger = logging.getLogger("codeflash_python") + + +def check_formatter_installed( + formatter_cmds: list[str], exit_on_failure: bool = True, language: str = "python" +) -> bool: + if not formatter_cmds or formatter_cmds[0] == "disabled": + return True + first_cmd = formatter_cmds[0] + cmd_tokens = shlex.split(first_cmd) if isinstance(first_cmd, str) else [first_cmd] + + if not cmd_tokens: + return True + + exe_name = cmd_tokens[0] + command_str = " ".join(formatter_cmds).replace(" $file", "") + + if shutil.which(exe_name) is None: + logger.error( + "Could not find formatter: %s\nPlease install it or update 'formatter-cmds' in your codeflash configuration", + command_str, + ) + return False + + tmp_code = """print("hello world")""" + + try: + with tempfile.TemporaryDirectory() as tmpdir: + tmp_file = Path(tmpdir) / "test_codeflash_formatter.py" + tmp_file.write_text(tmp_code, encoding="utf-8") + format_code(formatter_cmds, tmp_file, print_status=False, exit_on_failure=False) + return True + except FileNotFoundError: + logger.error( # noqa: TRY400 + "Could not find formatter: %s\nPlease 
@lru_cache(maxsize=1)
def get_codeflash_api_key() -> str:
    """Return the Codeflash API key, looking in the environment and the shell config.

    The ``CODEFLASH_API_KEY`` environment variable wins; the key stored in the
    shell configuration is used only when the variable is unset or empty. When
    the variable is set but the shell config holds no key, the key is saved to
    the shell config on a best-effort basis (failures are logged at debug level
    and otherwise ignored).

    A successful result is memoized by ``lru_cache``.

    Raises:
        OSError: If no key is found, or the key does not start with ``cf-``.
            The key itself is never included in full in the message.
    """
    env_api_key = os.environ.get("CODEFLASH_API_KEY")
    shell_api_key = read_api_key_from_shell_config()
    # Only the last four characters are ever logged.
    logger.debug(
        "env_utils.py:get_codeflash_api_key - env_api_key: %s, shell_api_key: %s",
        "***" + env_api_key[-4:] if env_api_key else None,
        "***" + shell_api_key[-4:] if shell_api_key else None,
    )
    # If we have an env var but it's not in shell config, save it for persistence
    if env_api_key and not shell_api_key:
        try:
            logger.debug("env_utils.py:get_codeflash_api_key - Saving API key from environment to shell config")
            result = save_api_key_to_rc(env_api_key)
            if result.is_ok():
                logger.debug(
                    "env_utils.py:get_codeflash_api_key - Automatically saved API key from environment to shell config: %s",
                    result.unwrap(),
                )
            else:
                logger.debug("env_utils.py:get_codeflash_api_key - Failed to save API key: %s", result.error)  # type: ignore[unresolved-attribute]
        except Exception as e:
            # Persisting the key is a convenience; never let it block key lookup.
            logger.debug(
                "env_utils.py:get_codeflash_api_key - Failed to automatically save API key to shell config: %s", e
            )

    # The environment variable takes precedence; the shell RC value is the fallback.
    # NOTE(review): an earlier comment here described preferring the shell config when
    # running under the LSP server. This code has no such branch - confirm whether that
    # mode still needs separate handling.
    api_key = env_api_key or shell_api_key

    api_secret_docs_message = "For more information, refer to the documentation at [https://docs.codeflash.ai/optimizing-with-codeflash/codeflash-github-actions#manual-setup]."  # noqa
    if not api_key:
        msg = (
            "I didn't find a Codeflash API key in your environment.\nYou can generate one at "
            "https://app.codeflash.ai/app/apikeys ,\nthen set it as a CODEFLASH_API_KEY environment variable.\n"
            f"{api_secret_docs_message}"
        )
        if is_repo_a_fork():
            msg = (
                "Codeflash API key not detected in your environment. It appears you're running Codeflash from a GitHub fork.\n"
                "For external contributors, please ensure you've added your own API key to your fork's repository secrets and set it as the CODEFLASH_API_KEY environment variable.\n"
                f"{api_secret_docs_message}"
            )
            # NOTE(review): presumably exit_with_message terminates the process; the raise
            # below then only runs if it returns - confirm against its definition.
            exit_with_message(msg)
        raise OSError(msg)
    if not api_key.startswith("cf-"):
        # Whatever was found may be some other credential pasted by mistake, so the
        # message shows it masked, the same way the debug log above does.
        masked_key = "***" + api_key[-4:]
        msg = (
            f"Your Codeflash API key seems to be invalid. It should start with a 'cf-' prefix; I found '{masked_key}' "
            f"instead.\nYou can generate one at https://app.codeflash.ai/app/apikeys ,\nthen set it as a "
            f"CODEFLASH_API_KEY environment variable."
        )
        raise OSError(msg)
    return api_key
@lru_cache(maxsize=1)
def get_cached_gh_event_data() -> dict[str, Any]:
    """Load the JSON document named by ``GITHUB_EVENT_PATH``, once per process.

    Returns an empty dict when the variable is unset or empty. Otherwise the
    file is read as UTF-8 and parsed as JSON. The result (including the empty
    dict) is memoized by ``lru_cache``, so later changes to the variable or
    the file are not observed until the cache is cleared.
    """
    payload_location = os.getenv("GITHUB_EVENT_PATH")
    if payload_location:
        with open(payload_location, encoding="utf-8") as stream:
            payload: dict[str, Any] = json.loads(stream.read())
        return payload
    return {}
def get_diff_lines_count(diff_output: str) -> int:
    """Count the changed lines in a unified diff.

    A line counts when it begins with ``+`` or ``-``, except for the
    ``+++`` / ``---`` file-header lines.
    """
    changed = 0
    for row in diff_output.split("\n"):
        if row[:3] in ("+++", "---"):
            # File header, not a content change.
            continue
        if row[:1] in ("+", "-"):
            changed += 1
    return changed
formatter provided + return re.sub(r"\n{2,}", "\n\n", generated_test_source) + with tempfile.TemporaryDirectory() as test_dir_str: + # try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed) return code with 2 or more newlines substituted with 2 newlines + original_temp = Path(test_dir_str) / "original_temp.py" + original_temp.write_text(generated_test_source, encoding="utf8") + _, formatted_code, changed = apply_formatter_cmds( + formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=False + ) + if not changed: + return re.sub(r"\n{2,}", "\n\n", formatted_code) + return formatted_code + + +def format_code( + formatter_cmds: list[str], + path: str | Path, + optimized_code: str = "", + check_diff: bool = False, + print_status: bool = True, + exit_on_failure: bool = True, +) -> str: + if isinstance(path, str): + path = Path(path) + + # TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution + formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled" + if formatter_name == "disabled": + return path.read_text(encoding="utf8") + + with tempfile.TemporaryDirectory() as test_dir_str: + original_code = path.read_text(encoding="utf8") + original_code_lines = len(original_code.split("\n")) + + if check_diff and original_code_lines > 50: + # we don't count the formatting diff for the optimized function as it should be well-formatted + original_code_without_opfunc = original_code.replace(optimized_code, "") + + original_temp = Path(test_dir_str) / "original_temp.py" + original_temp.write_text(original_code_without_opfunc, encoding="utf8") + + formatted_temp, formatted_code, changed = apply_formatter_cmds( + formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=exit_on_failure + ) + + if not changed: + logger.warning( + "No changes detected in %s after formatting, are you sure you have valid formatter commands?", 
def get_git_diff(
    repo_directory: Path | None = None, *, only_this_commit: str | None = None, uncommitted_changes: bool = False
) -> dict[str, list[int]]:
    """Map each changed ``.py`` file to the line numbers that changed in it.

    Which diff is inspected:

    * ``only_this_commit`` given: that commit against its first parent.
    * else ``uncommitted_changes``: the working tree against ``HEAD``.
    * otherwise: the ``HEAD`` commit against its first parent, or against the
      empty tree when ``HEAD`` is a root commit and has no parent.

    Blank-line and end-of-line whitespace changes are ignored by git.

    Args:
        repo_directory: Where to look for the repository (parents are searched
            too). Defaults to the current working directory.
        only_this_commit: Revision whose changes should be inspected.
        uncommitted_changes: Inspect uncommitted changes instead of ``HEAD``.

    Returns:
        ``{path: [line numbers]}`` with paths joined onto the repository's
        working directory. Line numbers refer to the new version of the file:
        the non-blank added lines, or for a file with deletions only, the
        start line of each hunk.
    """
    # Object id git uses for the empty tree in SHA-1 repositories.
    # NOTE(review): a SHA-256 repository uses a different id - confirm if those must be supported.
    empty_tree = "4b825dc642cb6eb9a060e54bf8d69288fbee4904"

    if repo_directory is None:
        repo_directory = Path.cwd()
    repository = git.Repo(repo_directory, search_parent_directories=True)
    commit = repository.head.commit
    if only_this_commit:
        uni_diff_text = repository.git.diff(
            only_this_commit + "^1", only_this_commit, ignore_blank_lines=True, ignore_space_at_eol=True
        )
    elif uncommitted_changes:
        uni_diff_text = repository.git.diff("HEAD", ignore_blank_lines=True, ignore_space_at_eol=True)
    else:
        # "<sha>^1" does not resolve for a root commit; diff it against the empty tree instead.
        base_rev = commit.hexsha + "^1" if commit.parents else empty_tree
        uni_diff_text = repository.git.diff(
            base_rev, commit.hexsha, ignore_blank_lines=True, ignore_space_at_eol=True
        )
    patch_set = PatchSet(StringIO(uni_diff_text))
    change_list: dict[str, list[int]] = {}  # file path -> changed line numbers
    for patched_file in patch_set:
        file_path: Path = Path(patched_file.path)
        if file_path.suffix != ".py":
            continue
        file_path = Path(repository.working_dir) / file_path
        logger.debug("file name: %s", file_path)

        # Row numbers (in the new file) of the added lines.
        add_line_no: list[int] = [
            line.target_line_no for hunk in patched_file for line in hunk if line.is_added and line.value.strip() != ""
        ]

        logger.debug("added lines: %s", add_line_no)

        # Row numbers (in the old file) of the deleted lines.
        del_line_no: list[int] = [
            line.source_line_no
            for hunk in patched_file
            for line in hunk
            if line.is_removed and line.value.strip() != ""
        ]

        logger.debug("deleted lines: %s", del_line_no)

        if not add_line_no and del_line_no:
            # Deletion-only changes: use hunk target start lines so we can still
            # match the surrounding function in the current (target) file.
            add_line_no = [hunk.target_start for hunk in patched_file]
        change_list[str(file_path)] = add_line_no
    return change_list
+ + Handles detached HEAD state and other edge cases by falling back to + the default branch (main or master) or "main" if no default branch exists. + + :param repo: An optional Repo object. If not provided, the function will + search for a repository in the current and parent directories. + :return: The name of the current branch, or "main" if HEAD is detached or + the branch cannot be determined. + """ + repository: Repo = repo if repo else git.Repo(search_parent_directories=True) + + # Check if HEAD is detached (active_branch will be None) + if repository.head.is_detached: + logger.warning( + "HEAD is detached. Cannot determine current branch. Falling back to 'main'. " + "Consider checking out a branch before running Codeflash." + ) + # Try to find the default branch (main or master) + for default_branch in ["main", "master"]: + try: + if default_branch in repository.branches: + logger.info("Using '%s' as fallback branch.", default_branch) + return default_branch + except Exception as e: + logger.debug("Error checking for branch '%s': %s", default_branch, e) + continue + # If no default branch found, return "main" as a safe default + return "main" + + # HEAD is not detached, safe to access active_branch + try: + return repository.active_branch.name + except (AttributeError, TypeError) as e: + logger.warning( + "Could not determine active branch: %s. Falling back to 'main'. 
@cache
def get_repo_owner_and_name(repo: Repo | None = None, git_remote: str | None = "origin") -> tuple[str, str]:
    """Derive ``(owner, repository name)`` from the URL of a git remote.

    The last two ``/``-separated components of the remote URL are taken as
    owner and name. A trailing ``/`` and a trailing ``.git`` are dropped, in
    that order, so ``.../owner/repo.git/`` also yields ``repo``. For
    scp-style URLs such as ``git@host:owner/repo.git`` the ``host:`` part is
    cut off the owner.

    Results are memoized per ``(repo, git_remote)`` by ``functools.cache``.

    Args:
        repo: Repository to inspect; forwarded to ``get_remote_url``.
        git_remote: Name of the remote; forwarded to ``get_remote_url``.
    """
    remote_url = get_remote_url(repo, git_remote)  # call only once
    # Strip the slash first: removing ".git" first would miss "repo.git/".
    remote_url = remote_url.rstrip("/").removesuffix(".git")
    split_url = remote_url.split("/")
    repo_owner_with_host, repo_name = split_url[-2], split_url[-1]
    repo_owner = repo_owner_with_host.split(":")[1] if ":" in repo_owner_with_host else repo_owner_with_host
    return repo_owner, repo_name
def check_and_push_branch(repo: git.Repo, git_remote: str | None = "origin", *, wait_for_push: bool = False) -> bool:
    """Make sure the current branch exists on the remote, offering to push it.

    Args:
        repo: Repository whose active branch is checked.
        git_remote: Remote name; ``None`` means ``"origin"``.
        wait_for_push: After a push, sleep for three seconds before returning.

    Returns:
        ``True`` when ``<remote>/<branch>`` is already among the repository's
        refs, or the user agreed to push and the push was issued. ``False``
        when HEAD is detached, the active branch cannot be determined, stdin
        is missing or not a terminal, or the user declined.
    """
    # Check if HEAD is detached
    if repo.head.is_detached:
        logger.warning("HEAD is detached. Cannot push branch. Please check out a branch before creating a PR.")
        return False

    # Safe to access active_branch when HEAD is not detached
    try:
        current_branch = repo.active_branch
        current_branch_name = current_branch.name
    except (AttributeError, TypeError) as e:
        logger.warning("Could not determine active branch: %s. Cannot push branch.", e)
        return False

    # Resolve the remote name once and use it everywhere below. Using the raw
    # argument would look for the ref "None/<branch>" when git_remote is None.
    remote_name = git_remote if git_remote is not None else "origin"
    remote = repo.remote(name=remote_name)

    if f"{remote_name}/{current_branch_name}" in repo.refs:
        logger.debug("The branch '%s' is present in the remote repository.", current_branch_name)
        return True

    logger.warning("The branch '%s' is not pushed to the remote repository.", current_branch_name)
    # A missing stdin (None) cannot be prompted either, so treat it as non-interactive.
    stdin = sys.__stdin__
    if stdin is None or not stdin.isatty():
        logger.warning("Non-interactive shell detected. Branch will not be pushed.")
        return False

    answer = input(
        f"In order for me to create PRs, your current branch needs to be pushed. Do you want to push "
        f"the branch '{current_branch_name}' to the remote repository? [y/N] "
    )
    if answer.strip().lower() not in ("y", "yes"):
        logger.info("Branch '%s' has not been pushed to %s.", current_branch_name, remote_name)
        return False

    remote.push(current_branch)  # type: ignore[arg-type]
    logger.info("Branch '%s' has been pushed to %s.", current_branch_name, remote_name)
    if wait_for_push:
        # Give the push time to register with GitHub,
        # so that our modifications to it are not rejected.
        time.sleep(3)
    return True
CMD/Batch patterns and prefixes +CMD_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.MULTILINE) +CMD_RC_EXPORT_PREFIX = "set CODEFLASH_API_KEY=" + +# Unix shell patterns and prefixes +UNIX_RC_EXPORT_PATTERN = re.compile(r'^(?!#)export CODEFLASH_API_KEY=(?:"|\')?(cf-[^\s"\']+)(?:"|\')?$', re.MULTILINE) +UNIX_RC_EXPORT_PREFIX = "export CODEFLASH_API_KEY=" + + +def is_powershell() -> bool: + """Detect if we're running in PowerShell on Windows. + + Uses multiple heuristics: + 1. PSModulePath environment variable (PowerShell always sets this) + 2. COMSPEC pointing to powershell.exe + 3. TERM_PROGRAM indicating Windows Terminal (often uses PowerShell) + """ + if os.name != "nt": + return False + + # Primary check: PSMODULEPATH is set by PowerShell + # This is the most reliable indicator as PowerShell always sets this + ps_module_path = os.environ.get("PSMODULEPATH") + if ps_module_path: + logger.debug("shell_utils.py:is_powershell - Detected PowerShell via PSModulePath") + return True + + # Secondary check: COMSPEC points to PowerShell + comspec = os.environ.get("COMSPEC", "").lower() + if "powershell" in comspec: + logger.debug("shell_utils.py:is_powershell - Detected PowerShell via COMSPEC: %s", comspec) + return True + + # Tertiary check: Windows Terminal often uses PowerShell by default + # But we only use this if other indicators are ambiguous + term_program = os.environ.get("TERM_PROGRAM", "").lower() + # Check if we can find evidence of CMD (cmd.exe in COMSPEC) + # If not, assume PowerShell for Windows Terminal + if "windows" in term_program and "terminal" in term_program and "cmd.exe" not in comspec: + logger.debug("shell_utils.py:is_powershell - Detected PowerShell via Windows Terminal (COMSPEC: %s)", comspec) + return True + + logger.debug("shell_utils.py:is_powershell - Not PowerShell (COMSPEC: %s, TERM_PROGRAM: %s)", comspec, term_program) + return False + + +def read_api_key_from_shell_config() -> str | None: + """Read API key from 
shell configuration file.""" + shell_rc_path = get_shell_rc_path() + # Ensure shell_rc_path is a Path object for consistent handling + if not isinstance(shell_rc_path, Path): + shell_rc_path = Path(shell_rc_path) + + # Determine the correct pattern to use based on the file extension and platform + if os.name == "nt": # Windows + pattern = POWERSHELL_RC_EXPORT_PATTERN if shell_rc_path.suffix == ".ps1" else CMD_RC_EXPORT_PATTERN + else: # Unix-like + pattern = UNIX_RC_EXPORT_PATTERN + + try: + # Convert Path to string using as_posix() for cross-platform path compatibility + shell_rc_path_str = shell_rc_path.as_posix() if isinstance(shell_rc_path, Path) else str(shell_rc_path) + with open(shell_rc_path_str, encoding="utf8") as shell_rc: + shell_contents = shell_rc.read() + matches = pattern.findall(shell_contents) + if matches: + logger.debug("shell_utils.py:read_api_key_from_shell_config - Found API key in file: %s", shell_rc_path) + return matches[-1] + logger.debug("shell_utils.py:read_api_key_from_shell_config - No API key found in file: %s", shell_rc_path) + return None + except FileNotFoundError: + logger.debug("shell_utils.py:read_api_key_from_shell_config - File not found: %s", shell_rc_path) + return None + except Exception as e: + logger.debug("shell_utils.py:read_api_key_from_shell_config - Error reading file: %s", e) + return None + + +def get_shell_rc_path() -> Path: + """Get the path to the user's shell configuration file.""" + if os.name == "nt": # Windows + if is_powershell(): + return Path.home() / "codeflash_env.ps1" + return Path.home() / "codeflash_env.bat" + shell = os.environ.get("SHELL", "/bin/bash").split("/")[-1] + shell_rc_filename = {"zsh": ".zshrc", "ksh": ".kshrc", "csh": ".cshrc", "tcsh": ".cshrc", "dash": ".profile"}.get( + shell, ".bashrc" + ) # map each shell to its config file and default to .bashrc + return Path.home() / shell_rc_filename + + +def get_api_key_export_line(api_key: str) -> str: + """Get the appropriate export line 
based on the shell type.""" + if os.name == "nt": # Windows + if is_powershell(): + return f'{POWERSHELL_RC_EXPORT_PREFIX}"{api_key}"' + return f'{CMD_RC_EXPORT_PREFIX}"{api_key}"' + # Unix-like + return f'{UNIX_RC_EXPORT_PREFIX}"{api_key}"' + + +def save_api_key_to_rc(api_key: str) -> Result[str, str]: + """Save API key to the appropriate shell configuration file.""" + shell_rc_path = get_shell_rc_path() + # Ensure shell_rc_path is a Path object for consistent handling + if not isinstance(shell_rc_path, Path): + shell_rc_path = Path(shell_rc_path) + api_key_line = get_api_key_export_line(api_key) + + logger.debug("shell_utils.py:save_api_key_to_rc - Saving API key to: %s", shell_rc_path) + logger.debug("shell_utils.py:save_api_key_to_rc - API key line format: %s...", api_key_line[:30]) + + # Determine the correct pattern to use for replacement + if os.name == "nt": # Windows + if is_powershell(): + pattern = POWERSHELL_RC_EXPORT_PATTERN + logger.debug("shell_utils.py:save_api_key_to_rc - Using PowerShell pattern") + else: + pattern = CMD_RC_EXPORT_PATTERN + logger.debug("shell_utils.py:save_api_key_to_rc - Using CMD pattern") + else: # Unix-like + pattern = UNIX_RC_EXPORT_PATTERN + logger.debug("shell_utils.py:save_api_key_to_rc - Using Unix pattern") + + try: + # Create directory if it doesn't exist (ignore errors - file operation will fail if needed) + # Directory creation failed, but we'll still try to open the file + # The file operation itself will raise the appropriate exception if there are permission issues + with contextlib.suppress(OSError, PermissionError): + shell_rc_path.parent.mkdir(parents=True, exist_ok=True) + + # Convert Path to string using as_posix() for cross-platform path compatibility + shell_rc_path_str = shell_rc_path.as_posix() if isinstance(shell_rc_path, Path) else str(shell_rc_path) + + # Try to open in r+ mode (read and write in single operation) + # Handle FileNotFoundError if file doesn't exist (r+ requires file to exist) + try: + 
with open(shell_rc_path_str, "r+", encoding="utf8") as shell_file: + shell_contents = shell_file.read() + logger.debug("shell_utils.py:save_api_key_to_rc - Read existing file, length: %s", len(shell_contents)) + + # Initialize empty file with header for batch files if needed + if not shell_contents: + logger.debug("shell_utils.py:save_api_key_to_rc - File is empty, initializing") + if os.name == "nt" and not is_powershell(): + shell_contents = "@echo off" + logger.debug("shell_utils.py:save_api_key_to_rc - Added @echo off header for batch file") + + # Check if API key already exists in the current file + matches = pattern.findall(shell_contents) + existing_in_file = bool(matches) + logger.debug("shell_utils.py:save_api_key_to_rc - Existing key in file: %s", existing_in_file) + + if existing_in_file: + # Replace the existing API key line in this file + updated_shell_contents = re.sub(pattern, api_key_line, shell_contents) + action = "Updated CODEFLASH_API_KEY in" + logger.debug("shell_utils.py:save_api_key_to_rc - Replaced existing API key") + else: + # Append the new API key line + if shell_contents and not shell_contents.endswith(LF): + updated_shell_contents = shell_contents + LF + api_key_line + LF + else: + updated_shell_contents = shell_contents.rstrip() + f"{LF}{api_key_line}{LF}" + action = "Added CODEFLASH_API_KEY to" + logger.debug("shell_utils.py:save_api_key_to_rc - Appended new API key") + + # Write the updated contents + shell_file.seek(0) + shell_file.write(updated_shell_contents) + shell_file.truncate() + except FileNotFoundError: + # File doesn't exist, create it first with initial content + logger.debug("shell_utils.py:save_api_key_to_rc - File does not exist, creating new") + shell_contents = "" + # Initialize with header for batch files if needed + if os.name == "nt" and not is_powershell(): + shell_contents = "@echo off" + logger.debug("shell_utils.py:save_api_key_to_rc - Added @echo off header for batch file") + + # Create the file by opening 
in write mode + with open(shell_rc_path_str, "w", encoding="utf8") as shell_file: + shell_file.write(shell_contents) + + # Re-open in r+ mode to add the API key (r+ allows both read and write) + with open(shell_rc_path_str, "r+", encoding="utf8") as shell_file: + # Append the new API key line + updated_shell_contents = shell_contents.rstrip() + f"{LF}{api_key_line}{LF}" + action = "Added CODEFLASH_API_KEY to" + logger.debug("shell_utils.py:save_api_key_to_rc - Appended new API key to new file") + + # Write the updated contents + shell_file.seek(0) + shell_file.write(updated_shell_contents) + shell_file.truncate() + + logger.debug("shell_utils.py:save_api_key_to_rc - Successfully wrote to %s", shell_rc_path) + + return Ok(f"✅ {action} {shell_rc_path}") + except PermissionError as e: + logger.debug("shell_utils.py:save_api_key_to_rc - Permission error: %s", e) + return Err( + f"💡 I tried adding your Codeflash API key to {shell_rc_path} - but seems like I don't have permissions to do so.{LF}" + f"You'll need to open it yourself and add the following line:{LF}{LF}{api_key_line}{LF}" + ) + except Exception as e: + logger.debug("shell_utils.py:save_api_key_to_rc - Error: %s", e) + return Err( + f"💡 I went to save your Codeflash API key to {shell_rc_path}, but encountered an error: {e}{LF}" + f"To ensure your Codeflash API key is automatically loaded into your environment at startup, you can create {shell_rc_path} and add the following line:{LF}" + f"{LF}{api_key_line}{LF}" + ) + + +def make_env_with_project_root(project_root: Path | str) -> dict[str, str]: + """Return a copy of os.environ with project_root prepended to PYTHONPATH.""" + env = os.environ.copy() + project_root_str = str(project_root) + pythonpath = env.get("PYTHONPATH", "") + if pythonpath: + env["PYTHONPATH"] = f"{project_root_str}{os.pathsep}{pythonpath}" + else: + env["PYTHONPATH"] = project_root_str + return env + + +def get_cross_platform_subprocess_run_args( + cwd: Path | str | None = None, + env: 
Mapping[str, str] | None = None, + timeout: float | None = None, + check: bool = False, + text: bool = True, + capture_output: bool = True, +) -> dict[str, object]: + run_args: dict[str, object] = {"cwd": cwd, "env": env, "text": text, "timeout": timeout, "check": check} + # When text=True, use errors='replace' to handle non-UTF-8 bytes gracefully + # instead of raising UnicodeDecodeError + if text: + run_args["errors"] = "replace" + if sys.platform == "win32": + creationflags = subprocess.CREATE_NEW_PROCESS_GROUP + run_args["creationflags"] = creationflags + run_args["stdout"] = subprocess.PIPE + run_args["stderr"] = subprocess.PIPE + run_args["stdin"] = subprocess.DEVNULL + else: + run_args["capture_output"] = capture_output + + return run_args diff --git a/src/codeflash_python/code_utils/tabulate.py b/src/codeflash_python/code_utils/tabulate.py new file mode 100644 index 000000000..e43ea68ab --- /dev/null +++ b/src/codeflash_python/code_utils/tabulate.py @@ -0,0 +1,915 @@ +"""Adapted from tabulate (https://github.com/astanin/python-tabulate) written by Sergey Astanin and contributors (MIT License).""" + +"""Pretty-print tabular data.""" +# ruff: noqa + +import dataclasses +import math +import re +import warnings +from collections import namedtuple +from collections.abc import Iterable +from functools import reduce +from itertools import chain +from itertools import zip_longest as izip_longest + +import wcwidth # optional wide-character (CJK) support + +__all__ = ["tabulate", "tabulate_formats"] + +# minimum extra space in headers +MIN_PADDING = 2 + +_DEFAULT_FLOATFMT = "g" +_DEFAULT_INTFMT = "" +_DEFAULT_MISSINGVAL = "" +# default align will be overwritten by "left", "center" or "decimal" +# depending on the formatter +_DEFAULT_ALIGN = "default" + + +# if True, enable wide-character (CJK) support +WIDE_CHARS_MODE = wcwidth is not None + +# Constant that can be used as part of passed rows to generate a separating line +# It is purposely an unprintable character, 
very unlikely to be used in a table +SEPARATING_LINE = "\001" + +Line = namedtuple("Line", ["begin", "hline", "sep", "end"]) # noqa: PYI024 + + +DataRow = namedtuple("DataRow", ["begin", "sep", "end"]) # noqa: PYI024 + +TableFormat = namedtuple( # noqa: PYI024 + "TableFormat", + [ + "lineabove", + "linebelowheader", + "linebetweenrows", + "linebelow", + "headerrow", + "datarow", + "padding", + "with_header_hide", + ], +) + + +def _is_separating_line_value(value): + return type(value) is str and value.strip() == SEPARATING_LINE + + +def _is_separating_line(row): + row_type = type(row) + is_sl = (row_type == list or row_type == str) and ( + (len(row) >= 1 and _is_separating_line_value(row[0])) or (len(row) >= 2 and _is_separating_line_value(row[1])) + ) + + return is_sl + + +def _pipe_segment_with_colons(align, colwidth): + """Return a segment of a horizontal line with optional colons which + indicate column's alignment (as in `pipe` output format). + """ + w = colwidth + if align in {"right", "decimal"}: + return ("-" * (w - 1)) + ":" + if align == "center": + return ":" + ("-" * (w - 2)) + ":" + if align == "left": + return ":" + ("-" * (w - 1)) + return "-" * w + + +def _pipe_line_with_colons(colwidths, colaligns): + """Return a horizontal line with optional colons to indicate column's + alignment (as in `pipe` output format). + """ + if not colaligns: # e.g. 
printing an empty data frame (github issue #15) + colaligns = [""] * len(colwidths) + segments = [_pipe_segment_with_colons(a, w) for a, w in zip(colaligns, colwidths)] + return "|" + "|".join(segments) + "|" + + +_table_formats = { + "simple": TableFormat( + lineabove=Line("", "-", " ", ""), + linebelowheader=Line("", "-", " ", ""), + linebetweenrows=None, + linebelow=Line("", "-", " ", ""), + headerrow=DataRow("", " ", ""), + datarow=DataRow("", " ", ""), + padding=0, + with_header_hide=["lineabove", "linebelow"], + ), + "pipe": TableFormat( + lineabove=_pipe_line_with_colons, + linebelowheader=_pipe_line_with_colons, + linebetweenrows=None, + linebelow=None, + headerrow=DataRow("|", "|", "|"), + datarow=DataRow("|", "|", "|"), + padding=1, + with_header_hide=["lineabove"], + ), +} + +tabulate_formats = sorted(_table_formats.keys()) + +# The table formats for which multiline cells will be folded into subsequent +# table rows. The key is the original format specified at the API. The value is +# the format that will be used to represent the original format. 
+multiline_formats = {"plain": "plain", "pipe": "pipe"} + +_multiline_codes = re.compile(r"\r|\n|\r\n") +_multiline_codes_bytes = re.compile(b"\r|\n|\r\n") + +_esc = r"\x1b" +_csi = rf"{_esc}\[" +_osc = rf"{_esc}\]" +_st = rf"{_esc}\\" + +_ansi_escape_pat = rf""" + ( + # terminal colors, etc + {_csi} # CSI + [\x30-\x3f]* # parameter bytes + [\x20-\x2f]* # intermediate bytes + [\x40-\x7e] # final byte + | + # terminal hyperlinks + {_osc}8; # OSC opening + (\w+=\w+:?)* # key=value params list (submatch 2) + ; # delimiter + ([^{_esc}]+) # URI - anything but ESC (submatch 3) + {_st} # ST + ([^{_esc}]+) # link text - anything but ESC (submatch 4) + {_osc}8;;{_st} # "closing" OSC sequence + ) +""" +_ansi_codes = re.compile(_ansi_escape_pat, re.VERBOSE) +_ansi_codes_bytes = re.compile(_ansi_escape_pat.encode("utf8"), re.VERBOSE) +_ansi_color_reset_code = "\033[0m" + +_float_with_thousands_separators = re.compile(r"^(([+-]?[0-9]{1,3})(?:,([0-9]{3}))*)?(?(1)\.[0-9]*|\.[0-9]+)?$") + + +def _isnumber_with_thousands_separator(string): + try: + string = string.decode() + except (UnicodeDecodeError, AttributeError): + pass + + return bool(re.match(_float_with_thousands_separators, string)) + + +def _isconvertible(conv, string): + try: + conv(string) + return True + except (ValueError, TypeError): + return False + + +def _isnumber(string): + return ( + # fast path + type(string) in {float, int} + # covers 'NaN', +/- 'inf', and eg. '1e2', as well as any type + # convertible to int/float. + or ( + _isconvertible(float, string) + and ( + # some other type convertible to float + not isinstance(string, (str, bytes)) + # or, a numeric string eg. 
"1e1...", "NaN", ..., but isn't + # just an over/underflow + or ( + not (math.isinf(float(string)) or math.isnan(float(string))) + or string.lower() in {"inf", "-inf", "nan"} + ) + ) + ) + ) + + +def _isint(string, inttype=int): + return ( + type(string) is inttype + or ( + (hasattr(string, "is_integer") or hasattr(string, "__array__")) + and str(type(string)).startswith("= 0: + return len(string) - pos - 1 + return -1 # no point + return -1 # not a number + + +def _padleft(width, s): + fmt = "{0:>%ds}" % width + return fmt.format(s) + + +def _padright(width, s): + fmt = "{0:<%ds}" % width + return fmt.format(s) + + +def _padboth(width, s): + fmt = "{0:^%ds}" % width + return fmt.format(s) + + +def _padnone(ignore_width, s): + return s + + +def _strip_ansi(s): + if isinstance(s, str): + return _ansi_codes.sub(r"\4", s) + # a bytestring + return _ansi_codes_bytes.sub(rb"\4", s) + + +def _visible_width(s): + if wcwidth is not None and WIDE_CHARS_MODE: + len_fn = wcwidth.wcswidth + else: + len_fn = len + if isinstance(s, (str, bytes)): + return len_fn(_strip_ansi(s)) + return len_fn(str(s)) + + +def _is_multiline(s): + if isinstance(s, str): + return bool(re.search(_multiline_codes, s)) + # a bytestring + return bool(re.search(_multiline_codes_bytes, s)) + + +def _multiline_width(multiline_s, line_width_fn=len): + return max(map(line_width_fn, re.split("[\r\n]", multiline_s))) + + +def _choose_width_fn(has_invisible, enable_widechars, is_multiline): + if has_invisible: + line_width_fn = _visible_width + elif enable_widechars: # optional wide-character support if available + line_width_fn = wcwidth.wcswidth + else: + line_width_fn = len + if is_multiline: + width_fn = lambda s: _multiline_width(s, line_width_fn) # noqa + else: + width_fn = line_width_fn + return width_fn + + +def _align_column_choose_padfn(strings, alignment, has_invisible, preserve_whitespace): + if alignment == "right": + if not preserve_whitespace: + strings = [s.strip() for s in strings] + padfn = 
_padleft + elif alignment == "center": + if not preserve_whitespace: + strings = [s.strip() for s in strings] + padfn = _padboth + elif alignment == "decimal": + if has_invisible: + decimals = [_afterpoint(_strip_ansi(s)) for s in strings] + else: + decimals = [_afterpoint(s) for s in strings] + maxdecimals = max(decimals) + strings = [s + (maxdecimals - decs) * " " for s, decs in zip(strings, decimals)] + padfn = _padleft + elif not alignment: + padfn = _padnone + else: + if not preserve_whitespace: + strings = [s.strip() for s in strings] + padfn = _padright + return strings, padfn + + +def _align_column_choose_width_fn(has_invisible, enable_widechars, is_multiline): + if has_invisible: + line_width_fn = _visible_width + elif enable_widechars: # optional wide-character support if available + line_width_fn = wcwidth.wcswidth + else: + line_width_fn = len + if is_multiline: + width_fn = lambda s: _align_column_multiline_width(s, line_width_fn) # noqa + else: + width_fn = line_width_fn + return width_fn + + +def _align_column_multiline_width(multiline_s, line_width_fn=len): + return list(map(line_width_fn, re.split("[\r\n]", multiline_s))) + + +def _flat_list(nested_list): + ret = [] + for item in nested_list: + if isinstance(item, list): + ret.extend(item) + else: + ret.append(item) + return ret + + +def _align_column( + strings, + alignment, + minwidth=0, + has_invisible=True, + enable_widechars=False, + is_multiline=False, + preserve_whitespace=False, +): + strings, padfn = _align_column_choose_padfn(strings, alignment, has_invisible, preserve_whitespace) + width_fn = _align_column_choose_width_fn(has_invisible, enable_widechars, is_multiline) + + s_widths = list(map(width_fn, strings)) + maxwidth = max(max(_flat_list(s_widths)), minwidth) + if is_multiline: + if not enable_widechars and not has_invisible: + padded_strings = ["\n".join([padfn(maxwidth, s) for s in ms.splitlines()]) for ms in strings] + else: + # enable wide-character width corrections + s_lens = 
[[len(s) for s in re.split("[\r\n]", ms)] for ms in strings] + visible_widths = [[maxwidth - (w - l) for w, l in zip(mw, ml)] for mw, ml in zip(s_widths, s_lens)] + # wcswidth and _visible_width don't count invisible characters; + # padfn doesn't need to apply another correction + padded_strings = [ + "\n".join([padfn(w, s) for s, w in zip((ms.splitlines() or ms), mw)]) + for ms, mw in zip(strings, visible_widths) + ] + elif not enable_widechars and not has_invisible: + padded_strings = [padfn(maxwidth, s) for s in strings] + else: + # enable wide-character width corrections + s_lens = list(map(len, strings)) + visible_widths = [maxwidth - (w - l) for w, l in zip(s_widths, s_lens)] + # wcswidth and _visible_width don't count invisible characters; + # padfn doesn't need to apply another correction + padded_strings = [padfn(w, s) for s, w in zip(strings, visible_widths)] + return padded_strings + + +def _more_generic(type1, type2): + types = {type(None): 0, bool: 1, int: 2, float: 3, bytes: 4, str: 5} + invtypes = {5: str, 4: bytes, 3: float, 2: int, 1: bool, 0: type(None)} + moregeneric = max(types.get(type1, 5), types.get(type2, 5)) + return invtypes[moregeneric] + + +def _column_type(strings, has_invisible=True, numparse=True): + types = [_type(s, has_invisible, numparse) for s in strings] + return reduce(_more_generic, types, bool) + + +def _format(val, valtype, floatfmt, intfmt, missingval="", has_invisible=True): + if val is None: + return missingval + if isinstance(val, (bytes, str)) and not val: + return "" + + if valtype is str: + return f"{val}" + if valtype is int: + if isinstance(val, str): + val_striped = val.encode("unicode_escape").decode("utf-8") + colored = re.search(r"(\\[xX]+[0-9a-fA-F]+\[\d+[mM]+)([0-9.]+)(\\.*)$", val_striped) + if colored: + total_groups = len(colored.groups()) + if total_groups == 3: + digits = colored.group(2) + if digits.isdigit(): + val_new = colored.group(1) + format(int(digits), intfmt) + colored.group(3) + val = 
val_new.encode("utf-8").decode("unicode_escape") + intfmt = "" + return format(val, intfmt) + if valtype is bytes: + try: + return str(val, "ascii") + except (TypeError, UnicodeDecodeError): + return str(val) + elif valtype is float: + is_a_colored_number = has_invisible and isinstance(val, (str, bytes)) + if is_a_colored_number: + raw_val = _strip_ansi(val) + formatted_val = format(float(raw_val), floatfmt) + return val.replace(raw_val, formatted_val) + if isinstance(val, str) and "," in val: + val = val.replace(",", "") # handle thousands-separators + return format(float(val), floatfmt) + else: + return f"{val}" + + +def _align_header(header, alignment, width, visible_width, is_multiline=False, width_fn=None): + """Pad string header to width chars given known visible_width of the header.""" + if is_multiline: + assert width_fn is not None + header_lines = re.split(_multiline_codes, header) + padded_lines = [_align_header(h, alignment, width, width_fn(h)) for h in header_lines] + return "\n".join(padded_lines) + # else: not multiline + ninvisible = len(header) - visible_width + width += ninvisible + if alignment == "left": + return _padright(width, header) + if alignment == "center": + return _padboth(width, header) + if not alignment: + return f"{header}" + return _padleft(width, header) + + +def _remove_separating_lines(rows): + if isinstance(rows, list): + separating_lines = [] + sans_rows = [] + for index, row in enumerate(rows): + if _is_separating_line(row): + separating_lines.append(index) + else: + sans_rows.append(row) + return sans_rows, separating_lines + return rows, None + + +def _bool(val): + """A wrapper around standard bool() which doesn't throw on NumPy arrays""" + try: + return bool(val) + except ValueError: # val is likely to be a numpy array with many elements + return False + + +def _normalize_tabular_data(tabular_data, headers, showindex="default"): + try: + bool(headers) + except ValueError: # numpy.ndarray, pandas.core.index.Index, ... 
+ headers = list(headers) + + err_msg = ( + "\n\nTo build a table python-tabulate requires two-dimensional data " + "like a list of lists or similar." + "\nDid you forget a pair of extra [] or ',' in ()?" + ) + index = None + if hasattr(tabular_data, "keys") and hasattr(tabular_data, "values"): + # dict-like and pandas.DataFrame? + if callable(tabular_data.values): + # likely a conventional dict + keys = tabular_data.keys() + try: + rows = list(izip_longest(*tabular_data.values())) # columns have to be transposed + except TypeError: # not iterable + raise TypeError(err_msg) + + elif hasattr(tabular_data, "index"): + # values is a property, has .index => it's likely a pandas.DataFrame (pandas 0.11.0) + keys = list(tabular_data) + if showindex in {"default", "always", True} and tabular_data.index.name is not None: + if isinstance(tabular_data.index.name, list): + keys[:0] = tabular_data.index.name + else: + keys[:0] = [tabular_data.index.name] + vals = tabular_data.values # values matrix doesn't need to be transposed + # for DataFrames add an index per default + index = list(tabular_data.index) + rows = [list(row) for row in vals] + else: + raise ValueError("tabular data doesn't appear to be a dict or a DataFrame") + + if headers == "keys": + headers = list(map(str, keys)) # headers should be strings + + else: # it's a usual iterable of iterables, or a NumPy array, or an iterable of dataclasses + try: + rows = list(tabular_data) + except TypeError: # not iterable + raise TypeError(err_msg) + + if headers == "keys" and not rows: + # an empty table (issue #81) + headers = [] + elif headers == "keys" and hasattr(tabular_data, "dtype") and tabular_data.dtype.names: + # numpy record array + headers = tabular_data.dtype.names + elif headers == "keys" and len(rows) > 0 and isinstance(rows[0], tuple) and hasattr(rows[0], "_fields"): + # namedtuple + headers = list(map(str, rows[0]._fields)) + elif len(rows) > 0 and hasattr(rows[0], "keys") and hasattr(rows[0], "values"): + # 
dict-like object + uniq_keys = set() # implements hashed lookup + keys = [] # storage for set + if headers == "firstrow": + firstdict = rows[0] if len(rows) > 0 else {} + keys.extend(firstdict.keys()) + uniq_keys.update(keys) + rows = rows[1:] + for row in rows: + for k in row.keys(): + # Save unique items in input order + if k not in uniq_keys: + keys.append(k) + uniq_keys.add(k) + if headers == "keys": + headers = keys + elif isinstance(headers, dict): + # a dict of headers for a list of dicts + headers = [headers.get(k, k) for k in keys] + headers = list(map(str, headers)) + elif headers == "firstrow": + if len(rows) > 0: + headers = [firstdict.get(k, k) for k in keys] + headers = list(map(str, headers)) + else: + headers = [] + elif headers: + raise ValueError("headers for a list of dicts is not a dict or a keyword") + rows = [[row.get(k) for k in keys] for row in rows] + + elif ( + headers == "keys" + and hasattr(tabular_data, "description") + and hasattr(tabular_data, "fetchone") + and hasattr(tabular_data, "rowcount") + ): + # Python Database API cursor object (PEP 0249) + # print tabulate(cursor, headers='keys') + headers = [column[0] for column in tabular_data.description] + + elif dataclasses is not None and len(rows) > 0 and dataclasses.is_dataclass(rows[0]): + # Python's dataclass + field_names = [field.name for field in dataclasses.fields(rows[0])] + if headers == "keys": + headers = field_names + rows = [[getattr(row, f) for f in field_names] for row in rows] + + elif headers == "keys" and len(rows) > 0: + # keys are column indices + headers = list(map(str, range(len(rows[0])))) + + # take headers from the first row if necessary + if headers == "firstrow" and len(rows) > 0: + if index is not None: + headers = [index[0]] + list(rows[0]) + index = index[1:] + else: + headers = rows[0] + headers = list(map(str, headers)) # headers should be strings + rows = rows[1:] + elif headers == "firstrow": + headers = [] + + headers = list(map(str, headers)) # 
type: ignore[arg-type] + # rows = list(map(list, rows)) + rows = list(map(lambda r: r if _is_separating_line(r) else list(r), rows)) + + # add or remove an index column + showindex_is_a_str = type(showindex) in {str, bytes} + if showindex == "never" or (not _bool(showindex) and not showindex_is_a_str): + pass + + # pad with empty headers for initial columns if necessary + headers_pad = 0 + if headers and len(rows) > 0: + headers_pad = max(0, len(rows[0]) - len(headers)) + headers = [""] * headers_pad + headers + + return rows, headers, headers_pad + + +def _to_str(s, encoding="utf8", errors="ignore"): + if isinstance(s, bytes): + return s.decode(encoding=encoding, errors=errors) + return str(s) + + +def tabulate( + tabular_data, + headers=(), + tablefmt="simple", + floatfmt=_DEFAULT_FLOATFMT, + intfmt=_DEFAULT_INTFMT, + numalign=_DEFAULT_ALIGN, + stralign=_DEFAULT_ALIGN, + missingval=_DEFAULT_MISSINGVAL, + showindex="default", + disable_numparse=False, + colglobalalign=None, + colalign=None, + preserve_whitespace=False, + maxcolwidths=None, + headersglobalalign=None, + headersalign=None, + rowalign=None, + maxheadercolwidths=None, +) -> str: + if tabular_data is None: + tabular_data = [] + + list_of_lists, headers, headers_pad = _normalize_tabular_data(tabular_data, headers, showindex=showindex) + list_of_lists, separating_lines = _remove_separating_lines(list_of_lists) + + # PrettyTable formatting does not use any extra padding. + # Numbers are not parsed and are treated the same as strings for alignment. + # Check if pretty is the format being used and override the defaults so it + # does not impact other formats. 
+ min_padding = MIN_PADDING + if tablefmt == "pretty": + min_padding = 0 + disable_numparse = True + numalign = "center" if numalign == _DEFAULT_ALIGN else numalign + stralign = "center" if stralign == _DEFAULT_ALIGN else stralign + else: + numalign = "decimal" if numalign == _DEFAULT_ALIGN else numalign + stralign = "left" if stralign == _DEFAULT_ALIGN else stralign + + # 'colon_grid' uses colons in the line beneath the header to represent a column's + # alignment instead of literally aligning the text differently. Hence, + # left alignment of the data in the text output is enforced. + if tablefmt == "colon_grid": + colglobalalign = "left" + headersglobalalign = "left" + + # optimization: look for ANSI control codes once, + # enable smart width functions only if a control code is found + # + # convert the headers and rows into a single, tab-delimited string ensuring + # that any bytestrings are decoded safely (i.e. errors ignored) + plain_text = "\t".join( + chain( + # headers + map(_to_str, headers), + # rows: chain the rows together into a single iterable after mapping + # the bytestring conversino to each cell value + chain.from_iterable(map(_to_str, row) for row in list_of_lists), + ) + ) + + has_invisible = _ansi_codes.search(plain_text) is not None + + enable_widechars = wcwidth is not None and WIDE_CHARS_MODE + if not isinstance(tablefmt, TableFormat) and tablefmt in multiline_formats and _is_multiline(plain_text): + tablefmt = multiline_formats.get(tablefmt, tablefmt) + is_multiline = True + else: + is_multiline = False + width_fn = _choose_width_fn(has_invisible, enable_widechars, is_multiline) + + # format rows and columns, convert numeric values to strings + cols = list(izip_longest(*list_of_lists)) + numparses = _expand_numparse(disable_numparse, len(cols)) + coltypes = [_column_type(col, numparse=np) for col, np in zip(cols, numparses)] + if isinstance(floatfmt, str): # old version + float_formats = len(cols) * [floatfmt] # just duplicate the string 
to use in each column + else: # if floatfmt is list, tuple etc we have one per column + float_formats = list(floatfmt) + if len(float_formats) < len(cols): + float_formats.extend((len(cols) - len(float_formats)) * [_DEFAULT_FLOATFMT]) + if isinstance(intfmt, str): # old version + int_formats = len(cols) * [intfmt] # just duplicate the string to use in each column + else: # if intfmt is list, tuple etc we have one per column + int_formats = list(intfmt) + if len(int_formats) < len(cols): + int_formats.extend((len(cols) - len(int_formats)) * [_DEFAULT_INTFMT]) + if isinstance(missingval, str): + missing_vals = len(cols) * [missingval] + else: + missing_vals = list(missingval) + if len(missing_vals) < len(cols): + missing_vals.extend((len(cols) - len(missing_vals)) * [_DEFAULT_MISSINGVAL]) + cols = [ + [_format(v, ct, fl_fmt, int_fmt, miss_v, has_invisible) for v in c] + for c, ct, fl_fmt, int_fmt, miss_v in zip(cols, coltypes, float_formats, int_formats, missing_vals) + ] + + # align columns + # first set global alignment + if colglobalalign is not None: # if global alignment provided + aligns = [colglobalalign] * len(cols) + else: # default + aligns = [numalign if ct in {int, float} else stralign for ct in coltypes] + # then specific alignments + if colalign is not None: + assert isinstance(colalign, Iterable) + if isinstance(colalign, str): + warnings.warn( + f"As a string, `colalign` is interpreted as {[c for c in colalign]}. " + f'Did you mean `colglobalalign = "{colalign}"` or `colalign = ("{colalign}",)`?', + stacklevel=2, + ) + for idx, align in enumerate(colalign): + if not idx < len(aligns): + break + if align != "global": + aligns[idx] = align + minwidths = [width_fn(h) + min_padding for h in headers] if headers else [0] * len(cols) + aligns_copy = aligns.copy() + # Reset alignments in copy of alignments list to "left" for 'colon_grid' format, + # which enforces left alignment in the text output of the data. 
+ if tablefmt == "colon_grid": + aligns_copy = ["left"] * len(cols) + cols = [ + _align_column(c, a, minw, has_invisible, enable_widechars, is_multiline, preserve_whitespace) + for c, a, minw in zip(cols, aligns_copy, minwidths) + ] + + aligns_headers = None + if headers: + # align headers and add headers + t_cols = cols or [[""]] * len(headers) + # first set global alignment + if headersglobalalign is not None: # if global alignment provided + aligns_headers = [headersglobalalign] * len(t_cols) + else: # default + aligns_headers = aligns or [stralign] * len(headers) + # then specific header alignments + if headersalign is not None: + assert isinstance(headersalign, Iterable) + if isinstance(headersalign, str): + warnings.warn( + f"As a string, `headersalign` is interpreted as {[c for c in headersalign]}. " + f'Did you mean `headersglobalalign = "{headersalign}"` ' + f'or `headersalign = ("{headersalign}",)`?', + stacklevel=2, + ) + for idx, align in enumerate(headersalign): + hidx = headers_pad + idx + if not hidx < len(aligns_headers): + break + if align == "same" and hidx < len(aligns): # same as column align + aligns_headers[hidx] = aligns[hidx] + elif align != "global": + aligns_headers[hidx] = align + minwidths = [max(minw, max(width_fn(cl) for cl in c)) for minw, c in zip(minwidths, t_cols)] + headers = [ + _align_header(h, a, minw, width_fn(h), is_multiline, width_fn) + for h, a, minw in zip(headers, aligns_headers, minwidths) + ] + rows = list(zip(*cols)) + else: + minwidths = [max(width_fn(cl) for cl in c) for c in cols] + rows = list(zip(*cols)) + + if not isinstance(tablefmt, TableFormat): + tablefmt = _table_formats.get(tablefmt, _table_formats["simple"]) + + ra_default = rowalign if isinstance(rowalign, str) else None + rowaligns = _expand_iterable(rowalign, len(rows), ra_default) + return _format_table(tablefmt, headers, aligns_headers, rows, minwidths, aligns, is_multiline, rowaligns=rowaligns) + + +def _expand_numparse(disable_numparse, 
column_count): + if isinstance(disable_numparse, Iterable): + numparses = [True] * column_count + for index in disable_numparse: + numparses[index] = False + return numparses + return [not disable_numparse] * column_count + + +def _expand_iterable(original, num_desired, default): + if isinstance(original, Iterable) and not isinstance(original, str): + return list(original) + [default] * (num_desired - len(original)) + return [default] * num_desired + + +def _pad_row(cells, padding): + if cells: + if cells == SEPARATING_LINE: + return SEPARATING_LINE + pad = " " * padding + padded_cells = [pad + cell + pad for cell in cells] + return padded_cells + return cells + + +def _build_simple_row(padded_cells, rowfmt): + begin, sep, end = rowfmt + return (begin + sep.join(padded_cells) + end).rstrip() + + +def _build_row(padded_cells, colwidths, colaligns, rowfmt): + if not rowfmt: + return None + if callable(rowfmt): + return rowfmt(padded_cells, colwidths, colaligns) + return _build_simple_row(padded_cells, rowfmt) + + +def _append_basic_row(lines, padded_cells, colwidths, colaligns, rowfmt, rowalign=None): + # NOTE: rowalign is ignored and exists for api compatibility with _append_multiline_row + lines.append(_build_row(padded_cells, colwidths, colaligns, rowfmt)) + return lines + + +def _build_line(colwidths, colaligns, linefmt): + """Return a string which represents a horizontal line.""" + if not linefmt: + return None + if callable(linefmt): + return linefmt(colwidths, colaligns) + begin, fill, sep, end = linefmt + cells = [fill * w for w in colwidths] + return _build_simple_row(cells, (begin, sep, end)) + + +def _append_line(lines, colwidths, colaligns, linefmt): + lines.append(_build_line(colwidths, colaligns, linefmt)) + return lines + + +def _format_table(fmt, headers, headersaligns, rows, colwidths, colaligns, is_multiline, rowaligns): + lines = [] + hidden = fmt.with_header_hide if (headers and fmt.with_header_hide) else [] + pad = fmt.padding + headerrow = 
fmt.headerrow + + padded_widths = [(w + 2 * pad) for w in colwidths] + pad_row = _pad_row + append_row = _append_basic_row + + padded_headers = pad_row(headers, pad) + + if fmt.lineabove and "lineabove" not in hidden: + _append_line(lines, padded_widths, colaligns, fmt.lineabove) + + if padded_headers: + append_row(lines, padded_headers, padded_widths, headersaligns, headerrow) + if fmt.linebelowheader and "linebelowheader" not in hidden: + _append_line(lines, padded_widths, colaligns, fmt.linebelowheader) + + if rows and fmt.linebetweenrows and "linebetweenrows" not in hidden: + # initial rows with a line below + for row, ralign in zip(rows[:-1], rowaligns): + if row != SEPARATING_LINE: + append_row(lines, pad_row(row, pad), padded_widths, colaligns, fmt.datarow, rowalign=ralign) + _append_line(lines, padded_widths, colaligns, fmt.linebetweenrows) + # the last row without a line below + append_row(lines, pad_row(rows[-1], pad), padded_widths, colaligns, fmt.datarow, rowalign=rowaligns[-1]) + else: + separating_line = ( + fmt.linebetweenrows or fmt.linebelowheader or fmt.linebelow or fmt.lineabove or Line("", "", "", "") + ) + for row in rows: + # test to see if either the 1st column or the 2nd column (account for showindex) has + # the SEPARATING_LINE flag + if _is_separating_line(row): + _append_line(lines, padded_widths, colaligns, separating_line) + else: + append_row(lines, pad_row(row, pad), padded_widths, colaligns, fmt.datarow) + + if fmt.linebelow and "linebelow" not in hidden: + _append_line(lines, padded_widths, colaligns, fmt.linebelow) + + if headers or rows: + output = "\n".join(lines) + return output + # a completely empty table + return "" diff --git a/src/codeflash_python/code_utils/time_utils.py b/src/codeflash_python/code_utils/time_utils.py new file mode 100644 index 000000000..72aed461e --- /dev/null +++ b/src/codeflash_python/code_utils/time_utils.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from functools import lru_cache + + 
def humanize_runtime(time_in_ns: int) -> str:
    """Render a duration given in nanoseconds as a short human-readable string.

    The value is scaled to the largest unit (nanoseconds up to days) in which it
    is at least 1, rounded to three significant digits, and then padded or
    truncated so that it reads ``1.23``, ``12.3`` or ``123``.

    Args:
        time_in_ns: Duration in nanoseconds.

    Returns:
        A string such as ``"1.23 microseconds"`` or ``"1.00 second"``. The unit
        is singular only when the displayed value is exactly 1.

    """
    # (exclusive upper bound in microseconds, divisor from microseconds, singular unit).
    # Anything at or above the last bound is reported in days.
    scales = (
        (1000, 1, "microsecond"),
        (1000000, 1000, "millisecond"),
        (60000000, 1000000, "second"),
        (3600000000, 60000000, "minute"),
        (86400000000, 3600000000, "hour"),
    )

    text = str(time_in_ns)
    unit = "nanosecond"

    if time_in_ns / 1000 >= 1:
        time_micro = float(time_in_ns) / 1000
        divisor, unit = 86400000000, "day"
        for upper_bound, scale_divisor, scale_unit in scales:
            if time_micro < upper_bound:
                divisor, unit = scale_divisor, scale_unit
                break

        text = f"{time_micro / divisor:.3g}"
        if "e" in text:
            # ``.3g`` falls back to scientific notation once the rounded value
            # reaches 1000 (e.g. 999.9 -> "1e+03"). Expand it to plain digits so
            # the exponent is never mistaken for a fractional part below.
            text = f"{float(text):.0f}"

    # Pluralise from the text that is actually displayed, so a value that rounds
    # up to 2 is never paired with a singular unit.
    if text != "1":
        unit += "s"

    whole, _, fraction = text.partition(".")
    if len(whole) == 1:
        # One leading digit: always show two decimals (1 -> 1.00, 1.5 -> 1.50).
        runtime_human = f"{whole}.{fraction[:2].ljust(2, '0')}"
    elif len(whole) == 2:
        # Two leading digits: show exactly one decimal (12 -> 12.0).
        runtime_human = f"{whole}.{fraction[:1] or '0'}"
    else:
        # Three or more leading digits: the integer part alone is precise enough.
        runtime_human = whole

    return f"{runtime_human} {unit}"
+def format_time(nanoseconds: int) -> str: + """Format nanoseconds into a human-readable string with 3 significant digits when needed.""" + # Define conversion factors and units + if not isinstance(nanoseconds, int): + raise TypeError("Input must be an integer.") + if nanoseconds < 0: + raise ValueError("Input must be a positive integer.") + + if nanoseconds < 1_000: + return f"{nanoseconds}ns" + if nanoseconds < 1_000_000: + value = nanoseconds / 1_000 + return ( + f"{value:.2f}\u03bcs" if value < 10 else (f"{value:.1f}\u03bcs" if value < 100 else f"{int(value)}\u03bcs") + ) + if nanoseconds < 1_000_000_000: + value = nanoseconds / 1_000_000 + return f"{value:.2f}ms" if value < 10 else (f"{value:.1f}ms" if value < 100 else f"{int(value)}ms") + value = nanoseconds / 1_000_000_000 + return f"{value:.2f}s" if value < 10 else (f"{value:.1f}s" if value < 100 else f"{int(value)}s") + + +def format_perf(percentage: float) -> str: + """Format percentage into a human-readable string with 3 significant digits when needed.""" + # Branch order optimized + abs_perc = abs(percentage) + if abs_perc >= 100: + return f"{percentage:.0f}" + if abs_perc >= 10: + return f"{percentage:.1f}" + if abs_perc >= 1: + return f"{percentage:.2f}" + return f"{percentage:.3f}" + + +@lru_cache(maxsize=4096) +def format_runtime_comment(original_time_ns: int, optimized_time_ns: int, comment_prefix: str = "#") -> str: + perf_gain = format_perf( + abs(((original_time_ns - optimized_time_ns) / optimized_time_ns) * 100) if optimized_time_ns else 0.0 + ) + status = "slower" if optimized_time_ns > original_time_ns else "faster" + return ( + f"{comment_prefix} {format_time(original_time_ns)} -> {format_time(optimized_time_ns)} ({perf_gain}% {status})" + ) diff --git a/src/codeflash_python/code_utils/version_check.py b/src/codeflash_python/code_utils/version_check.py new file mode 100644 index 000000000..00b55773e --- /dev/null +++ b/src/codeflash_python/code_utils/version_check.py @@ -0,0 +1,86 @@ 
+"""Version checking utilities for codeflash.""" + +from __future__ import annotations + +import logging +import time + +import requests +from packaging import version + +logger = logging.getLogger("codeflash_python") + +try: + from codeflash_python.version import __version__ +except ImportError: + __version__: str = "0.0.0" + +# Simple cache to avoid checking too frequently +_version_cache: dict[str, str | float] = {"version": "0.0.0", "timestamp": float(0)} +_cache_duration = 3600 # 1 hour cache + + +def get_latest_version_from_pypi() -> str | None: + """Get the latest version of codeflash from PyPI. + + Returns: + The latest version string from PyPI, or None if the request fails. + + """ + # Check cache first + current_time = time.time() + cached_version = _version_cache["version"] + cached_timestamp = _version_cache["timestamp"] + assert isinstance(cached_timestamp, float) + if cached_version is not None and current_time - cached_timestamp < _cache_duration: + assert isinstance(cached_version, str) + return cached_version + + try: + response = requests.get("https://pypi.org/pypi/codeflash/json", timeout=2) + if response.status_code == 200: + data = response.json() + latest_version = data["info"]["version"] + + # Update cache + _version_cache["version"] = latest_version + _version_cache["timestamp"] = current_time + + return latest_version + logger.debug("Failed to fetch version from PyPI: %s", response.status_code) + return None + except requests.RequestException as e: + logger.debug("Network error fetching version from PyPI: %s", e) + return None + except (KeyError, ValueError) as e: + logger.debug("Invalid response format from PyPI: %s", e) + return None + except Exception as e: + logger.debug("Unexpected error fetching version from PyPI: %s", e) + return None + + +def check_for_newer_minor_version() -> None: + """Check if a newer minor version is available on PyPI and notify the user. 
+ + This function compares the current version with the latest version on PyPI. + If a newer minor version is available, it prints an informational message + suggesting the user upgrade. + """ + latest_version = get_latest_version_from_pypi() + + if not latest_version: + return + + try: + current_parsed = version.parse(__version__) + latest_parsed = version.parse(latest_version) + + # Check if there's a newer minor version available + # We only notify for minor version updates, not patch updates + if latest_parsed > current_parsed: # < > == operators can be directly applied on version objects + logger.warning("A newer version(%s) of Codeflash is available, please update soon!", latest_version) + + except version.InvalidVersion as e: + logger.debug("Invalid version format: %s", e) + return diff --git a/src/codeflash_python/context/__init__.py b/src/codeflash_python/context/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/codeflash_python/context/ast_helpers.py b/src/codeflash_python/context/ast_helpers.py new file mode 100644 index 000000000..e5dd5193a --- /dev/null +++ b/src/codeflash_python/context/ast_helpers.py @@ -0,0 +1,323 @@ +from __future__ import annotations + +import ast +import os +from pathlib import Path + +from codeflash_python.models.models import CodeStringsMarkdown + + +def parse_and_collect_imports(code_context: CodeStringsMarkdown) -> tuple[ast.Module, dict[str, str]] | None: + all_code = "\n".join(cs.code for cs in code_context.code_strings) + try: + tree = ast.parse(all_code) + except SyntaxError: + return None + imported_names: dict[str, str] = {} + + # Directly iterate over the module body and nested structures instead of ast.walk + # This avoids traversing every single node in the tree + def collect_imports(nodes: list[ast.stmt]) -> None: + for node in nodes: + if isinstance(node, ast.ImportFrom) and node.module: + for alias in node.names: + if alias.name != "*": + imported_name = alias.asname if alias.asname else 
alias.name + imported_names[imported_name] = node.module + # Recursively check nested structures (function defs, class defs, if statements, etc.) + elif isinstance( + node, + ( + ast.FunctionDef, + ast.AsyncFunctionDef, + ast.ClassDef, + ast.If, + ast.For, + ast.AsyncFor, + ast.While, + ast.With, + ast.AsyncWith, + ast.Try, + ast.ExceptHandler, + ), + ): + if hasattr(node, "body"): + collect_imports(node.body) + if hasattr(node, "orelse"): + collect_imports(node.orelse) # type: ignore[arg-type] + if hasattr(node, "finalbody"): + collect_imports(node.finalbody) # type: ignore[arg-type] + if hasattr(node, "handlers"): + for handler in node.handlers: # type: ignore[attr-defined] + collect_imports(handler.body) + # Handle match/case statements (Python 3.10+) + elif hasattr(ast, "Match") and isinstance(node, ast.Match): + for case in node.cases: # type: ignore[attr-defined] + collect_imports(case.body) + + collect_imports(tree.body) + return tree, imported_names + + +def collect_existing_class_names(tree: ast.Module) -> set[str]: + class_names = set() + stack = list(tree.body) + + while stack: + node = stack.pop() + if isinstance(node, ast.ClassDef): + class_names.add(node.name) + stack.extend(node.body) + elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + stack.extend(node.body) + elif isinstance(node, (ast.If, ast.For, ast.While, ast.With)): + stack.extend(node.body) + if hasattr(node, "orelse"): + stack.extend(node.orelse) # type: ignore[arg-type] + elif isinstance(node, ast.Try): + stack.extend(node.body) + stack.extend(node.orelse) + stack.extend(node.finalbody) + for handler in node.handlers: + stack.extend(handler.body) + + return class_names + + +BUILTIN_AND_TYPING_NAMES = frozenset( + { + "int", + "str", + "float", + "bool", + "bytes", + "bytearray", + "complex", + "list", + "dict", + "set", + "frozenset", + "tuple", + "type", + "object", + "None", + "NoneType", + "Ellipsis", + "NotImplemented", + "memoryview", + "range", + "slice", + "property", 
+ "classmethod", + "staticmethod", + "super", + "Optional", + "Union", + "Any", + "List", + "Dict", + "Set", + "FrozenSet", + "Tuple", + "Type", + "Callable", + "Iterator", + "Generator", + "Coroutine", + "AsyncGenerator", + "AsyncIterator", + "Iterable", + "AsyncIterable", + "Sequence", + "MutableSequence", + "Mapping", + "MutableMapping", + "Collection", + "Awaitable", + "Literal", + "Final", + "ClassVar", + "TypeVar", + "TypeAlias", + "ParamSpec", + "Concatenate", + "Annotated", + "TypeGuard", + "Self", + "Unpack", + "TypeVarTuple", + "Never", + "NoReturn", + "SupportsInt", + "SupportsFloat", + "SupportsComplex", + "SupportsBytes", + "SupportsAbs", + "SupportsRound", + "IO", + "TextIO", + "BinaryIO", + "Pattern", + "Match", + } +) + + +def collect_type_names_from_annotation(node: ast.expr | None) -> set[str]: + if node is None: + return set() + if isinstance(node, ast.Name): + return {node.id} + if isinstance(node, ast.Subscript): + names = collect_type_names_from_annotation(node.value) + names |= collect_type_names_from_annotation(node.slice) + return names + if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr): + return collect_type_names_from_annotation(node.left) | collect_type_names_from_annotation(node.right) + if isinstance(node, ast.Tuple): + names = set[str]() + for elt in node.elts: + names |= collect_type_names_from_annotation(elt) + return names + return set() + + +MAX_RAW_PROJECT_CLASS_BODY_ITEMS = 8 +MAX_RAW_PROJECT_CLASS_LINES = 40 + + +def get_expr_name(node: ast.AST | None) -> str | None: + if node is None: + return None + + parts: list[str] = [] + current = node + while True: + if isinstance(current, ast.Attribute): + parts.append(current.attr) + current = current.value + continue + if isinstance(current, ast.Call): + current = current.func + continue + if isinstance(current, ast.Name): + base_name = current.id + else: + base_name = None + break + + if not parts: + return base_name + + parts.reverse() + if base_name is not None: + 
parts.insert(0, base_name) + return ".".join(parts) + + +def collect_import_aliases(module_tree: ast.Module) -> dict[str, str]: + aliases: dict[str, str] = {} + for node in module_tree.body: + if isinstance(node, ast.Import): + for alias in node.names: + bound_name = alias.asname if alias.asname else alias.name.split(".")[0] + aliases[bound_name] = alias.name + elif isinstance(node, ast.ImportFrom) and node.module: + for alias in node.names: + bound_name = alias.asname if alias.asname else alias.name + aliases[bound_name] = f"{node.module}.{alias.name}" + return aliases + + +def find_class_node_by_name(class_name: str, module_tree: ast.Module) -> ast.ClassDef | None: + stack: list[ast.AST] = [module_tree] + while stack: + node = stack.pop() + body = getattr(node, "body", None) + if body: + for item in body: + if isinstance(item, ast.ClassDef): + if item.name == class_name: + return item + stack.append(item) + elif isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)): + stack.append(item) + return None + + +def expr_matches_name(node: ast.AST | None, import_aliases: dict[str, str], suffix: str) -> bool: + expr_name = get_expr_name(node) + if expr_name is None: + return False + + suffix_dot = "." 
+ suffix + if expr_name == suffix or expr_name.endswith(suffix_dot): + return True + resolved_name = import_aliases.get(expr_name) + return resolved_name is not None and (resolved_name == suffix or resolved_name.endswith(suffix_dot)) + + +def get_node_source(node: ast.AST | None, module_source: str, fallback: str = "...") -> str: + if node is None: + return fallback + source_segment = ast.get_source_segment(module_source, node) + if source_segment is not None: + return source_segment + try: + return ast.unparse(node) + except Exception: + return fallback + + +def bool_literal(node: ast.AST) -> bool | None: + if isinstance(node, ast.Constant) and isinstance(node.value, bool): + return node.value + return None + + +def is_namedtuple_class(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> bool: + for base in class_node.bases: # noqa: SIM110 + if expr_matches_name(base, import_aliases, "NamedTuple"): + return True + return False + + +def get_dataclass_config(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> tuple[bool, bool, bool]: + for decorator in class_node.decorator_list: + if not expr_matches_name(decorator, import_aliases, "dataclass"): + continue + init_enabled = True + kw_only = False + if isinstance(decorator, ast.Call): + for keyword in decorator.keywords: + literal_value = bool_literal(keyword.value) + if literal_value is None: + continue + if keyword.arg == "init": + init_enabled = literal_value + elif keyword.arg == "kw_only": + kw_only = literal_value + return True, init_enabled, kw_only + return False, False, False + + +def is_classvar_annotation(annotation: ast.expr, import_aliases: dict[str, str]) -> bool: + annotation_root = annotation.value if isinstance(annotation, ast.Subscript) else annotation + return expr_matches_name(annotation_root, import_aliases, "ClassVar") + + +def is_project_subpath(module_path: Path, project_root_path: Path) -> bool: + return str(module_path.resolve()).startswith(str(project_root_path.resolve()) + 
os.sep) + + +def get_class_start_line(class_node: ast.ClassDef) -> int: + if class_node.decorator_list: + return min(d.lineno for d in class_node.decorator_list) + return class_node.lineno + + +def class_has_explicit_init(class_node: ast.ClassDef) -> bool: + for item in class_node.body: + if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.name == "__init__": + return True + return False diff --git a/src/codeflash_python/context/call_graph_index.py b/src/codeflash_python/context/call_graph_index.py new file mode 100644 index 000000000..24eca72b0 --- /dev/null +++ b/src/codeflash_python/context/call_graph_index.py @@ -0,0 +1,668 @@ +from __future__ import annotations + +import hashlib +import logging +import os +import sqlite3 +from collections import defaultdict +from pathlib import Path +from typing import TYPE_CHECKING + +from codeflash_python.code_utils.code_utils import path_belongs_to_site_packages +from codeflash_python.context.types import IndexResult +from codeflash_python.context.utils import get_qualified_name +from codeflash_python.models.models import FunctionSource + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + + from jedi.api.classes import Name + + from codeflash_python.models.call_graph import CallGraph + + +# --------------------------------------------------------------------------- +# Module-level helpers (must be top-level for ProcessPoolExecutor pickling) +# --------------------------------------------------------------------------- + + +logger = logging.getLogger("codeflash_python") + +PARALLEL_THRESHOLD = 8 + +# Per-worker state, initialised by init_index_worker in child processes +worker_jedi_project: object | None = None +worker_project_root_str: str | None = None + + +def init_index_worker(project_root: str) -> None: + import jedi + + global worker_jedi_project, worker_project_root_str + worker_jedi_project = jedi.Project(path=project_root) + worker_project_root_str = project_root + + +def 
resolve_definitions(ref: Name) -> list[Name]: + try: + inferred = ref.infer() + valid = [d for d in inferred if d.type in ("function", "class", "statement")] + if valid: + return valid + except Exception: + pass + + try: + result: list[Name] = ref.goto(follow_imports=True, follow_builtin_imports=False) + return result + except Exception: + return [] + + +def is_valid_definition(definition: Name, caller_qualified_name: str, project_root_str: str) -> bool: + definition_path = definition.module_path + if definition_path is None: + return False + + if not str(definition_path).startswith(project_root_str + os.sep): + return False + + if path_belongs_to_site_packages(definition_path): + return False + + if not definition.full_name or not definition.full_name.startswith(definition.module_name): + return False + + if definition.type not in ("function", "class", "statement"): + return False + + try: + def_qn = get_qualified_name(definition.module_name, definition.full_name) + if def_qn == caller_qualified_name: + return False + except ValueError: + return False + + try: + from codeflash_python.context.jedi_helpers import belongs_to_function_qualified + + if belongs_to_function_qualified(definition, caller_qualified_name): + return False + except Exception: + pass + + return True + + +def get_enclosing_function_qn(ref: Name) -> str | None: + try: + parent = ref.parent() + if parent is None or parent.type != "function": + return None + if not parent.full_name or not parent.full_name.startswith(parent.module_name): + return None + return get_qualified_name(parent.module_name, parent.full_name) + except (ValueError, AttributeError): + return None + + +def analyze_file(file_path: Path, jedi_project: object, project_root_str: str) -> tuple[set[tuple[str, ...]], bool]: + """Pure Jedi analysis -- no DB access. 
Returns (edges, had_error).""" + import jedi + + resolved = str(file_path.resolve()) + + try: + script = jedi.Script(path=file_path, project=jedi_project) + refs = script.get_names(all_scopes=True, definitions=False, references=True) + except Exception: + return set(), True + + edges: set[tuple[str, ...]] = set() + + for ref in refs: + try: + caller_qn = get_enclosing_function_qn(ref) + if caller_qn is None: + continue + + definitions = resolve_definitions(ref) + if not definitions: + continue + + definition = definitions[0] + definition_path = definition.module_path + if definition_path is None: + continue + + if not is_valid_definition(definition, caller_qn, project_root_str): + continue + + edge_base = (resolved, caller_qn, str(definition_path)) + + if definition.type == "function": + callee_qn = get_qualified_name(definition.module_name, definition.full_name) + if len(callee_qn.split(".")) > 2: + continue + edges.add( + ( + *edge_base, + callee_qn, + definition.full_name, + definition.name, + definition.type, + definition.get_line_code(), + ) + ) + elif definition.type == "class": + init_qn = get_qualified_name(definition.module_name, f"{definition.full_name}.__init__") + if len(init_qn.split(".")) > 2: + continue + edges.add( + ( + *edge_base, + init_qn, + f"{definition.full_name}.__init__", + "__init__", + definition.type, + definition.get_line_code(), + ) + ) + elif definition.type == "statement": + callee_qn = get_qualified_name(definition.module_name, definition.full_name) + edges.add( + ( + *edge_base, + callee_qn, + definition.full_name, + definition.name, + definition.type, + definition.get_line_code(), + ) + ) + except Exception: + continue + + return edges, False + + +def index_file_worker(args: tuple[str, str]) -> tuple[str, str, set[tuple[str, ...]], bool]: + """Worker entry point for ProcessPoolExecutor.""" + file_path_str, file_hash = args + assert worker_project_root_str is not None + edges, had_error = analyze_file(Path(file_path_str), 
worker_jedi_project, worker_project_root_str) + return file_path_str, file_hash, edges, had_error + + +# --------------------------------------------------------------------------- + + +class CallGraphIndex: + SCHEMA_VERSION = 2 + + def __init__(self, project_root: Path, language: str = "python", db_path: Path | None = None) -> None: + import jedi + + self.project_root = project_root.resolve() + self.project_root_str = str(self.project_root) + self.language = language + self.jedi_project = jedi.Project(path=self.project_root) + + if db_path is None: + from codeflash_python.code_utils.compat import codeflash_cache_db + + db_path = codeflash_cache_db + + self.conn = sqlite3.connect(str(db_path)) + self.conn.execute("PRAGMA journal_mode=WAL") + self.indexed_file_hashes: dict[str, str] = {} + self.resolved_paths: dict[Path, str] = {} + self.init_schema() + + def resolve_path(self, file_path: Path) -> str: + cached = self.resolved_paths.get(file_path) + if cached is not None: + return cached + resolved = str(file_path.resolve()) + self.resolved_paths[file_path] = resolved + return resolved + + def init_schema(self) -> None: + cur = self.conn.cursor() + cur.execute("CREATE TABLE IF NOT EXISTS cg_schema_version (version INTEGER PRIMARY KEY)") + + row = cur.execute("SELECT version FROM cg_schema_version LIMIT 1").fetchone() + if row is None: + cur.execute("INSERT INTO cg_schema_version (version) VALUES (?)", (self.SCHEMA_VERSION,)) + elif row[0] != self.SCHEMA_VERSION: + for table in [ + "cg_call_edges", + "cg_indexed_files", + "cg_languages", + "cg_projects", + "cg_project_meta", + "indexed_files", + "call_edges", + ]: + cur.execute(f"DROP TABLE IF EXISTS {table}") + cur.execute("DELETE FROM cg_schema_version") + cur.execute("INSERT INTO cg_schema_version (version) VALUES (?)", (self.SCHEMA_VERSION,)) + + cur.execute( + """ + CREATE TABLE IF NOT EXISTS indexed_files ( + project_root TEXT NOT NULL, + language TEXT NOT NULL, + file_path TEXT NOT NULL, + file_hash TEXT NOT 
NULL, + PRIMARY KEY (project_root, language, file_path) + ) + """ + ) + cur.execute( + """ + CREATE TABLE IF NOT EXISTS call_edges ( + project_root TEXT NOT NULL, + language TEXT NOT NULL, + caller_file TEXT NOT NULL, + caller_qualified_name TEXT NOT NULL, + callee_file TEXT NOT NULL, + callee_qualified_name TEXT NOT NULL, + callee_fully_qualified_name TEXT NOT NULL, + callee_only_function_name TEXT NOT NULL, + callee_definition_type TEXT NOT NULL, + callee_source_line TEXT NOT NULL, + PRIMARY KEY (project_root, language, caller_file, caller_qualified_name, + callee_file, callee_qualified_name) + ) + """ + ) + cur.execute( + """ + CREATE INDEX IF NOT EXISTS idx_call_edges_caller + ON call_edges (project_root, language, caller_file, caller_qualified_name) + """ + ) + self.conn.commit() + + def get_callees( + self, file_path_to_qualified_names: dict[Path, set[str]] + ) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]: + file_path_to_function_source: dict[Path, set[FunctionSource]] = defaultdict(set) + function_source_list: list[FunctionSource] = [] + + all_caller_keys: list[tuple[str, str]] = [] + for file_path, qualified_names in file_path_to_qualified_names.items(): + resolved = self.resolve_path(file_path) + self.ensure_file_indexed(file_path, resolved) + all_caller_keys.extend((resolved, qn) for qn in qualified_names) + + if not all_caller_keys: + return file_path_to_function_source, function_source_list + + cur = self.conn.cursor() + cur.execute("CREATE TEMP TABLE IF NOT EXISTS _caller_keys (caller_file TEXT, caller_qualified_name TEXT)") + cur.execute("DELETE FROM _caller_keys") + cur.executemany("INSERT INTO _caller_keys VALUES (?, ?)", all_caller_keys) + + rows = cur.execute( + """ + SELECT ce.callee_file, ce.callee_qualified_name, ce.callee_fully_qualified_name, + ce.callee_only_function_name, ce.callee_definition_type, ce.callee_source_line + FROM call_edges ce + INNER JOIN _caller_keys ck + ON ce.caller_file = ck.caller_file AND 
ce.caller_qualified_name = ck.caller_qualified_name + WHERE ce.project_root = ? AND ce.language = ? + AND NOT (ce.callee_file = ce.caller_file AND ce.callee_qualified_name = ce.caller_qualified_name) + AND ce.callee_definition_type IN ('function', 'class', 'statement') + """, + (self.project_root_str, self.language), + ).fetchall() + + for callee_file, callee_qn, callee_fqn, callee_name, callee_type, callee_src in rows: + callee_path = Path(callee_file) + fs = FunctionSource( + file_path=callee_path, + qualified_name=callee_qn, + fully_qualified_name=callee_fqn, + only_function_name=callee_name, + source_code=callee_src, + definition_type=callee_type, + ) + file_path_to_function_source[callee_path].add(fs) + function_source_list.append(fs) + + return file_path_to_function_source, function_source_list + + def count_callees_per_function( + self, file_path_to_qualified_names: dict[Path, set[str]] + ) -> dict[tuple[Path, str], int]: + all_caller_keys: list[tuple[Path, str, str]] = [] + for file_path, qualified_names in file_path_to_qualified_names.items(): + resolved = self.resolve_path(file_path) + self.ensure_file_indexed(file_path, resolved) + all_caller_keys.extend((file_path, resolved, qn) for qn in qualified_names) + + if not all_caller_keys: + return {} + + cur = self.conn.cursor() + cur.execute("CREATE TEMP TABLE IF NOT EXISTS _count_keys (caller_file TEXT, caller_qualified_name TEXT)") + cur.execute("DELETE FROM _count_keys") + cur.executemany( + "INSERT INTO _count_keys VALUES (?, ?)", [(resolved, qn) for _, resolved, qn in all_caller_keys] + ) + + rows = cur.execute( + """ + SELECT ck.caller_file, ck.caller_qualified_name, COUNT(ce.rowid) + FROM _count_keys ck + LEFT JOIN call_edges ce + ON ce.caller_file = ck.caller_file AND ce.caller_qualified_name = ck.caller_qualified_name + AND ce.project_root = ? AND ce.language = ? 
+ AND NOT (ce.callee_file = ce.caller_file AND ce.callee_qualified_name = ce.caller_qualified_name) + GROUP BY ck.caller_file, ck.caller_qualified_name + """, + (self.project_root_str, self.language), + ).fetchall() + + resolved_to_path: dict[str, Path] = {resolved: fp for fp, resolved, _ in all_caller_keys} + counts: dict[tuple[Path, str], int] = {} + for caller_file, caller_qn, cnt in rows: + counts[(resolved_to_path[caller_file], caller_qn)] = cnt + + return counts + + def ensure_file_indexed(self, file_path: Path, resolved: str | None = None) -> IndexResult: + if resolved is None: + resolved = self.resolve_path(file_path) + + # Always read and hash the file before checking the cache so we detect on-disk changes + try: + content = file_path.read_text(encoding="utf-8") + except Exception: + return IndexResult(file_path=file_path, cached=False, num_edges=0, edges=(), cross_file_edges=0, error=True) + + file_hash = hashlib.sha256(content.encode("utf-8")).hexdigest() + + if self.is_file_cached(resolved, file_hash): + return IndexResult(file_path=file_path, cached=True, num_edges=0, edges=(), cross_file_edges=0, error=False) + + return self.index_file(file_path, file_hash, resolved) + + def index_file(self, file_path: Path, file_hash: str, resolved: str | None = None) -> IndexResult: + if resolved is None: + resolved = self.resolve_path(file_path) + edges, had_error = analyze_file(file_path, self.jedi_project, self.project_root_str) + if had_error: + logger.debug("CallGraphIndex: failed to parse %s", file_path) + return self.persist_edges(file_path, resolved, file_hash, edges, had_error) + + def persist_edges( + self, file_path: Path, resolved: str, file_hash: str, edges: set[tuple[str, ...]], had_error: bool + ) -> IndexResult: + cur = self.conn.cursor() + scope = (self.project_root_str, self.language) + + # Clear existing data for this file + cur.execute( + "DELETE FROM call_edges WHERE project_root = ? AND language = ? 
AND caller_file = ?", (*scope, resolved) + ) + cur.execute( + "DELETE FROM indexed_files WHERE project_root = ? AND language = ? AND file_path = ?", (*scope, resolved) + ) + + # Insert new edges if parsing succeeded + if not had_error and edges: + cur.executemany( + """ + INSERT OR REPLACE INTO call_edges + (project_root, language, caller_file, caller_qualified_name, + callee_file, callee_qualified_name, callee_fully_qualified_name, + callee_only_function_name, callee_definition_type, callee_source_line) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + [(*scope, *edge) for edge in edges], + ) + + # Record that this file has been indexed + cur.execute( + "INSERT OR REPLACE INTO indexed_files (project_root, language, file_path, file_hash) VALUES (?, ?, ?, ?)", + (*scope, resolved, file_hash), + ) + + self.conn.commit() + self.indexed_file_hashes[resolved] = file_hash + + # Build summary for return value + edges_summary = tuple( + (caller_qn, callee_name, caller_file != callee_file) + for (caller_file, caller_qn, callee_file, _, _, callee_name, _, _) in edges + ) + cross_file_count = sum(is_cross_file for _, _, is_cross_file in edges_summary) + + return IndexResult( + file_path=file_path, + cached=False, + num_edges=len(edges), + edges=edges_summary, + cross_file_edges=cross_file_count, + error=had_error, + ) + + def build_index(self, file_paths: Iterable[Path], on_progress: Callable[[IndexResult], None] | None = None) -> None: + """Pre-index a batch of files, using multiprocessing for large uncached batches.""" + to_index: list[tuple[Path, str, str]] = [] + + for file_path in file_paths: + resolved = self.resolve_path(file_path) + + # Fast path: already indexed this session + if resolved in self.indexed_file_hashes: + self.report_progress( + on_progress, + IndexResult( + file_path=file_path, cached=True, num_edges=0, edges=(), cross_file_edges=0, error=False + ), + ) + continue + + try: + content = file_path.read_text(encoding="utf-8") + except Exception: + 
self.report_progress( + on_progress, + IndexResult( + file_path=file_path, cached=False, num_edges=0, edges=(), cross_file_edges=0, error=True + ), + ) + continue + + file_hash = hashlib.sha256(content.encode("utf-8")).hexdigest() + + # Check if already cached (in-memory or DB) + if self.is_file_cached(resolved, file_hash): + self.report_progress( + on_progress, + IndexResult( + file_path=file_path, cached=True, num_edges=0, edges=(), cross_file_edges=0, error=False + ), + ) + continue + + to_index.append((file_path, resolved, file_hash)) + + if not to_index: + return + + # Index uncached files + if len(to_index) >= PARALLEL_THRESHOLD: + self.build_index_parallel(to_index, on_progress) + else: + for file_path, resolved, file_hash in to_index: + result = self.index_file(file_path, file_hash, resolved) + self.report_progress(on_progress, result) + + def is_file_cached(self, resolved: str, file_hash: str) -> bool: + """Check if file is cached in memory or DB.""" + if self.indexed_file_hashes.get(resolved) == file_hash: + return True + + row = self.conn.execute( + "SELECT file_hash FROM indexed_files WHERE project_root = ? AND language = ? 
AND file_path = ?", + (self.project_root_str, self.language, resolved), + ).fetchone() + + if row and row[0] == file_hash: + self.indexed_file_hashes[resolved] = file_hash + return True + + return False + + def report_progress(self, on_progress: Callable[[IndexResult], None] | None, result: IndexResult) -> None: + """Report progress if callback provided.""" + if on_progress is not None: + on_progress(result) + + def build_index_parallel( + self, to_index: list[tuple[Path, str, str]], on_progress: Callable[[IndexResult], None] | None + ) -> None: + from concurrent.futures import ProcessPoolExecutor, as_completed + + max_workers = min(os.cpu_count() or 1, len(to_index), 8) + path_info: dict[str, tuple[Path, str]] = {resolved: (fp, fh) for fp, resolved, fh in to_index} + worker_args = [(resolved, fh) for _fp, resolved, fh in to_index] + + logger.debug("CallGraphIndex: indexing %s files across %s workers", len(to_index), max_workers) + + try: + with ProcessPoolExecutor( + max_workers=max_workers, initializer=init_index_worker, initargs=(self.project_root_str,) + ) as executor: + futures = {executor.submit(index_file_worker, args): args[0] for args in worker_args} + + for future in as_completed(futures): + resolved = futures[future] + file_path, file_hash = path_info[resolved] + + try: + _, _, edges, had_error = future.result() + except Exception: + logger.debug("CallGraphIndex: worker failed for %s", file_path) + self.persist_edges(file_path, resolved, file_hash, set(), had_error=True) + self.report_progress( + on_progress, + IndexResult( + file_path=file_path, cached=False, num_edges=0, edges=(), cross_file_edges=0, error=True + ), + ) + continue + + if had_error: + logger.debug("CallGraphIndex: failed to parse %s", file_path) + + result = self.persist_edges(file_path, resolved, file_hash, edges, had_error) + self.report_progress(on_progress, result) + + except Exception: + logger.debug("CallGraphIndex: parallel indexing failed, falling back to sequential") + 
self.fallback_sequential_index(to_index, on_progress) + + def fallback_sequential_index( + self, to_index: list[tuple[Path, str, str]], on_progress: Callable[[IndexResult], None] | None + ) -> None: + """Fallback to sequential indexing when parallel processing fails.""" + for file_path, resolved, file_hash in to_index: + # Skip files already persisted before the failure + if resolved in self.indexed_file_hashes: + continue + result = self.index_file(file_path, file_hash, resolved) + self.report_progress(on_progress, result) + + def get_call_graph( + self, file_path_to_qualified_names: dict[Path, set[str]], *, include_metadata: bool = False + ) -> CallGraph: + from codeflash_python.models.call_graph import CallEdge, CalleeMetadata, CallGraph, FunctionNode + + all_caller_keys: list[tuple[Path, str, str]] = [] + for file_path, qualified_names in file_path_to_qualified_names.items(): + resolved = self.resolve_path(file_path) + self.ensure_file_indexed(file_path, resolved) + all_caller_keys.extend((file_path, resolved, qn) for qn in qualified_names) + + if not all_caller_keys: + return CallGraph(edges=[]) + + cur = self.conn.cursor() + cur.execute("CREATE TEMP TABLE IF NOT EXISTS _graph_keys (caller_file TEXT, caller_qualified_name TEXT)") + cur.execute("DELETE FROM _graph_keys") + cur.executemany( + "INSERT INTO _graph_keys VALUES (?, ?)", [(resolved, qn) for _, resolved, qn in all_caller_keys] + ) + + if include_metadata: + rows = cur.execute( + """ + SELECT ce.caller_file, ce.caller_qualified_name, + ce.callee_file, ce.callee_qualified_name, + ce.callee_fully_qualified_name, ce.callee_only_function_name, + ce.callee_definition_type, ce.callee_source_line + FROM call_edges ce + INNER JOIN _graph_keys gk + ON ce.caller_file = gk.caller_file AND ce.caller_qualified_name = gk.caller_qualified_name + WHERE ce.project_root = ? AND ce.language = ? 
    def close(self) -> None:
        """Close the SQLite connection held in ``self.conn``."""
        self.conn.close()
def collect_synthetic_constructor_type_names(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> set[str]:
    """Collect type names used by the fields that feed a generated ``__init__``.

    Only namedtuple-style classes and dataclasses whose ``init`` is enabled are
    considered; any other class yields an empty set. ClassVar-annotated fields and
    fields declared with ``field(init=False)`` are skipped.
    """
    dataclass_flag, init_enabled, _ = get_dataclass_config(class_node, import_aliases)
    namedtuple_flag = is_namedtuple_class(class_node, import_aliases)
    if not (namedtuple_flag or dataclass_flag):
        return set()
    if dataclass_flag and not init_enabled:
        return set()

    collected: set[str] = set()
    for stmt in class_node.body:
        is_named_field = (
            isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.annotation is not None
        )
        if not is_named_field or is_classvar_annotation(stmt.annotation, import_aliases):
            continue

        in_init = True
        initializer = stmt.value
        if isinstance(initializer, ast.Call) and expr_matches_name(initializer.func, import_aliases, "field"):
            # Only the first ``init=`` keyword counts; a non-literal value leaves the default (True).
            init_keyword = next((kw for kw in initializer.keywords if kw.arg == "init"), None)
            if init_keyword is not None:
                literal = bool_literal(init_keyword.value)
                if literal is not None:
                    in_init = literal

        if in_init:
            collected |= collect_type_names_from_annotation(stmt.annotation)

    return collected
def build_synthetic_init_stub(
    class_node: ast.ClassDef, module_source: str, import_aliases: dict[str, str]
) -> str | None:
    """Render an ``__init__`` stub for a namedtuple-style class or a dataclass.

    Args:
        class_node: The class whose generated constructor should be described.
        module_source: Source text of the module containing *class_node*.
        import_aliases: Import alias mapping used to recognise dataclass/namedtuple/field names.

    Returns:
        A method stub (signature plus ``...`` body), or ``None`` when the class is
        neither a namedtuple nor a dataclass, when the dataclass has ``init``
        disabled, or when no field contributes a parameter.
    """
    is_namedtuple = is_namedtuple_class(class_node, import_aliases)
    is_dataclass, dataclass_init_enabled, dataclass_kw_only = get_dataclass_config(class_node, import_aliases)
    if not is_namedtuple and not is_dataclass:
        return None
    if is_dataclass and not dataclass_init_enabled:
        return None

    parameters = extract_synthetic_init_parameters(
        class_node, module_source, import_aliases, kw_only_by_default=dataclass_kw_only
    )
    if not parameters:
        return None

    positional_parts: list[str] = []
    keyword_only_parts: list[str] = []
    for param_name, annotation_source, default_value, kw_only in parameters:
        part = f"{param_name}: {annotation_source}"
        if default_value is not None:
            part += f" = {default_value}"
        (keyword_only_parts if kw_only else positional_parts).append(part)

    # A generated dataclass __init__ lists every regular field first and moves the
    # keyword-only ones after them. Emitting "*" at the first keyword-only field in
    # declaration order would wrongly turn later regular fields into keyword-only ones.
    signature_parts = ["self", *positional_parts]
    if keyword_only_parts:
        signature_parts.append("*")
        signature_parts.extend(keyword_only_parts)

    signature = ", ".join(signature_parts)
    return f"    def __init__({signature}):\n        ..."
def get_module_source_and_tree(
    module_path: Path, module_cache: dict[Path, tuple[str, ast.Module]]
) -> tuple[str, ast.Module] | None:
    """Read and parse *module_path*, memoizing successful results in *module_cache*.

    Returns the ``(source, tree)`` pair, or ``None`` when the file cannot be read or
    parsed. Failures are not cached.
    """
    if module_path in module_cache:
        return module_cache[module_path]

    try:
        text = module_path.read_text(encoding="utf-8")
        parsed = ast.parse(text)
    except Exception:
        return None
    else:
        entry = (text, parsed)
        module_cache[module_path] = entry
        return entry
def resolve_imported_class_reference(
    base_expr_name: str,
    current_module_tree: ast.Module,
    current_module_path: Path,
    project_root_path: Path,
    module_cache: dict[Path, tuple[str, ast.Module]],
) -> tuple[str, Path] | None:
    """Resolve a base-class expression to ``(class_name, defining_module_path)``.

    A bare name defined as a class in the current module resolves to that module.
    Otherwise the name is expanded through the module's import aliases and resolved
    with jedi by following a synthetic ``from <module> import <class>`` statement.

    Returns ``None`` when the name cannot be expanded to a dotted path, jedi finds
    nothing, the target lies outside the project root, or the target module cannot
    be read and parsed.
    """
    import jedi

    aliases = collect_import_aliases(current_module_tree)
    is_dotted = "." in base_expr_name

    # A bare name that is a class of the current module needs no import resolution.
    if not is_dotted and find_class_node_by_name(base_expr_name, current_module_tree) is not None:
        return base_expr_name, current_module_path

    if base_expr_name in aliases:
        dotted_target = aliases[base_expr_name]
    else:
        dotted_target = base_expr_name
        if is_dotted:
            head, _, tail = base_expr_name.partition(".")
            if head in aliases:
                dotted_target = f"{aliases[head]}.{tail}"

    module_name, separator, imported_name = dotted_target.rpartition(".")
    if not separator:
        return None

    import_prefix = f"from {module_name} import "
    try:
        script = jedi.Script(import_prefix + imported_name, project=get_jedi_project(str(project_root_path)))
        # Column is the end of the imported name on line 1.
        definitions = script.goto(1, len(import_prefix) + len(imported_name), follow_imports=True)
    except Exception:
        return None

    if not definitions:
        return None
    target_path = definitions[0].module_path
    if target_path is None:
        return None
    if not is_project_subpath(target_path, project_root_path):
        return None
    if get_module_source_and_tree(target_path, module_cache) is None:
        return None
    return imported_name, target_path
def resolve_instance_class_name(name: str, module_tree: ast.Module) -> str | None:
    """Guess the class behind a module-level variable called *name*.

    Top-level statements are scanned in order and the first usable one wins:

    * ``name = Cls(...)`` or ``name = Cls.method(...)`` yields ``"Cls"``
    * ``name: Cls = ...`` or ``name: Cls[...] = ...`` yields ``"Cls"``

    Returns ``None`` when no such statement is found.
    """
    for stmt in module_tree.body:
        candidate = None  # expression expected to be the bare class name

        if isinstance(stmt, ast.Assign):
            binds_name = any(isinstance(target, ast.Name) and target.id == name for target in stmt.targets)
            if binds_name and isinstance(stmt.value, ast.Call):
                callee = stmt.value.func
                # For ``Cls.method(...)`` the class is the object the attribute hangs off.
                candidate = callee.value if isinstance(callee, ast.Attribute) else callee
        elif isinstance(stmt, ast.AnnAssign):
            if isinstance(stmt.target, ast.Name) and stmt.target.id == name:
                annotation = stmt.annotation
                # For ``Cls[...]`` the class is the subscripted value.
                candidate = annotation.value if isinstance(annotation, ast.Subscript) else annotation

        if isinstance(candidate, ast.Name):
            return candidate.id

    return None
tree, imported_names = result + + if not imported_names: + return CodeStringsMarkdown(code_strings=[]) + + existing_classes = collect_existing_class_names(tree) + + code_strings: list[CodeString] = [] + emitted_class_names: set[str] = set() + + # --- Step 1: Project class definitions (jedi resolution + recursive base extraction) --- + extracted_classes: set[tuple[Path, str]] = set() + module_cache: dict[Path, tuple[str, ast.Module]] = {} + + def get_module_source_and_tree(module_path: Path) -> tuple[str, ast.Module] | None: + if module_path in module_cache: + return module_cache[module_path] + try: + module_source = module_path.read_text(encoding="utf-8") + module_tree = ast.parse(module_source) + except Exception: + return None + else: + module_cache[module_path] = (module_source, module_tree) + return module_source, module_tree + + def extract_class_and_bases( + class_name: str, module_path: Path, module_source: str, module_tree: ast.Module + ) -> None: + if (module_path, class_name) in extracted_classes: + return + + class_node = None + for node in ast.walk(module_tree): + if isinstance(node, ast.ClassDef) and node.name == class_name: + class_node = node + break + + if class_node is None: + return + + for base in class_node.bases: + base_name = None + if isinstance(base, ast.Name): + base_name = base.id + elif isinstance(base, ast.Attribute): + continue + + if base_name and base_name not in existing_classes: + extract_class_and_bases(base_name, module_path, module_source, module_tree) + + if (module_path, class_name) in extracted_classes: + return + + lines = module_source.split("\n") + start_line = class_node.lineno + if class_node.decorator_list: + start_line = min(d.lineno for d in class_node.decorator_list) + class_source = "\n".join(lines[start_line - 1 : class_node.end_lineno]) + + full_source = class_source + + code_strings.append(CodeString(code=full_source, file_path=module_path)) + extracted_classes.add((module_path, class_name)) + 
def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str:
    """Extract import statements needed for a class definition.

    Names are gathered from the class's base classes, its decorators, the type
    annotations of its fields, and calls used as field values (e.g. ``field(...)``).
    Every top-level ``import`` / ``from ... import`` statement of the module that
    binds at least one of those names is returned, in module order.

    Args:
        module_tree: Parsed module that contains *class_node*.
        class_node: The class whose dependencies are wanted.
        module_source: Source text that *module_tree* was parsed from.

    Returns:
        The matching import statements joined by newlines (empty string if none).
        Multi-line import statements are returned in full.
    """
    needed_names: set[str] = set()

    def add_root_name(expr: ast.expr) -> None:
        # ``pkg.mod.Thing`` is brought into scope by importing its left-most name.
        while isinstance(expr, ast.Attribute):
            expr = expr.value
        if isinstance(expr, ast.Name):
            needed_names.add(expr.id)

    # Bases are type expressions too: ``Base``, ``abc.ABC``, ``Generic[T]``.
    for base in class_node.bases:
        collect_names_from_annotation(base, needed_names)

    # Decorators may be bare (``@dataclass``), dotted, or called (``@dataclass(frozen=True)``).
    for decorator in class_node.decorator_list:
        add_root_name(decorator.func if isinstance(decorator, ast.Call) else decorator)

    for item in class_node.body:
        if isinstance(item, ast.AnnAssign):
            collect_names_from_annotation(item.annotation, needed_names)
        # Dataclass fields are usually annotated assignments, so inspect the value of
        # both plain and annotated assignments for calls such as ``field(...)``.
        if isinstance(item, (ast.Assign, ast.AnnAssign)) and isinstance(item.value, ast.Call):
            add_root_name(item.value.func)

    source_lines = module_source.split("\n")
    import_statements: list[str] = []
    for node in module_tree.body:
        if isinstance(node, ast.Import):
            # ``import a.b.c`` binds ``a`` unless an alias is given.
            bound_names = [alias.asname or alias.name.split(".")[0] for alias in node.names]
        elif isinstance(node, ast.ImportFrom):
            bound_names = [alias.asname or alias.name for alias in node.names]
        else:
            continue

        if any(bound in needed_names for bound in bound_names):
            # Take every physical line of the statement so parenthesized imports stay intact.
            last_line = node.end_lineno or node.lineno
            import_statements.append("\n".join(source_lines[node.lineno - 1 : last_line]))

    return "\n".join(import_statements)


def collect_names_from_annotation(node: ast.expr, names: set[str]) -> None:
    """Recursively collect type annotation names from an AST node.

    Handles plain names, dotted names (the left-most name is recorded), subscripts,
    tuples and lists (e.g. the parameter list of ``Callable[[A], B]``), ``X | Y``
    unions, and string forward references. Names found are added to *names* in place.
    """
    _collect_annotation_names(node, names, parse_strings=True)


def _collect_annotation_names(node: ast.expr, names: set[str], *, parse_strings: bool) -> None:
    """Worker for ``collect_names_from_annotation``; *parse_strings* is off inside ``Literal[...]``."""
    if isinstance(node, ast.Name):
        names.add(node.id)
    elif isinstance(node, ast.Attribute):
        root = node.value
        while isinstance(root, ast.Attribute):
            root = root.value
        if isinstance(root, ast.Name):
            names.add(root.id)
    elif isinstance(node, ast.Subscript):
        _collect_annotation_names(node.value, names, parse_strings=parse_strings)
        head = node.value
        is_literal = (isinstance(head, ast.Name) and head.id == "Literal") or (
            isinstance(head, ast.Attribute) and head.attr == "Literal"
        )
        # Strings inside Literal[...] are values, not forward references.
        _collect_annotation_names(node.slice, names, parse_strings=parse_strings and not is_literal)
    elif isinstance(node, (ast.Tuple, ast.List)):
        for elt in node.elts:
            _collect_annotation_names(elt, names, parse_strings=parse_strings)
    elif isinstance(node, ast.BinOp):  # For Union types with | syntax
        _collect_annotation_names(node.left, names, parse_strings=parse_strings)
        _collect_annotation_names(node.right, names, parse_strings=parse_strings)
    elif parse_strings and isinstance(node, ast.Constant) and isinstance(node.value, str):
        # Forward reference such as ``"Foo"`` or ``"list[Foo]"``.
        try:
            forward_ref = ast.parse(node.value, mode="eval")
        except (SyntaxError, ValueError):
            return
        _collect_annotation_names(forward_ref.body, names, parse_strings=True)
def build_testgen_context(
    helpers_of_fto_dict: dict[Path, set[FunctionSource]],
    helpers_of_helpers_dict: dict[Path, set[FunctionSource]],
    project_root_path: Path,
    *,
    remove_docstrings: bool = False,
    include_enrichment: bool = True,
    function_to_optimize: FunctionToOptimize | None = None,
) -> CodeStringsMarkdown:
    """Assemble the code context used for test generation.

    The base context is the TESTGEN extraction of the given helper files. Two
    optional groups of code blocks are appended to it, in this order:

    1. enrichment blocks from ``enrich_testgen_context`` (when ``include_enrichment``);
    2. constructor stubs from ``extract_parameter_type_constructors`` (when
       ``function_to_optimize`` is given), skipping classes already present
       in the context built so far.

    Args:
        helpers_of_fto_dict: File path -> sources of the target function and its helpers.
        helpers_of_helpers_dict: File path -> sources of helpers of those helpers.
        project_root_path: Root path of the project.
        remove_docstrings: Forwarded to the base extraction.
        include_enrichment: Whether to append the enrichment blocks.
        function_to_optimize: When given, constructor stubs for it are appended.

    Returns:
        The combined context.
    """

    def with_extra(base: CodeStringsMarkdown, extra: CodeStringsMarkdown) -> CodeStringsMarkdown:
        # Only build a new container when there is something to append.
        if not extra.code_strings:
            return base
        return CodeStringsMarkdown(code_strings=base.code_strings + extra.code_strings)

    context = extract_code_markdown_context_from_files(
        helpers_of_fto_dict,
        helpers_of_helpers_dict,
        project_root_path,
        remove_docstrings=remove_docstrings,
        code_context_type=CodeContextType.TESTGEN,
    )

    if include_enrichment:
        context = with_extra(context, enrich_testgen_context(context, project_root_path))

    if function_to_optimize is None:
        return context

    # Class names already present in the context are handed to the stub extractor.
    parsed = parse_and_collect_imports(context)
    known_classes = collect_existing_class_names(parsed[0]) if parsed else set()
    stubs = extract_parameter_type_constructors(function_to_optimize, project_root_path, known_classes)
    return with_extra(context, stubs)
in qn}) + + helpers_of_helpers_dict, helpers_of_helpers_list = get_function_sources_from_jedi( + helpers_of_fto_qualified_names_dict, project_root_path + ) + + # Extract code context for optimization + final_read_writable_code = extract_code_markdown_context_from_files( + helpers_of_fto_dict, + {}, + project_root_path, + remove_docstrings=False, + code_context_type=CodeContextType.READ_WRITABLE, + ) + + # Ensure the target file is first in the code blocks so the LLM knows which file to optimize + target_relative = function_to_optimize.file_path.resolve().relative_to(project_root_path.resolve()) + target_blocks = [cs for cs in final_read_writable_code.code_strings if cs.file_path == target_relative] + other_blocks = [cs for cs in final_read_writable_code.code_strings if cs.file_path != target_relative] + if target_blocks: + final_read_writable_code.code_strings = target_blocks + other_blocks + + read_only_code_markdown = extract_code_markdown_context_from_files( + helpers_of_fto_dict, + helpers_of_helpers_dict, + project_root_path, + remove_docstrings=False, + code_context_type=CodeContextType.READ_ONLY, + ) + hashing_code_context = extract_code_markdown_context_from_files( + helpers_of_fto_dict, + helpers_of_helpers_dict, + project_root_path, + remove_docstrings=True, + code_context_type=CodeContextType.HASHING, + ) + + # Handle token limits + final_read_writable_tokens = encoded_tokens_len(final_read_writable_code.markdown) + if final_read_writable_tokens > optim_token_limit: + raise ValueError(READ_WRITABLE_LIMIT_ERROR) + + # Setup preexisting objects for code replacer + preexisting_objects = set( + chain( + *(find_preexisting_objects(codestring.code) for codestring in final_read_writable_code.code_strings), + *(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings), + ) + ) + read_only_context_code = read_only_code_markdown.markdown + + # Progressive fallback for read-only context token limits + read_only_tokens = 
encoded_tokens_len(read_only_context_code) + if final_read_writable_tokens + read_only_tokens > optim_token_limit: + logger.debug("Code context has exceeded token limit, removing docstrings from read-only code") + read_only_code_no_docstrings = extract_code_markdown_context_from_files( + helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True + ) + read_only_context_code = read_only_code_no_docstrings.markdown + if final_read_writable_tokens + encoded_tokens_len(read_only_context_code) > optim_token_limit: + logger.debug("Code context has exceeded token limit, removing read-only code") + read_only_context_code = "" + + # Progressive fallback for testgen context token limits + testgen_context = build_testgen_context( + helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, function_to_optimize=function_to_optimize + ) + + if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: + logger.debug("Testgen context exceeded token limit, removing docstrings") + testgen_context = build_testgen_context( + helpers_of_fto_dict, + helpers_of_helpers_dict, + project_root_path, + remove_docstrings=True, + function_to_optimize=function_to_optimize, + ) + + if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: + logger.debug("Testgen context still exceeded token limit, removing enrichment") + testgen_context = build_testgen_context( + helpers_of_fto_dict, + helpers_of_helpers_dict, + project_root_path, + remove_docstrings=True, + include_enrichment=False, + ) + + if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: + raise ValueError(TESTGEN_LIMIT_ERROR) + code_hash_context = hashing_code_context.markdown + code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest() + + all_helper_fqns = list({fs.fully_qualified_name for fs in helpers_of_fto_list + helpers_of_helpers_list}) + + return CodeOptimizationContext( + testgen_context=testgen_context, + 
def extract_code_markdown_context_from_files(
    helpers_of_fto: dict[Path, set[FunctionSource]],
    helpers_of_helpers: dict[Path, set[FunctionSource]],
    project_root_path: Path,
    remove_docstrings: bool = False,
    code_context_type: CodeContextType = CodeContextType.READ_ONLY,
) -> CodeStringsMarkdown:
    """Extract code context from the helper files and collect it as markdown code blocks.

    Two groups of files are processed, each file exactly once:

    1. Files containing the function to optimize (FTO) and its first-degree
       helpers. Helpers of helpers that live in the same file are passed along
       as secondary names.
    2. Files containing only helpers of helpers.

    Neither input dictionary is modified.

    Args:
        helpers_of_fto: File path -> sources of the FTO and its helpers.
        helpers_of_helpers: File path -> sources of helpers of those helpers.
        project_root_path: Root path of the project.
        remove_docstrings: Whether to remove docstrings from the extracted code.
        code_context_type: Kind of context to extract (READ_ONLY, READ_WRITABLE,
            TESTGEN or HASHING).

    Returns:
        A ``CodeStringsMarkdown`` holding one code string per file that
        produced non-empty context, FTO/helper files first.
    """
    # Split the helpers of helpers by whether their file is also an FTO/helper file.
    # This is done on local copies: the caller passes the same dictionaries to
    # several extraction calls, so subtracting in place would alter its data.
    secondary_in_primary_files: dict[Path, set[FunctionSource]] = {}
    secondary_only_files: dict[Path, set[FunctionSource]] = {}
    for file_path, function_sources in helpers_of_helpers.items():
        if file_path in helpers_of_fto:
            # A helper of a helper may also be a direct helper of the FTO; count it once.
            secondary_in_primary_files[file_path] = function_sources - helpers_of_fto[file_path]
        else:
            secondary_only_files[file_path] = function_sources

    code_context_markdown = CodeStringsMarkdown()

    # Files that contain the FTO and first-degree helpers.
    for file_path, function_sources in helpers_of_fto.items():
        secondary_sources = secondary_in_primary_files.get(file_path, set())
        result = process_file_context(
            file_path=file_path,
            primary_qualified_names={func.qualified_name for func in function_sources},
            secondary_qualified_names={func.qualified_name for func in secondary_sources},
            code_context_type=code_context_type,
            remove_docstrings=remove_docstrings,
            project_root_path=project_root_path,
            helper_functions=list(function_sources | secondary_sources),
        )
        if result is not None:
            code_context_markdown.code_strings.append(result)

    # Files that contain only helpers of helpers.
    for file_path, helper_function_sources in secondary_only_files.items():
        result = process_file_context(
            file_path=file_path,
            primary_qualified_names=set(),
            secondary_qualified_names={func.qualified_name for func in helper_function_sources},
            code_context_type=code_context_type,
            remove_docstrings=remove_docstrings,
            project_root_path=project_root_path,
            helper_functions=list(helper_function_sources),
        )
        if result is not None:
            code_context_markdown.code_strings.append(result)

    return code_context_markdown
def is_dunder_method(name: str) -> bool:
    """Return True when ``name`` is an ASCII name of the form ``__x__``.

    The name must be longer than four characters, so the bare ``____`` does
    not qualify.
    """
    if len(name) <= 4 or not name.isascii():
        return False
    return name[:2] == "__" and name[-2:] == "__"
def prune_cst(
    node: cst.CSTNode,
    target_functions: set[str],
    prefix: str = "",
    *,
    defs_with_usages: dict[str, UsageInfo] | None = None,
    helpers: set[str] | None = None,
    remove_docstrings: bool = False,
    include_target_in_output: bool = True,
    exclude_init_from_targets: bool = False,
    keep_class_init: bool = False,
    include_dunder_methods: bool = False,
    include_init_dunder: bool = False,
) -> tuple[cst.CSTNode | None, bool]:
    """Prune a CST down to the definitions selected by the given criteria.

    One implementation serves every context type; the keyword options choose
    the behaviour.

    Args:
        node: The CST node to filter.
        target_functions: Qualified names of the target functions.
        prefix: Qualified-name prefix of the enclosing class ("" at module level).
        defs_with_usages: Definitions with usage info (READ_WRITABLE mode).
        helpers: Qualified names of helper functions (READ_ONLY/TESTGEN modes).
        remove_docstrings: Strip docstrings from the kept functions and classes.
        include_target_in_output: Keep target functions in the output.
        exclude_init_from_targets: Drop ``__init__`` even when it is a target (HASHING mode).
        keep_class_init: Keep ``__init__`` methods of classes (READ_WRITABLE mode).
        include_dunder_methods: Keep dunder methods (READ_ONLY/TESTGEN modes).
        include_init_dunder: Whether ``__init__`` counts among the kept dunder methods.

    Returns:
        ``(filtered_node, found_target)``: the pruned node, or ``None`` when it
        should be removed, and whether a target was found in this subtree.
    """
    # Keyword options passed unchanged to every recursive call.
    options = {
        "defs_with_usages": defs_with_usages,
        "helpers": helpers,
        "remove_docstrings": remove_docstrings,
        "include_target_in_output": include_target_in_output,
        "exclude_init_from_targets": exclude_init_from_targets,
        "keep_class_init": keep_class_init,
        "include_dunder_methods": include_dunder_methods,
        "include_init_dunder": include_init_dunder,
    }

    def without_docstring(func: cst.FunctionDef) -> cst.CSTNode:
        # Return the function as-is unless docstring removal was requested.
        if remove_docstrings and isinstance(func.body, cst.IndentedBlock):
            return func.with_changes(body=remove_docstring_from_body(func.body))
        return func

    # Import statements are always dropped here.
    if isinstance(node, (cst.Import, cst.ImportFrom)):
        return None, False

    if isinstance(node, cst.FunctionDef):
        func_name = node.name.value
        qualified_name = f"{prefix}.{func_name}" if prefix else func_name

        # Helpers take priority over targets.
        if helpers and qualified_name in helpers:
            return without_docstring(node), True

        if qualified_name in target_functions:
            # HASHING mode: a targeted __init__ is removed and not counted as found.
            if exclude_init_from_targets and func_name == "__init__":
                return None, False
            if not include_target_in_output:
                return None, True
            return without_docstring(node), True

        # READ_WRITABLE mode: class __init__ is kept verbatim.
        if keep_class_init and func_name == "__init__":
            return node, False

        # READ_ONLY/TESTGEN modes: optionally keep dunder methods.
        keeps_as_dunder = (
            include_dunder_methods
            and len(func_name) > 4
            and func_name.startswith("__")
            and func_name.endswith("__")
        )
        if not keeps_as_dunder:
            return None, False
        if func_name == "__init__" and not include_init_dunder:
            return None, False
        return without_docstring(node), False

    if isinstance(node, cst.ClassDef):
        # Classes reached with a non-empty prefix (i.e. nested ones) are dropped.
        if prefix:
            return None, False
        if not isinstance(node.body, cst.IndentedBlock):
            raise ValueError("ClassDef body is not an IndentedBlock")  # noqa: TRY004
        class_name = node.name.value

        # READ_WRITABLE mode: a class that holds no target but is used by one is kept whole.
        if defs_with_usages:
            contains_target = False
            for stmt in node.body.body:
                if isinstance(stmt, cst.FunctionDef) and f"{class_name}.{stmt.name.value}" in target_functions:
                    contains_target = True
                    break
            if (
                not contains_target
                and class_name in defs_with_usages
                and defs_with_usages[class_name].used_by_qualified_function
            ):
                return node, True

        # Filter each statement of the class body.
        kept_statements: list[cst.CSTNode] = []
        found_in_class = False
        for stmt in node.body.body:
            filtered, found_target = prune_cst(stmt, target_functions, class_name, **options)
            found_in_class |= found_target
            if filtered:
                kept_statements.append(filtered)

        if not found_in_class:
            return None, False
        if not kept_statements:
            return None, True

        pruned_body = node.body.with_changes(body=kept_statements)
        if remove_docstrings:
            assert isinstance(pruned_body, cst.IndentedBlock)
            return node.with_changes(body=remove_docstring_from_body(pruned_body)), True
        return node.with_changes(body=pruned_body), True

    # READ_WRITABLE mode: assignments survive only when they are used.
    if defs_with_usages is not None and isinstance(node, (cst.Assign, cst.AnnAssign, cst.AugAssign)):
        if is_assignment_used(node, defs_with_usages):
            return node, True
        return None, False

    # Any other node: recurse into its child sections, if it has any.
    section_names = get_section_names(node)
    if not section_names:
        return node, False

    def prune_child(child: cst.CSTNode) -> tuple[cst.CSTNode | None, bool]:
        return prune_cst(child, target_functions, prefix, **options)

    if helpers is not None:
        return recurse_sections(node, section_names, prune_child, keep_non_target_children=True)
    return recurse_sections(node, section_names, prune_child)
def get_function_to_optimize_as_function_source(
    function_to_optimize: FunctionToOptimize, project_root_path: Path
) -> FunctionSource:
    """Build the ``FunctionSource`` that describes the function to optimize.

    Jedi lists every definition in the function's file. The first one that is
    a function with the same name and the same qualified name is used.

    Args:
        function_to_optimize: The function to look up.
        project_root_path: Project root used to obtain the jedi project.

    Returns:
        A ``FunctionSource`` for the matching definition.

    Raises:
        ValueError: If no definition in the file matches.
    """
    import jedi

    script = jedi.Script(path=function_to_optimize.file_path, project=get_jedi_project(str(project_root_path)))

    # All definitions in the file, across every scope.
    for candidate in script.get_names(all_scopes=True, definitions=True, references=False):
        try:
            if candidate.type != "function" or not candidate.full_name:
                continue
            if candidate.name != function_to_optimize.function_name:
                continue
            if not candidate.full_name.startswith(candidate.module_name):
                continue
            if (
                get_qualified_name(candidate.module_name, candidate.full_name)
                != function_to_optimize.qualified_name
            ):
                continue
            return FunctionSource(
                file_path=function_to_optimize.file_path,
                qualified_name=function_to_optimize.qualified_name,
                fully_qualified_name=candidate.full_name,
                only_function_name=candidate.name,
                source_code=candidate.get_line_code(),
            )
        except Exception as e:
            # A failure on one candidate must not stop the search.
            logger.exception("Error while getting function source: %s", e)

    raise ValueError(
        f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}"  # noqa: EM102
    )
def get_function_sources_from_jedi(
    file_path_to_qualified_function_names: dict[Path, set[str]], project_root_path: Path
) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]:
    """Find the project-local definitions referenced inside the given functions.

    For every file and every qualified function name, the references located
    inside that function are resolved with jedi. A resolved definition is
    recorded when it lies under ``project_root_path`` and outside
    site-packages, is not itself inside the function, is a function, class or
    statement, and has a qualified name of at most two dotted components.
    A class is recorded as its ``__init__``.

    Args:
        file_path_to_qualified_function_names: File path -> qualified names of
            the functions whose references are resolved.
        project_root_path: Root path of the project.

    Returns:
        A pair of (definition file path -> set of sources, list of all
        sources in discovery order).
    """
    import jedi

    sources_by_file = defaultdict(set)
    all_sources: list[FunctionSource] = []

    for file_path, qualified_function_names in file_path_to_qualified_function_names.items():
        script = jedi.Script(path=file_path, project=get_jedi_project(str(project_root_path)))
        file_refs = script.get_names(all_scopes=True, definitions=False, references=True)

        for qualified_function_name in qualified_function_names:
            refs_in_function = [
                ref
                for ref in file_refs
                if ref.full_name and belongs_to_function_qualified(ref, qualified_function_name)
            ]
            for ref in refs_in_function:
                try:
                    definitions: list[Name] = ref.goto(follow_imports=True, follow_builtin_imports=False)
                except Exception:
                    logger.debug("Error while getting definitions for %s", qualified_function_name)
                    continue
                if not definitions:
                    continue

                # TODO: there can be multiple definitions, see how to handle such cases
                definition = definitions[0]
                definition_path = definition.module_path
                if definition_path is None:
                    continue

                # Re-anchor the path on project_root_path when it resolves to a location under it.
                try:
                    rel = definition_path.resolve().relative_to(project_root_path.resolve())
                    definition_path = project_root_path / rel
                except ValueError:
                    pass

                # The definition must be part of this project and not defined within the original function.
                if path_belongs_to_site_packages(definition_path):
                    continue
                if not str(definition_path).startswith(str(project_root_path) + os.sep):
                    continue
                if not definition.full_name:
                    continue
                if belongs_to_function_qualified(definition, qualified_function_name):
                    continue
                if not definition.full_name.startswith(definition.module_name):
                    continue
                if definition.type not in ("function", "class", "statement"):
                    continue

                if definition.type == "class":
                    # A class is recorded as its constructor.
                    fqn = f"{definition.full_name}.__init__"
                    func_name = "__init__"
                else:
                    fqn = definition.full_name
                    func_name = definition.name

                qualified_name = get_qualified_name(definition.module_name, fqn)
                # Avoid self-references (recursive calls) and nested functions/classes
                if qualified_name == qualified_function_name:
                    continue
                if qualified_name.count(".") > 1:
                    continue

                function_source = FunctionSource(
                    file_path=definition_path,
                    qualified_name=qualified_name,
                    fully_qualified_name=fqn,
                    only_function_name=func_name,
                    source_code=definition.get_line_code(),
                    definition_type=definition.type,
                )
                sources_by_file[definition_path].add(function_source)
                all_sources.append(function_source)

    return sources_by_file, all_sources
get_qualified_name(name.module_name, name.full_name) == qualified_function_name + ): + # Handles function definition and recursive function calls + return False + if (name := name.parent()) and name.type == "function": + return get_qualified_name(name.module_name, name.full_name) == qualified_function_name + return False + except ValueError: + return False diff --git a/src/codeflash_python/context/type_extraction.py b/src/codeflash_python/context/type_extraction.py new file mode 100644 index 000000000..8f4840397 --- /dev/null +++ b/src/codeflash_python/context/type_extraction.py @@ -0,0 +1,249 @@ +"""Parameter type constructor extraction and import analysis for class context.""" + +from __future__ import annotations + +import ast +import logging +from pathlib import Path +from typing import TYPE_CHECKING + +from codeflash_python.code_utils.code_utils import path_belongs_to_site_packages +from codeflash_python.context.ast_helpers import ( + BUILTIN_AND_TYPING_NAMES, + collect_import_aliases, + collect_type_names_from_annotation, + find_class_node_by_name, + is_project_subpath, +) +from codeflash_python.context.class_extraction import ( + append_project_class_context, + collect_synthetic_constructor_type_names, + extract_init_stub_from_class, + get_module_source_and_tree, + should_use_raw_project_class_context, +) +from codeflash_python.context.jedi_helpers import get_jedi_project +from codeflash_python.models.models import CodeString, CodeStringsMarkdown + +if TYPE_CHECKING: + from codeflash_core.models import FunctionToOptimize + +logger = logging.getLogger("codeflash_python") + + +def collect_type_names_from_function( + func_node: ast.FunctionDef | ast.AsyncFunctionDef, tree: ast.Module, class_name: str | None +) -> set[str]: + type_names: set[str] = set() + for arg in func_node.args.args + func_node.args.posonlyargs + func_node.args.kwonlyargs: + type_names |= collect_type_names_from_annotation(arg.annotation) + if func_node.args.vararg: + type_names |= 
def build_import_from_map(tree: ast.Module) -> dict[str, str]:
    """Map each name bound by a ``from X import ...`` statement to its module ``X``.

    The whole tree is walked, so imports nested inside functions or classes
    are included. An aliased import is keyed by its alias. Statements without
    a module name (``from . import x``) are skipped. When a name is imported
    more than once, the import visited last wins.
    """
    return {
        (alias.asname or alias.name): node.module
        for node in ast.walk(tree)
        if isinstance(node, ast.ImportFrom) and node.module
        for alias in node.names
    }
+ ): + if function_to_optimize.starting_line is not None and node.lineno != function_to_optimize.starting_line: + continue + func_node = node + break + if func_node is None: + return CodeStringsMarkdown(code_strings=[]) + + type_names = collect_type_names_from_function(func_node, tree, function_to_optimize.class_name) + type_names -= BUILTIN_AND_TYPING_NAMES + type_names -= existing_class_names + if not type_names: + return CodeStringsMarkdown(code_strings=[]) + + import_map = build_import_from_map(tree) + + code_strings: list[CodeString] = [] + module_cache: dict[Path, tuple[str, ast.Module]] = {} + emitted_classes: set[tuple[Path, str]] = set() + emitted_class_names: set[str] = set() + + def append_type_context(type_name: str, module_name: str, *, transitive: bool = False) -> None: + try: + script_code = f"from {module_name} import {type_name}" + script = jedi.Script(script_code, project=get_jedi_project(str(project_root_path))) + definitions = script.goto(1, len(f"from {module_name} import ") + len(type_name), follow_imports=True) + if not definitions: + return + + module_path = definitions[0].module_path + if not module_path: + return + resolved_module = module_path.resolve() + module_str = str(resolved_module) + is_project = is_project_subpath(module_path, project_root_path) + is_third_party = "site-packages" in module_str + if transitive and not is_project and not is_third_party: + return + + module_result = get_module_source_and_tree(module_path, module_cache) + if module_result is None: + return + mod_source, mod_tree = module_result + + class_key = (module_path, type_name) + if class_key in emitted_classes or type_name in existing_class_names: + return + + class_node = find_class_node_by_name(type_name, mod_tree) + if class_node is not None and is_project: + import_aliases = collect_import_aliases(mod_tree) + if should_use_raw_project_class_context(class_node, import_aliases): + if append_project_class_context( + type_name, + module_path, + 
project_root_path, + module_cache, + existing_class_names, + emitted_classes, + emitted_class_names, + code_strings, + ): + return + + stub = extract_init_stub_from_class(type_name, mod_source, mod_tree) + if stub: + code_strings.append(CodeString(code=stub, file_path=module_path)) + emitted_classes.add(class_key) + emitted_class_names.add(type_name) + except Exception: + if transitive: + logger.debug("Error extracting transitive constructor stub for %s from %s", type_name, module_name) + else: + logger.debug("Error extracting constructor stub for %s from %s", type_name, module_name) + + for type_name in sorted(type_names): + module_name = import_map.get(type_name) + if not module_name: + continue + append_type_context(type_name, module_name) + + # Transitive extraction (one level): for each extracted stub, find __init__ param types and extract their stubs + transitive_import_map = dict(import_map) + for _, cached_tree in module_cache.values(): + for name, module in build_import_from_map(cached_tree).items(): + transitive_import_map.setdefault(name, module) + + emitted_names = type_names | existing_class_names | emitted_class_names | BUILTIN_AND_TYPING_NAMES + transitive_type_names: set[str] = set() + for cs in code_strings: + try: + stub_tree = ast.parse(cs.code) + except SyntaxError: + continue + import_aliases = collect_import_aliases(stub_tree) + for stub_node in ast.walk(stub_tree): + if isinstance(stub_node, (ast.FunctionDef, ast.AsyncFunctionDef)) and stub_node.name in ( + "__init__", + "__post_init__", + ): + for arg in stub_node.args.args + stub_node.args.posonlyargs + stub_node.args.kwonlyargs: + transitive_type_names |= collect_type_names_from_annotation(arg.annotation) + elif isinstance(stub_node, ast.ClassDef): + transitive_type_names |= collect_synthetic_constructor_type_names(stub_node, import_aliases) + transitive_type_names -= emitted_names + for type_name in sorted(transitive_type_names): + module_name = transitive_import_map.get(type_name) + if 
def is_project_module_cached(module_name: str, project_root_path: Path, cache: dict[str, bool]) -> bool:
    """Return ``is_project_module(module_name, project_root_path)``, memoized in *cache*.

    Args:
        module_name: Dotted module name to classify.
        project_root_path: Root directory of the project.
        cache: Caller-owned mapping of module name to a previous answer; it is
            updated in place on a miss.

    Returns:
        True when the module is classified as belonging to the project.

    """
    cached = cache.get(module_name)
    if cached is not None:
        return cached
    is_project = is_project_module(module_name, project_root_path)
    cache[module_name] = is_project
    return is_project


def is_project_path(module_path: Path | None, project_root_path: Path) -> bool:
    """Return True when *module_path* lies under *project_root_path* and is not in site-packages.

    Args:
        module_path: Filesystem path of a module, or None when unknown.
        project_root_path: Root directory of the project.

    Returns:
        False for None, for paths accepted by ``path_belongs_to_site_packages``
        and for paths outside the resolved project root; True otherwise.

    """
    if module_path is None:
        return False
    # site-packages must be checked first because .venv/site-packages is under project root
    if path_belongs_to_site_packages(module_path):
        return False
    try:
        module_path.resolve().relative_to(project_root_path.resolve())
    except ValueError:
        # relative_to raises ValueError when the path is not under the root.
        return False
    return True


def is_project_module(module_name: str, project_root_path: Path) -> bool:
    """Check if a module is part of the project (not external/stdlib).

    Args:
        module_name: Dotted module name to look up with ``importlib.util.find_spec``.
        project_root_path: Root directory of the project.

    Returns:
        True only when the module resolves to a real file located under the
        project root. Modules that cannot be found, and modules without a file
        location, are never project modules.

    """
    import importlib.util

    try:
        spec = importlib.util.find_spec(module_name)
    except (ImportError, ValueError):  # ModuleNotFoundError is a subclass of ImportError
        return False
    if spec is None or spec.origin is None:
        return False
    # Built-in and frozen modules report an origin such as "built-in" or "frozen",
    # which is a label rather than a path. Wrapping it in Path would resolve it
    # against the current working directory and could wrongly place it under the
    # project root, so only specs with a real file location are considered.
    if not spec.has_location:
        return False
    return is_project_path(Path(spec.origin), project_root_path)
@dataclass
class PythonCodeContext:
    """Code context extracted for optimization.

    Contains the target function code and all relevant dependencies
    needed for the AI to understand and optimize the function.

    Attributes:
        target_code: Source code of the function to optimize.
        target_file: Path to the file containing the target function.
        helper_functions: List of helper functions called by the target
            (defaults to an empty list, created per instance).
        read_only_context: Additional context code (read-only dependencies);
            defaults to the empty string.
        imported_type_skeletons: Text for imported type skeletons; defaults to
            the empty string. NOTE(review): the exact content is not visible
            in this module - confirm against the code that fills this field.
        imports: List of import statements needed (defaults to an empty list,
            created per instance).
        language: The programming language; defaults to "python".

    """

    target_code: str
    target_file: Path
    helper_functions: list[HelperFunction] = field(default_factory=list)
    read_only_context: str = ""
    imported_type_skeletons: str = ""
    imports: list[str] = field(default_factory=list)
    language: str = "python"
def function_sources_to_helpers(sources: list[FunctionSource]) -> list[HelperFunction]:
    """Build one ``HelperFunction`` per ``FunctionSource``, preserving input order.

    Line numbers are relative to the helper's own source text: ``start_line`` is
    always 1 and ``end_line`` is the number of newline characters plus one.
    """
    helpers: list[HelperFunction] = []
    for source in sources:
        code = source.source_code
        helpers.append(
            HelperFunction(
                name=source.only_function_name,
                qualified_name=source.qualified_name,
                file_path=source.file_path,
                source_code=code,
                # TODO: FunctionSource should carry real line numbers from jedi definitions
                start_line=1,
                end_line=code.count("\n") + 1,
            )
        )
    return helpers
diff --git a/src/codeflash_python/context/unused_definition_remover.py b/src/codeflash_python/context/unused_definition_remover.py new file mode 100644 index 000000000..6b366f7c0 --- /dev/null +++ b/src/codeflash_python/context/unused_definition_remover.py @@ -0,0 +1,568 @@ +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +import libcst as cst + +if TYPE_CHECKING: + from collections.abc import Callable + + +logger = logging.getLogger("codeflash_python") + + +@dataclass +class UsageInfo: + """Information about a name and its usage.""" + + name: str + used_by_qualified_function: bool = False + dependencies: set[str] = field(default_factory=set) + + +def extract_names_from_targets(target: cst.CSTNode) -> list[str]: + """Extract all variable names from a target node, including from tuple unpacking.""" + names = [] + + # Handle a simple name + if isinstance(target, cst.Name): + names.append(target.value) + + # Handle any node with a value attribute (StarredElement, etc.) + elif hasattr(target, "value"): + names.extend(extract_names_from_targets(target.value)) # type: ignore[arg-type] + + # Handle any node with elements attribute (tuples, lists, etc.) 
def recurse_sections(
    node: cst.CSTNode,
    section_names: list[str],
    prune_fn: Callable[[cst.CSTNode], tuple[cst.CSTNode | None, bool]],
    keep_non_target_children: bool = False,
) -> tuple[cst.CSTNode | None, bool]:
    """Apply *prune_fn* to the children held in each named section of *node*.

    Args:
        node: Node whose sections are inspected; must provide ``with_changes``.
        section_names: Attribute names to inspect (missing or None attributes are skipped).
        prune_fn: Called once per child; returns ``(replacement_or_None, found_target)``.
        keep_non_target_children: When False (strict mode) a section is rewritten
            only if one of its children reported a target. When True (lenient
            mode) a section is also rewritten when any child merely survived pruning.

    Returns:
        ``(new_node, found_target)``. Strict mode yields ``(None, False)`` when no
        target was found anywhere. Lenient mode yields ``(None, False)`` when no
        section ended up with a replacement.

    """
    replacements: dict[str, list[cst.CSTNode] | cst.CSTNode] = {}
    any_target = False

    for section_name in section_names:
        content = getattr(node, section_name, None)
        if content is None:
            continue

        if isinstance(content, (list, tuple)):
            survivors = []
            section_hit = False
            for child in content:
                pruned, hit = prune_fn(child)
                if pruned:
                    survivors.append(pruned)
                section_hit |= hit
            # A hit always rewrites the section; lenient mode also rewrites it
            # when at least one child survived.
            if section_hit or (keep_non_target_children and survivors):
                any_target |= section_hit
                replacements[section_name] = survivors
            continue

        # Section holds a single child node rather than a sequence.
        pruned, hit = prune_fn(content)
        if hit or keep_non_target_children:
            any_target |= hit
            if pruned:
                replacements[section_name] = pruned

    if keep_non_target_children:
        if replacements:
            return node.with_changes(**replacements), any_target
        return None, False
    if not any_target:
        return None, False
    if replacements:
        return node.with_changes(**replacements), True
    return node, True
def get_section_names(node: cst.CSTNode) -> list[str]:
    """List the child-section attributes that *node* defines.

    The candidates are checked in the fixed order ``body``, ``orelse``,
    ``finalbody``, ``handlers``. An attribute counts as present whenever
    ``hasattr`` succeeds, even if its value is None.
    """
    present: list[str] = []
    for attribute in ("body", "orelse", "finalbody", "handlers"):
        if hasattr(node, attribute):
            present.append(attribute)
    return present
using the visitor pattern with depth tracking.""" + + METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,) + + def __init__(self, definitions: dict[str, UsageInfo]) -> None: + super().__init__() + self.definitions = definitions + # Track function and class depths + self.function_depth = 0 + self.class_depth = 0 + # Track top-level qualified names + self.current_top_level_name = "" + self.current_class = "" + # Track if we're processing a top-level variable + self.processing_variable = False + self.current_variable_names = set() + + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + function_name = node.name.value + + if self.function_depth == 0: + # This is a top-level function + if self.class_depth > 0: + # If inside a class, we're now tracking dependencies at the class level + self.current_top_level_name = f"{self.current_class}.{function_name}" + else: + # Regular top-level function + self.current_top_level_name = function_name + + # Check parameter type annotations for dependencies + if hasattr(node, "params") and node.params: + for param in node.params.params: + if param.annotation: + # Visit the annotation to extract dependencies + self.collect_annotation_dependencies(param.annotation) + + self.function_depth += 1 + + def collect_annotation_dependencies(self, annotation: cst.Annotation) -> None: + """Extract dependencies from type annotations.""" + if hasattr(annotation, "annotation"): + # Extract names from annotation (could be Name, Attribute, Subscript, etc.) 
+ self.extract_names_from_annotation(annotation.annotation) + + def extract_names_from_annotation(self, node: cst.CSTNode) -> None: + """Extract names from a type annotation node.""" + # Simple name reference like 'int', 'str', or custom type + if isinstance(node, cst.Name): + name = node.value + if ( + name in self.definitions + and name != self.current_top_level_name + and self.current_top_level_name + and self.current_top_level_name in self.definitions + ): + self.definitions[self.current_top_level_name].dependencies.add(name) + + # Handle compound annotations like List[int], Dict[str, CustomType], etc. + elif isinstance(node, cst.Subscript): + if hasattr(node, "value"): + self.extract_names_from_annotation(node.value) + if hasattr(node, "slice"): + for slice_item in node.slice: + if hasattr(slice_item, "slice"): + self.extract_names_from_annotation(slice_item.slice) + + # Handle attribute access like module.Type + elif isinstance(node, cst.Attribute): + if hasattr(node, "value"): + self.extract_names_from_annotation(node.value) + # No need to check the attribute name itself as it's likely not a top-level definition + + def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: + self.function_depth -= 1 + + if self.function_depth == 0 and self.class_depth == 0: + # Exiting top-level function that's not in a class + self.current_top_level_name = "" + + def visit_ClassDef(self, node: cst.ClassDef) -> None: + class_name = node.name.value + + if self.class_depth == 0: + # This is a top-level class + self.current_class = class_name + self.current_top_level_name = class_name + + # Track base classes as dependencies + for base in node.bases: + if isinstance(base.value, cst.Name): + base_name = base.value.value + if base_name in self.definitions and class_name in self.definitions: + self.definitions[class_name].dependencies.add(base_name) + elif isinstance(base.value, cst.Attribute): + # Handle cases like module.ClassName + attr_name = base.value.attr.value + 
if attr_name in self.definitions and class_name in self.definitions: + self.definitions[class_name].dependencies.add(attr_name) + + self.class_depth += 1 + + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: + self.class_depth -= 1 + + if self.class_depth == 0: + # Exiting top-level class + self.current_class = "" + self.current_top_level_name = "" + + def visit_Assign(self, node: cst.Assign) -> None: + # Only handle top-level assignments + if self.function_depth == 0 and self.class_depth == 0: + for target in node.targets: + # Extract all variable names from the target + names = extract_names_from_targets(target.target) + + # Check if any of these names are top-level definitions we're tracking + tracked_names = [name for name in names if name in self.definitions] + if tracked_names: + self.processing_variable = True + self.current_variable_names.update(tracked_names) + # Use the first tracked name as the current top-level name (for dependency tracking) + self.current_top_level_name = tracked_names[0] + + def leave_Assign(self, original_node: cst.Assign) -> None: + if self.processing_variable: + self.processing_variable = False + self.current_variable_names.clear() + self.current_top_level_name = "" + + def visit_AnnAssign(self, node: cst.AnnAssign) -> None: + # Extract names from the variable annotations + if hasattr(node, "annotation") and node.annotation: + # First mark we're processing a variable to avoid recording it as a dependency of itself + self.processing_variable = True + if isinstance(node.target, cst.Name): + self.current_variable_names.add(node.target.value) + else: + self.current_variable_names.update(extract_names_from_targets(node.target)) + + # Process the annotation + self.collect_annotation_dependencies(node.annotation) + + # Reset processing state + self.processing_variable = False + self.current_variable_names.clear() + + def visit_Name(self, node: cst.Name) -> None: + name = node.value + + # Skip if we're not inside a tracked 
class QualifiedFunctionUsageMarker:
    """Marks definitions that are used by specific qualified functions.

    The marker flips ``used_by_qualified_function`` on the ``UsageInfo`` entries
    of the requested functions and on everything reachable from them through
    their ``dependencies`` sets.
    """

    def __init__(self, definitions: dict[str, UsageInfo], qualified_function_names: set[str]) -> None:
        """Store the inputs and precompute the expanded set of root names.

        Args:
            definitions: Mapping of definition name to its usage record; the
                records are mutated in place by ``mark_used_definitions``.
            qualified_function_names: Names to keep; methods use ``Class.method``.

        """
        self.definitions = definitions
        self.qualified_function_names = qualified_function_names
        self.expanded_qualified_functions = self.expand_qualified_functions()

    def expand_qualified_functions(self) -> set[str]:
        """Expand the qualified function names to include related methods.

        For every dotted name the part before the first dot is treated as the
        class name. The class itself is added, together with every known
        definition of the form ``Class.__something__``.
        """
        expanded = set(self.qualified_function_names)

        for qualified_name in self.qualified_function_names:
            if "." not in qualified_name:
                continue
            class_name = qualified_name.split(".", 1)[0]
            expanded.add(class_name)

            dunder_prefix = f"{class_name}.__"
            expanded.update(
                name for name in self.definitions if name.startswith(dunder_prefix) and name.endswith("__")
            )

        return expanded

    def mark_used_definitions(self) -> None:
        """Find all qualified functions and mark them and their dependencies as used."""
        defs = self.definitions

        # Only names that actually exist as definitions can be marked.
        for func_name in self.expanded_qualified_functions & defs.keys():
            info = defs[func_name]
            # Roots are always marked and always have their dependencies walked,
            # even if an earlier traversal already flagged them.
            info.used_by_qualified_function = True
            for dep in info.dependencies:
                self.mark_as_used_recursively(dep)

    def mark_as_used_recursively(self, name: str) -> None:
        """Mark a name and everything it transitively depends on as used.

        The traversal uses an explicit worklist instead of call-stack recursion,
        so long dependency chains cannot raise ``RecursionError``. Unknown names
        are ignored, and already-marked entries are not revisited, which also
        terminates dependency cycles.
        """
        defs = self.definitions
        pending = [name]
        while pending:
            info = defs.get(pending.pop())
            if info is None or info.used_by_qualified_function:
                continue
            info.used_by_qualified_function = True
            pending.extend(info.dependencies)
+ + Args: + ---- + node: The CST node to process + definitions: Dictionary of definition info + + Returns: + ------- + (filtered_node, used_by_function): + filtered_node: The modified CST node or None if it should be removed + used_by_function: True if this node or any child is used by qualified functions + + """ + # Skip import statements + if isinstance(node, (cst.Import, cst.ImportFrom)): + return node, True + + # Never remove function definitions + if isinstance(node, cst.FunctionDef): + return node, True + + # Never remove class definitions + if isinstance(node, cst.ClassDef): + class_name = node.name.value + + # Check if any methods or variables in this class are used + method_or_var_used = False + class_has_dependencies = False + + # Check if class itself is marked as used + if class_name in definitions and definitions[class_name].used_by_qualified_function: + class_has_dependencies = True + + if hasattr(node, "body") and isinstance(node.body, cst.IndentedBlock): + updates = {} + new_statements = [] + + for statement in node.body.body: + # Keep all function definitions + if isinstance(statement, cst.FunctionDef): + method_name = f"{class_name}.{statement.name.value}" + if method_name in definitions and definitions[method_name].used_by_qualified_function: + method_or_var_used = True + new_statements.append(statement) + # Only process variable assignments + elif isinstance(statement, (cst.Assign, cst.AnnAssign, cst.AugAssign)): + var_used = False + + if is_assignment_used(statement, definitions, name_prefix=f"{class_name}."): + var_used = True + method_or_var_used = True + + if var_used or class_has_dependencies: + new_statements.append(statement) + else: + # Keep all other statements in the class + new_statements.append(statement) + + # Update the class body + new_body = node.body.with_changes(body=new_statements) + updates["body"] = new_body + + return node.with_changes(**updates), True + + return node, method_or_var_used or class_has_dependencies + + # 
def collect_top_level_defs_with_usages(
    code: str | cst.Module, qualified_function_names: set[str]
) -> dict[str, UsageInfo]:
    """Return the top-level definitions of *code* with their usage flags filled in.

    The work happens in three passes: collect the top-level classes, variables
    and functions; record dependencies between them with ``DependencyCollector``;
    then mark everything reachable from *qualified_function_names* as used.
    *code* may be source text or an already parsed ``cst.Module``.
    """
    if isinstance(code, cst.Module):
        parsed = code
    else:
        parsed = cst.parse_module(code)

    # Pass 1: every top-level class, variable and function.
    usage_by_name = collect_top_level_definitions(parsed)

    # Pass 2: dependency edges; the collector needs parent metadata, hence the wrapper.
    cst.MetadataWrapper(parsed).visit(DependencyCollector(usage_by_name))

    # Pass 3: flag the requested functions and, transitively, what they depend on.
    QualifiedFunctionUsageMarker(usage_by_name, qualified_function_names).mark_used_definitions()
    return usage_by_name
For methods, use format 'classname.methodname' + + """ + try: + module = cst.parse_module(code) + except Exception as e: + logger.debug("Failed to parse code with libcst: %s: %s", type(e).__name__, e) + return code + + try: + defs_with_usages = collect_top_level_defs_with_usages(module, qualified_function_names) + + # Apply the recursive removal transformation + modified_module, _ = remove_unused_definitions_recursively(module, defs_with_usages) + + return modified_module.code if modified_module else "" # type: ignore[unresolved-attribute] + except Exception as e: + # If any other error occurs during processing, return the original code + logger.debug("Error processing code to remove unused definitions: %s: %s", type(e).__name__, e) + return code diff --git a/src/codeflash_python/context/unused_helper_detection.py b/src/codeflash_python/context/unused_helper_detection.py new file mode 100644 index 000000000..e7a375f92 --- /dev/null +++ b/src/codeflash_python/context/unused_helper_detection.py @@ -0,0 +1,313 @@ +"""Detection and reversion of unused helper functions in optimized code.""" + +from __future__ import annotations + +import ast +import logging +from collections import defaultdict +from itertools import chain +from pathlib import Path +from typing import TYPE_CHECKING + +from codeflash_python.models.models import CodeString, CodeStringsMarkdown +from codeflash_python.static_analysis.code_replacer import replace_function_definitions_in_module + +if TYPE_CHECKING: + from codeflash_core.models import FunctionToOptimize + from codeflash_python.models.models import CodeOptimizationContext, FunctionSource + + +logger = logging.getLogger("codeflash_python") + + +def revert_unused_helper_functions( + project_root: Path, unused_helpers: list[FunctionSource], original_helper_code: dict[Path, str] +) -> None: + """Revert unused helper functions back to their original definitions. 
+ + Args: + project_root: project_root + unused_helpers: List of unused helper functions to revert + original_helper_code: Dictionary mapping file paths to their original code + + """ + if not unused_helpers: + return + + logger.debug("Reverting %s unused helper function(s) to original definitions", len(unused_helpers)) + + # Resolve all path keys for consistent comparison (Windows 8.3 short names may differ from Jedi-resolved paths) + resolved_original_helper_code = {p.resolve(): code for p, code in original_helper_code.items()} + + # Group unused helpers by file path + unused_helpers_by_file = defaultdict(list) + for helper in unused_helpers: + unused_helpers_by_file[helper.file_path.resolve()].append(helper) + + # For each file, revert the unused helper functions to their original definitions + for file_path, helpers_in_file in unused_helpers_by_file.items(): + if file_path in resolved_original_helper_code: + try: + # Get original code for this file + original_code = resolved_original_helper_code[file_path] + + # Use the code replacer to selectively revert only the unused helper functions + helper_names = [helper.qualified_name for helper in helpers_in_file] + reverted_code = replace_function_definitions_in_module( + function_names=helper_names, + optimized_code=CodeStringsMarkdown( + code_strings=[ + CodeString(code=original_code, file_path=Path(file_path).relative_to(project_root)) + ] + ), # Use original code as the "optimized" code to revert + module_abspath=file_path, + preexisting_objects=set(), # Empty set since we're reverting + project_root_path=project_root, + should_add_global_assignments=False, # since we revert helpers functions after applying the optimization, we know that the file already has global assignments added, otherwise they would be added twice. 
def analyze_imports_in_optimized_code(
    optimized_ast: ast.AST, code_context: CodeOptimizationContext
) -> dict[str, set[str]]:
    """Map names imported by the optimized code to the helper functions they denote.

    Helpers are matched to import statements by the stem of the file that defines
    them, so ``from utils import helper`` and ``import utils`` both match helpers
    defined in ``utils.py``. Helpers whose ``definition_type`` is ``"class"`` are
    ignored.

    Args:
        optimized_ast: The AST of the optimized code
        code_context: The code optimization context containing helper functions

    Returns:
        Dictionary mapping each usable name (``alias`` for
        ``from module import func as alias``; ``module_alias.func`` for
        ``import module as module_alias``) to the set of qualified and fully
        qualified helper names it may refer to. Only names that resolve to at
        least one helper appear as keys.

    """
    imported_names_map: defaultdict[str, set[str]] = defaultdict(set)

    # module stem -> function name -> helpers; serves "from module import name"
    helpers_by_file_and_func = defaultdict(dict)
    # module stem -> helpers; serves "import module"
    helpers_by_file = defaultdict(list)
    for helper in code_context.helper_functions:
        if helper.definition_type != "class":  # Include when definition_type is None (non-Python)
            module_name = helper.file_path.stem
            helpers_by_file_and_func[module_name].setdefault(helper.only_function_name, []).append(helper)
            helpers_by_file[module_name].append(helper)

    # Collect only import nodes to avoid per-node isinstance checks across the whole AST
    class _ImportCollector(ast.NodeVisitor):
        def __init__(self) -> None:
            self.nodes: list[ast.AST] = []

        def visit_Import(self, node: ast.Import) -> None:
            # Import nodes have nothing of interest below them, so do not recurse
            self.nodes.append(node)

        def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
            # Import-from nodes have nothing of interest below them, so do not recurse
            self.nodes.append(node)

    collector = _ImportCollector()
    collector.visit(optimized_ast)

    for node in collector.nodes:
        if isinstance(node, ast.ImportFrom):
            # "from module import function [as alias]"
            # NOTE(review): node.module is compared verbatim with file stems, so a dotted
            # module such as "pkg.utils" does not match helpers in utils.py - confirm intended.
            file_entry = helpers_by_file_and_func.get(node.module) if node.module else None
            if not file_entry:
                continue
            for alias in node.names:
                helpers = file_entry.get(alias.name)
                if helpers:
                    imported_set = imported_names_map[alias.asname or alias.name]
                    for helper in helpers:
                        imported_set.add(helper.qualified_name)
                        imported_set.add(helper.fully_qualified_name)

        elif isinstance(node, ast.Import):
            # "import module [as alias]": helpers are then called as alias.function
            for alias in node.names:
                imported_name = alias.asname or alias.name
                # Index the map only per concrete helper: merely touching the defaultdict
                # with a placeholder key would leave an empty junk entry in the result.
                for helper in helpers_by_file.get(alias.name, ()):
                    full_call_set = imported_names_map[f"{imported_name}.{helper.only_function_name}"]
                    full_call_set.add(helper.qualified_name)
                    full_call_set.add(helper.fully_qualified_name)

    return dict(imported_names_map)
function_to_optimize: FunctionToOptimize, + code_context: CodeOptimizationContext, + optimized_code: str | CodeStringsMarkdown, +) -> list[FunctionSource]: + """Detect helper functions that are no longer called by the optimized entrypoint function. + + Args: + function_to_optimize: The function to optimize + code_context: The code optimization context containing helper functions + optimized_code: The optimized code to analyze + + Returns: + List of FunctionSource objects representing unused helper functions + + """ + if isinstance(optimized_code, CodeStringsMarkdown) and len(optimized_code.code_strings) > 0: + return list( + chain.from_iterable( + detect_unused_helper_functions(function_to_optimize, code_context, code.code) + for code in optimized_code.code_strings + ) + ) + + try: + # Parse the optimized code to analyze function calls and imports + optimized_ast = ast.parse(optimized_code) # type: ignore[call-overload] + + # Find the optimized entrypoint function + entrypoint_function_ast = find_target_node(optimized_ast, function_to_optimize) + + if not entrypoint_function_ast: + logger.debug("Could not find entrypoint function %s in optimized code", function_to_optimize.function_name) + return [] + + # First, analyze imports to build a mapping of imported names to their original qualified names + imported_names_map = analyze_imports_in_optimized_code(optimized_ast, code_context) + + # Extract all function calls and attribute references in the entrypoint function + called_function_names = {function_to_optimize.function_name} + for node in ast.walk(entrypoint_function_ast): + if isinstance(node, ast.Call): + if isinstance(node.func, ast.Name): + # Regular function call: function_name() + called_name = node.func.id + called_function_names.add(called_name) + # Also add the qualified name if this is an imported function + mapped_names = imported_names_map.get(called_name) + if mapped_names: + called_function_names.update(mapped_names) + elif isinstance(node.func, 
ast.Attribute): + # Method call: obj.method() or self.method() or module.function() + if isinstance(node.func.value, ast.Name): + attr_name = node.func.attr + value_id = node.func.value.id + if value_id == "self": + # self.method_name() -> add both method_name and ClassName.method_name + called_function_names.add(attr_name) + # For class methods, also add the qualified name + if hasattr(function_to_optimize, "parents") and function_to_optimize.parents: + class_name = function_to_optimize.parents[0].name + called_function_names.add(f"{class_name}.{attr_name}") + else: + called_function_names.add(attr_name) + full_call = f"{value_id}.{attr_name}" + called_function_names.add(full_call) + # Check if this is a module.function call that maps to a helper + mapped_names = imported_names_map.get(full_call) + if mapped_names: + called_function_names.update(mapped_names) + # Handle nested attribute access like obj.attr.method() + else: + called_function_names.add(node.func.attr) + elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name): + # Attribute reference without call: e.g. self._parse1 = self._parse_literal + # This covers methods used as callbacks, stored in variables, passed as arguments, etc. 
+ attr_name = node.attr + value_id = node.value.id + if value_id == "self": + called_function_names.add(attr_name) + if hasattr(function_to_optimize, "parents") and function_to_optimize.parents: + class_name = function_to_optimize.parents[0].name + called_function_names.add(f"{class_name}.{attr_name}") + else: + called_function_names.add(attr_name) + full_ref = f"{value_id}.{attr_name}" + called_function_names.add(full_ref) + mapped_names = imported_names_map.get(full_ref) + if mapped_names: + called_function_names.update(mapped_names) + + logger.debug("Functions called in optimized entrypoint: %s", called_function_names) + logger.debug("Imported names mapping: %s", imported_names_map) + + # Find helper functions that are no longer called + unused_helpers = [] + entrypoint_file_path = function_to_optimize.file_path + for helper_function in code_context.helper_functions: + jedi_type = helper_function.definition_type + if jedi_type != "class": # Include when definition_type is None (non-Python) + # Check if the helper function is called using multiple name variants + helper_qualified_name = helper_function.qualified_name + helper_simple_name = helper_function.only_function_name + helper_fully_qualified_name = helper_function.fully_qualified_name + + # Check membership efficiently - exit early on first match + if ( + helper_qualified_name in called_function_names + or helper_simple_name in called_function_names + or helper_fully_qualified_name in called_function_names + ): + is_called = True + # For cross-file helpers, also consider module-based calls + elif helper_function.file_path != entrypoint_file_path: + # Add potential module.function combinations + module_name = helper_function.file_path.stem + module_call = f"{module_name}.{helper_simple_name}" + is_called = module_call in called_function_names + else: + is_called = False + + if not is_called: + unused_helpers.append(helper_function) + logger.debug("Helper function %s is not called in optimized code", 
def get_qualified_name(module_name: str, full_qualified_name: str) -> str:
    """Return ``full_qualified_name`` with its leading ``module_name.`` removed.

    Args:
        module_name: Dotted module path expected at the start of the name.
        full_qualified_name: Fully qualified name, e.g. ``pkg.mod.Class.method``.

    Returns:
        The part after ``module_name.``, e.g. ``Class.method``.

    Raises:
        ValueError: If ``full_qualified_name`` is empty, is identical to
            ``module_name``, or does not start with ``module_name`` followed by a dot.

    """
    if not full_qualified_name:
        msg = "full_qualified_name cannot be empty"
        raise ValueError(msg)
    if module_name == full_qualified_name:
        msg = f"{full_qualified_name} is the same as {module_name}"
        raise ValueError(msg)
    # Require the dot boundary: "foo" is a string prefix of "foobar.baz" but not its
    # module, and slicing past a non-dot character would silently return garbage.
    if not full_qualified_name.startswith(f"{module_name}."):
        msg = f"{full_qualified_name} does not start with {module_name}"
        raise ValueError(msg)
    return full_qualified_name[len(module_name) + 1 :]
def existing_unit_test_count(
    func: FunctionToOptimize, project_root: Path, function_to_tests: dict[str, set[FunctionCalledInTest]]
) -> int:
    """Count the distinct existing unit tests recorded for ``func``.

    Tests are looked up under ``<module name>.<qualified name>``. Only entries of
    type ``TestType.EXISTING_UNIT_TEST`` are counted, and everything from the first
    ``[`` of the test function name onwards is dropped, so all parametrizations of
    one test count once.
    """
    module_name = module_name_from_file_path_cached(func.file_path, project_root)
    recorded = function_to_tests.get(f"{module_name}.{func.qualified_name}", set())
    distinct: set[tuple[Path, str | None, str]] = {
        (
            called.tests_in_file.test_file,
            called.tests_in_file.test_class,
            called.tests_in_file.test_function.split("[", 1)[0],
        )
        for called in recorded
        if called.tests_in_file.test_type == TestType.EXISTING_UNIT_TEST
    }
    return len(distinct)
def discover_unit_tests(
    cfg: TestConfig,
    discover_only_these_tests: list[Path] | None = None,
    file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] | None = None,
) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]:
    """Discover tests with the strategy that matches ``cfg.test_framework``.

    Args:
        cfg: Test configuration; ``test_framework`` must be ``"pytest"`` or ``"unittest"``.
        discover_only_these_tests: Optional list of test files, forwarded to the strategy.
        file_to_funcs_to_optimize: Optional mapping of file to target functions; its
            values are flattened into one list and forwarded for import filtering.

    Returns:
        The strategy's result: the function-to-tests mapping, the number of
        discovered tests and the number of discovered replay tests.

    Raises:
        ValueError: If ``cfg.test_framework`` is not a supported framework.

    """
    framework = cfg.test_framework
    if framework == "pytest":
        discover = discover_tests_pytest
    elif framework == "unittest":
        discover = discover_tests_unittest
    else:
        error_message = f"Unsupported test framework: {cfg.test_framework}"
        raise ValueError(error_message)

    # Flatten every target function into a single list for import filtering
    functions_to_optimize = None
    if file_to_funcs_to_optimize:
        functions_to_optimize = []
        for funcs_in_file in file_to_funcs_to_optimize.values():
            functions_to_optimize.extend(funcs_in_file)

    tests_by_function, discovered_count, replay_count = discover(
        cfg, discover_only_these_tests, functions_to_optimize
    )
    return tests_by_function, discovered_count, replay_count
text=True, capture_output=True + ) + result = subprocess.run( # noqa: PLW1510 # type: ignore[call-overload] + [ + SAFE_SYS_EXECUTABLE, + Path(__file__).parent / "pytest_new_process_discovery.py", + str(project_root), + str(tests_root), + str(tmp_pickle_path), + ], + **run_kwargs, + ) + try: + with tmp_pickle_path.open(mode="rb") as f: + exitcode, tests, pytest_rootdir = pickle.load(f) + except Exception as e: + tests, pytest_rootdir = [], None + logger.exception("Failed to discover tests: %s", e) + exitcode = -1 + finally: + tmp_pickle_path.unlink(missing_ok=True) + if exitcode != 0: + if exitcode == 2 and "ERROR collecting" in result.stdout: + # Pattern matches "===== ERRORS =====" (any number of =) and captures everything after + match = ERROR_PATTERN.search(result.stdout) + error_section = match.group(1) if match else result.stdout + + logger.warning( + "Failed to collect tests. Pytest Exit code: %s=%s\n %s", + exitcode, + PytestExitCode(exitcode).name, + error_section, + ) + if "ModuleNotFoundError" in result.stdout: + match = ImportErrorPattern.search(result.stdout) + if match: + error_message = match.group() + logger.warning("⚠️ %s", error_message) + + elif 0 <= exitcode <= 5: + logger.warning("Failed to collect tests. Pytest Exit code: %s=%s", exitcode, PytestExitCode(exitcode).name) + else: + logger.warning("Failed to collect tests. 
Pytest Exit code: %s", exitcode) + else: + logger.debug("Pytest collection exit code: %s", exitcode) + if pytest_rootdir is not None: + cfg.tests_project_rootdir = Path(pytest_rootdir) + if discover_only_these_tests: + resolved_discover_only = {p.resolve() for p in discover_only_these_tests} + else: + resolved_discover_only = None + file_to_test_map: dict[Path, list[TestsInFile]] = defaultdict(list) + for test in tests: + if "__replay_test" in test["test_file"]: + test_type = TestType.REPLAY_TEST + elif "test_concolic_coverage" in test["test_file"]: + test_type = TestType.CONCOLIC_COVERAGE_TEST + else: + test_type = TestType.EXISTING_UNIT_TEST + + test_file_path = Path(test["test_file"]).resolve() + test_obj = TestsInFile( + test_file=test_file_path, + test_class=test["test_class"], + test_function=test["test_function"], + test_type=test_type, + ) + if resolved_discover_only and test_obj.test_file not in resolved_discover_only: + continue + file_to_test_map[test_obj.test_file].append(test_obj) + # Within these test files, find the project functions they are referring to and return their names/locations + return process_test_files(file_to_test_map, cfg, functions_to_optimize) + + +def discover_tests_unittest( + cfg: TestConfig, + discover_only_these_tests: list[Path] | None = None, + functions_to_optimize: list[FunctionToOptimize] | None = None, +) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]: + tests_root: Path = cfg.tests_root + loader: unittest.TestLoader = unittest.TestLoader() + tests: unittest.TestSuite = loader.discover(str(tests_root)) + file_to_test_map: defaultdict[Path, list[TestsInFile]] = defaultdict(list) + + def get_test_details(_test: unittest.TestCase) -> TestsInFile | None: + _test_function, _test_module, _test_suite_name = ( + _test._testMethodName, + _test.__class__.__module__, + _test.__class__.__qualname__, + ) + + _test_module_path = Path(_test_module.replace(".", os.sep)).with_suffix(".py") + _test_module_path = tests_root / 
_test_module_path + if not _test_module_path.exists() or ( + discover_only_these_tests and _test_module_path not in discover_only_these_tests + ): + return None + if "__replay_test" in str(_test_module_path): + test_type = TestType.REPLAY_TEST + elif "test_concolic_coverage" in str(_test_module_path): + test_type = TestType.CONCOLIC_COVERAGE_TEST + else: + test_type = TestType.EXISTING_UNIT_TEST + return TestsInFile( + test_file=_test_module_path, test_function=_test_function, test_type=test_type, test_class=_test_suite_name + ) + + for _test_suite in tests._tests: + for test_suite_2 in _test_suite._tests: # type: ignore[unresolved-attribute] + if not hasattr(test_suite_2, "_tests"): + logger.warning("Didn't find tests for %s", test_suite_2) + continue + + for test in test_suite_2._tests: + # some test suites are nested, so we need to go deeper + if not hasattr(test, "_testMethodName") and hasattr(test, "_tests"): + for test_2 in test._tests: + if not hasattr(test_2, "_testMethodName"): + logger.warning("Didn't find tests for %s", test_2) # it goes deeper? 
def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | None]:
    """Split a trailing numeric suffix off a unittest test name.

    A name counts as parameterized when the text after its last underscore consists
    only of digits, e.g. ``test_add_3``.

    Returns:
        ``(True, base_name, digits)`` for a parameterized name, otherwise
        ``(False, function_name, None)``.

    """
    base_name, underscore, suffix = function_name.rpartition("_")
    if underscore and suffix.isdigit():
        return True, base_name, suffix
    return False, function_name, None
sys_path=jedi_sys_path) + + tests_cache = TestsCache(project_root_path) + for test_file, functions in file_to_test_map.items(): + file_hash = TestsCache.compute_file_hash(test_file) + + cached_function_to_test_map = tests_cache.get_function_to_test_map_for_file(str(test_file), file_hash) + + if cfg.use_cache and cached_function_to_test_map: + for qualified_name, test_set in cached_function_to_test_map.items(): + function_to_test_map[qualified_name].update(test_set) + + for function_called_in_test in test_set: + if function_called_in_test.tests_in_file.test_type == TestType.REPLAY_TEST: + num_discovered_replay_tests += 1 + num_discovered_tests += 1 + + continue + try: + script = jedi.Script(path=test_file, project=jedi_project) + test_functions = set() + + all_names = script.get_names(all_scopes=True, references=True) + all_names_top = script.get_names(all_scopes=True) + all_defs = [name for name in all_names if name.is_definition()] + + top_level_functions = {name.name: name for name in all_names_top if name.type == "function"} + top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} + + except Exception as e: + logger.debug("Failed to get jedi script for %s: %s", test_file, e) + continue + + if test_framework == "pytest": + for function in functions: + if "[" in function.test_function: + function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[0] + parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[1] + if function_name in top_level_functions: + test_functions.add( + TestFunction(function_name, function.test_class, parameters, function.test_type) + ) + elif function.test_function in top_level_functions: + test_functions.add( + TestFunction(function.test_function, function.test_class, None, function.test_type) + ) + elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(function.test_function): + base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", function.test_function) + if 
base_name in top_level_functions: + test_functions.add( + TestFunction( + function_name=base_name, + test_class=function.test_class, + parameters=function.test_function, + test_type=function.test_type, + ) + ) + + elif test_framework == "unittest": + functions_to_search = [elem.test_function for elem in functions] + test_suites = {elem.test_class for elem in functions} + + matching_names = test_suites & top_level_classes.keys() + for matched_name in matching_names: + for def_name in all_defs: + if ( + def_name.type == "function" + and def_name.full_name is not None + and f".{matched_name}." in def_name.full_name + ): + for function in functions_to_search: + (is_parameterized, new_function, parameters) = discover_parameters_unittest(function) + + if is_parameterized and new_function == def_name.name: + test_functions.add( + TestFunction( + function_name=def_name.name, + test_class=matched_name, + parameters=parameters, + test_type=functions[0].test_type, + ) + ) + elif function == def_name.name: + test_functions.add( + TestFunction( + function_name=def_name.name, + test_class=matched_name, + parameters=None, + test_type=functions[0].test_type, + ) + ) + + test_functions_by_name = defaultdict(list) + for func in test_functions: + test_functions_by_name[func.function_name].append(func) + + test_function_names_set = set(test_functions_by_name.keys()) + relevant_names = [] + + names_with_full_name = [name for name in all_names if name.full_name is not None] + + for name in names_with_full_name: + match = FUNCTION_NAME_REGEX.search(name.full_name) + if match and match.group(1) in test_function_names_set: + relevant_names.append((name, match.group(1))) + + for name, scope in relevant_names: + try: + definition = name.goto(follow_imports=True, follow_builtin_imports=False) + except Exception as e: + logger.debug(str(e)) + continue + try: + if not definition or definition[0].type != "function": + # Fallback: Try to match against functions_to_optimize when Jedi can't resolve 
+ # This handles cases where Jedi fails with pytest fixtures + if functions_to_optimize_by_name and name.name: + for func_to_opt in functions_to_optimize_by_name.get(name.name, []): + from codeflash_python.models.function_types import qualified_name_with_modules_from_root + + qualified_name_with_modules = qualified_name_with_modules_from_root( + func_to_opt, project_root_path + ) + + # Only add if this test actually tests the function we're optimizing + for test_func in test_functions_by_name[scope]: + if test_func.parameters is not None: + if test_framework == "pytest": + scope_test_function = f"{test_func.function_name}[{test_func.parameters}]" + else: # unittest + scope_test_function = f"{test_func.function_name}_{test_func.parameters}" + else: + scope_test_function = test_func.function_name + + function_to_test_map[qualified_name_with_modules].add( + FunctionCalledInTest( + tests_in_file=TestsInFile( + test_file=test_file, + test_class=test_func.test_class, + test_function=scope_test_function, + test_type=test_func.test_type, + ), + position=CodePosition(line_no=name.line, col_no=name.column), + ) + ) + tests_cache.insert_test( + file_path=str(test_file), + file_hash=file_hash, + qualified_name_with_modules_from_root=qualified_name_with_modules, + function_name=scope, + test_class=test_func.test_class or "", + test_function=scope_test_function, + test_type=test_func.test_type, + line_number=name.line, + col_number=name.column, + ) + + if test_func.test_type == TestType.REPLAY_TEST: + num_discovered_replay_tests += 1 + + num_discovered_tests += 1 + continue + definition_obj = definition[0] + definition_path = str(definition_obj.module_path) + + project_root_str = str(project_root_path) + if ( + definition_path.startswith(project_root_str + os.sep) + and definition_obj.module_name != name.module_name + and definition_obj.full_name is not None + ): + # Pre-compute common values outside the inner loop + module_prefix = definition_obj.module_name + "." 
# Cache module name resolution to avoid repeated Path.resolve()/relative_to() calls
@lru_cache(maxsize=128)
def module_name_from_file_path_cached(file_path: Path, project_root: Path) -> str:
    """Memoized wrapper around ``module_name_from_file_path``.

    Results are cached per ``(file_path, project_root)`` argument pair in an LRU
    cache holding at most 128 entries, so both arguments must be hashable.
    """
    return module_name_from_file_path(file_path, project_root)
@dataclass
class FunctionFilterCriteria:
    """Criteria for filtering which functions to discover.

    Attributes:
        include_patterns: Glob patterns for functions to include.
        exclude_patterns: Glob patterns for functions to exclude.
        require_return: Only include functions with return statements.
        require_export: Presumably only include exported functions; the flag is
            not read in this class - confirm against the discovery code.
        include_async: Include async functions.
        include_methods: Include class methods.
        min_lines: Minimum number of lines in the function.
        max_lines: Maximum number of lines in the function.

    """

    include_patterns: list[str] = field(default_factory=list)
    exclude_patterns: list[str] = field(default_factory=list)
    require_return: bool = True
    require_export: bool = True
    include_async: bool = True
    include_methods: bool = True
    min_lines: int | None = None
    max_lines: int | None = None

    def __post_init__(self) -> None:
        """Translate the glob patterns to compiled regexes once, at construction."""

        def compile_globs(globs):
            return [re.compile(fnmatch.translate(glob)) for glob in globs]

        self.include_regexes = compile_globs(self.include_patterns)
        self.exclude_regexes = compile_globs(self.exclude_patterns)

    def matches_include_patterns(self, name: str) -> bool:
        """Return True when there are no include patterns or ``name`` matches one."""
        compiled = self.include_regexes
        return not compiled or any(rx.match(name) for rx in compiled)

    def matches_exclude_patterns(self, name: str) -> bool:
        """Return True when ``name`` matches at least one exclude pattern."""
        # any() over an empty list is False, which covers the "no patterns" case
        return any(rx.match(name) for rx in self.exclude_regexes)
def was_function_previously_optimized(
    function_to_optimize: FunctionToOptimize, code_context: CodeOptimizationContext, args: Namespace
) -> bool:
    """Report whether the optimization API already knows this function's code hash.

    The check only runs when a repository owner and name can be determined, a PR
    number is available and ``args.no_pr`` is not set; otherwise the function is
    treated as not previously optimized.

    Args:
        function_to_optimize: The function about to be optimized; its file path and
            qualified name are sent to the API.
        code_context: Supplies ``hashing_code_context_hash``, sent as the code hash.
        args: Parsed CLI arguments; only the optional ``no_pr`` attribute is read.

    Returns:
        True if the API response lists at least one entry under
        ``already_optimized_tuples``; False otherwise, including when the
        preconditions above are not met or the API call raises.

    """
    # This check exists to detect duplicate optimizations in the GitHub Action;
    # it is not needed in LSP mode.
    # NOTE(review): earlier notes say the API call also records new function hashes;
    # that happens (if at all) inside is_function_being_optimized_again - confirm.
    owner = None
    repo = None
    # Outside a git repository owner/repo stay None and the check is skipped below
    with contextlib.suppress(git.exc.InvalidGitRepositoryError):
        owner, repo = get_repo_owner_and_name()

    pr_number = get_pr_number()

    if not owner or not repo or pr_number is None or getattr(args, "no_pr", False):
        return False

    func_hash = code_context.hashing_code_context_hash

    # The API takes a list of contexts; a single-element list is sent here
    code_contexts = [
        {
            "file_path": str(function_to_optimize.file_path),
            "function_name": function_to_optimize.qualified_name,
            "code_hash": func_hash,
        }
    ]

    try:
        result = is_function_being_optimized_again(owner, repo, pr_number, code_contexts)
        already_optimized_paths: list[tuple[str, str]] = result.get("already_optimized_tuples", [])
        return len(already_optimized_paths) > 0

    except Exception as e:
        logger.warning("Failed to check optimization status: %s", e)
        # Best effort: if the API call fails, treat the function as not yet optimized
        return False
already_optimized_paths = check_optimization_status(modified_functions, project_root) + + # Ignore files with submodule path, cache the submodule paths + submodule_paths = ignored_submodule_paths(module_root) + + functions_count: int = 0 + test_functions_removed_count: int = 0 + non_modules_removed_count: int = 0 + site_packages_removed_count: int = 0 + ignore_paths_removed_count: int = 0 + malformed_paths_count: int = 0 + submodule_ignored_paths_count: int = 0 + blocklist_funcs_removed_count: int = 0 + previous_checkpoint_functions_removed_count: int = 0 + # Normalize paths for case-insensitive comparison on Windows + tests_root_str = os.path.normcase(str(tests_root)) + module_root_str = os.path.normcase(str(module_root)) + project_root_str = os.path.normcase(str(project_root)) + + # Check if tests_root overlaps with module_root or project_root + # In this case, we need to use file pattern matching instead of directory matching + tests_root_overlaps_source = tests_root_str in (module_root_str, project_root_str) or module_root_str.startswith( + tests_root_str + os.sep + ) + + # Test file patterns for when tests_root overlaps with source + test_file_name_patterns = (".test.", ".spec.", "_test.", "_spec.") + test_dir_patterns = (os.sep + "test" + os.sep, os.sep + "tests" + os.sep, os.sep + "__tests__" + os.sep) + + def is_test_file(file_path_normalized: str) -> bool: + if tests_root_overlaps_source: + file_lower = file_path_normalized.lower() + basename = Path(file_lower).name + if basename.startswith("test_") or basename == "conftest.py": + return True + if any(pattern in file_lower for pattern in test_file_name_patterns): + return True + if project_root_str and file_lower.startswith(project_root_str.lower()): + relative_path = file_lower[len(project_root_str) :] + return any(pattern in relative_path for pattern in test_dir_patterns) + return False + return file_path_normalized.startswith(tests_root_str + os.sep) + + # We desperately need Python 3.10+ only support 
to make this code readable with structural pattern matching + for file_path_path, functions in modified_functions.items(): + _functions = functions + file_path = str(file_path_path) + file_path_normalized = os.path.normcase(file_path) + if is_test_file(file_path_normalized): + test_functions_removed_count += len(_functions) + continue + if file_path_path in ignore_paths or any( + file_path_normalized.startswith(os.path.normcase(str(ignore_path)) + os.sep) for ignore_path in ignore_paths + ): + ignore_paths_removed_count += 1 + continue + if file_path_path in submodule_paths or any( + file_path_normalized.startswith(os.path.normcase(str(submodule_path)) + os.sep) + for submodule_path in submodule_paths + ): + submodule_ignored_paths_count += 1 + continue + if path_belongs_to_site_packages(Path(file_path)): + site_packages_removed_count += len(_functions) + continue + if not file_path_normalized.startswith(module_root_str + os.sep): + non_modules_removed_count += len(_functions) + continue + + try: + ast.parse(f"import {module_name_from_file_path(Path(file_path), resolved_project_root)}") + except SyntaxError: + malformed_paths_count += 1 + continue + + if blocklist_funcs: + functions_tmp = [] + for function in _functions: + if ( + function.file_path.name in blocklist_funcs + and function.qualified_name in blocklist_funcs[function.file_path.name] + ): + # This function is in blocklist, we can skip it + blocklist_funcs_removed_count += 1 + continue + # This function is NOT in blocklist. 
we can keep it + functions_tmp.append(function) + _functions = functions_tmp + + if previous_checkpoint_functions: + functions_tmp = [] + for function in _functions: + if ( + qualified_name_with_modules_from_root(function, resolved_project_root) + in previous_checkpoint_functions + ): + previous_checkpoint_functions_removed_count += 1 + continue + functions_tmp.append(function) + _functions = functions_tmp + + filtered_modified_functions[file_path_path] = _functions + functions_count += len(_functions) + + if not disable_logs: + log_info = { + "Test functions removed": test_functions_removed_count, + "Site-package functions removed": site_packages_removed_count, + "Non-importable file paths": malformed_paths_count, + "Functions outside module-root": non_modules_removed_count, + "Files from ignored paths": ignore_paths_removed_count, + "Files from ignored submodules": submodule_ignored_paths_count, + "Blocklisted functions removed": blocklist_funcs_removed_count, + "Functions skipped from checkpoint": previous_checkpoint_functions_removed_count, + } + entries = [f"{label}: {cnt}" for label, cnt in log_info.items() if cnt > 0] + if entries: + logger.info("Ignored functions and files: %s", ", ".join(entries)) + return {k: v for k, v in filtered_modified_functions.items() if v}, functions_count + + +def is_test_file_by_pattern(file_path: Path) -> bool: + """Check if a file is a test file using naming conventions. + + Used when tests_root overlaps with module_root, so directory-based filtering would + incorrectly exclude all source files. Falls back to filename and directory patterns. 
+ """ + name = file_path.name.lower() + if name.startswith("test_") or name == "conftest.py": + return True + test_name_patterns = (".test.", ".spec.", "_test.", "_spec.") + if any(p in name for p in test_name_patterns): + return True + path_str = str(file_path).lower() + test_dir_patterns = (os.sep + "test" + os.sep, os.sep + "tests" + os.sep, os.sep + "__tests__" + os.sep) + return any(p in path_str for p in test_dir_patterns) + + +def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list[Path], module_root: Path) -> bool: + """Optimized version of the filter_functions function above. + + Takes in file paths and returns the count of files that are to be optimized. + """ + submodule_paths = None + # When tests_root overlaps module_root (e.g., both are "src"), use pattern matching + # instead of directory matching to avoid filtering out all source files. + tests_root_overlaps = tests_root == module_root or module_root.is_relative_to(tests_root) + if tests_root_overlaps: + if is_test_file_by_pattern(file_path): + return False + elif file_path.is_relative_to(tests_root): + return False + if file_path in ignore_paths or any(file_path.is_relative_to(ignore_path) for ignore_path in ignore_paths): + return False + if path_belongs_to_site_packages(file_path): + return False + if not file_path.is_relative_to(module_root): + return False + if submodule_paths is None: + submodule_paths = ignored_submodule_paths(module_root) + return not ( + file_path in submodule_paths + or any(file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths) + ) diff --git a/src/codeflash_python/discovery/function_visitors.py b/src/codeflash_python/discovery/function_visitors.py new file mode 100644 index 000000000..8a6d9a312 --- /dev/null +++ b/src/codeflash_python/discovery/function_visitors.py @@ -0,0 +1,250 @@ +"""AST/CST-based function discovery and inspection.""" + +from __future__ import annotations + +import ast +import logging +from typing 
logger = logging.getLogger("codeflash_python")


def is_class_defined_in_file(class_name: str, file_path: Path) -> bool:
    """Return True if *file_path* exists and contains a class named *class_name* at any nesting level."""
    if not file_path.exists():
        return False
    with file_path.open(encoding="utf8") as file:
        source = file.read()
    tree = ast.parse(source)
    return any(isinstance(node, ast.ClassDef) and node.name == class_name for node in ast.walk(tree))


# =============================================================================
# CST-based function discovery
# =============================================================================


class ReturnStatementVisitor(cst.CSTVisitor):
    """Records whether the visited subtree contains any ``return`` statement."""

    def __init__(self) -> None:
        super().__init__()
        self.has_return_statement: bool = False

    def visit_Return(self, node: cst.Return) -> None:
        self.has_return_statement = True


class FunctionVisitor(cst.CSTVisitor):
    """Collects top-level and class-level functions that contain a return statement.

    Pytest fixtures, properties and functions nested inside other functions are skipped.
    Results accumulate in ``self.functions``.
    """

    METADATA_DEPENDENCIES = (cst.metadata.PositionProvider, cst.metadata.ParentNodeProvider)

    def __init__(self, file_path: Path) -> None:
        super().__init__()
        self.file_path: Path = file_path
        self.functions: list[FunctionToOptimize] = []

    @staticmethod
    def is_pytest_fixture(node: cst.FunctionDef) -> bool:
        """Return True for ``@fixture`` / ``@pytest.fixture`` decorators, with or without call parentheses."""
        for decorator in node.decorators:
            dec = decorator.decorator
            if isinstance(dec, cst.Call):
                dec = dec.func
            if isinstance(dec, cst.Attribute) and dec.attr.value == "fixture":
                if isinstance(dec.value, cst.Name) and dec.value.value == "pytest":
                    return True
            if isinstance(dec, cst.Name) and dec.value == "fixture":
                return True
        return False

    @staticmethod
    def is_property(node: cst.FunctionDef) -> bool:
        """Return True for bare ``@property`` / ``@cached_property`` decorators (dotted forms are not recognized)."""
        for decorator in node.decorators:
            dec = decorator.decorator
            if isinstance(dec, cst.Name) and dec.value in ("property", "cached_property"):
                return True
        return False

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        # The return visitor walks the whole function subtree, nested definitions included.
        return_visitor: ReturnStatementVisitor = ReturnStatementVisitor()
        node.visit(return_visitor)
        if return_visitor.has_return_statement and not self.is_pytest_fixture(node) and not self.is_property(node):
            pos: CodeRange = self.get_metadata(cst.metadata.PositionProvider, node)
            parents: CSTNode | None = self.get_metadata(cst.metadata.ParentNodeProvider, node)
            ast_parents: list[FunctionParent] = []
            while parents is not None:
                if isinstance(parents, cst.FunctionDef):
                    # Skip nested functions — only discover top-level and class-level functions
                    return
                if isinstance(parents, cst.ClassDef):
                    ast_parents.append(FunctionParent(parents.name.value, parents.__class__.__name__))
                parents = self.get_metadata(cst.metadata.ParentNodeProvider, parents, default=None)
            self.functions.append(
                FunctionToOptimize(
                    function_name=node.name.value,
                    file_path=self.file_path,
                    # Parents were gathered innermost-first; store them outermost-first.
                    parents=list(reversed(ast_parents)),
                    starting_line=pos.start.line,
                    ending_line=pos.end.line,
                    is_async=bool(node.asynchronous),
                    language="python",
                )
            )


def discover_functions(
    source: str, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
) -> list[FunctionToOptimize]:
    """Discover optimizable functions in Python *source*.

    Args:
        source: Module source code.
        file_path: Path recorded on every returned function.
        filter_criteria: Optional criteria; only ``include_async``, ``include_methods``
            and ``require_return`` are consulted here.

    Returns:
        The discovered functions, each with ``is_method`` set when any parent is a class.

    """
    criteria = filter_criteria or FunctionFilterCriteria()

    tree = cst.parse_module(source)

    wrapper = cst.metadata.MetadataWrapper(tree)
    function_visitor = FunctionVisitor(file_path=file_path)
    wrapper.visit(function_visitor)

    functions: list[FunctionToOptimize] = []
    for func in function_visitor.functions:
        if not criteria.include_async and func.is_async:
            continue

        if not criteria.include_methods and func.parents:
            continue

        # NOTE(review): FunctionVisitor always sets starting_line and already requires a
        # return statement, so this check looks like a no-op; confirm before relying on
        # require_return=False to include functions without a return.
        if criteria.require_return and func.starting_line is None:
            continue

        func_with_is_method = FunctionToOptimize(
            function_name=func.function_name,
            file_path=file_path,
            parents=func.parents,
            starting_line=func.starting_line,
            ending_line=func.ending_line,
            starting_col=func.starting_col,
            ending_col=func.ending_col,
            is_async=func.is_async,
            is_method=len(func.parents) > 0 and any(p.type == "ClassDef" for p in func.parents),
            language=func.language,
        )
        functions.append(func_with_is_method)

    return functions


@dataclass(frozen=True)
class FunctionProperties:
    """Result of ``inspect_top_level_functions_or_methods``."""

    is_top_level: bool
    has_args: bool | None
    is_staticmethod: bool | None
    is_classmethod: bool | None
    staticmethod_class_name: str | None


class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
    """Locates a function by name, or a method by class and name, and records its properties.

    Function and class visitors do not descend into their bodies, so definitions
    nested inside other functions or classes are not considered.
    """

    def __init__(
        self, file_name: Path, function_or_method_name: str, class_name: str | None = None, line_no: int | None = None
    ) -> None:
        self.file_name = file_name
        self.class_name = class_name
        self.function_name = function_or_method_name
        self.is_top_level = False
        self.function_has_args: bool | None = None
        self.line_no = line_no
        self.is_staticmethod = False
        self.is_classmethod = False

    def _record_top_level_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
        """Record a plain function matching the searched name (only when no class was requested)."""
        if self.class_name is None and node.name == self.function_name:
            self.is_top_level = True
            arguments = node.args
            self.function_has_args = any(
                (arguments.args, arguments.kwonlyargs, arguments.kwarg, arguments.posonlyargs, arguments.vararg)
            )

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        self._record_top_level_function(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        self._record_top_level_function(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        # iterate over the class methods
        if node.name == self.class_name:
            for body_node in node.body:
                if (
                    isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef))
                    and body_node.name == self.function_name
                ):
                    self.is_top_level = True
                    if any(
                        isinstance(decorator, ast.Name) and decorator.id == "classmethod"
                        for decorator in body_node.decorator_list
                    ):
                        self.is_classmethod = True
                    elif any(
                        isinstance(decorator, ast.Name) and decorator.id == "staticmethod"
                        for decorator in body_node.decorator_list
                    ):
                        self.is_staticmethod = True
                    return
        elif self.line_no:
            # If we have line number info, check if class has a static method with the same line number
            # This way, if we don't have the class name, we can still find the static method
            for body_node in node.body:
                if (
                    isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef))
                    and body_node.name == self.function_name
                    and body_node.lineno in {self.line_no, self.line_no + 1}
                    and any(
                        isinstance(decorator, ast.Name) and decorator.id == "staticmethod"
                        for decorator in body_node.decorator_list
                    )
                ):
                    self.is_staticmethod = True
                    self.is_top_level = True
                    self.class_name = node.name
                    return


def inspect_top_level_functions_or_methods(
    file_name: Path, function_or_method_name: str, class_name: str | None = None, line_no: int | None = None
) -> FunctionProperties | None:
    """Inspect *file_name* for the given function or method.

    Returns:
        The collected ``FunctionProperties``, or None if the file cannot be parsed.

    """
    with file_name.open(encoding="utf8") as file:
        try:
            ast_module = ast.parse(file.read())
        except Exception:
            return None
    visitor = TopLevelFunctionOrMethodVisitor(
        file_name=file_name, function_or_method_name=function_or_method_name, class_name=class_name, line_no=line_no
    )
    visitor.visit(ast_module)
    staticmethod_class_name = visitor.class_name if visitor.is_staticmethod else None
    return FunctionProperties(
        is_top_level=visitor.is_top_level,
        has_args=visitor.function_has_args,
        is_staticmethod=visitor.is_staticmethod,
        is_classmethod=visitor.is_classmethod,
        staticmethod_class_name=staticmethod_class_name,
    )
index 000000000..d4e528741 --- /dev/null +++ b/src/codeflash_python/discovery/functions_to_optimize.py @@ -0,0 +1,415 @@ +from __future__ import annotations + +import ast +import logging +import os +import random +import warnings +from collections import defaultdict +from pathlib import Path +from typing import TYPE_CHECKING + +from codeflash_python.code_utils.code_utils import exit_with_message +from codeflash_python.code_utils.config_consts import PYTHON_DIR_EXCLUDES, PYTHON_FILE_EXTENSIONS +from codeflash_python.code_utils.git_utils import get_git_diff +from codeflash_python.discovery.discover_unit_tests import discover_unit_tests +from codeflash_python.discovery.filter_criteria import FunctionFilterCriteria +from codeflash_python.discovery.function_filtering import filter_functions +from codeflash_python.discovery.function_visitors import discover_functions, is_class_defined_in_file +from codeflash_python.telemetry.posthog_cf import ph + +if TYPE_CHECKING: + from codeflash_core.config import TestConfig + from codeflash_core.models import FunctionToOptimize + +logger = logging.getLogger("codeflash_python") + + +# ============================================================================= +# Multi-language support helpers +# ============================================================================= + +_VCS_EXCLUDES = frozenset({".git", ".hg", ".svn"}) + + +def parse_dir_excludes(patterns: frozenset[str]) -> tuple[frozenset[str], tuple[str, ...], tuple[str, ...]]: + """Split glob patterns into exact names, prefixes, and suffixes. + + Patterns ending with ``*`` become prefix matches, patterns starting with ``*`` + become suffix matches, and plain strings become exact matches. 
+ """ + exact: set[str] = set() + prefixes: list[str] = [] + suffixes: list[str] = [] + for p in patterns: + if p.endswith("*"): + prefixes.append(p[:-1]) + elif p.startswith("*"): + suffixes.append(p[1:]) + else: + exact.add(p) + return frozenset(exact), tuple(prefixes), tuple(suffixes) + + +def get_files_for_language( + module_root_path: Path, ignore_paths: list[Path] | None = None, language: str | None = None +) -> list[Path]: + """Get all source files for supported languages. + + Args: + module_root_path: Root path to search for source files. + ignore_paths: List of paths to ignore (can be files or directories). + language: Optional specific language to filter for. If None, includes all supported languages. + + Returns: + List of file paths matching supported extensions. + + """ + if ignore_paths is None: + ignore_paths = [] + + extensions = PYTHON_FILE_EXTENSIONS + all_patterns = PYTHON_DIR_EXCLUDES | _VCS_EXCLUDES + + dir_excludes, prefixes, suffixes = parse_dir_excludes(all_patterns) + + ignore_dirs: set[str] = set() + ignore_files: set[Path] = set() + for p in ignore_paths: + p = Path(p) if not isinstance(p, Path) else p + if p.is_file(): + ignore_files.add(p) + else: + ignore_dirs.add(str(p)) + + files: list[Path] = [] + for dirpath, dirnames, filenames in os.walk(module_root_path): + dirnames[:] = [ + d + for d in dirnames + if d not in dir_excludes + and not (prefixes and d.startswith(prefixes)) + and not (suffixes and d.endswith(suffixes)) + and str(Path(dirpath) / d) not in ignore_dirs + ] + for fname in filenames: + if fname.endswith(extensions): + fpath = Path(dirpath, fname) + if fpath not in ignore_files: + files.append(fpath) + return files + + +def find_all_functions_via_language_support(file_path: Path) -> dict[Path, list[FunctionToOptimize]]: + """Find all optimizable functions using the language support abstraction. 
+ + This function uses the registered language support for the file's language + to discover functions, then converts them to FunctionToOptimize instances. + """ + functions: dict[Path, list[FunctionToOptimize]] = {} + + try: + criteria = FunctionFilterCriteria(require_return=True) + source = file_path.read_text(encoding="utf-8") + functions[file_path] = discover_functions(source, file_path, criteria) + except Exception as e: + logger.debug("Failed to discover functions in %s: %s", file_path, e) + + return functions + + +def get_functions_to_optimize( + optimize_all: str | None, + replay_test: list[Path] | None, + file: Path | str | None, + only_get_this_function: str | None, + test_cfg: TestConfig, + ignore_paths: list[Path], + project_root: Path, + module_root: Path, + previous_checkpoint_functions: dict[str, dict[str, str]] | None = None, +) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]: + assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, ( + "Only one of optimize_all, replay_test, or file should be provided" + ) + functions: dict[Path, list[FunctionToOptimize]] + trace_file_path: Path | None = None + is_lsp = False + with warnings.catch_warnings(): + warnings.simplefilter(action="ignore", category=SyntaxWarning) + if optimize_all: + functions = get_all_files_and_functions(Path(optimize_all), ignore_paths) + elif replay_test: + functions, trace_file_path = get_all_replay_test_functions( + replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root + ) + elif file is not None: + file = Path(file) if isinstance(file, str) else file + functions = find_all_functions_in_file(file) + if only_get_this_function is not None: + split_function = only_get_this_function.split(".") + if len(split_function) > 2: + if is_lsp: + return functions, 0, None + exit_with_message( + "Function name should be in the format 'function_name' or 'class_name.function_name'" + ) + if len(split_function) == 2: + class_name, 
only_function_name = split_function + else: + class_name = None + only_function_name = split_function[0] + found_function = None + for fn in functions.get(file, []): + if only_function_name == fn.function_name and ( + class_name is None or class_name == fn.top_level_parent_name + ): + found_function = fn + if found_function is None: + if is_lsp: + return functions, 0, None + + found = closest_matching_file_function_name(only_get_this_function, functions) + if found is not None: + file, found_function = found + exit_with_message( + f"Function {only_get_this_function} not found in file {file}\nor the function does not have a 'return' statement or is a property.\n" + f"Did you mean {found_function.qualified_name} instead?" + ) + + exit_with_message( + f"Function {only_get_this_function} not found in file {file}\nor the function does not have a 'return' statement or is a property" + ) + + assert found_function is not None + functions[file] = [found_function] + else: + logger.info("Finding all functions modified in the current git diff ...") + ph("cli-optimizing-git-diff") + functions = get_functions_within_git_diff(uncommitted_changes=False) + filtered_modified_functions, functions_count = filter_functions( + functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions + ) + + return filtered_modified_functions, functions_count, trace_file_path + + +def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[Path, list[FunctionToOptimize]]: + modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=uncommitted_changes) + return get_functions_within_lines(modified_lines) + + +def closest_matching_file_function_name( + qualified_fn_to_find: str, found_fns: dict[Path, list[FunctionToOptimize]] +) -> tuple[Path, FunctionToOptimize] | None: + """Find the closest matching function name using Levenshtein distance. 
+ + Args: + qualified_fn_to_find: Function name to find in format "Class.function" or "function" + found_fns: Dictionary of file paths to list of functions + + Returns: + Tuple of (file_path, function) for closest match, or None if no matches found + + """ + min_distance = 4 + closest_match = None + closest_file = None + + qualified_fn_to_find_lower = qualified_fn_to_find.lower() + + # Cache levenshtein_distance locally for improved lookup speed + _levenshtein = levenshtein_distance + + for file_path, functions in found_fns.items(): + for function in functions: + # Compare either full qualified name or just function name + fn_name = function.qualified_name.lower() + # If the absolute length difference is already >= min_distance, skip calculation + if abs(len(qualified_fn_to_find_lower) - len(fn_name)) >= min_distance: + continue + dist = _levenshtein(qualified_fn_to_find_lower, fn_name) + + if dist < min_distance: + min_distance = dist + closest_match = function + closest_file = file_path + + if closest_match is not None and closest_file is not None: + return closest_file, closest_match + return None + + +def levenshtein_distance(s1: str, s2: str) -> int: + if len(s1) > len(s2): + s1, s2 = s2, s1 + len1 = len(s1) + len2 = len(s2) + # Use a preallocated list instead of creating a new list every iteration + previous = list(range(len1 + 1)) + current = [0] * (len1 + 1) + + for index2 in range(len2): + char2 = s2[index2] + current[0] = index2 + 1 + for index1 in range(len1): + char1 = s1[index1] + if char1 == char2: + current[index1 + 1] = previous[index1] + else: + # Fast min calculation without tuple construct + a = previous[index1] + b = previous[index1 + 1] + c = current[index1] + min_val = min(b, a) + min_val = min(c, min_val) + current[index1 + 1] = 1 + min_val + # Swap references instead of copying + previous, current = current, previous + return previous[len1] + + +def get_functions_inside_a_commit(commit_hash: str) -> dict[Path, list[FunctionToOptimize]]: + 
modified_lines: dict[str, list[int]] = get_git_diff(only_this_commit=commit_hash) + return get_functions_within_lines(modified_lines) + + +def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[Path, list[FunctionToOptimize]]: + functions: dict[Path, list[FunctionToOptimize]] = {} + for path_str, lines_in_file in modified_lines.items(): + path = Path(path_str) + if not path.exists(): + continue + all_functions = find_all_functions_in_file(path) + functions[path] = [ + func + for func in all_functions.get(path, []) + if func.starting_line is not None + and func.ending_line is not None + and any(func.starting_line <= line <= func.ending_line for line in lines_in_file) + ] + return functions + + +def get_all_files_and_functions( + module_root_path: Path, ignore_paths: list[Path], language: str | None = None +) -> dict[Path, list[FunctionToOptimize]]: + """Get all optimizable functions from files in the module root. + + Args: + module_root_path: Root path to search for source files. + ignore_paths: List of paths to ignore. + language: Optional specific language to filter for. If None, includes all supported languages. + + Returns: + Dictionary mapping file paths to lists of FunctionToOptimize. + + """ + functions: dict[Path, list[FunctionToOptimize]] = {} + for file_path in get_files_for_language(module_root_path, ignore_paths, language): + functions.update(find_all_functions_in_file(file_path).items()) + # Randomize the order of the files to optimize to avoid optimizing the same file in the same order every time. + # Helpful if an optimize-all run is stuck and we restart it. 
+ files_list = list(functions.items()) + random.shuffle(files_list) + return dict(files_list) + + +def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOptimize]]: + """Find all optimizable functions in a file using the language support abstraction.""" + if file_path.suffix.lower() not in PYTHON_FILE_EXTENSIONS: + return {} + try: + criteria = FunctionFilterCriteria(require_return=True) + source = file_path.read_text(encoding="utf-8") + return {file_path: discover_functions(source, file_path, criteria)} + except Exception as e: + logger.debug("Failed to discover functions in %s: %s", file_path, e) + return {} + + +def get_all_replay_test_functions( + replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path +) -> tuple[dict[Path, list[FunctionToOptimize]], Path]: + trace_file_path: Path | None = None + for replay_test_file in replay_test: + try: + with replay_test_file.open("r", encoding="utf8") as f: + tree = ast.parse(f.read()) + for node in ast.walk(tree): + if isinstance(node, ast.Assign): + for target in node.targets: + if ( + isinstance(target, ast.Name) + and target.id == "trace_file_path" + and isinstance(node.value, ast.Constant) + and isinstance(node.value.value, str) + ): + trace_file_path = Path(node.value.value) + break + if trace_file_path: + break + if trace_file_path: + break + except Exception as e: + logger.warning("Error parsing replay test file %s: %s", replay_test_file, e) + + if trace_file_path is None: + logger.error("Could not find trace_file_path in replay test files.") + exit_with_message("Could not find trace_file_path in replay test files.") + raise AssertionError("Unreachable") # exit_with_message never returns + + if not trace_file_path.exists(): + logger.error("Trace file not found: %s", trace_file_path) + exit_with_message( + f"Trace file not found: {trace_file_path}\n" + "The trace file referenced in the replay test no longer exists.\n" + "This can happen if the trace file was cleaned up after a 
previous optimization run.\n" + "Please regenerate the replay test by re-running 'codeflash optimize' with your command." + ) + + function_tests, _, _ = discover_unit_tests(test_cfg, discover_only_these_tests=replay_test) + # Get the absolute file paths for each function, excluding class name if present + filtered_valid_functions = defaultdict(list) + file_to_functions_map = defaultdict(list) + # below logic can be cleaned up with a better data structure to store the function paths + for function in function_tests: + parts = function.split(".") + module_path_parts = parts[:-1] # Exclude the function or method name + function_name = parts[-1] + # Check if the second-to-last part is a class name + class_name = ( + module_path_parts[-1] + if module_path_parts + and is_class_defined_in_file( + module_path_parts[-1], Path(project_root_path, *module_path_parts[:-1]).with_suffix(".py") + ) + else None + ) + if class_name: + # If there is a class name, append it to the module path + qualified_function_name = class_name + "." 
logger = logging.getLogger("codeflash_python")


class ImportAnalyzer(ast.NodeVisitor):
    """Detect whether a test module imports or references any target function.

    Targets are given as bare names (``"func"``) or dotted names (``"Class.method"``,
    ``"pkg.mod.func"``). After ``visit(tree)``, ``found_any_target_function`` reports whether a
    match was seen and ``found_qualified_name`` holds the matched name. Every visitor returns
    immediately once a match has been recorded, so traversal stops at the first hit.
    """

    def __init__(self, function_names_to_find: set[str]) -> None:
        """Store the targets and precompute the lookup tables used by the visitors."""
        self.function_names_to_find = function_names_to_find
        self.found_any_target_function: bool = False
        self.found_qualified_name: str | None = None
        self.imported_modules: set[str] = set()
        self.has_dynamic_imports: bool = False
        self.wildcard_modules: set[str] = set()
        # alias name -> original name, recorded for ``from m import X as Y``
        self.alias_mapping: dict[str, str] = {}
        # variable name -> class name, recorded for ``var = ImportedName(...)``
        self.instance_mapping: dict[str, str] = {}

        self.exact_names = function_names_to_find
        # first dotted component -> dotted targets that start with it
        self.prefix_roots: dict[str, list[str]] = {}
        self.dot_names: set[str] = set()
        # last dotted component -> every "part before the last dot" it appears under
        self.dot_methods: dict[str, set[str]] = {}
        self.class_method_to_target: dict[tuple[str, str], str] = {}

        for name in function_names_to_find:
            if "." not in name:
                continue
            root, method = name.rsplit(".", 1)
            self.dot_names.add(name)
            self.dot_methods.setdefault(method, set()).add(root)
            self.class_method_to_target[(root, method)] = name
            self.prefix_roots.setdefault(name.split(".", 1)[0], []).append(name)

    def _mark_found(self, qualified_name: str) -> None:
        """Record ``qualified_name`` as the match that ends the traversal."""
        self.found_any_target_function = True
        self.found_qualified_name = qualified_name

    def visit_Import(self, node: ast.Import) -> None:
        """Handle 'import module' statements."""
        if self.found_any_target_function:
            return

        for alias in node.names:
            module_name = alias.asname if alias.asname else alias.name
            self.imported_modules.add(module_name)

            # importlib makes any later attribute access a potential dynamic import
            if alias.name == "importlib":
                self.has_dynamic_imports = True

            # The module itself is a target
            if module_name in self.function_names_to_find:
                self._mark_found(module_name)
                return
            # A target lives somewhere below this module
            module_prefix = f"{module_name}."
            for target_func in self.function_names_to_find:
                if target_func.startswith(module_prefix):
                    self._mark_found(target_func)
                    return

    def visit_Assign(self, node: ast.Assign) -> None:
        """Track ``var = ImportedName(...)`` so later ``var.method`` accesses can be resolved."""
        if self.found_any_target_function:
            return

        value = node.value
        if isinstance(value, ast.Call) and isinstance(value.func, ast.Name):
            class_name = value.func.id
            if class_name in self.imported_modules:
                # Resolve an import alias back to the name it was imported from
                original_class = self.alias_mapping.get(class_name, class_name)
                for target in node.targets:
                    if isinstance(target, ast.Name):
                        self.instance_mapping[target.id] = original_class

        # Continue visiting child nodes
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Handle 'from module import name' statements."""
        if self.found_any_target_function:
            return

        mod = node.module
        if not mod:
            # ``from . import x`` carries no module name to match against
            return

        fnames = self.exact_names

        for alias in node.names:
            aname = alias.name
            if aname == "*":
                self.wildcard_modules.add(mod)
                continue

            imported_name = alias.asname if alias.asname else aname
            self.imported_modules.add(imported_name)

            if alias.asname:
                self.alias_mapping[imported_name] = aname

            if mod == "importlib" and aname == "import_module":
                self.has_dynamic_imports = True

            qname = f"{mod}.{aname}"

            # Exact match on the imported name or on its module-qualified form
            if aname in fnames:
                self._mark_found(aname)
                return
            if qname in fnames:
                self._mark_found(qname)
                return

            # Check if any target function is a method of the imported class/module.
            # Be conservative except when an alias is used (which requires exact method matching);
            # aliased usage is detected later in visit_Attribute.
            for target_func in fnames:
                if "." in target_func:
                    class_name, _method_name = target_func.split(".", 1)
                    if aname == class_name and not alias.asname:
                        self._mark_found(target_func)
                        return

            # prefix_roots is keyed by the FIRST dotted component of each target, so the
            # candidates must be looked up by the root of qname. Looking up the full dotted
            # qname (which always contains a dot) could never hit and silently disabled this check.
            prefix = qname + "."
            for target_func in self.prefix_roots.get(qname.split(".", 1)[0], ()):
                if target_func.startswith(prefix):
                    self._mark_found(target_func)
                    return

    def visit_Attribute(self, node: ast.Attribute) -> None:
        """Handle attribute access like module.function_name."""
        if self.found_any_target_function:
            return

        node_attr = node.attr
        # Only plain names (``name.attr``) have an ``id``; anything else yields None
        val_id = getattr(node.value, "id", None)

        # Access through an imported module/class name
        if val_id is not None and val_id in self.imported_modules:
            if node_attr in self.function_names_to_find:
                self._mark_found(node_attr)
                return
            roots_possible = self.dot_methods.get(node_attr)
            if roots_possible:
                imported_name = val_id
                original_name = self.alias_mapping.get(imported_name, imported_name)
                if original_name in roots_possible:
                    self._mark_found(self.class_method_to_target[(original_name, node_attr)])
                    return
                # Also check the imported name itself (without resolving the alias)
                if imported_name in roots_possible:
                    self._mark_found(
                        self.class_method_to_target.get((imported_name, node_attr), f"{imported_name}.{node_attr}")
                    )
                    return

        # Method access on a variable known to hold an instance of an imported class
        if val_id is not None and val_id in self.instance_mapping:
            class_name = self.instance_mapping[val_id]
            roots_possible = self.dot_methods.get(node_attr)
            if roots_possible and class_name in roots_possible:
                self._mark_found(self.class_method_to_target[(class_name, node_attr)])
                return

        # With dynamic imports in play, any attribute named like a target counts
        if self.has_dynamic_imports and node_attr in self.function_names_to_find:
            self._mark_found(node_attr)
            return

        if not self.found_any_target_function:
            ast.NodeVisitor.generic_visit(self, node)

    def visit_Call(self, node: ast.Call) -> None:
        """Handle function calls, particularly __import__."""
        if self.found_any_target_function:
            return

        # __import__ could load any target, so only flag it and stay conservative later
        if isinstance(node.func, ast.Name) and node.func.id == "__import__":
            self.has_dynamic_imports = True

        self.generic_visit(node)

    def visit_Name(self, node: ast.Name) -> None:
        """Handle direct name usage like target_function()."""
        if self.found_any_target_function:
            return

        if node.id == "__import__":
            self.has_dynamic_imports = True

        # Direct usage such as: result = target_function()
        if node.id in self.function_names_to_find:
            self._mark_found(node.id)
            return

        # The name could have been brought in by ``from module import *``
        name_suffix = f".{node.id}"
        for wildcard_module in self.wildcard_modules:
            module_prefix = f"{wildcard_module}."
            for target_func in self.function_names_to_find:
                if target_func.startswith(module_prefix) and target_func.endswith(name_suffix):
                    self._mark_found(target_func)
                    return

        self.generic_visit(node)

    def generic_visit(self, node: ast.AST) -> None:
        """Override generic_visit to stop traversal if a target function is found."""
        if self.found_any_target_function:
            return
        self.fast_generic_visit(node)

    def fast_generic_visit(self, node: ast.AST) -> None:
        """Iteratively traverse the children of ``node`` with an explicit stack.

        Children that have a ``visit_<Type>`` method defined directly on this class are handed
        to it; all other children are pushed on the stack and expanded later. Returns as soon
        as ``found_any_target_function`` becomes True.
        """
        if self.found_any_target_function:
            return

        # Only methods defined on this exact class are considered (no getattr fallback)
        visit_cache = type(self).__dict__

        stack = [(node._fields, node)]
        append = stack.append
        pop = stack.pop

        while stack:
            fields, curr_node = pop()
            for field in fields:
                value = getattr(curr_node, field, None)
                if isinstance(value, list):
                    for item in value:
                        if self.found_any_target_function:
                            return
                        if isinstance(item, ast.AST):
                            meth = visit_cache.get("visit_" + item.__class__.__name__)
                            if meth is not None:
                                meth(self, item)
                            else:
                                append((item._fields, item))
                    continue
                if isinstance(value, ast.AST):
                    if self.found_any_target_function:
                        return
                    meth = visit_cache.get("visit_" + value.__class__.__name__)
                    if meth is not None:
                        meth(self, value)
                    else:
                        append((value._fields, value))


def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: set[str]) -> bool:
    """Return True when the test file may import or use any of ``target_functions``.

    The answer is deliberately conservative: a file that cannot be read, decoded or parsed is
    reported as relevant (True) so that it is never filtered out by mistake.
    """
    try:
        source_code = Path(test_file_path).read_text(encoding="utf-8")
        tree = ast.parse(source_code, filename=str(test_file_path))
        analyzer = ImportAnalyzer(target_functions)
        analyzer.visit(tree)
    except (SyntaxError, OSError, ValueError) as e:
        # OSError covers missing/unreadable files; ValueError covers UnicodeDecodeError and
        # sources that ast.parse rejects with ValueError.
        logger.debug("Failed to analyze imports in %s: %s", test_file_path, e)
        return True

    if analyzer.found_any_target_function:
        return True

    # Be conservative with dynamic imports: if one is present and a target name appears
    # anywhere in the source text, the file has to be processed.
    if analyzer.has_dynamic_imports:
        return any(target_func in source_code for target_func in target_functions)

    return False


def filter_test_files_by_imports(
    file_to_test_map: dict[Path, list[TestsInFile]], target_functions: set[str]
) -> dict[Path, list[TestsInFile]]:
    """Filter test files based on import analysis to reduce Jedi processing.

    Args:
        file_to_test_map: Original mapping of test files to test functions
        target_functions: Set of function names we're optimizing

    Returns:
        Filtered mapping of test files to test functions. When ``target_functions`` is empty
        the original mapping object is returned unchanged.

    """
    if not target_functions:
        return file_to_test_map

    filtered_map = {
        test_file: test_functions
        for test_file, test_functions in file_to_test_map.items()
        if analyze_imports_in_test_file(test_file, target_functions)
    }

    logger.debug(
        "analyzed %s test files for imports, filtered down to %s relevant files",
        len(file_to_test_map),
        len(filtered_map),
    )
    return filtered_map
def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, Any]]:
    """Convert collected pytest items into plain, picklable dictionaries.

    Each entry has the keys ``test_file``, ``test_class`` (``None`` for items that are not
    defined on a class) and ``test_function``.
    """
    test_results = []
    for test in pytest_tests:
        # Not every collected item exposes ``cls`` (e.g. items contributed by other plugins),
        # so a missing attribute is treated like "not in a class" instead of crashing.
        test_class = test.parent.name if getattr(test, "cls", None) else None
        test_results.append({"test_file": str(test.path), "test_class": test_class, "test_function": test.name})
    return test_results


class PytestCollectionPlugin:
    """Pytest plugin that pickles the collected tests to ``pickle_path`` when collection ends."""

    def pytest_collection_finish(self, session) -> None:
        """Record the collected items and write ``(exit_code, tests, rootdir)`` to the pickle file."""
        global pytest_rootdir, collected_tests

        collected_tests.extend(session.items)
        pytest_rootdir = session.config.rootdir

        # Write results immediately since pytest.main() will exit after this callback, not always with a success code
        tests = parse_pytest_collection_results(collected_tests)
        exit_code = getattr(session.config, "exitstatus", 0)
        with Path(pickle_path).open("wb") as f:
            pickle.dump((exit_code, tests, pytest_rootdir), f, protocol=pickle.HIGHEST_PROTOCOL)

    def pytest_collection_modifyitems(self, items) -> None:
        """Mark every item that requests the ``benchmark`` fixture as skipped."""
        # Imported here because the module-level ``pytest`` name is only bound under the
        # ``__main__`` guard below.
        import pytest

        skip_benchmark = pytest.mark.skip(reason="Skipping benchmark tests")
        for item in items:
            # Items without ``fixturenames`` cannot request the benchmark fixture
            if "benchmark" in getattr(item, "fixturenames", ()):
                item.add_marker(skip_benchmark)


if __name__ == "__main__":
    import pytest

    try:
        pytest.main(
            [tests_root, "-p", "no:logging", "--collect-only", "-m", "not skip", "-p", "no:codeflash-benchmark"],
            plugins=[PytestCollectionPlugin()],
        )
    except Exception as e:
        logger.warning("Failed to collect tests: %s", e)
        # Best effort: leave a failure marker so the parent process does not read a stale result
        try:
            with Path(pickle_path).open("wb") as f:
                pickle.dump((-1, [], None), f, protocol=pickle.HIGHEST_PROTOCOL)
        except Exception as pickle_error:
            logger.warning("Failed to write failure pickle: %s", pickle_error)
logger = logging.getLogger("codeflash_python")


class TestsCache:
    """SQLite-backed cache of discovered tests, keyed by project root, file path and file hash.

    Rows are buffered in ``pending_rows`` by ``insert_test`` and written by ``flush``. Lookups
    are additionally memoised in ``memory_cache`` for the lifetime of the instance.
    """

    SCHEMA_VERSION = 1  # Increment this when schema changes

    def __init__(self, project_root_path: Path) -> None:
        """Open the cache database and make sure its tables match ``SCHEMA_VERSION``."""
        self.project_root_path = project_root_path.resolve().as_posix()
        self.connection = sqlite3.connect(codeflash_cache_db)
        self.cur = self.connection.cursor()
        cursor = self.cur

        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS schema_version(
                version INTEGER PRIMARY KEY
            )
            """
        )

        version_row = cursor.execute("SELECT version FROM schema_version").fetchone()
        current_version = None if not version_row else version_row[0]

        if current_version != self.SCHEMA_VERSION:
            logger.debug(
                "Schema version mismatch (current: %s, expected: %s). Recreating tables.",
                current_version,
                self.SCHEMA_VERSION,
            )
            # Throw the old data away, then stamp the database with the current version
            for statement in (
                "DROP TABLE IF EXISTS discovered_tests",
                "DROP INDEX IF EXISTS idx_discovered_tests_project_file_path_hash",
                "DELETE FROM schema_version",
            ):
                cursor.execute(statement)
            cursor.execute("INSERT INTO schema_version (version) VALUES (?)", (self.SCHEMA_VERSION,))
            self.connection.commit()

        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS discovered_tests(
                project_root_path TEXT,
                file_path TEXT,
                file_hash TEXT,
                qualified_name_with_modules_from_root TEXT,
                function_name TEXT,
                test_class TEXT,
                test_function TEXT,
                test_type TEXT,
                line_number INTEGER,
                col_number INTEGER
            )
            """
        )
        cursor.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_discovered_tests_project_file_path_hash
            ON discovered_tests (project_root_path, file_path, file_hash)
            """
        )

        self.memory_cache = {}
        self.pending_rows: list[tuple[str, str, str, str, str, str, int | TestType, int, int]] = []
        self.writes_enabled = True

    def insert_test(
        self,
        file_path: str,
        file_hash: str,
        qualified_name_with_modules_from_root: str,
        function_name: str,
        test_class: str,
        test_function: str,
        test_type: TestType,
        line_number: int,
        col_number: int,
    ) -> None:
        """Buffer one discovered-test row; nothing is written until ``flush`` is called."""
        # Enum-like values are stored by their ``value``; anything else is stored as given
        stored_test_type = getattr(test_type, "value", test_type)
        row = (
            file_path,
            file_hash,
            qualified_name_with_modules_from_root,
            function_name,
            test_class,
            test_function,
            stored_test_type,
            line_number,
            col_number,
        )
        self.pending_rows.append(row)

    def flush(self) -> None:
        """Write the buffered rows to the database and empty the buffer.

        The buffer is always emptied. An ``sqlite3.OperationalError`` during the write is
        logged and turns further writes off for this instance.
        """
        buffered = self.pending_rows
        if not buffered:
            return
        try:
            if self.writes_enabled:
                try:
                    self.cur.executemany(
                        "INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
                        [(self.project_root_path, *row) for row in buffered],
                    )
                    self.connection.commit()
                except sqlite3.OperationalError as e:
                    logger.debug("Failed to persist discovered test cache, disabling cache writes: %s", e)
                    self.writes_enabled = False
        finally:
            buffered.clear()

    def get_function_to_test_map_for_file(
        self, file_path: str, file_hash: str
    ) -> dict[str, set[FunctionCalledInTest]] | None:
        """Return the cached function-to-tests mapping for a file, or None when nothing is cached."""
        cache_key = (self.project_root_path, file_path, file_hash)
        try:
            return self.memory_cache[cache_key]
        except KeyError:
            pass

        self.cur.execute(
            "SELECT * FROM discovered_tests WHERE project_root_path = ? AND file_path = ? AND file_hash = ?",
            cache_key,
        )
        rows = self.cur.fetchall()
        if not rows:
            return None

        mapping: dict[str, set[FunctionCalledInTest]] = {}
        for row in rows:
            # Column order follows the CREATE TABLE statement in __init__
            tests_in_file = TestsInFile(
                test_file=Path(row[1]), test_class=row[5], test_function=row[6], test_type=TestType(int(row[7]))
            )
            called = FunctionCalledInTest(
                tests_in_file=tests_in_file, position=CodePosition(line_no=row[8], col_no=row[9])
            )
            mapping.setdefault(row[3], set()).add(called)

        self.memory_cache[cache_key] = mapping
        return mapping

    @staticmethod
    def compute_file_hash(path: Path) -> str:
        """Return the SHA-256 hex digest of the file at ``path``, read in 8 KiB chunks."""
        digest = hashlib.sha256(usedforsecurity=False)
        view = memoryview(bytearray(8192))
        with path.open("rb", buffering=0) as stream:
            # readinto returns 0 at end of file, which is the sentinel that ends the loop
            for count in iter(lambda: stream.readinto(view), 0):
                digest.update(view[:count])
        return digest.hexdigest()

    def close(self) -> None:
        """Close the cursor and the database connection."""
        self.cur.close()
        self.connection.close()
( + detect_unused_helper_functions, + revert_unused_helper_functions, +) +from codeflash_python.discovery.function_filtering import was_function_previously_optimized +from codeflash_python.models.experiment_metadata import ExperimentMetadata +from codeflash_python.models.models import OptimizationSet, TestFiles, TestingMode, TestResults +from codeflash_python.optimizer import resolve_python_function_ast +from codeflash_python.optimizer_mixins import ( + BaselineEstablishmentMixin, + CandidateEvaluationMixin, + CodeReplacementMixin, + RefinementMixin, + ResultProcessingMixin, + TestExecutionMixin, + TestGenerationMixin, + TestReviewMixin, +) +from codeflash_python.static_analysis.code_replacer import add_custom_marker_to_all_tests, modify_autouse_fixture +from codeflash_python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator +from codeflash_python.static_analysis.numerical_detection import is_numerical_code +from codeflash_python.static_analysis.reference_analysis import get_opt_review_metrics +from codeflash_python.telemetry.posthog_cf import ph +from codeflash_python.verification.equivalence import compare_test_results +from codeflash_python.verification.path_utils import file_name_from_test_module_name +from codeflash_python.verification.test_output_utils import calculate_function_throughput_from_test_results + +if TYPE_CHECKING: + from argparse import Namespace + from typing import Any + + from codeflash_core.config import TestConfig + from codeflash_core.danom import Result + from codeflash_core.models import FunctionToOptimize + from codeflash_python.api.aiservice import AiServiceClient + from codeflash_python.api.types import TestDiff, TestFileReview + from codeflash_python.context.types import DependencyResolver + from codeflash_python.models.models import ( + BenchmarkKey, + BestOptimization, + CodeOptimizationContext, + CodeStringsMarkdown, + ConcurrencyMetrics, + CoverageData, + FunctionCalledInTest, + 
GeneratedTestsList, + OriginalCodeBaseline, + ) + +logger = logging.getLogger("codeflash_python") + + +class FunctionOptimizer( + TestGenerationMixin, + TestExecutionMixin, + TestReviewMixin, + BaselineEstablishmentMixin, + CandidateEvaluationMixin, + RefinementMixin, + ResultProcessingMixin, + CodeReplacementMixin, +): + def __init__( + self, + function_to_optimize: FunctionToOptimize, + test_cfg: TestConfig, + function_to_optimize_source_code: str = "", + function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None, + function_to_optimize_ast: ast.FunctionDef | ast.AsyncFunctionDef | None = None, + aiservice_client: AiServiceClient | None = None, + function_benchmark_timings: dict[BenchmarkKey, int] | None = None, + total_benchmark_timings: dict[BenchmarkKey, int] | None = None, + args: Namespace | None = None, + replay_tests_dir: Path | None = None, + call_graph: DependencyResolver | None = None, + effort_override: str | None = None, + ) -> None: + self.project_root = test_cfg.project_root.resolve() + self.test_cfg = test_cfg + self.aiservice_client = aiservice_client + resolved_file_path = function_to_optimize.file_path.resolve() + if resolved_file_path != function_to_optimize.file_path: + function_to_optimize = dataclasses.replace(function_to_optimize, file_path=resolved_file_path) + self.function_to_optimize = function_to_optimize + self.function_to_optimize_source_code = ( + function_to_optimize_source_code + if function_to_optimize_source_code + else function_to_optimize.file_path.read_text(encoding="utf8") + ) + if not function_to_optimize_ast: + try: + original_module_ast = ast.parse(self.function_to_optimize_source_code) + self.function_to_optimize_ast = resolve_python_function_ast( + function_to_optimize.function_name, function_to_optimize.parents, original_module_ast + ) + except SyntaxError: + self.function_to_optimize_ast = None + else: + self.function_to_optimize_ast = function_to_optimize_ast + self.function_to_tests = function_to_tests if 
def get_trace_id(self, exp_type: str) -> str:
    """Return the trace ID to use for ``exp_type``.

    Without an experiment the function trace ID is returned unchanged; with one, its last
    four characters are replaced by ``exp_type``.
    """
    trace_id = self.function_trace_id
    if not self.experiment_id:
        return trace_id
    return trace_id[:-4] + exp_type
def cleanup_generated_files(self) -> None:
    """Pass the instrumented-behavior and benchmarking paths of every test file to ``cleanup_paths``."""
    from codeflash_python.code_utils.code_utils import cleanup_paths

    # Same order as before: behavior file first, then benchmarking file, per test file
    cleanup_paths(
        [
            generated_path
            for test_file in self.test_files
            for generated_path in (test_file.instrumented_behavior_file_path, test_file.benchmarking_file_path)
        ]
    )
def display_repaired_functions(
    self, generated_tests: GeneratedTestsList, reviews: list[TestFileReview], original_sources: dict[int, str]
) -> None:
    """Log every repaired test function and, at debug level, its before/after diff.

    Args:
        generated_tests: Generated tests whose current source already contains the repairs.
        reviews: Reviews naming, per test index, the functions that were repaired.
        original_sources: Pre-repair source per test index; a missing entry means "no diff".

    Functions that cannot be found in the repaired source are skipped. Names are processed in
    sorted order so the log output is stable between runs.
    """
    import libcst as cst

    def extract_functions(source: str, names: set[str]) -> dict[str, str]:
        """Extract functions by name from top-level and class bodies."""
        try:
            tree = cst.parse_module(source)
        except cst.ParserSyntaxError:
            logger.debug("Failed to parse source for diff display", exc_info=True)
            return {}
        result: dict[str, str] = {}
        for node in tree.body:
            if isinstance(node, cst.FunctionDef) and node.name.value in names:
                result[node.name.value] = tree.code_for_node(node)
            elif isinstance(node, cst.ClassDef):
                for child in node.body.body:
                    if isinstance(child, cst.FunctionDef) and child.name.value in names:
                        result[child.name.value] = tree.code_for_node(child)
        return result

    for review in reviews:
        gt = generated_tests.generated_tests[review.test_index]
        repaired_names = {f.function_name for f in review.functions_to_repair}
        new_source = gt.generated_original_test_source
        old_source = original_sources.get(review.test_index, "")

        old_funcs = extract_functions(old_source, repaired_names)
        new_funcs = extract_functions(new_source, repaired_names)

        for name in sorted(repaired_names):
            new_func = new_funcs.get(name, "")
            if not new_func:
                continue
            logger.info("Repaired: %s", name)

            old_func = old_funcs.get(name, "")
            if old_func and old_func != new_func:
                diff = unified_diff_strings(old_func, new_func, fromfile=f"{name} (before)", tofile=f"{name} (after)")
                if diff:
                    # Previously the diff was computed and then dropped; surface it for debugging
                    logger.debug("%s", diff)
def line_profiler_step(
    self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], candidate_index: int
) -> dict[str, Any]:
    """Run the test files in LINE_PROFILE mode and return the profiler output.

    Args:
        code_context: Context forwarded to ``add_decorator_imports`` and the test run.
        original_helper_code: Mapping of helper module path to its original source; the
            keys are also the files scanned for JIT decorators.
        candidate_index: Passed to ``get_test_env`` as ``codeflash_test_iteration``.

    Returns:
        The non-``TestResults`` value produced by ``run_and_parse_tests``, or the empty
        placeholder ``{"timings": {}, "unit": 0, "str_out": ""}`` when profiling is
        skipped (JIT decorator found) or the run produced a ``TestResults`` object.
    """
    # Read the code currently on disk for the target function.
    candidate_fto_code = Path(self.function_to_optimize.file_path).read_text("utf-8")
    if contains_jit_decorator(candidate_fto_code):
        logger.info(
            "Skipping line profiler for %s - code contains JIT decorator", self.function_to_optimize.function_name
        )
        return {"timings": {}, "unit": 0, "str_out": ""}

    # Same check for every helper module, again using the on-disk contents.
    for module_abspath in original_helper_code:
        candidate_helper_code = Path(module_abspath).read_text("utf-8")
        if contains_jit_decorator(candidate_helper_code):
            logger.info(
                "Skipping line profiler for %s - helper code contains JIT decorator",
                self.function_to_optimize.function_name,
            )
            return {"timings": {}, "unit": 0, "str_out": ""}

    try:
        test_env = self.get_test_env(
            codeflash_loop_index=0, codeflash_test_iteration=candidate_index, codeflash_tracer_disable=1
        )
        # NOTE(review): presumably this instruments the source files on disk, which is why
        # the finally-block rewrites them afterwards - confirm against add_decorator_imports.
        line_profiler_output_file = add_decorator_imports(self.function_to_optimize, code_context)
        line_profile_results, _ = self.run_and_parse_tests(
            testing_type=TestingMode.LINE_PROFILE,
            test_env=test_env,
            test_files=self.test_files,
            optimization_iteration=0,
            testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
            enable_coverage=False,
            code_context=code_context,
            line_profiler_output_file=line_profiler_output_file,
        )
    finally:
        # Always write the stored original function source and helper sources back,
        # whether or not the profiling run raised.
        self.write_code_and_helpers(
            self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
        )
    # A TestResults object with no entries is reported as a timeout.
    if isinstance(line_profile_results, TestResults) and not line_profile_results.test_results:
        logger.warning(
            "Timeout occurred while running line profiler for original function %s",
            self.function_to_optimize.function_name,
        )
        return {"timings": {}, "unit": 0, "str_out": ""}
    # An empty "str_out" only triggers a warning; the result is still returned below.
    if not isinstance(line_profile_results, TestResults) and line_profile_results.get("str_out") == "":
        logger.warning(
            "Couldn't run line profiler for original function %s", self.function_to_optimize.function_name
        )
    return (
        line_profile_results
        if not isinstance(line_profile_results, TestResults)
        else {"timings": {}, "unit": 0, "str_out": ""}
    )
def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]:
    """Run the pre-flight checks and gather the inputs needed for an optimization run.

    On success the ``Ok`` payload is ``(should_run_experiment, code_context,
    original_helper_code)``, where ``original_helper_code`` maps each helper
    file path to the text read from disk. An ``Err`` is returned when
    ``--no-gen-tests`` is set and no existing tests are known for the function,
    when the code context cannot be built, or when the function was previously
    optimized and the random retry roll does not pass.
    """
    experiment_enabled = self.experiment_id is not None
    ph("cli-optimize-function-start", {"function_trace_id": self.function_trace_id})

    # With --no-gen-tests there is nothing to fall back on, so require existing tests up front.
    assert self.args is not None
    if self.args.no_gen_tests:
        from codeflash_python.models.function_types import qualified_name_with_modules_from_root

        qualname = qualified_name_with_modules_from_root(self.function_to_optimize, self.project_root)
        existing_tests = self.function_to_tests.get(qualname)
        if not existing_tests:
            return Err(
                f"No existing tests found for '{self.function_to_optimize.function_name}'. "
                f"Cannot optimize without tests when --no-gen-tests is set."
            )

    self.cleanup_leftover_test_return_values()
    file_name_from_test_module_name.cache_clear()
    context_result = self.get_code_optimization_context()
    if not context_result.is_ok():
        return Err(cast("Err", context_result).error)
    code_context: CodeOptimizationContext = context_result.unwrap()

    # Snapshot every helper file's current contents, one read per distinct path.
    helper_paths = {helper.file_path for helper in code_context.helper_functions}
    original_helper_code: dict[Path, str] = {
        helper_path: helper_path.read_text(encoding="utf8") for helper_path in helper_paths
    }

    # A previously attempted function is retried only when the random roll stays at or
    # below REPEAT_OPTIMIZATION_PROBABILITY; the roll is drawn unconditionally.
    assert self.args is not None
    self.code_already_exists = was_function_previously_optimized(self.function_to_optimize, code_context, self.args)
    retry_roll = random.random()  # noqa: S311
    if self.code_already_exists and retry_roll > REPEAT_OPTIMIZATION_PROBABILITY:
        return Err("Function optimization previously attempted, skipping.")

    return Ok((experiment_enabled, code_context, original_helper_code))
instrumented_unittests_created_for_function, + original_conftest_content, + ) = test_setup_result.unwrap() + + optimizations_set, function_references = optimization_result.unwrap() + + precomputed_behavioral: tuple[TestResults, CoverageData | None] | None = None + assert self.args is not None + if generated_tests.generated_tests and self.args.testgen_review: + review_result = self.review_and_repair_tests( + generated_tests=generated_tests, code_context=code_context, original_helper_code=original_helper_code + ) + if not review_result.is_ok(): + return Err(cast("Err", review_result).error) + generated_tests, review_behavioral, review_coverage = review_result.unwrap() + if review_behavioral is not None: + precomputed_behavioral = (review_behavioral, review_coverage) + + # Full baseline (behavioral + benchmarking) runs once on the final approved tests + baseline_setup_result = self.setup_and_establish_baseline( + code_context=code_context, + original_helper_code=original_helper_code, + function_to_concolic_tests=function_to_concolic_tests, + generated_test_paths=generated_test_paths, + generated_perf_test_paths=generated_perf_test_paths, + instrumented_unittests_created_for_function=instrumented_unittests_created_for_function, + original_conftest_content=original_conftest_content, + precomputed_behavioral=precomputed_behavioral, + ) + + if not baseline_setup_result.is_ok(): + return Err(cast("Err", baseline_setup_result).error) + + ( + function_to_optimize_qualified_name, + function_to_all_tests, + original_code_baseline, + test_functions_to_remove, + file_path_to_helper_classes, + ) = baseline_setup_result.unwrap() + + best_optimization = self.find_and_process_best_optimization( + optimizations_set=optimizations_set, + code_context=code_context, + original_code_baseline=original_code_baseline, + original_helper_code=original_helper_code, + file_path_to_helper_classes=file_path_to_helper_classes, + 
def generate_optimizations(
    self,
    read_writable_code: CodeStringsMarkdown,
    read_only_context_code: str,
    run_experiment: bool = False,
    is_numerical_code: bool | None = None,
) -> Result[tuple[OptimizationSet, str], str]:
    """Generate optimization candidates for the function.

    Backend handles multi-model diversity.

    Submits the control request (and, when ``run_experiment`` is true, a second
    request through ``local_aiservice_client``) plus the review-metrics lookup to
    ``self.executor``, waits for all of them, and packages the results.

    Returns:
        ``Ok((OptimizationSet, function_references))`` on success, or ``Err`` when
        the control request produced no candidates.
    """
    assert self.aiservice_client is not None
    n_candidates = get_effort_value(EffortKeys.N_OPTIMIZER_CANDIDATES, self.effort)
    future_optimization_candidates = self.executor.submit(
        self.aiservice_client.optimize_code,
        read_writable_code.markdown,
        read_only_context_code,
        # In experiment mode the last four characters of the trace id are replaced by a group tag.
        self.function_trace_id[:-4] + "EXP0" if run_experiment else self.function_trace_id,
        ExperimentMetadata(id=self.experiment_id, group="control") if run_experiment else None,
        language=self.function_to_optimize.language,
        language_version=PYTHON_LANGUAGE_VERSION,
        is_async=self.function_to_optimize.is_async,
        n_candidates=n_candidates,
        is_numerical_code=is_numerical_code,
    )

    future_references = self.executor.submit(
        self.get_optimization_review_metrics,
        self.function_to_optimize_source_code,
        self.function_to_optimize.file_path,
        self.function_to_optimize.qualified_name,
        self.project_root,
        self.test_cfg.tests_root,
        self.function_to_optimize.language,
    )

    futures = [future_optimization_candidates, future_references]
    future_candidates_exp = None

    if run_experiment:
        assert self.local_aiservice_client is not None
        # NOTE(review): unlike the control call, is_numerical_code is not forwarded here -
        # confirm whether that is intentional.
        future_candidates_exp = self.executor.submit(
            self.local_aiservice_client.optimize_code,
            read_writable_code.markdown,
            read_only_context_code,
            self.function_trace_id[:-4] + "EXP1",
            ExperimentMetadata(id=self.experiment_id, group="experiment"),
            language=self.function_to_optimize.language,
            language_version=PYTHON_LANGUAGE_VERSION,
            is_async=self.function_to_optimize.is_async,
            n_candidates=n_candidates,
        )
        futures.append(future_candidates_exp)

    # Wait for optimization futures to complete
    concurrent.futures.wait(futures)

    # Retrieve results - optimize_code is expected to return the list of candidates
    candidates = future_optimization_candidates.result()

    if not candidates:
        return Err(f"/!\\ NO OPTIMIZATIONS GENERATED for {self.function_to_optimize.function_name}")

    # Handle experiment results
    candidates_experiment = None
    if future_candidates_exp:
        candidates_experiment = future_candidates_exp.result()
    function_references = future_references.result()

    return Ok((OptimizationSet(control=candidates, experiment=candidates_experiment), function_references))
@lru_cache(maxsize=1)
def get_valid_subdirs(current_dir: Path | None = None) -> list[str]:
    """Return the names of candidate project sub-directories of *current_dir*.

    Hidden directories (leading ``.``), dunder directories (leading ``__``) and
    anything listed in ``ignore_subdirs`` are excluded. When *current_dir* is
    ``None`` (or otherwise falsy) the process working directory is scanned.

    The result is memoised for the most recent argument only, and the list that
    is returned IS the cached object - treat it as read-only.
    """
    path_str = str(current_dir) if current_dir else "."
    # scandir holds an OS directory handle; the context manager releases it
    # deterministically, including when is_dir() raises part-way through.
    with os.scandir(path_str) as entries:
        return [
            entry.name
            for entry in entries
            if entry.is_dir() and not entry.name.startswith((".", "__")) and entry.name not in ignore_subdirs
        ]
def get_formatter_cmds(formatter: str) -> list[str]:
    """Translate a formatter choice into the list of formatter command strings.

    Known presets ("black", "ruff", "disabled" / "don't use a formatter") map to
    fixed commands, "other" yields a placeholder the user must edit, and any
    other value is treated as a custom command line that is split on " && ".
    """
    presets = {
        "black": ("black $file",),
        "ruff": ("ruff check --exit-zero --fix $file", "ruff format $file"),
        "don't use a formatter": ("disabled",),
        "disabled": ("disabled",),
    }
    preset = presets.get(formatter)
    if preset is not None:
        # Hand out a fresh list each time so callers can mutate it freely.
        return list(preset)

    if formatter == "other":
        click.echo(
            "🔧 In pyproject.toml, please replace 'your-formatter' with the command you use to format your code."
        )
        return ["your-formatter $file"]

    # str.split returns a single-element list when the separator is absent,
    # which covers both the chained and the single custom-command cases.
    return formatter.split(" && ")
def create_empty_pyproject_toml(pyproject_toml_path: Path) -> None:
    """Write a minimal pyproject.toml containing an empty ``[tool.codeflash]`` table.

    On success the user is told where the file was created and asked to
    acknowledge before continuing. If the file cannot be written (``OSError``),
    an error is echoed and ``apologize_and_exit`` is called.
    """
    ph("cli-create-pyproject-toml")
    # Define a minimal pyproject.toml content
    new_pyproject_toml = tomlkit.document()
    new_pyproject_toml["tool"] = {"codeflash": {}}

    # Only the write itself is guarded: an OSError raised later (for example by the
    # interactive prompt) must not be misreported as a failure to create the file.
    try:
        pyproject_toml_path.write_text(tomlkit.dumps(new_pyproject_toml), encoding="utf8")
    except OSError:
        click.echo("❌ Failed to create pyproject.toml. Please check your disk permissions and available space.")
        apologize_and_exit()
        return

    # Check if the pyproject.toml file was created
    if pyproject_toml_path.exists():
        print(f"Created a pyproject.toml file at {pyproject_toml_path}")
        print("Your project is now ready for Codeflash configuration!")
        print("\nPress any key to continue...")
        input()
    ph("cli-created-pyproject-toml")
@dataclass
class CallGraph:
    """Directed call graph built from a flat list of caller -> callee edges.

    ``forward``, ``reverse`` and ``nodes`` are indexes derived from ``edges`` once,
    in ``__post_init__``; they are not refreshed if ``edges`` is mutated later.
    """

    edges: list[CallEdge]
    # caller -> edges leaving that caller
    forward: dict[FunctionNode, list[CallEdge]] = field(default_factory=dict, init=False, repr=False)
    # callee -> edges arriving at that callee
    reverse: dict[FunctionNode, list[CallEdge]] = field(default_factory=dict, init=False, repr=False)
    # every node that appears as caller or callee of some edge
    nodes: set[FunctionNode] = field(default_factory=set, init=False, repr=False)

    def __post_init__(self) -> None:
        fwd: dict[FunctionNode, list[CallEdge]] = {}
        rev: dict[FunctionNode, list[CallEdge]] = {}
        nodes: set[FunctionNode] = set()
        for edge in self.edges:
            fwd.setdefault(edge.caller, []).append(edge)
            rev.setdefault(edge.callee, []).append(edge)
            nodes.add(edge.caller)
            nodes.add(edge.callee)
        self.forward = fwd
        self.reverse = rev
        self.nodes = nodes

    def callees_of(self, node: FunctionNode) -> list[CallEdge]:
        """Return the edges whose caller is *node* (empty list if none)."""
        return self.forward.get(node, [])

    def callers_of(self, node: FunctionNode) -> list[CallEdge]:
        """Return the edges whose callee is *node* (empty list if none)."""
        return self.reverse.get(node, [])

    @staticmethod
    def _reachable(adjacency, pick_neighbor, start, max_depth):
        """Breadth-first walk from *start*, returning every node reached.

        *adjacency* maps a node to its edges and *pick_neighbor* extracts the node
        on the far side of an edge. *start* itself is only included when a cycle
        leads back to it. With *max_depth* set, nodes more than that many edges
        away are not reported.
        """
        reached: set = set()
        frontier: deque = deque([(start, 0)])
        while frontier:
            current, depth = frontier.popleft()
            if max_depth is not None and depth >= max_depth:
                continue
            for edge in adjacency.get(current, ()):
                neighbor = pick_neighbor(edge)
                if neighbor not in reached:
                    reached.add(neighbor)
                    frontier.append((neighbor, depth + 1))
        return reached

    def descendants(self, node: FunctionNode, max_depth: int | None = None) -> set[FunctionNode]:
        """Return all nodes reachable from *node* by following calls, up to *max_depth* hops."""
        return self._reachable(self.forward, lambda edge: edge.callee, node, max_depth)

    def ancestors(self, node: FunctionNode, max_depth: int | None = None) -> set[FunctionNode]:
        """Return all nodes that (transitively) call *node*, up to *max_depth* hops.

        The walk is breadth-first on purpose: with a depth limit, a depth-first walk
        can first meet a node through a longer path, mark it visited at a too-large
        depth, and then never expand it through the shorter path - silently dropping
        callers that are within range.
        """
        return self._reachable(self.reverse, lambda edge: edge.caller, node, max_depth)

    def subgraph(self, nodes: set[FunctionNode]) -> CallGraph:
        """Return a new graph holding only the edges whose both endpoints are in *nodes*."""
        filtered = [e for e in self.edges if e.caller in nodes and e.callee in nodes]
        return CallGraph(edges=filtered)

    def leaf_functions(self) -> set[FunctionNode]:
        """Return the nodes that have no outgoing edges."""
        return self.nodes - set(self.forward.keys())

    def root_functions(self) -> set[FunctionNode]:
        """Return the nodes that have no incoming edges."""
        return self.nodes - set(self.reverse.keys())

    def topological_order(self) -> list[FunctionNode]:
        """Return the nodes ordered leaves-first (callees before their callers).

        Nodes that take part in a cycle never reach in-degree zero; they are left
        out of the result and a warning with their count is logged.
        """
        in_degree: dict[FunctionNode, int] = {}
        all_nodes = self.nodes
        for node in all_nodes:
            in_degree.setdefault(node, 0)
        for edge in self.edges:
            in_degree[edge.callee] = in_degree.get(edge.callee, 0) + 1

        forward_map = self.forward
        queue = deque(node for node, deg in in_degree.items() if deg == 0)
        result: list[FunctionNode] = []
        while queue:
            node = queue.popleft()
            result.append(node)
            for edge in forward_map.get(node, []):
                in_degree[edge.callee] -= 1
                if in_degree[edge.callee] == 0:
                    queue.append(edge.callee)

        if len(result) < len(all_nodes):
            logger.warning(
                "Call graph contains cycles: %d of %d nodes excluded from topological order",
                len(all_nodes) - len(result),
                len(all_nodes),
            )

        # Leaves-first: reverse the topological order
        result.reverse()
        return result
def callees_from_graph(graph: CallGraph) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]:
    """Convert the graph's annotated edges into ``FunctionSource`` records.

    Edges without ``callee_metadata`` are ignored. Returns a pair: a mapping from
    callee file path to the set of sources defined there, and a flat list with one
    entry per annotated edge, in edge order.
    """
    from codeflash_python.models.models import FunctionSource

    sources_by_file: dict[Path, set[FunctionSource]] = defaultdict(set)
    all_sources: list[FunctionSource] = []

    annotated_edges = (edge for edge in graph.edges if edge.callee_metadata is not None)
    for edge in annotated_edges:
        callee, metadata = edge.callee, edge.callee_metadata
        source = FunctionSource(
            file_path=callee.file_path,
            qualified_name=callee.qualified_name,
            fully_qualified_name=metadata.fully_qualified_name,
            only_function_name=metadata.only_function_name,
            source_code=metadata.source_line,
            definition_type=metadata.definition_type,
        )
        sources_by_file[callee.file_path].add(source)
        all_sources.append(source)

    return sources_by_file, all_sources
def qualified_name_with_modules_from_root(fto: FunctionToOptimize, project_root_path: Path) -> str:
    """Return ``<module name>.<qualified name>`` for *fto*.

    The module name is whatever ``module_name_from_file_path`` derives from the
    function's file path and *project_root_path*.
    """
    # Imported lazily, as in the rest of this module, rather than at import time.
    from codeflash_python.code_utils.code_utils import module_name_from_file_path

    module_name = module_name_from_file_path(fto.file_path, project_root_path)
    return f"{module_name}.{fto.qualified_name}"
@dataclass(frozen=True)
class FunctionSource:
    """Immutable record locating one function definition and carrying its source text.

    Equality and hashing deliberately ignore ``definition_type``: two records
    that agree on path, names and source are the same function.
    """

    file_path: Path
    qualified_name: str
    fully_qualified_name: str
    only_function_name: str
    source_code: str
    definition_type: str | None = None  # e.g. "function", "class"; None for non-Python languages

    def _identity_key(self) -> tuple[Path, str, str, str, str]:
        # Single source of truth for __eq__ and __hash__ so the two cannot drift apart.
        return (
            self.file_path,
            self.qualified_name,
            self.fully_qualified_name,
            self.only_function_name,
            self.source_code,
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, FunctionSource):
            # Let Python try the reflected comparison; `==` still evaluates to False
            # when the other operand does not know this type either.
            return NotImplemented
        return self._identity_key() == other._identity_key()

    def __hash__(self) -> int:
        return hash(self._identity_key())
def to_string(self) -> str: + return ( + f"Original timing for {self.benchmark_name}::{self.test_function}: {self.original_timing}\n" + f"Expected new timing for {self.benchmark_name}::{self.test_function}: {self.expected_new_timing}\n" + f"Benchmark speedup for {self.benchmark_name}::{self.test_function}: {self.speedup_percent:.2f}%\n" + ) + + def to_dict(self) -> dict[str, Any]: + return { + "benchmark_name": self.benchmark_name, + "test_function": self.test_function, + "original_timing": self.original_timing, + "expected_new_timing": self.expected_new_timing, + "speedup_percent": self.speedup_percent, + } + + +@dataclass +class ProcessedBenchmarkInfo: + benchmark_details: list[BenchmarkDetail] + + def to_string(self) -> str: + if not self.benchmark_details: + return "" + + result = "Benchmark Performance Details:\n" + for detail in self.benchmark_details: + result += detail.to_string() + "\n" + return result + + def to_dict(self) -> dict[str, list[dict[str, Any]]]: + return {"benchmark_details": [detail.to_dict() for detail in self.benchmark_details]} + + +class CodeString(BaseModel): + code: str + file_path: Path | None = None + language: str = "python" # Language for validation + + @model_validator(mode="after") + def validate_code_syntax(self) -> CodeString: + """Validate code syntax for the specified language.""" + if self.language == "python": + validate_python_code(self.code) + else: + try: + compile(self.code, "", "exec") + except SyntaxError: + msg = f"Invalid {self.language.title()} code" + raise ValueError(msg) from None + return self + + +def get_comment_prefix(file_path: Path) -> str: + """Get the comment prefix for a given language.""" + return "#" + + +def get_code_block_splitter(file_path: Path | None) -> str: + if file_path is None: + return "" + comment_prefix = get_comment_prefix(file_path) + return f"{comment_prefix} file: {file_path.as_posix()}" + + +# Pattern to match markdown code blocks with optional language tag and file path +# Matches: 
```language:filepath\ncode\n``` or ```language\ncode\n``` +markdown_pattern = re.compile(r"```(\w+)(?::([^\n]+))?\n(.*?)\n```", re.DOTALL) + + +class CodeStringsMarkdown(BaseModel): + code_strings: list[CodeString] = [] + language: str = "python" # Language for markdown code block tags + _cache: dict[str, Any] = PrivateAttr(default_factory=dict) + + @property + def flat(self) -> str: + """Returns the combined source code module from all code blocks. + + Each block is prefixed by a file path comment to indicate its origin. + The comment prefix is determined by the language attribute. + + Returns: + str: The concatenated code of all blocks with file path annotations. + + !! Important !!: + Avoid parsing the flat code with multiple files, + parsing may result in unexpected behavior. + + + """ + if self._cache.get("flat") is not None: + return self._cache["flat"] + self._cache["flat"] = "\n".join( + get_code_block_splitter(block.file_path) + "\n" + block.code for block in self.code_strings + ) + return self._cache["flat"] + + @property + def markdown(self) -> str: + """Returns a Markdown-formatted string containing all code blocks. + + Each block is enclosed in a triple-backtick code block with an optional + file path suffix (e.g., ```python:filename.py). + + The language tag is determined by the `language` attribute. + + Returns: + str: Markdown representation of the code blocks. + + """ + return "\n".join( + [ + f"```{self.language}{':' + code_string.file_path.as_posix() if code_string.file_path else ''}\n{code_string.code.strip()}\n```" + for code_string in self.code_strings + ] + ) + + def file_to_path(self) -> dict[str, str]: + """Return a dictionary mapping file paths to their corresponding code blocks. + + Returns: + dict[str, str]: Mapping from file path (as string) to code. 
+ + """ + if self._cache.get("file_to_path") is not None: + return self._cache["file_to_path"] + self._cache["file_to_path"] = { + str(code_string.file_path): code_string.code for code_string in self.code_strings + } + return self._cache["file_to_path"] + + @staticmethod + def parse_markdown_code(markdown_code: str, expected_language: str = "python") -> CodeStringsMarkdown: + """Parse a Markdown string into a CodeStringsMarkdown object. + + Extracts code blocks and their associated file paths and constructs a new CodeStringsMarkdown instance. + + Args: + markdown_code (str): The Markdown-formatted string to parse. + expected_language (str): The expected language of code blocks (default: "python"). + + Returns: + CodeStringsMarkdown: Parsed object containing code blocks. + + """ + matches = markdown_pattern.findall(markdown_code) + code_string_list = [] + detected_language = expected_language + try: + for language, file_path, code in matches: + # Use the first detected language or the expected language + if language: + detected_language = language + if file_path: + path = file_path.strip() + code_string_list.append(CodeString(code=code, file_path=Path(path), language=detected_language)) + else: + # No file path specified - skip this block or create with None + code_string_list.append(CodeString(code=code, file_path=None, language=detected_language)) + return CodeStringsMarkdown(code_strings=code_string_list, language=detected_language) + except ValidationError: + # if any file is invalid, return an empty CodeStringsMarkdown for the entire context + return CodeStringsMarkdown(language=expected_language) + + +class CodeOptimizationContext(BaseModel): + testgen_context: CodeStringsMarkdown + read_writable_code: CodeStringsMarkdown + read_only_context_code: str = "" + hashing_code_context: str = "" + hashing_code_context_hash: str = "" + helper_functions: list[FunctionSource] + testgen_helper_fqns: list[str] = [] + preexisting_objects: set[tuple[str, 
tuple[FunctionParent, ...]]] + + +class OptimizedCandidateResult(BaseModel): + max_loop_count: int + best_test_runtime: int + behavior_test_results: TestResults + benchmarking_test_results: TestResults + replay_benchmarking_test_results: dict[BenchmarkKey, TestResults] | None = None + optimization_candidate_index: int + total_candidate_timing: int + async_throughput: int | None = None + concurrency_metrics: ConcurrencyMetrics | None = None + + +class GeneratedTests(BaseModel): + generated_original_test_source: str + instrumented_behavior_test_source: str + instrumented_perf_test_source: str + raw_generated_test_source: str | None = None + behavior_file_path: Path + perf_file_path: Path + + +class GeneratedTestsList(BaseModel): + generated_tests: list[GeneratedTests] + + +class TestFile(BaseModel): + instrumented_behavior_file_path: Path + benchmarking_file_path: Path | None = None + original_file_path: Path | None = None + original_source: str | None = None + test_type: TestType + tests_in_file: list[TestsInFile] | None = None + + +class TestFiles(BaseModel): + test_files: list[TestFile] + + def get_by_type(self, test_type: TestType) -> TestFiles: + return TestFiles(test_files=[test_file for test_file in self.test_files if test_file.test_type == test_type]) + + def add(self, test_file: TestFile) -> None: + if test_file not in self.test_files: + self.test_files.append(test_file) + else: + msg = "Test file already exists in the list" + raise ValueError(msg) + + def get_by_original_file_path(self, file_path: Path) -> TestFile | None: + normalized = self._normalize_path_for_comparison(file_path) + for test_file in self.test_files: + if test_file.original_file_path is None: + continue + normalized_test_path = self._normalize_path_for_comparison(test_file.original_file_path) + if normalized == normalized_test_path: + return test_file + return None + + def get_test_type_by_instrumented_file_path(self, file_path: Path) -> TestType | None: + normalized = 
self._normalize_path_for_comparison(file_path) + for test_file in self.test_files: + normalized_behavior_path = self._normalize_path_for_comparison(test_file.instrumented_behavior_file_path) + if normalized == normalized_behavior_path: + return test_file.test_type + if test_file.benchmarking_file_path is not None: + normalized_benchmark_path = self._normalize_path_for_comparison(test_file.benchmarking_file_path) + if normalized == normalized_benchmark_path: + return test_file.test_type + + # Fallback: try filename-only matching when normalized paths don't match + file_name = file_path.name + for test_file in self.test_files: + if ( + test_file.instrumented_behavior_file_path + and test_file.instrumented_behavior_file_path.name == file_name + ): + return test_file.test_type + if test_file.benchmarking_file_path and test_file.benchmarking_file_path.name == file_name: + return test_file.test_type + + return None + + def get_test_type_by_original_file_path(self, file_path: Path) -> TestType | None: + normalized = self._normalize_path_for_comparison(file_path) + for test_file in self.test_files: + if test_file.original_file_path is None: + continue + normalized_test_path = self._normalize_path_for_comparison(test_file.original_file_path) + if normalized == normalized_test_path: + return test_file.test_type + return None + + @staticmethod + @lru_cache(maxsize=4096) + def _normalize_path_for_comparison(path: Path) -> str: + """Normalize a path for cross-platform comparison. + + Resolves the path to an absolute path and handles Windows case-insensitivity. 
+ """ + try: + resolved = str(path.resolve()) + except (OSError, RuntimeError): + # If resolve fails (e.g., file doesn't exist), use absolute path + resolved = str(path.absolute()) + # Only lowercase on Windows where filesystem is case-insensitive + return resolved.lower() if sys.platform == "win32" else resolved + + def __iter__(self) -> Iterator[TestFile]: # type: ignore[override] + return iter(self.test_files) + + def __len__(self) -> int: + return len(self.test_files) + + +class OptimizationSet(BaseModel): + control: list[OptimizedCandidate] + experiment: list[OptimizedCandidate] | None + + +@dataclass(frozen=True) +class TestsInFile: + test_file: Path + test_class: str | None + test_function: str + test_type: TestType + + +class OptimizedCandidateSource(str, Enum): + OPTIMIZE = "OPTIMIZE" + OPTIMIZE_LP = "OPTIMIZE_LP" + REFINE = "REFINE" + REPAIR = "REPAIR" + ADAPTIVE = "ADAPTIVE" + JIT_REWRITE = "JIT_REWRITE" + + +@dataclass(frozen=True) +class OptimizedCandidate: + source_code: CodeStringsMarkdown + explanation: str + optimization_id: str + source: OptimizedCandidateSource + parent_id: str | None = None + model: str | None = None # Which LLM model generated this candidate + + +@dataclass(frozen=True) +class FunctionCalledInTest: + tests_in_file: TestsInFile + position: CodePosition + + +@dataclass(frozen=True) +class CodePosition: + line_no: int + col_no: int + + +class OriginalCodeBaseline(BaseModel): + behavior_test_results: TestResults + benchmarking_test_results: TestResults + replay_benchmarking_test_results: dict[BenchmarkKey, TestResults] | None = None + line_profile_results: dict + runtime: int + coverage_results: CoverageData | None + async_throughput: int | None = None + concurrency_metrics: ConcurrencyMetrics | None = None + + +class CoverageStatus(Enum): + NOT_FOUND = "Coverage Data Not Found" + PARSED_SUCCESSFULLY = "Parsed Successfully" + + +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class CoverageData: + """Represents the 
coverage data for a specific function in a source file, using one or more test files.""" + + file_path: Path + coverage: float + function_name: str + functions_being_tested: list[str] + graph: dict[str, dict[str, Collection[object]]] + code_context: CodeOptimizationContext + main_func_coverage: FunctionCoverage + dependent_func_coverage: FunctionCoverage | None + status: CoverageStatus + blank_re: Pattern[str] = re.compile(r"\s*(#|$)") + else_re: Pattern[str] = re.compile(r"\s*else\s*:\s*(#|$)") + + def build_message(self) -> str: + if self.status == CoverageStatus.NOT_FOUND: + return f"No coverage data found for {self.function_name}" + return f"{self.coverage:.1f}%" + + def log_coverage(self) -> None: + lines = ["Test Coverage Results", f" Main Function: {self.main_func_coverage.name}: {self.coverage:.2f}%"] + if self.dependent_func_coverage: + lines.append( + f" Dependent Function: {self.dependent_func_coverage.name}: {self.dependent_func_coverage.coverage:.2f}%" + ) + lines.append(f" Total Coverage: {self.coverage:.2f}%") + logger.info("\n".join(lines)) + + if not self.coverage: + logger.debug(self.graph) + + @classmethod + def create_empty(cls, file_path: Path, function_name: str, code_context: CodeOptimizationContext) -> CoverageData: + return cls( + file_path=file_path, + coverage=0.0, + function_name=function_name, + functions_being_tested=[function_name], + graph={ + function_name: { + "executed_lines": set(), + "unexecuted_lines": set(), + "executed_branches": [], + "unexecuted_branches": [], + } + }, + code_context=code_context, + main_func_coverage=FunctionCoverage( + name=function_name, + coverage=0.0, + executed_lines=[], + unexecuted_lines=[], + executed_branches=[], + unexecuted_branches=[], + ), + dependent_func_coverage=None, + status=CoverageStatus.NOT_FOUND, + ) + + +@dataclass +class FunctionCoverage: + """Represents the coverage data for a specific function in a source file.""" + + name: str + coverage: float + executed_lines: list[int] + 
unexecuted_lines: list[int] + executed_branches: list[list[int]] + unexecuted_branches: list[list[int]] + + +class TestingMode(enum.Enum): + BEHAVIOR = "behavior" + PERFORMANCE = "performance" + LINE_PROFILE = "line_profile" + CONCURRENCY = "concurrency" + + +# Intentionally duplicated in codeflash_capture (runs in subprocess, can't import from here) +class VerificationType(str, Enum): + FUNCTION_CALL = ( + "function_call" # Correctness verification for a test function, checks input values and output values) + ) + INIT_STATE_FTO = "init_state_fto" # Correctness verification for fto class instance attributes after init + INIT_STATE_HELPER = "init_state_helper" # Correctness verification for helper class instance attributes after init + + +@dataclass(frozen=True) +class InvocationId: + test_module_path: str # The fully qualified name of the test module + test_class_name: str | None # The name of the class where the test is defined + test_function_name: str | None # The name of the test_function. Does not include the components of the file_name + function_getting_tested: str + iteration_id: str | None + + # test_module_path:TestSuiteClass.test_function_name:function_tested:iteration_id + def id(self) -> str: + class_prefix = f"{self.test_class_name}." 
if self.test_class_name else "" + return ( + f"{self.test_module_path}:{class_prefix}{self.test_function_name}:" + f"{self.function_getting_tested}:{self.iteration_id}" + ) + + # TestSuiteClass.test_function_name + def test_fn_qualified_name(self) -> str: + # Use f-string with inline conditional to reduce string concatenation operations + return ( + f"{self.test_class_name}.{self.test_function_name}" + if self.test_class_name + else str(self.test_function_name) + ) + + @staticmethod + def from_str_id(string_id: str, iteration_id: str | None = None) -> InvocationId: + components = string_id.split(":") + assert len(components) == 4 + second_components = components[1].split(".") + if len(second_components) == 1: + test_class_name = None + test_function_name = second_components[0] + else: + test_class_name = second_components[0] + test_function_name = second_components[1] + return InvocationId( + test_module_path=components[0], + test_class_name=test_class_name, + test_function_name=test_function_name, + function_getting_tested=components[2], + iteration_id=iteration_id if iteration_id else components[3], + ) + + +@dataclass(frozen=True) +class FunctionTestInvocation: + loop_index: int # The loop index of the function invocation, starts at 1 + id: InvocationId # The fully qualified name of the function invocation (id) + file_name: Path # The file where the test is defined + did_pass: bool # Whether the test this function invocation was part of, passed or failed + runtime: int | None # Time in nanoseconds + test_framework: str # unittest or pytest + test_type: TestType + return_value: object | None # The return value of the function invocation + timed_out: bool | None + verification_type: str | None = VerificationType.FUNCTION_CALL + stdout: str | None = None + + @property + def unique_invocation_loop_id(self) -> str: + return f"{self.loop_index}:{self.id.id()}" + + +class TestResults(BaseModel): # noqa: PLW1641 + # don't modify these directly, use the add method + # 
also we don't support deletion of test results elements - caution is advised + test_results: list[FunctionTestInvocation] = [] + test_result_idx: dict[str, int] = {} + + perf_stdout: str | None = None + # mapping between test function name and stdout failure message + test_failures: dict[str, str] | None = None + + def add(self, function_test_invocation: FunctionTestInvocation) -> None: + unique_id = function_test_invocation.unique_invocation_loop_id + test_result_idx = self.test_result_idx + if unique_id in test_result_idx: + if DEBUG_MODE: + logger.warning("Test result with id %s already exists. SKIPPING", unique_id) + return + test_results = self.test_results + test_result_idx[unique_id] = len(test_results) + test_results.append(function_test_invocation) + + def merge(self, other: TestResults) -> None: + original_len = len(self.test_results) + self.test_results.extend(other.test_results) + for k, v in other.test_result_idx.items(): + if k in self.test_result_idx: + msg = f"Test result with id {k} already exists." 
+ raise ValueError(msg) + self.test_result_idx[k] = v + original_len + + def group_by_benchmarks( + self, benchmark_keys: list[BenchmarkKey], benchmark_replay_test_dir: Path, project_root: Path + ) -> dict[BenchmarkKey, TestResults]: + """Group TestResults by benchmark for calculating improvements for each benchmark.""" + test_results_by_benchmark = defaultdict(TestResults) + benchmark_module_path = {} + for benchmark_key in benchmark_keys: + benchmark_module_path[benchmark_key] = module_name_from_file_path( + benchmark_replay_test_dir.resolve() + / f"test_{benchmark_key.module_path.replace('.', '_')}__replay_test_", + project_root, + traverse_up=True, + ) + for test_result in self.test_results: + if test_result.test_type == TestType.REPLAY_TEST: + for benchmark_key, module_path in benchmark_module_path.items(): + if test_result.id.test_module_path.startswith(module_path): + test_results_by_benchmark[benchmark_key].add(test_result) + + return test_results_by_benchmark + + def get_by_unique_invocation_loop_id(self, unique_invocation_loop_id: str) -> FunctionTestInvocation | None: + try: + return self.test_results[self.test_result_idx[unique_invocation_loop_id]] + except (IndexError, KeyError): + return None + + def get_all_ids(self) -> set[InvocationId]: + return {test_result.id for test_result in self.test_results} + + def get_all_unique_invocation_loop_ids(self) -> set[str]: + return {test_result.unique_invocation_loop_id for test_result in self.test_results} + + def number_of_loops(self) -> int: + if not self.test_results: + return 0 + return max(test_result.loop_index for test_result in self.test_results) + + def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]: + report: dict[TestType, dict[str, int]] = {tt: {"passed": 0, "failed": 0} for tt in TestType} + for test_result in self.test_results: + if test_result.loop_index != 1: + continue + if test_result.did_pass: + report[test_result.test_type]["passed"] += 1 + else: + 
report[test_result.test_type]["failed"] += 1 + return report + + @staticmethod + def report_to_string(report: dict[TestType, dict[str, int]]) -> str: + return " ".join( + [ + f"{test_type.to_name()}- (Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']})" + for test_type in TestType + ] + ) + + @staticmethod + def report_to_tree(report: dict[TestType, dict[str, int]], title: str) -> str: + lines = [title] + for test_type in TestType: + if test_type is TestType.INIT_STATE_TEST: + continue + lines.append( + f" {test_type.to_name()} - Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']}" + ) + return "\n".join(lines) + + def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]: + # Efficient single traversal, directly accumulating into a dict. + # can track mins here and only sums can be return in total_passed_runtime + by_id: dict[InvocationId, list[int]] = {} + for result in self.test_results: + if result.did_pass: + if result.runtime: + by_id.setdefault(result.id, []).append(result.runtime) + else: + msg = ( + f"Ignoring test case that passed but had no runtime -> {result.id}, " + f"Loop # {result.loop_index}, Test Type: {result.test_type}, " + f"Verification Type: {result.verification_type}" + ) + logger.debug(msg) + return by_id + + def total_passed_runtime(self) -> int: + """Calculate the sum of runtimes of all test cases that passed. + + A testcase runtime is the minimum value of all looped execution runtimes. + + :return: The runtime in nanoseconds. + """ + # TODO this doesn't look at the intersection of tests of baseline and original + return sum( + [min(usable_runtime_data) for _, usable_runtime_data in self.usable_runtime_data_by_test_case().items()] + ) + + def effective_loop_count(self) -> int: + """Calculate the effective number of complete loops. + + Returns the maximum loop_index seen across all test results. This represents + the number of timing iterations that were performed. 
+ + :return: The effective loop count, or 0 if no test results. + """ + if not self.test_results: + return 0 + # Get all loop indices from results that have timing data + loop_indices = {result.loop_index for result in self.test_results if result.runtime is not None} + if not loop_indices: + # Fallback: use all loop indices even without runtime + loop_indices = {result.loop_index for result in self.test_results} + return max(loop_indices) if loop_indices else 0 + + def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Path]: + map_gen_test_file_to_no_of_tests = Counter() + for gen_test_result in self.test_results: + if ( + gen_test_result.test_type == TestType.GENERATED_REGRESSION + and gen_test_result.id.test_function_name not in test_functions_to_remove + ): + map_gen_test_file_to_no_of_tests[gen_test_result.file_name] += 1 + return map_gen_test_file_to_no_of_tests + + def __iter__(self) -> Iterator[FunctionTestInvocation]: # type: ignore[override] + return iter(self.test_results) + + def __len__(self) -> int: + return len(self.test_results) + + def __getitem__(self, index: int) -> FunctionTestInvocation: + return self.test_results[index] + + def __setitem__(self, index: int, value: FunctionTestInvocation) -> None: + self.test_results[index] = value + + def __contains__(self, value: FunctionTestInvocation) -> bool: + return value in self.test_results + + def __bool__(self) -> bool: + return bool(self.test_results) + + def __eq__(self, other: object) -> bool: + # Unordered comparison + if type(self) is not type(other): + return False + if len(self) != len(other): # type: ignore[arg-type] + return False + from codeflash_python.verification.comparator import comparator + + original_recursion_limit = sys.getrecursionlimit() + cast("TestResults", other) + for test_result in self: + other_test_result = other.get_by_unique_invocation_loop_id(test_result.unique_invocation_loop_id) # type: ignore[attr-defined] + if other_test_result is None: + 
return False + + if original_recursion_limit < 5000: + sys.setrecursionlimit(5000) + if ( + test_result.file_name != other_test_result.file_name + or test_result.did_pass != other_test_result.did_pass + or test_result.runtime != other_test_result.runtime + or test_result.test_framework != other_test_result.test_framework + or test_result.test_type != other_test_result.test_type + or not comparator(test_result.return_value, other_test_result.return_value) + ): + sys.setrecursionlimit(original_recursion_limit) + return False + sys.setrecursionlimit(original_recursion_limit) + return True diff --git a/src/codeflash_python/models/test_result.py b/src/codeflash_python/models/test_result.py new file mode 100644 index 000000000..0942ec705 --- /dev/null +++ b/src/codeflash_python/models/test_result.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from pathlib import Path + + +@dataclass +class TestInfo: + """Information about a test that exercises a function. + + Attributes: + test_name: Name of the test function. + test_file: Path to the test file. + test_class: Name of the test class, if any. + + """ + + test_name: str + test_file: Path + test_class: str | None = None + + @property + def full_test_path(self) -> str: + """Get full test path in pytest format (file::class::function).""" + file_path = self.test_file.as_posix() + if self.test_class: + return f"{file_path}::{self.test_class}::{self.test_name}" + return f"{file_path}::{self.test_name}" + + +@dataclass +class TestResult: + """Language-agnostic test result. + + Captures the outcome of running a single test, including timing + and behavioral data for equivalence checking. + + Attributes: + test_name: Name of the test function. + test_file: Path to the test file. + passed: Whether the test passed. + runtime_ns: Execution time in nanoseconds. + return_value: The return value captured from the test. 
+ stdout: Standard output captured during test execution. + stderr: Standard error captured during test execution. + error_message: Error message if the test failed. + + """ + + test_name: str + test_file: Path + passed: bool + runtime_ns: int | None = None + return_value: Any = None + stdout: str = "" + stderr: str = "" + error_message: str | None = None diff --git a/src/codeflash_python/models/test_type.py b/src/codeflash_python/models/test_type.py new file mode 100644 index 000000000..154e3f7f2 --- /dev/null +++ b/src/codeflash_python/models/test_type.py @@ -0,0 +1,22 @@ +from enum import Enum + + +class TestType(Enum): + EXISTING_UNIT_TEST = 1 + INSPIRED_REGRESSION = 2 + GENERATED_REGRESSION = 3 + REPLAY_TEST = 4 + CONCOLIC_COVERAGE_TEST = 5 + INIT_STATE_TEST = 6 + + def to_name(self) -> str: + return _TO_NAME_MAP.get(self, "") + + +_TO_NAME_MAP: dict[TestType, str] = { + TestType.EXISTING_UNIT_TEST: "⚙️ Existing Unit Tests", + TestType.INSPIRED_REGRESSION: "🎨 Inspired Regression Tests", + TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests", + TestType.REPLAY_TEST: "⏪ Replay Tests", + TestType.CONCOLIC_COVERAGE_TEST: "🔎 Concolic Coverage Tests", +} diff --git a/src/codeflash_python/normalizer.py b/src/codeflash_python/normalizer.py new file mode 100644 index 000000000..f5cee935a --- /dev/null +++ b/src/codeflash_python/normalizer.py @@ -0,0 +1,181 @@ +"""Python code normalizer using AST transformation.""" + +from __future__ import annotations + +import ast +from typing import cast + + +class VariableNormalizer(ast.NodeTransformer): + """Normalizes only local variable names in AST to canonical forms like var_0, var_1, etc. + + Preserves function names, class names, parameters, built-ins, and imported names. 
+ """ + + def __init__(self) -> None: + self.var_counter = 0 + self.var_mapping: dict[str, str] = {} + self.scope_stack: list[dict] = [] + self.builtins = set(dir(__builtins__)) + self.imports: set[str] = set() + self.global_vars: set[str] = set() + self.nonlocal_vars: set[str] = set() + self.parameters: set[str] = set() + + def enter_scope(self) -> None: + """Enter a new scope (function/class).""" + self.scope_stack.append( + {"var_mapping": dict(self.var_mapping), "var_counter": self.var_counter, "parameters": set(self.parameters)} + ) + + def exit_scope(self) -> None: + """Exit current scope and restore parent scope.""" + if self.scope_stack: + scope = self.scope_stack.pop() + self.var_mapping = scope["var_mapping"] + self.var_counter = scope["var_counter"] + self.parameters = scope["parameters"] + + def get_normalized_name(self, name: str) -> str: + """Get or create normalized name for a variable.""" + if ( + name in self.builtins + or name in self.imports + or name in self.global_vars + or name in self.nonlocal_vars + or name in self.parameters + ): + return name + + if name not in self.var_mapping: + self.var_mapping[name] = f"var_{self.var_counter}" + self.var_counter += 1 + return self.var_mapping[name] + + def visit_Import(self, node: ast.Import) -> ast.Import: + """Track imported names.""" + for alias in node.names: + name = alias.asname if alias.asname else alias.name + self.imports.add(name.split(".")[0]) + return node + + def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom: + """Track imported names from modules.""" + for alias in node.names: + name = alias.asname if alias.asname else alias.name + self.imports.add(name) + return node + + def visit_Global(self, node: ast.Global) -> ast.Global: + """Track global variable declarations.""" + self.global_vars.update(node.names) + return node + + def visit_Nonlocal(self, node: ast.Nonlocal) -> ast.Nonlocal: + """Track nonlocal variable declarations.""" + self.nonlocal_vars.update(node.names) + 
return node + + def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: + """Process function but keep function name and parameters unchanged.""" + self.enter_scope() + + for arg in node.args.args: + self.parameters.add(arg.arg) + if node.args.vararg: + self.parameters.add(node.args.vararg.arg) + if node.args.kwarg: + self.parameters.add(node.args.kwarg.arg) + for arg in node.args.kwonlyargs: + self.parameters.add(arg.arg) + + node = cast("ast.FunctionDef", self.generic_visit(node)) + self.exit_scope() + return node + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef: + """Handle async functions same as regular functions.""" + return self.visit_FunctionDef(node) # type: ignore[return-value] + + def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: + """Process class but keep class name unchanged.""" + self.enter_scope() + node = cast("ast.ClassDef", self.generic_visit(node)) + self.exit_scope() + return node + + def visit_Name(self, node: ast.Name) -> ast.Name: + """Normalize variable names in Name nodes.""" + if isinstance(node.ctx, (ast.Store, ast.Del)): + if ( + node.id not in self.builtins + and node.id not in self.imports + and node.id not in self.parameters + and node.id not in self.global_vars + and node.id not in self.nonlocal_vars + ): + node.id = self.get_normalized_name(node.id) + elif isinstance(node.ctx, ast.Load) and node.id in self.var_mapping: + node.id = self.var_mapping[node.id] + return node + + def visit_ExceptHandler(self, node: ast.ExceptHandler) -> ast.ExceptHandler: + """Normalize exception variable names.""" + if node.name: + node.name = self.get_normalized_name(node.name) + return cast("ast.ExceptHandler", self.generic_visit(node)) + + def visit_comprehension(self, node: ast.comprehension) -> ast.comprehension: + """Normalize comprehension target variables.""" + old_mapping = dict(self.var_mapping) + old_counter = self.var_counter + + node = cast("ast.comprehension", 
self.generic_visit(node)) + + self.var_mapping = old_mapping + self.var_counter = old_counter + return node + + def visit_For(self, node: ast.For) -> ast.For: + """Handle for loop target variables.""" + return cast("ast.For", self.generic_visit(node)) + + def visit_With(self, node: ast.With) -> ast.With: + """Handle with statement as variables.""" + return cast("ast.With", self.generic_visit(node)) + + +def remove_docstrings_from_ast(node: ast.AST) -> None: + """Remove docstrings from AST nodes.""" + node_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module) + stack = [node] + while stack: + current_node = stack.pop() + if isinstance(current_node, node_types): + body = current_node.body + if ( + body + and isinstance(body[0], ast.Expr) + and isinstance(body[0].value, ast.Constant) + and isinstance(body[0].value.value, str) + ): + current_node.body = body[1:] + stack.extend([child for child in body if isinstance(child, node_types)]) + + +def normalize_python_code(code: str, remove_docstrings: bool = True) -> str: + """Normalize Python code to a canonical form for comparison. + + Replaces local variable names with canonical forms (var_0, var_1, etc.) + while preserving function names, class names, parameters, and imports. 
def __init__(self, args: Namespace) -> None:
    """Keep the parsed CLI arguments and build the shared test config and AI client.

    ``tests_root``, ``test_project_root`` and ``project_root`` are read
    directly from ``args``.  ``pytest_cmd`` and ``benchmarks_root`` are
    optional: a missing or falsy ``pytest_cmd`` falls back to ``"pytest"``
    and a missing ``benchmarks_root`` becomes ``None``.

    Args:
        args: Parsed command-line namespace for this optimization run.
    """
    self.args = args

    pytest_command = getattr(args, "pytest_cmd", None) or "pytest"
    benchmarks_root = getattr(args, "benchmarks_root", None)
    self.test_cfg = TestConfig(
        tests_root=args.tests_root,
        tests_project_rootdir=args.test_project_root,
        project_root=args.project_root,
        test_command=pytest_command,
        benchmark_tests_root=benchmarks_root,
    )

    # Imported here rather than at module level, as in the original.
    from codeflash_python.api.aiservice import AiServiceClient

    self.aiservice_client = AiServiceClient()
    self.replay_tests_dir = None
    # Populated by mirror_paths_for_worktree_mode with deep copies of the pre-mirroring state.
    self.original_args_and_test_cfg: tuple[Namespace, TestConfig] | None = None
    # Last callee counts computed by rank_by_dependency_count, keyed by (file path, qualified name).
    self.cached_callee_counts: dict[tuple[Path, str], int] = {}
def rank_all_functions_globally(
    self,
    file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]],
    trace_file_path: Path | None,
    call_graph: DependencyResolver | None = None,
    test_count_cache: dict[tuple[Path, str], int] | None = None,
) -> list[tuple[Path, FunctionToOptimize]]:
    """Order every candidate function across all files.

    Without a usable trace file, the functions are ranked by dependency count
    when a ``call_graph`` is given and otherwise returned in their original
    order.  With a trace file, ``FunctionRanker`` ranks them; when
    ``test_count_cache`` is non-empty the ranked list is re-sorted by
    addressable time (descending), then test count (descending), then the
    ranker's own position.  Any exception during ranking is logged and the
    original order is returned.

    Args:
        file_to_funcs_to_optimize: Candidate functions grouped by file.
        trace_file_path: Trace file to rank from, or ``None``.
        call_graph: Optional resolver used for the no-trace fallback.
        test_count_cache: Optional test counts keyed by (file path, qualified name).

    Returns:
        ``(file path, function)`` pairs in the order they should be optimized.
    """
    all_functions: list[tuple[Path, FunctionToOptimize]] = [
        (source_file, fn) for source_file, fns in file_to_funcs_to_optimize.items() for fn in fns
    ]

    if not (trace_file_path and trace_file_path.exists()):
        if call_graph is None:
            logger.debug("No trace file available, using original function order")
            return all_functions
        return self.rank_by_dependency_count(all_functions, call_graph, test_count_cache=test_count_cache)

    try:
        from codeflash_python.benchmarking.function_ranker import FunctionRanker

        logger.info("loading|Ranking functions globally by performance impact...")
        ranker = FunctionRanker(trace_file_path)
        candidates = [fn for _, fn in all_functions]
        ranked_functions = ranker.rank_functions(candidates)

        # Map each function identity back to the file key it was listed under.
        owning_file: dict[tuple[Path, str, int | None], Path] = {
            (fn.file_path, fn.qualified_name, fn.starting_line): source_file for source_file, fn in all_functions
        }

        scored: list[tuple[Path, FunctionToOptimize, float, int]] = []
        for position, fn in enumerate(ranked_functions):
            source_file = owning_file.get((fn.file_path, fn.qualified_name, fn.starting_line))
            if not source_file:
                continue
            scored.append((source_file, fn, ranker.get_function_addressable_time(fn), position))

        if test_count_cache:
            counts = test_count_cache

            def priority(entry: tuple[Path, FunctionToOptimize, float, int]) -> tuple[float, int, int]:
                entry_file, entry_fn, addressable_time, position = entry
                return (-addressable_time, -counts.get((entry_file, entry_fn.qualified_name), 0), position)

            scored.sort(key=priority)

        globally_ranked = [(source_file, fn) for source_file, fn, _time, _position in scored]

        logger.info(
            "Globally ranked %s functions by addressable time (filtered %s low-importance functions)",
            len(ranked_functions),
            len(candidates) - len(ranked_functions),
        )

    except Exception as e:
        logger.warning("Could not perform global ranking: %s", e)
        logger.debug("Falling back to original function order")
        return all_functions
    else:
        return globally_ranked
test_count_cache: + ranked = sorted( + enumerate(all_functions), + key=lambda x: ( + -callee_counts.get((x[1][0], x[1][1].qualified_name), 0), + -test_count_cache.get((x[1][0], x[1][1].qualified_name), 0), + x[0], + ), + ) + else: + ranked = sorted( + enumerate(all_functions), key=lambda x: (-callee_counts.get((x[1][0], x[1][1].qualified_name), 0), x[0]) + ) + logger.debug("Ranked %s functions by dependency count (most complex first)", len(ranked)) + return [item for _, item in ranked] + + @staticmethod + def find_leftover_instrumented_test_files(test_root: Path) -> list[Path]: + """Search for all paths within the test_root that match instrumented test file patterns. + + Patterns: + - 'test.*__perf_test_{0,1}.py' + - 'test_.*__unit_test_{0,1}.py' + - 'test_.*__perfinstrumented.py' + - 'test_.*__perfonlyinstrumented.py' + + Returns: + A list of matching file paths. + + """ + pattern = re.compile( + r"(?:" + r"test.*__perf_test_\d?\.py|test_.*__unit_test_\d?\.py|" + r"test_.*__perfinstrumented\.py|test_.*__perfonlyinstrumented\.py" + r")$" + ) + + return [ + file_path for file_path in test_root.rglob("*") if file_path.is_file() and pattern.match(file_path.name) + ] + + def mirror_paths_for_worktree_mode(self, worktree_dir: Path) -> None: + """Mirror file paths from original git root to worktree directory. + + This updates all paths in args and test_cfg to point to their + corresponding locations in the worktree. + + Args: + worktree_dir: The worktree directory to mirror paths to. 
+ + """ + original_args = copy.deepcopy(self.args) + original_test_cfg = copy.deepcopy(self.test_cfg) + self.original_args_and_test_cfg = (original_args, original_test_cfg) + + original_git_root = git_root_dir() + + # mirror project_root + self.args.project_root = mirror_path(self.args.project_root, original_git_root, worktree_dir) + self.test_cfg.project_root = mirror_path(self.test_cfg.project_root, original_git_root, worktree_dir) # type: ignore[assignment] + + # mirror module_root + self.args.module_root = mirror_path(self.args.module_root, original_git_root, worktree_dir) + + # mirror target file + if hasattr(self.args, "file") and self.args.file: + self.args.file = mirror_path(self.args.file, original_git_root, worktree_dir) + + if hasattr(self.args, "all") and self.args.all: + # the args.all path is the same as module_root. + self.args.all = mirror_path(self.args.all, original_git_root, worktree_dir) + + # mirror tests root + self.args.tests_root = mirror_path(self.args.tests_root, original_git_root, worktree_dir) + self.test_cfg.tests_root = mirror_path(self.test_cfg.tests_root, original_git_root, worktree_dir) # type: ignore[assignment] + + # mirror tests project root + self.args.test_project_root = mirror_path(self.args.test_project_root, original_git_root, worktree_dir) + self.test_cfg.tests_project_rootdir = mirror_path( # type: ignore[assignment,unresolved-attribute] + self.test_cfg.tests_project_rootdir, + original_git_root, + worktree_dir, # type: ignore[unresolved-attribute] + ) + + # mirror benchmarks root paths + if hasattr(self.args, "benchmarks_root") and self.args.benchmarks_root: + self.args.benchmarks_root = mirror_path(self.args.benchmarks_root, original_git_root, worktree_dir) + if hasattr(self.test_cfg, "benchmark_tests_root") and self.test_cfg.benchmark_tests_root: + self.test_cfg.benchmark_tests_root = mirror_path( # type: ignore[assignment] + self.test_cfg.benchmark_tests_root, + original_git_root, + worktree_dir, # type: 
def mirror_path(path: Path | str, src_root: Path, dest_root: Path) -> Path:
    """Return the location under ``dest_root`` that corresponds to ``path`` under ``src_root``.

    Both ``path`` and ``src_root`` are resolved before the relative part is
    computed; ``dest_root`` is used as given.

    Raises:
        ValueError: If the resolved ``path`` is not inside the resolved ``src_root``.
    """
    resolved_path = Path(path).resolve()
    resolved_root = src_root.resolve()
    tail = resolved_path.relative_to(resolved_root)
    return Path(dest_root, tail)
def resolve_python_function_ast(
    function_name: str, parents: list[FunctionParent], module_ast: ast.Module
) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
    """Find the AST node of a function or method inside an already-parsed module.

    Thin wrapper that forwards its arguments unchanged to
    ``get_first_top_level_function_or_method_ast`` and returns that result.
    """
    from codeflash_python.static_analysis.static_analysis import (
        get_first_top_level_function_or_method_ast as locate_definition,
    )

    located = locate_definition(function_name, parents, module_ast)
    return located
000000000..31f066e7c --- /dev/null +++ b/src/codeflash_python/optimizer_mixins/__init__.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from codeflash_python.optimizer_mixins.baseline import BaselineEstablishmentMixin +from codeflash_python.optimizer_mixins.candidate_evaluation import CandidateEvaluationMixin +from codeflash_python.optimizer_mixins.candidate_structures import CandidateForest, CandidateNode, CandidateProcessor +from codeflash_python.optimizer_mixins.code_replacement import CodeReplacementMixin +from codeflash_python.optimizer_mixins.refinement import RefinementMixin +from codeflash_python.optimizer_mixins.result_processing import ResultProcessingMixin +from codeflash_python.optimizer_mixins.test_execution import TestExecutionMixin +from codeflash_python.optimizer_mixins.test_generation import TestGenerationMixin +from codeflash_python.optimizer_mixins.test_review import TestReviewMixin + +__all__ = [ + "BaselineEstablishmentMixin", + "CandidateEvaluationMixin", + "CandidateForest", + "CandidateNode", + "CandidateProcessor", + "CodeReplacementMixin", + "RefinementMixin", + "ResultProcessingMixin", + "TestExecutionMixin", + "TestGenerationMixin", + "TestReviewMixin", +] diff --git a/src/codeflash_python/optimizer_mixins/_protocol.py b/src/codeflash_python/optimizer_mixins/_protocol.py new file mode 100644 index 000000000..bd1d57c3f --- /dev/null +++ b/src/codeflash_python/optimizer_mixins/_protocol.py @@ -0,0 +1,388 @@ +"""Type-only protocol declaring the shared interface of FunctionOptimizer. + +Mixin classes inherit from this under TYPE_CHECKING so that type checkers can +resolve cross-mixin attribute and method accesses on ``self``. 
class FunctionOptimizerProtocol(Protocol):
    """Type-only description of the attributes and methods shared by FunctionOptimizer and its mixins.

    The members are grouped below by where they are implemented.  Every method
    body is ``...`` and every default value is written as ``...``: this class
    only declares signatures, it implements nothing.
    """

    # -- Instance attributes (set in FunctionOptimizer.__init__) --

    project_root: Path
    test_cfg: TestConfig
    aiservice_client: AiServiceClient | None
    local_aiservice_client: AiServiceClient | None
    function_to_optimize: FunctionToOptimize
    function_to_optimize_source_code: str
    function_to_optimize_ast: ast.FunctionDef | ast.AsyncFunctionDef | None
    function_to_tests: dict[str, set[FunctionCalledInTest]]
    experiment_id: str | None
    test_files: TestFiles
    effort: str
    args: Namespace | None
    function_trace_id: str
    original_module_path: str
    function_benchmark_timings: dict[BenchmarkKey, int]
    total_benchmark_timings: dict[BenchmarkKey, int]
    replay_tests_dir: Path | None
    call_graph: DependencyResolver | None
    executor: concurrent.futures.ThreadPoolExecutor
    optimization_review: str
    future_all_code_repair: list[concurrent.futures.Future]
    future_all_refinements: list[concurrent.futures.Future]
    future_adaptive_optimizations: list[concurrent.futures.Future]
    repair_counter: int
    adaptive_optimization_counter: int
    is_numerical_code: bool | None
    code_already_exists: bool

    # -- Methods defined in FunctionOptimizer --

    def get_test_env(
        self, codeflash_loop_index: int, codeflash_test_iteration: int, codeflash_tracer_disable: int = ...
    ) -> dict: ...

    def get_trace_id(self, exp_type: str) -> str: ...

    def cleanup_async_helper_file(self) -> None: ...

    def get_results_not_matched_error(self) -> Err: ...

    def instrument_capture(self, file_path_to_helper_classes: dict[Path, set[str]]) -> None: ...

    def instrument_async_for_mode(self, mode: TestingMode) -> None: ...

    def should_check_coverage(self) -> bool: ...

    def parse_line_profile_test_results(
        self, line_profiler_output_file: Path | None
    ) -> tuple[TestResults | dict, CoverageData | None]: ...

    def compare_candidate_results(
        self,
        baseline_results: OriginalCodeBaseline,
        candidate_behavior_results: TestResults,
        optimization_candidate_index: int,
    ) -> tuple[bool, list[TestDiff]]: ...

    def replace_function_and_helpers_with_optimized_code(
        self,
        code_context: CodeOptimizationContext,
        optimized_code: CodeStringsMarkdown,
        original_helper_code: dict[Path, str],
    ) -> bool: ...

    def collect_async_metrics(
        self,
        benchmarking_results: TestResults,
        code_context: CodeOptimizationContext,
        helper_code: dict[Path, str],
        test_env: dict[str, str],
    ) -> tuple[int | None, ConcurrencyMetrics | None]: ...

    def line_profiler_step(
        self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], candidate_index: int
    ) -> dict[str, Any]: ...

    def display_repaired_functions(
        self, generated_tests: GeneratedTestsList, reviews: list[TestFileReview], original_sources: dict[int, str]
    ) -> None: ...

    def instrument_test_fixtures(self, test_paths: list[Path]) -> dict[Path, str] | None: ...

    def fixup_generated_tests(self, generated_tests: GeneratedTestsList) -> GeneratedTestsList: ...

    # -- Methods from CodeReplacementMixin --

    @staticmethod
    def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, str], path: Path) -> None: ...

    def reformat_code_and_helpers(
        self,
        helper_functions: list[FunctionSource],
        path: Path,
        original_code: str,
        optimized_context: CodeStringsMarkdown,
    ) -> tuple[str, dict[Path, str]]: ...

    def group_functions_by_file(self, code_context: CodeOptimizationContext) -> dict[Path, set[str]]: ...

    def revert_code_and_helpers(self, original_helper_code: dict[Path, str]) -> None: ...

    # -- Methods from TestExecutionMixin --

    def run_and_parse_tests(
        self,
        testing_type: TestingMode,
        test_env: dict[str, str],
        test_files: TestFiles,
        optimization_iteration: int,
        testing_time: float = ...,
        *,
        enable_coverage: bool = ...,
        pytest_min_loops: int = ...,
        pytest_max_loops: int = ...,
        code_context: CodeOptimizationContext | None = ...,
        line_profiler_output_file: Path | None = ...,
    ) -> tuple[TestResults | dict, CoverageData | None]: ...

    def run_behavioral_validation(
        self,
        code_context: CodeOptimizationContext,
        original_helper_code: dict[Path, str],
        file_path_to_helper_classes: dict[Path, set[str]],
    ) -> tuple[TestResults, CoverageData | None] | None: ...

    def instrument_existing_tests(self, function_to_all_tests: dict[str, set[FunctionCalledInTest]]) -> set[Path]: ...

    def run_concurrency_benchmark(
        self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], test_env: dict[str, str]
    ) -> ConcurrencyMetrics | None: ...

    # -- Methods from BaselineEstablishmentMixin --

    def establish_original_code_baseline(
        self,
        code_context: CodeOptimizationContext,
        original_helper_code: dict[Path, str],
        file_path_to_helper_classes: dict[Path, set[str]],
        precomputed_behavioral: tuple[TestResults, CoverageData | None] | None = ...,
    ) -> Result[tuple[OriginalCodeBaseline, list[str]], str]: ...

    def setup_and_establish_baseline(
        self,
        code_context: CodeOptimizationContext,
        original_helper_code: dict[Path, str],
        function_to_concolic_tests: dict[str, set[FunctionCalledInTest]],
        generated_test_paths: list[Path],
        generated_perf_test_paths: list[Path],
        instrumented_unittests_created_for_function: set[Path],
        original_conftest_content: dict[Path, str] | None,
        precomputed_behavioral: tuple[TestResults, CoverageData | None] | None = ...,
    ) -> Result[
        tuple[str, dict[str, set[FunctionCalledInTest]], OriginalCodeBaseline, list[str], dict[Path, set[str]]], str
    ]: ...

    def build_helper_classes_map(self, code_context: CodeOptimizationContext) -> dict[Path, set[str]]: ...

    # -- Methods from TestGenerationMixin --

    def generate_tests(
        self,
        testgen_context: CodeStringsMarkdown,
        helper_functions: list[FunctionSource],
        testgen_helper_fqns: list[str],
        generated_test_paths: list[Path],
        generated_perf_test_paths: list[Path],
    ) -> Result[tuple[int, GeneratedTestsList, dict[str, set[FunctionCalledInTest]], str], str]: ...

    def submit_test_generation_tasks(
        self,
        executor: concurrent.futures.ThreadPoolExecutor,
        source_code_being_tested: str,
        helper_function_names: list[str],
        generated_test_paths: list[Path],
        generated_perf_test_paths: list[Path],
    ) -> list[concurrent.futures.Future]: ...

    def generate_and_instrument_tests(
        self, code_context: CodeOptimizationContext
    ) -> Result[
        tuple[
            GeneratedTestsList,
            dict[str, set[FunctionCalledInTest]],
            str,
            list[Path],
            list[Path],
            set[Path],
            dict[Path, str] | None,
        ],
        str,
    ]: ...

    # -- Methods from TestReviewMixin --

    def review_and_repair_tests(
        self,
        generated_tests: GeneratedTestsList,
        code_context: CodeOptimizationContext,
        original_helper_code: dict[Path, str],
    ) -> Result[tuple[GeneratedTestsList, TestResults | None, CoverageData | None], str]: ...

    # -- Methods from CandidateEvaluationMixin --

    def handle_successful_candidate(
        self,
        candidate: OptimizedCandidate,
        candidate_result: OptimizedCandidateResult,
        code_context: CodeOptimizationContext,
        original_code_baseline: OriginalCodeBaseline,
        original_helper_code: dict[Path, str],
        candidate_index: int,
        eval_ctx: CandidateEvaluationContext,
    ) -> BestOptimization: ...

    def select_best_optimization(
        self,
        eval_ctx: CandidateEvaluationContext,
        code_context: CodeOptimizationContext,
        original_code_baseline: OriginalCodeBaseline,
        ai_service_client: AiServiceClient,
        exp_type: str,
        function_references: str,
    ) -> BestOptimization | None: ...

    def log_evaluation_results(
        self,
        eval_ctx: CandidateEvaluationContext,
        best_optimization: BestOptimization,
        original_code_baseline: OriginalCodeBaseline,
        ai_service_client: AiServiceClient,
        exp_type: str,
    ) -> None: ...

    def run_optimized_candidate(
        self,
        *,
        optimization_candidate_index: int,
        baseline_results: OriginalCodeBaseline,
        original_helper_code: dict[Path, str],
        file_path_to_helper_classes: dict[Path, set[str]],
        eval_ctx: CandidateEvaluationContext,
        code_context: CodeOptimizationContext,
        candidate: OptimizedCandidate,
        exp_type: str,
    ) -> Result[OptimizedCandidateResult, str]: ...

    def process_single_candidate(
        self,
        candidate_node: CandidateNode,
        candidate_index: int,
        total_candidates: int,
        code_context: CodeOptimizationContext,
        original_code_baseline: OriginalCodeBaseline,
        original_helper_code: dict[Path, str],
        file_path_to_helper_classes: dict[Path, set[str]],
        eval_ctx: CandidateEvaluationContext,
        exp_type: str,
        function_references: str,
        normalized_original: str,
    ) -> BestOptimization | None: ...

    def determine_best_candidate(
        self,
        *,
        candidates: list[OptimizedCandidate],
        code_context: CodeOptimizationContext,
        original_code_baseline: OriginalCodeBaseline,
        original_helper_code: dict[Path, str],
        file_path_to_helper_classes: dict[Path, set[str]],
        exp_type: str,
        function_references: str,
    ) -> BestOptimization | None: ...

    def call_adaptive_optimize(
        self,
        trace_id: str,
        original_source_code: str,
        prev_candidates: list[OptimizedCandidate],
        eval_ctx: CandidateEvaluationContext,
        ai_service_client: AiServiceClient,
    ) -> concurrent.futures.Future[OptimizedCandidate | None] | None: ...

    def repair_optimization(
        self,
        original_source_code: str,
        modified_source_code: str,
        test_diffs: list[TestDiff],
        trace_id: str,
        optimization_id: str,
        ai_service_client: AiServiceClient,
        executor: concurrent.futures.ThreadPoolExecutor,
        language: str = ...,
    ) -> concurrent.futures.Future[OptimizedCandidate | None]: ...

    def repair_if_possible(
        self,
        candidate: OptimizedCandidate,
        diffs: list[TestDiff],
        eval_ctx: CandidateEvaluationContext,
        code_context: CodeOptimizationContext,
        test_results_count: int,
        exp_type: str,
    ) -> None: ...

    # -- Methods from ResultProcessingMixin --

    def find_and_process_best_optimization(
        self,
        optimizations_set: OptimizationSet,
        code_context: CodeOptimizationContext,
        original_code_baseline: OriginalCodeBaseline,
        original_helper_code: dict[Path, str],
        file_path_to_helper_classes: dict[Path, set[str]],
        function_to_optimize_qualified_name: str,
        function_to_all_tests: dict[str, set[FunctionCalledInTest]],
        generated_tests: GeneratedTestsList,
        test_functions_to_remove: list[str],
        concolic_test_str: str | None,
        function_references: str,
    ) -> BestOptimization | None: ...

    def process_review(
        self,
        original_code_baseline: OriginalCodeBaseline,
        best_optimization: BestOptimization,
        generated_tests: GeneratedTestsList,
        test_functions_to_remove: list[str],
        concolic_test_str: str | None,
        original_code_combined: dict[Path, str],
        new_code_combined: dict[Path, str],
        explanation: Explanation,
        function_to_all_tests: dict[str, set[FunctionCalledInTest]],
        exp_type: str,
        original_helper_code: dict[Path, str],
        code_context: CodeOptimizationContext,
        function_references: str,
    ) -> None: ...

    def log_successful_optimization(
        self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str
    ) -> None: ...
def establish_original_code_baseline(
    self,
    code_context: CodeOptimizationContext,
    original_helper_code: dict[Path, str],
    file_path_to_helper_classes: dict[Path, set[str]],
    precomputed_behavioral: tuple[TestResults, CoverageData | None] | None = None,
) -> Result[tuple[OriginalCodeBaseline, list[str]], str]:
    """Run the unmodified function through behavior and performance tests to build its baseline.

    Steps, in order:

    1. Obtain behavioral results and coverage, either from
       ``precomputed_behavioral`` or by instrumenting and running the
       behavior tests.  The original source and helpers are written back
       afterwards even if the run raises.
    2. Reject the baseline if no behavioral results exist or, when coverage
       checking is enabled, if the coverage critic fails.
    3. Run the line profiler step, then the performance tests.
    4. Reject the baseline if the summed passing runtime is zero or missing.
    5. Collect async metrics and, in benchmark mode, group results per benchmark.

    Args:
        code_context: Optimization context for the function under test.
        original_helper_code: Original helper sources by path, used to restore files.
        file_path_to_helper_classes: Helper class names by file, passed to capture instrumentation.
        precomputed_behavioral: Optional ``(behavioral results, coverage)`` pair to
            reuse instead of running the behavior tests again.

    Returns:
        ``Ok((baseline, functions_to_remove))`` where ``functions_to_remove`` lists
        the generated regression tests that failed on the original code, or
        ``Err(message)`` when no baseline could be established.
    """
    # TestResults is only imported under TYPE_CHECKING at module level, but the
    # isinstance() assertions below need the real class at runtime.
    from codeflash_python.models.models import TestResults

    test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1)

    if precomputed_behavioral is not None:
        # Reuse behavioral results from the review cycle (no repairs were needed)
        behavioral_results, coverage_results = precomputed_behavioral
        logger.debug("[PIPELINE] Reusing behavioral results from test review cycle (no repairs were made)")
    else:
        if self.function_to_optimize.is_async:
            self.instrument_async_for_mode(TestingMode.BEHAVIOR)

        # Instrument codeflash capture
        try:
            self.instrument_capture(file_path_to_helper_classes)
            logger.debug("[PIPELINE] Establishing baseline with %s test files", len(self.test_files))
            for idx, tf in enumerate(self.test_files):
                logger.debug(
                    "[PIPELINE] Test file %s: behavior=%s, perf=%s",
                    idx,
                    tf.instrumented_behavior_file_path,
                    tf.benchmarking_file_path,
                )
            behavioral_results, coverage_results = self.run_and_parse_tests(
                testing_type=TestingMode.BEHAVIOR,
                test_env=test_env,
                test_files=self.test_files,
                optimization_iteration=0,
                testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
                enable_coverage=True,
                code_context=code_context,
            )
            assert isinstance(behavioral_results, TestResults)
        finally:
            # Remove codeflash capture
            self.write_code_and_helpers(
                self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
            )
    if not behavioral_results:
        logger.warning(
            "force_lsp|Couldn't run any tests for original function %s. Skipping optimization.",
            self.function_to_optimize.function_name,
        )
        return Err("Failed to establish a baseline for the original code - bevhavioral tests failed.")
    # Skip coverage check for non-Python languages (coverage not yet supported)
    if self.should_check_coverage() and not coverage_critic(coverage_results):
        did_pass_all_tests = all(result.did_pass for result in behavioral_results)
        if not did_pass_all_tests:
            return Err("Tests failed to pass for the original code.")
        from codeflash_python.code_utils.config_consts import COVERAGE_THRESHOLD

        coverage_pct = coverage_results.coverage if coverage_results else 0
        return Err(
            f"Test coverage is {coverage_pct}%, which is below the required threshold of {COVERAGE_THRESHOLD}%."
        )

    line_profile_results = self.line_profiler_step(
        code_context=code_context, original_helper_code=original_helper_code, candidate_index=0
    )

    logger.debug(
        "[BENCHMARK-START] Starting benchmarking tests with %s test files", len(self.test_files.test_files)
    )
    for idx, tf in enumerate(self.test_files.test_files):
        logger.debug("[BENCHMARK-FILES] Test file %s: perf_file=%s", idx, tf.benchmarking_file_path)

    if self.function_to_optimize.is_async:
        self.instrument_async_for_mode(TestingMode.PERFORMANCE)

    try:
        benchmarking_results, _ = self.run_and_parse_tests(
            testing_type=TestingMode.PERFORMANCE,
            test_env=test_env,
            test_files=self.test_files,
            optimization_iteration=0,
            testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
            enable_coverage=False,
            code_context=code_context,
        )
        assert isinstance(benchmarking_results, TestResults)
        logger.debug("[BENCHMARK-DONE] Got %s benchmark results", len(benchmarking_results.test_results))
    finally:
        # Only the async path instrumented the source for this run, so only it needs restoring.
        if self.function_to_optimize.is_async:
            self.write_code_and_helpers(
                self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
            )

    total_timing = benchmarking_results.total_passed_runtime()  # caution: doesn't handle the loop index
    functions_to_remove = [
        result.id.test_function_name
        for result in behavioral_results
        if (result.test_type == TestType.GENERATED_REGRESSION and not result.did_pass)
    ]

    # A zero runtime logs both warnings, exactly as before; any other falsy value logs only the second.
    if total_timing == 0:
        logger.warning("The overall summed benchmark runtime of the original function is 0, couldn't run tests.")
    if not total_timing:
        logger.warning("Failed to run the tests for the original function, skipping optimization")
        return Err("Failed to establish a baseline for the original code.")

    loop_count = benchmarking_results.effective_loop_count()
    logger.info(
        "h3|⌚ Original code summed runtime measured over '%s' loop%s: '%s' per full loop",
        loop_count,
        "s" if loop_count > 1 else "",
        humanize_runtime(total_timing),
    )
    logger.debug("Total original code runtime (ns): %s", total_timing)

    async_throughput, concurrency_metrics = self.collect_async_metrics(
        benchmarking_results, code_context, original_helper_code, test_env
    )

    assert self.args is not None
    if self.args.benchmark:
        assert self.replay_tests_dir is not None
        replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(
            list(self.total_benchmark_timings.keys()), self.replay_tests_dir, self.project_root
        )
    return Ok(
        (
            OriginalCodeBaseline(
                behavior_test_results=behavioral_results,
                benchmarking_test_results=benchmarking_results,
                replay_benchmarking_test_results=replay_benchmarking_test_results if self.args.benchmark else None,
                runtime=total_timing,
                coverage_results=coverage_results,
                line_profile_results=line_profile_results,
                async_throughput=async_throughput,
                concurrency_metrics=concurrency_metrics,
            ),
            functions_to_remove,
        )
    )
generated_perf_test_paths: list[Path], + instrumented_unittests_created_for_function: set[Path], + original_conftest_content: dict[Path, str] | None, + precomputed_behavioral: tuple[TestResults, CoverageData | None] | None = None, + ) -> Result[ + tuple[str, dict[str, set[FunctionCalledInTest]], OriginalCodeBaseline, list[str], dict[Path, set[str]]], str + ]: + """Set up baseline context and establish original code baseline.""" + from codeflash_python.code_utils.code_utils import restore_conftest + + function_to_optimize_qualified_name = self.function_to_optimize.qualified_name + function_to_all_tests = { + key: self.function_to_tests.get(key, set()) | function_to_concolic_tests.get(key, set()) + for key in set(self.function_to_tests) | set(function_to_concolic_tests) + } + + file_path_to_helper_classes = self.build_helper_classes_map(code_context) + + baseline_result = self.establish_original_code_baseline( + code_context=code_context, + original_helper_code=original_helper_code, + file_path_to_helper_classes=file_path_to_helper_classes, + precomputed_behavioral=precomputed_behavioral, + ) + + paths_to_cleanup = ( + generated_test_paths + generated_perf_test_paths + list(instrumented_unittests_created_for_function) + ) + + if not baseline_result.is_ok(): + assert self.args is not None + if self.args.override_fixtures and original_conftest_content is not None: + restore_conftest(original_conftest_content) + cleanup_paths(paths_to_cleanup) + self.cleanup_async_helper_file() + return Err(cast("Err", baseline_result).error) + + original_code_baseline, test_functions_to_remove = baseline_result.unwrap() + # Check test quantity for all languages + quantity_ok = quantity_of_tests_critic(original_code_baseline) + coverage_ok = coverage_critic(original_code_baseline.coverage_results) if self.should_check_coverage() else True + if isinstance(original_code_baseline, OriginalCodeBaseline) and (not coverage_ok or not quantity_ok): + assert self.args is not None + if 
def build_helper_classes_map(self, code_context: CodeOptimizationContext) -> dict[Path, set[str]]:
    """Collect, per source file, the class names that own helper methods.

    A helper contributes only when its qualified name contains a dot and it is
    not the function being optimized. The text before the first dot is recorded
    under the helper's ``file_path``.

    Returns:
        A ``defaultdict(set)`` keyed by helper file path.
    """
    classes_by_file: dict[Path, set[str]] = defaultdict(set)
    for helper in code_context.helper_functions:
        helper_name = helper.qualified_name
        if helper_name == self.function_to_optimize.qualified_name:
            # The optimization target itself is never treated as a helper.
            continue
        owner, dot, _rest = helper_name.partition(".")
        if not dot:
            # Undotted name: no owning class to record.
            continue
        classes_by_file[helper.file_path].add(owner)
    return classes_by_file
def handle_successful_candidate(
    self,
    candidate: OptimizedCandidate,
    candidate_result: OptimizedCandidateResult,
    code_context: CodeOptimizationContext,
    original_code_baseline: OriginalCodeBaseline,
    original_helper_code: dict[Path, str],
    candidate_index: int,
    eval_ctx: CandidateEvaluationContext,
) -> BestOptimization:
    """Package a candidate that passed the critics into a ``BestOptimization``.

    Runs the line-profiler step for the candidate and stores its ``"str_out"``
    entry on ``eval_ctx``. When ``self.args.benchmark`` is set, also computes a
    performance gain per benchmark key against the baseline's replay results.

    Returns:
        The ``BestOptimization`` describing this candidate.
    """
    profile_output = self.line_profiler_step(
        code_context=code_context, original_helper_code=original_helper_code, candidate_index=candidate_index
    )
    eval_ctx.record_line_profiler_result(candidate.optimization_id, profile_output["str_out"])

    assert self.args is not None
    benchmark_mode = self.args.benchmark
    gain_by_benchmark = {}
    if benchmark_mode:
        assert self.replay_tests_dir is not None
        baseline_by_benchmark = original_code_baseline.replay_benchmarking_test_results
        assert baseline_by_benchmark is not None
        assert self.total_benchmark_timings is not None
        grouped_results = candidate_result.benchmarking_test_results.group_by_benchmarks(
            list(self.total_benchmark_timings.keys()), self.replay_tests_dir, self.project_root
        )
        # One gain figure per benchmark: baseline replay runtime vs. candidate runtime.
        gain_by_benchmark = {
            benchmark_key: performance_gain(
                original_runtime_ns=baseline_by_benchmark[benchmark_key].total_passed_runtime(),
                optimized_runtime_ns=benchmark_results.total_passed_runtime(),
            )
            for benchmark_key, benchmark_results in grouped_results.items()
        }

    # NOTE(review): the replay slot below receives the ungrouped benchmarking results,
    # exactly as before; confirm whether the per-benchmark grouping was intended.
    return BestOptimization(
        candidate=candidate,
        helper_functions=code_context.helper_functions,
        code_context=code_context,
        runtime=candidate_result.best_test_runtime,
        line_profiler_test_results=profile_output,
        winning_behavior_test_results=candidate_result.behavior_test_results,
        replay_performance_gain=gain_by_benchmark if benchmark_mode else None,
        winning_benchmarking_test_results=candidate_result.benchmarking_test_results,
        winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results,
        async_throughput=candidate_result.async_throughput,
        concurrency_metrics=candidate_result.concurrency_metrics,
    )
normalize_code(valid_opt.candidate.source_code.flat.strip()) + new_candidate_with_shorter_code = OptimizedCandidate( + source_code=eval_ctx.ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"], + optimization_id=valid_opt.candidate.optimization_id, + explanation=valid_opt.candidate.explanation, + source=valid_opt.candidate.source, + parent_id=valid_opt.candidate.parent_id, + ) + new_best_opt = BestOptimization( + candidate=new_candidate_with_shorter_code, + helper_functions=valid_opt.helper_functions, + code_context=valid_opt.code_context, + runtime=valid_opt.runtime, + line_profiler_test_results=valid_opt.line_profiler_test_results, + winning_behavior_test_results=valid_opt.winning_behavior_test_results, + replay_performance_gain=valid_opt.replay_performance_gain, + winning_benchmarking_test_results=valid_opt.winning_benchmarking_test_results, + winning_replay_benchmarking_test_results=valid_opt.winning_replay_benchmarking_test_results, + async_throughput=valid_opt.async_throughput, + concurrency_metrics=valid_opt.concurrency_metrics, + ) + valid_candidates_with_shorter_code.append(new_best_opt) + diff_lens_list.append( + diff_length(new_best_opt.candidate.source_code.flat, code_context.read_writable_code.flat) + ) + diff_strs.append( + unified_diff_strings(code_context.read_writable_code.flat, new_best_opt.candidate.source_code.flat) + ) + speedups_list.append( + 1 + + performance_gain( + original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=new_best_opt.runtime + ) + ) + optimization_ids.append(new_best_opt.candidate.optimization_id) + runtimes_list.append(new_best_opt.runtime) + + if len(optimization_ids) > 1: + ranking = None + future_ranking = self.executor.submit( + ai_service_client.generate_ranking, + diffs=diff_strs, + optimization_ids=optimization_ids, + speedups=speedups_list, + trace_id=self.get_trace_id(exp_type), + function_references=function_references, + ) + concurrent.futures.wait([future_ranking]) + ranking = 
def log_evaluation_results(
    self,
    eval_ctx: CandidateEvaluationContext,
    best_optimization: BestOptimization,
    original_code_baseline: OriginalCodeBaseline,
    ai_service_client: AiServiceClient,
    exp_type: str,
) -> None:
    """Send the accumulated per-candidate results to the AI service.

    Forwards the tracking dictionaries held by ``eval_ctx`` together with the
    baseline runtime and the id of the winning optimization through
    ``ai_service_client.log_results``.
    """
    submit = ai_service_client.log_results
    report = {
        "function_trace_id": self.get_trace_id(exp_type),
        "speedup_ratio": eval_ctx.speedup_ratios,
        "original_runtime": original_code_baseline.runtime,
        "optimized_runtime": eval_ctx.optimized_runtimes,
        "is_correct": eval_ctx.is_correct,
        "optimized_line_profiler_results": eval_ctx.optimized_line_profiler_results,
        "optimizations_post": eval_ctx.optimizations_post,
        "metadata": {"best_optimization_id": best_optimization.candidate.optimization_id},
    }
    submit(**report)
Path(self.function_to_optimize.file_path).read_text("utf-8") + candidate_helper_code = {} + for module_abspath in original_helper_code: + candidate_helper_code[module_abspath] = Path(module_abspath).read_text("utf-8") + if self.function_to_optimize.is_async: + self.instrument_async_for_mode(TestingMode.BEHAVIOR) + + try: + self.instrument_capture(file_path_to_helper_classes) + + candidate_behavior_results, _ = self.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=self.test_files, + optimization_iteration=optimization_candidate_index, + testing_time=TOTAL_LOOPING_TIME_EFFECTIVE, + enable_coverage=False, + ) + finally: + self.write_code_and_helpers(candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path) + from codeflash_python.models.models import TestResults + + assert isinstance(candidate_behavior_results, TestResults) + match, diffs = self.compare_candidate_results( + baseline_results, candidate_behavior_results, optimization_candidate_index + ) + + if match: + logger.info("h3|Test results matched ✅") + else: + self.repair_if_possible(candidate, diffs, eval_ctx, code_context, len(candidate_behavior_results), exp_type) + return self.get_results_not_matched_error() + + logger.info("loading|Running performance tests for candidate %s...", optimization_candidate_index) + + if self.function_to_optimize.is_async: + self.instrument_async_for_mode(TestingMode.PERFORMANCE) + + try: + candidate_benchmarking_results, _ = self.run_and_parse_tests( + testing_type=TestingMode.PERFORMANCE, + test_env=test_env, + test_files=self.test_files, + optimization_iteration=optimization_candidate_index, + testing_time=TOTAL_LOOPING_TIME_EFFECTIVE, + enable_coverage=False, + ) + finally: + if self.function_to_optimize.is_async: + self.write_code_and_helpers( + candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path + ) + # Use effective_loop_count which represents the number of timing samples across 
def process_single_candidate(
    self,
    candidate_node: CandidateNode,
    candidate_index: int,
    total_candidates: int,
    code_context: CodeOptimizationContext,
    original_code_baseline: OriginalCodeBaseline,
    original_helper_code: dict[Path, str],
    file_path_to_helper_classes: dict[Path, set[str]],
    eval_ctx: CandidateEvaluationContext,
    exp_type: str,
    function_references: str,
    normalized_original: str,
) -> BestOptimization | None:
    """Evaluate one optimization candidate end to end.

    Steps, in order: remove stale per-candidate temp files, skip candidates whose
    normalized code equals the original or an already-seen candidate, write the
    candidate into the source files, run it through ``run_optimized_candidate``,
    record the outcome on ``eval_ctx`` and apply the speedup/test-quantity critics.

    For a candidate that passes the critics, a follow-up AI-service request is
    scheduled: an adaptive optimization if any candidate on the node's path to the
    root came from refinement, otherwise a refinement request. The resulting futures
    are appended to ``self.future_adaptive_optimizations`` /
    ``self.future_all_refinements``.

    Returns:
        The ``BestOptimization`` if the candidate passes the critics, ``None`` otherwise
        (identical/duplicate code, replacement failure, failed run, or rejected by a critic).
    """
    # Cleanup temp files
    get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
    get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)

    candidate = candidate_node.candidate

    # Normalized text is the identity used for both the "same as original" and the
    # duplicate check below.
    normalized_code = normalize_code(candidate.source_code.flat.strip())

    if normalized_code == normalized_original:
        logger.info("h3|Candidate %s/%s: Identical to original code, skipping.", candidate_index, total_candidates)
        return None

    if normalized_code in eval_ctx.ast_code_to_id:
        logger.info(
            "h3|Candidate %s/%s: Duplicate of a previous candidate, skipping.", candidate_index, total_candidates
        )
        eval_ctx.handle_duplicate_candidate(candidate, normalized_code, code_context)
        return None

    logger.info("h3|Optimization candidate %s/%s:", candidate_index, total_candidates)

    # Try to replace function with optimized code
    try:
        did_update = self.replace_function_and_helpers_with_optimized_code(
            code_context=code_context,
            optimized_code=candidate.source_code,
            original_helper_code=original_helper_code,
        )
        if not did_update:
            logger.info("No functions were replaced in the optimized code. Skipping optimization candidate.")
            return None
    except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e:
        logger.exception(e)
        # Put the original sources back before giving up on this candidate.
        self.write_code_and_helpers(
            self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
        )
        return None

    # Registered only after a successful replacement, so candidates that could not be
    # applied do not count as "seen" for the duplicate check above.
    eval_ctx.register_new_candidate(normalized_code, candidate, code_context)

    # Run the optimized candidate
    run_results = self.run_optimized_candidate(
        optimization_candidate_index=candidate_index,
        baseline_results=original_code_baseline,
        original_helper_code=original_helper_code,
        file_path_to_helper_classes=file_path_to_helper_classes,
        eval_ctx=eval_ctx,
        code_context=code_context,
        candidate=candidate,
        exp_type=exp_type,
    )

    if not run_results.is_ok():
        eval_ctx.record_failed_candidate(candidate.optimization_id)
        return None

    candidate_result: OptimizedCandidateResult = run_results.unwrap()
    perf_gain = performance_gain(
        original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=candidate_result.best_test_runtime
    )
    # The run itself succeeded; this is recorded whether or not the critics accept it below.
    eval_ctx.record_successful_candidate(candidate.optimization_id, candidate_result.best_test_runtime, perf_gain)

    # Check if this is a successful optimization
    is_successful_opt = speedup_critic(
        candidate_result,
        original_code_baseline.runtime,
        best_runtime_until_now=None,
        original_async_throughput=original_code_baseline.async_throughput,
        best_throughput_until_now=None,
        original_concurrency_metrics=original_code_baseline.concurrency_metrics,
        best_concurrency_ratio_until_now=None,
    ) and quantity_of_tests_critic(candidate_result)

    best_optimization = None

    if is_successful_opt:
        # Everything in this branch dereferences best_optimization, which is only
        # non-None for a candidate accepted by the critics.
        best_optimization = self.handle_successful_candidate(
            candidate=candidate,
            candidate_result=candidate_result,
            code_context=code_context,
            original_code_baseline=original_code_baseline,
            original_helper_code=original_helper_code,
            candidate_index=candidate_index,
            eval_ctx=eval_ctx,
        )
        eval_ctx.valid_optimizations.append(best_optimization)

        # Ancestry of this candidate (root first) decides which follow-up request is made.
        current_tree_candidates = candidate_node.path_to_root()
        is_candidate_refined_before = any(
            c.source == OptimizedCandidateSource.REFINE for c in current_tree_candidates
        )

        assert self.aiservice_client is not None
        # "EXP0" uses the primary client; any other experiment type uses the local one.
        aiservice_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
        assert aiservice_client is not None

        if is_candidate_refined_before:
            future_adaptive_optimization = self.call_adaptive_optimize(
                trace_id=self.get_trace_id(exp_type),
                original_source_code=code_context.read_writable_code.markdown,
                prev_candidates=current_tree_candidates,
                eval_ctx=eval_ctx,
                ai_service_client=aiservice_client,
            )
            # call_adaptive_optimize may return a falsy value; only real futures are queued.
            if future_adaptive_optimization:
                self.future_adaptive_optimizations.append(future_adaptive_optimization)
        else:
            # Refinement
            future_refinement = self.executor.submit(
                aiservice_client.optimize_code_refinement,
                request=[
                    AIServiceRefinerRequest(
                        optimization_id=best_optimization.candidate.optimization_id,
                        original_source_code=code_context.read_writable_code.markdown,
                        read_only_dependency_code=code_context.read_only_context_code,
                        original_code_runtime=original_code_baseline.runtime,
                        optimized_source_code=best_optimization.candidate.source_code.markdown,
                        optimized_explanation=best_optimization.candidate.explanation,
                        optimized_code_runtime=best_optimization.runtime,
                        # Gain expressed as a truncated integer percentage string, e.g. "42%".
                        speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_optimization.runtime) * 100)}%",
                        trace_id=self.get_trace_id(exp_type),
                        original_line_profiler_results=original_code_baseline.line_profile_results["str_out"],
                        optimized_line_profiler_results=best_optimization.line_profiler_test_results["str_out"],
                        function_references=function_references,
                        language=self.function_to_optimize.language,
                        language_version=PYTHON_LANGUAGE_VERSION,
                    )
                ],
            )
            self.future_all_refinements.append(future_refinement)

    return best_optimization
future_line_profile_results, + eval_ctx, + self.effort, + code_context.read_writable_code.markdown, + self.future_all_refinements, + self.future_all_code_repair, + self.future_adaptive_optimizations, + ) + candidate_index = 0 + normalized_original = normalize_code(code_context.read_writable_code.flat.strip()) + + # Process candidates using queue-based approach + while not processor.is_done(): + candidate_node = processor.get_next_candidate() + if candidate_node is None: + logger.debug("everything done, exiting") + break + + try: + candidate_index += 1 + self.process_single_candidate( + candidate_node=candidate_node, + candidate_index=candidate_index, + total_candidates=processor.candidate_len, + code_context=code_context, + original_code_baseline=original_code_baseline, + original_helper_code=original_helper_code, + file_path_to_helper_classes=file_path_to_helper_classes, + eval_ctx=eval_ctx, + exp_type=exp_type, + function_references=function_references, + normalized_original=normalized_original, + ) + except KeyboardInterrupt as e: + logger.exception("Optimization interrupted: %s", e) + raise + finally: + self.write_code_and_helpers( + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + ) + + # Select and return the best optimization + best_optimization = self.select_best_optimization( + eval_ctx=eval_ctx, + code_context=code_context, + original_code_baseline=original_code_baseline, + ai_service_client=ai_service_client, + exp_type=exp_type, + function_references=function_references, + ) + + if best_optimization: + self.log_evaluation_results( + eval_ctx=eval_ctx, + best_optimization=best_optimization, + original_code_baseline=original_code_baseline, + ai_service_client=ai_service_client, + exp_type=exp_type, + ) + + return best_optimization diff --git a/src/codeflash_python/optimizer_mixins/candidate_structures.py b/src/codeflash_python/optimizer_mixins/candidate_structures.py new file mode 100644 index 
class CandidateNode:
    """One candidate in the refinement tree, linked to its parent and children."""

    __slots__ = ("candidate", "children", "parent")

    def __init__(self, candidate: OptimizedCandidate) -> None:
        # ``candidate`` may temporarily be None for a placeholder parent created by
        # CandidateForest.add before the real parent candidate is registered.
        self.candidate = candidate
        self.parent: CandidateNode | None = None
        self.children: list[CandidateNode] = []

    def is_leaf(self) -> bool:
        """Return True when no child has been attached to this node."""
        return not self.children

    def path_to_root(self) -> list[OptimizedCandidate]:
        """Return the candidates from the root down to this node (inclusive).

        Placeholder ancestors whose candidate has not been registered yet are
        skipped, so every element of the result is a real candidate.
        """
        path = []
        node: CandidateNode | None = self
        while node is not None:
            if node.candidate is not None:
                path.append(node.candidate)
            node = node.parent
        return path[::-1]


class CandidateForest:
    """Index of candidate nodes by optimization id, linked through ``parent_id``."""

    def __init__(self) -> None:
        self.nodes: dict[str, CandidateNode] = {}

    def add(self, candidate: OptimizedCandidate) -> CandidateNode:
        """Register ``candidate`` and link it under its parent, returning its node.

        If the parent id is unknown, a placeholder node is created for it. When a
        candidate arrives whose id already maps to such a placeholder, the
        placeholder is filled in rather than left empty. Adding the same candidate
        again does not duplicate the parent/child link.
        """
        cid = candidate.optimization_id
        pid = candidate.parent_id

        node = self.nodes.get(cid)
        if node is None:
            node = CandidateNode(candidate)
            self.nodes[cid] = node
        elif node.candidate is None:
            # This id was first seen as somebody's parent; attach the real candidate now.
            node.candidate = candidate

        if pid is not None:
            parent = self.nodes.get(pid)
            if parent is None:
                parent = CandidateNode(candidate=None)  # type: ignore[arg-type]  # placeholder
                self.nodes[pid] = parent

            if node.parent is not parent:
                if node.parent is not None:
                    # Keep the tree consistent if a node is re-registered under another parent.
                    node.parent.children.remove(node)
                node.parent = parent
                parent.children.append(node)

        return node

    def get_node(self, cid: str) -> CandidateNode | None:
        """Return the node registered under ``cid``, or None if there is none."""
        return self.nodes.get(cid)
def get_next_candidate(self) -> CandidateNode | None:
    """Return the node for the next queued candidate.

    When the queue is empty, defers to ``handle_empty_queue`` and returns
    whatever it returns.
    """
    try:
        queued = self.candidate_queue.get_nowait()
    except queue.Empty:
        return self.handle_empty_queue()
    return self.forest.get_node(queued.optimization_id)
len(self.future_all_code_repair) > 0: + return self.process_candidates( + self.future_all_code_repair, + "Repairing {0} candidates", + "Added {0} candidates from repair, total candidates now: {1}", + self.future_all_code_repair.clear, + ) + if self.line_profiler_done and not self.refinement_done: + return self.process_candidates( + self.future_all_refinements, + "Refining generated code for improved quality and performance...", + "Added {0} candidates from refinement, total candidates now: {1}", + lambda: setattr(self, "refinement_done", True), + filter_candidates_func=self.filter_refined_candidates, + ) + if len(self.future_adaptive_optimizations) > 0: + return self.process_candidates( + self.future_adaptive_optimizations, + "Applying adaptive optimizations to {0} candidates", + "Added {0} candidates from adaptive optimization, total candidates now: {1}", + self.future_adaptive_optimizations.clear, + ) + return None # All done + + def process_candidates( + self, + future_candidates: list[concurrent.futures.Future], + loading_msg: str, + success_msg: str, + callback: Callable[[], None], + filter_candidates_func: Callable[[list[OptimizedCandidate]], list[OptimizedCandidate]] | None = None, + ) -> CandidateNode | None: + if len(future_candidates) == 0: + return None + concurrent.futures.wait(future_candidates) + candidates: list[OptimizedCandidate] = [] + for future_c in future_candidates: + candidate_result = future_c.result() + if not candidate_result: + continue + + if isinstance(candidate_result, list): + candidates.extend(candidate_result) + else: + candidates.append(candidate_result) + + candidates = filter_candidates_func(candidates) if filter_candidates_func else candidates + for candidate in candidates: + self.forest.add(candidate) + self.candidate_queue.put(candidate) + self.candidate_len += 1 + + if len(candidates) > 0: + logger.info(success_msg.format(len(candidates), self.candidate_len)) + + callback() + return self.get_next_candidate() + + def 
def is_done(self) -> bool:
    """Tell whether no stage, pending future, or queued candidate remains."""
    if not (self.line_profiler_done and self.refinement_done):
        return False
    if len(self.future_all_code_repair) != 0:
        return False
    if len(self.future_adaptive_optimizations) != 0:
        return False
    return self.candidate_queue.empty()
@dataclass
class CandidateEvaluationContext:
    """Per-candidate bookkeeping collected while candidates are evaluated.

    Most mappings are keyed by optimization id; ``ast_code_to_id`` is keyed by
    normalized code and remembers the first candidate seen for that code along
    with the shortest-diff source observed so far.
    """

    speedup_ratios: dict[str, float | None] = Field(default_factory=dict)
    optimized_runtimes: dict[str, float | None] = Field(default_factory=dict)
    is_correct: dict[str, bool] = Field(default_factory=dict)
    optimized_line_profiler_results: dict[str, str] = Field(default_factory=dict)
    ast_code_to_id: dict = Field(default_factory=dict)
    optimizations_post: dict[str, str] = Field(default_factory=dict)
    valid_optimizations: list = Field(default_factory=list)

    def record_failed_candidate(self, optimization_id: str) -> None:
        """Mark a candidate as incorrect with no runtime and no speedup."""
        self.optimized_runtimes[optimization_id] = None
        self.is_correct[optimization_id] = False
        self.speedup_ratios[optimization_id] = None

    def record_successful_candidate(self, optimization_id: str, runtime: float, speedup: float) -> None:
        """Mark a candidate as correct and store its runtime and speedup."""
        self.optimized_runtimes[optimization_id] = runtime
        self.is_correct[optimization_id] = True
        self.speedup_ratios[optimization_id] = speedup

    def record_line_profiler_result(self, optimization_id: str, result: str) -> None:
        """Store the line profiler output of a candidate."""
        self.optimized_line_profiler_results[optimization_id] = result

    def handle_duplicate_candidate(
        self, candidate: OptimizedCandidate, normalized_code: str, code_context: CodeOptimizationContext
    ) -> None:
        """Reuse the results of an earlier candidate with the same normalized code."""
        seen = self.ast_code_to_id[normalized_code]
        new_id = candidate.optimization_id
        past_opt_id = seen["optimization_id"]

        # Mirror the verdicts of the earlier evaluation onto the duplicate.
        self.speedup_ratios[new_id] = self.speedup_ratios[past_opt_id]
        self.is_correct[new_id] = self.is_correct[past_opt_id]
        self.optimized_runtimes[new_id] = self.optimized_runtimes[past_opt_id]

        # Line profiler output exists only for successful runs.
        past_profile = self.optimized_line_profiler_results
        if past_opt_id in past_profile:
            past_profile[new_id] = past_profile[past_opt_id]

        # Both ids point at the shortest source known *before* this candidate
        # is considered; the stored entry is refreshed only afterwards.
        shortest_markdown = seen["shorter_source_code"].markdown
        self.optimizations_post[new_id] = shortest_markdown
        self.optimizations_post[past_opt_id] = seen["shorter_source_code"].markdown

        new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat)
        if new_diff_len < seen["diff_len"]:
            seen["shorter_source_code"] = candidate.source_code
            seen["diff_len"] = new_diff_len

    def register_new_candidate(
        self, normalized_code: str, candidate: OptimizedCandidate, code_context: CodeOptimizationContext
    ) -> None:
        """Remember a candidate whose normalized code has not been seen before."""
        baseline_diff = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat)
        self.ast_code_to_id[normalized_code] = {
            "optimization_id": candidate.optimization_id,
            "shorter_source_code": candidate.source_code,
            "diff_len": baseline_diff,
        }

    def get_speedup_ratio(self, optimization_id: str) -> float | None:
        """Return the recorded speedup, or ``None`` when absent or failed."""
        return self.speedup_ratios.get(optimization_id)

    def get_optimized_runtime(self, optimization_id: str) -> float | None:
        """Return the recorded runtime, or ``None`` when absent or failed."""
        return self.optimized_runtimes.get(optimization_id)
class CodeReplacementMixin(_Base):
    """Write, reformat, and restore the optimized function and its helper files."""

    @staticmethod
    def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, str], path: Path) -> None:
        """Overwrite ``path`` with ``original_code`` and each helper file with its code."""
        with path.open("w", encoding="utf8") as target:
            target.write(original_code)
        for module_abspath, helper_code in original_helper_code.items():
            with Path(module_abspath).open("w", encoding="utf8") as helper_file:
                helper_file.write(helper_code)

    def reformat_code_and_helpers(
        self,
        helper_functions: list[FunctionSource],
        path: Path,
        original_code: str,
        optimized_context: CodeStringsMarkdown,
    ) -> tuple[str, dict[Path, str]]:
        """Format the optimized code for ``path`` and the source of every helper.

        Import sorting is applied only when it is not disabled in ``self.args``
        and the original code was already import-sorted.

        Returns:
            The formatted code for ``path`` and a mapping from each helper's
            file path to its formatted source.

        """
        assert self.args is not None
        formatter_cmds = self.args.formatter_cmds

        # Sort imports only if the untouched file was already sorted.
        should_sort_imports = (
            not self.args.disable_imports_sorting and sort_imports(code=original_code) == original_code
        )

        optimized_code = ""
        if optimized_context is not None:
            relative_key = str(path.resolve().relative_to(self.project_root))
            optimized_code = optimized_context.file_to_path().get(relative_key, "")

        new_code = format_code(
            formatter_cmds, path, optimized_code=optimized_code, check_diff=True, exit_on_failure=False
        )
        if should_sort_imports:
            new_code = sort_imports(new_code)

        new_helper_code: dict[Path, str] = {}
        for helper in helper_functions:
            formatted_helper_code = format_code(
                formatter_cmds,
                helper.file_path,
                optimized_code=helper.source_code,
                check_diff=True,
                exit_on_failure=False,
            )
            if should_sort_imports:
                formatted_helper_code = sort_imports(formatted_helper_code)
            new_helper_code[helper.file_path] = formatted_helper_code

        return new_code, new_helper_code

    def group_functions_by_file(self, code_context: CodeOptimizationContext) -> dict[Path, set[str]]:
        """Map each file to the qualified names of the target and helper functions in it.

        Only helpers whose ``definition_type`` is ``"function"`` or ``None``
        are included.
        """
        target = self.function_to_optimize
        functions_by_file: dict[Path, set[str]] = defaultdict(set)
        functions_by_file[target.file_path].add(target.qualified_name)
        for helper in code_context.helper_functions:
            if helper.definition_type not in ("function", None):
                continue
            functions_by_file[helper.file_path].add(helper.qualified_name)
        return functions_by_file

    def revert_code_and_helpers(self, original_helper_code: dict[Path, str]) -> None:
        """Restore the original source of the target function's file and of all helpers."""
        logger.info("Reverting code and helpers...")
        self.write_code_and_helpers(
            self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
        )
        self.cleanup_async_helper_file()
class RefinementMixin(_Base):
    """Submit adaptive-optimization and code-repair requests to the AI service."""

    def call_adaptive_optimize(
        self,
        trace_id: str,
        original_source_code: str,
        prev_candidates: list[OptimizedCandidate],
        eval_ctx: CandidateEvaluationContext,
        ai_service_client: AiServiceClient,
    ) -> concurrent.futures.Future[OptimizedCandidate | None] | None:
        """Submit an adaptive optimization request built from earlier candidates.

        Returns:
            The future of the submitted request, or ``None`` when the per-trace
            limit is reached or ``prev_candidates`` already holds enough
            adaptive candidates for the current effort level.

        """
        if self.adaptive_optimization_counter >= get_effort_value(
            EffortKeys.MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE, self.effort
        ):
            logger.debug(
                "Max adaptive optimizations reached for %s: %s",
                self.function_to_optimize.qualified_name,
                self.adaptive_optimization_counter,
            )
            return None

        adaptive_count = sum(1 for c in prev_candidates if c.source == OptimizedCandidateSource.ADAPTIVE)

        if adaptive_count >= get_effort_value(EffortKeys.ADAPTIVE_OPTIMIZATION_THRESHOLD, self.effort):
            return None

        from codeflash_python.api.types import AdaptiveOptimizedCandidate, AIServiceAdaptiveOptimizeRequest

        request_candidates = []

        for c in prev_candidates:
            speedup = eval_ctx.get_speedup_ratio(c.optimization_id)
            # ``None`` is the marker for a failed or unknown candidate.  A plain
            # truthiness test also caught a legitimate speedup of exactly 0.0
            # and mislabelled that candidate as behaviourally incorrect.
            if speedup is not None:
                speedup_summary = f"Performance gain: {int(speedup * 100 + 0.5)}%"
            else:
                speedup_summary = "Candidate didn't match the behavior of the original code"
            request_candidates.append(
                AdaptiveOptimizedCandidate(
                    optimization_id=c.optimization_id,
                    source_code=c.source_code.markdown,
                    explanation=c.explanation,
                    source=c.source,
                    speedup=speedup_summary,
                )
            )

        request = AIServiceAdaptiveOptimizeRequest(
            trace_id=trace_id, original_source_code=original_source_code, candidates=request_candidates
        )
        self.adaptive_optimization_counter += 1
        return self.executor.submit(ai_service_client.adaptive_optimize, request=request)

    def repair_optimization(
        self,
        original_source_code: str,
        modified_source_code: str,
        test_diffs: list[TestDiff],
        trace_id: str,
        optimization_id: str,
        ai_service_client: AiServiceClient,
        executor: concurrent.futures.ThreadPoolExecutor,
        language: str = "python",
    ) -> concurrent.futures.Future[OptimizedCandidate | None]:
        """Submit a code-repair request on ``executor`` and return its future."""
        from codeflash_python.api.types import AIServiceCodeRepairRequest

        request = AIServiceCodeRepairRequest(
            optimization_id=optimization_id,
            original_source_code=original_source_code,
            modified_source_code=modified_source_code,
            test_diffs=test_diffs,
            trace_id=trace_id,
            language=language,
        )
        return executor.submit(ai_service_client.code_repair, request=request)

    def repair_if_possible(
        self,
        candidate: OptimizedCandidate,
        diffs: list[TestDiff],
        eval_ctx: CandidateEvaluationContext,
        code_context: CodeOptimizationContext,
        test_results_count: int,
        exp_type: str,
    ) -> None:
        """Queue a repair request for a failing candidate when it looks worthwhile.

        The repair is skipped when the repair budget is used up, enough
        candidates are already correct, the candidate is not a first-pass
        candidate, there are no diffs or test results, or the share of
        unmatched results exceeds the effort-level limit.
        """
        max_repairs = get_effort_value(EffortKeys.MAX_CODE_REPAIRS_PER_TRACE, self.effort)
        if self.repair_counter >= max_repairs:
            logger.debug("Repair counter reached %s, skipping repair", max_repairs)
            return

        successful_candidates_count = sum(1 for is_correct in eval_ctx.is_correct.values() if is_correct)
        if successful_candidates_count >= MIN_CORRECT_CANDIDATES:
            logger.debug("%s of the candidates were correct, no need to repair", successful_candidates_count)
            return

        if candidate.source not in (OptimizedCandidateSource.OPTIMIZE, OptimizedCandidateSource.OPTIMIZE_LP):
            # only repair the first pass of the candidates for now
            logger.debug("Candidate is a result of %s, skipping repair", candidate.source.value)
            return
        if not diffs:
            logger.debug("No diffs found, skipping repair")
            return
        if test_results_count <= 0:
            # Guard the ratio below against a ZeroDivisionError.
            logger.debug("No test results available, skipping repair")
            return
        result_unmatched_perc = len(diffs) / test_results_count
        if result_unmatched_perc > get_effort_value(EffortKeys.REPAIR_UNMATCHED_PERCENTAGE_LIMIT, self.effort):
            logger.debug("Result unmatched percentage is %s%%, skipping repair", result_unmatched_perc * 100)
            return

        logger.debug(
            "Adding a candidate for repair, with %s diffs, (%s%% unmatched)", len(diffs), result_unmatched_perc * 100
        )
        # start repairing
        ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
        assert ai_service_client is not None
        self.repair_counter += 1
        self.future_all_code_repair.append(
            self.repair_optimization(
                original_source_code=code_context.read_writable_code.markdown,
                modified_source_code=candidate.source_code.markdown,
                test_diffs=diffs,
                # With an experiment running, the last four characters of the
                # trace id are replaced by the experiment arm.
                trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
                ai_service_client=ai_service_client,
                optimization_id=candidate.optimization_id,
                executor=self.executor,
                language=self.function_to_optimize.language,
            )
        )
class ResultProcessingMixin(_Base):
    """Turn the winning candidate into an explanation, a review, and a PR or staging entry."""

    def find_and_process_best_optimization(
        self,
        optimizations_set: OptimizationSet,
        code_context: CodeOptimizationContext,
        original_code_baseline: OriginalCodeBaseline,
        original_helper_code: dict[Path, str],
        file_path_to_helper_classes: dict[Path, set[str]],
        function_to_optimize_qualified_name: str,
        function_to_all_tests: dict[str, set[FunctionCalledInTest]],
        generated_tests: GeneratedTestsList,
        test_functions_to_remove: list[str],
        concolic_test_str: str | None,
        function_references: str,
    ) -> BestOptimization | None:
        """Find the best optimization candidate and process it with all required steps.

        The control (``EXP0``) and experiment (``EXP1``) candidate sets are
        handled in that order; a set that is ``None`` is skipped.  For each
        winner the code is replaced, reformatted, and handed to
        ``process_review``.

        Returns:
            The result of the last candidate set that was evaluated, or
            ``None`` when that set produced no winner or no set was evaluated.

        """
        assert self.args is not None
        best_optimization = None
        for candidates, exp_type in zip([optimizations_set.control, optimizations_set.experiment], ["EXP0", "EXP1"]):
            if candidates is None:
                continue

            best_optimization = self.determine_best_candidate(
                candidates=candidates,
                code_context=code_context,
                original_code_baseline=original_code_baseline,
                original_helper_code=original_helper_code,
                file_path_to_helper_classes=file_path_to_helper_classes,
                exp_type=exp_type,
                function_references=function_references,
            )
            ph(
                "cli-optimize-function-finished",
                {
                    "function_trace_id": self.function_trace_id[:-4] + exp_type
                    if self.experiment_id
                    else self.function_trace_id
                },
            )

            if best_optimization:
                logger.info("h2|Best candidate 🚀")
                processed_benchmark_info = None
                if self.args.benchmark and best_optimization.replay_performance_gain is not None:
                    processed_benchmark_info = process_benchmark_data(
                        replay_performance_gain=best_optimization.replay_performance_gain,
                        fto_benchmark_timings=self.function_benchmark_timings,
                        total_benchmark_timings=self.total_benchmark_timings,
                    )
                acceptance_reason = get_acceptance_reason(
                    original_runtime_ns=original_code_baseline.runtime,
                    optimized_runtime_ns=best_optimization.runtime,
                    original_async_throughput=original_code_baseline.async_throughput,
                    optimized_async_throughput=best_optimization.async_throughput,
                    original_concurrency_metrics=original_code_baseline.concurrency_metrics,
                    optimized_concurrency_metrics=best_optimization.concurrency_metrics,
                )
                explanation = Explanation(
                    raw_explanation_message=best_optimization.candidate.explanation,
                    winning_behavior_test_results=best_optimization.winning_behavior_test_results,
                    winning_benchmarking_test_results=best_optimization.winning_benchmarking_test_results,
                    original_runtime_ns=original_code_baseline.runtime,
                    best_runtime_ns=best_optimization.runtime,
                    function_name=function_to_optimize_qualified_name,
                    file_path=self.function_to_optimize.file_path,
                    benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None,
                    original_async_throughput=original_code_baseline.async_throughput,
                    best_async_throughput=best_optimization.async_throughput,
                    original_concurrency_metrics=original_code_baseline.concurrency_metrics,
                    best_concurrency_metrics=best_optimization.concurrency_metrics,
                    acceptance_reason=acceptance_reason,
                )

                self.replace_function_and_helpers_with_optimized_code(
                    code_context=code_context,
                    optimized_code=best_optimization.candidate.source_code,
                    original_helper_code=original_helper_code,
                )

                new_code, new_helper_code = self.reformat_code_and_helpers(
                    code_context.helper_functions,
                    explanation.file_path,
                    self.function_to_optimize_source_code,
                    optimized_context=best_optimization.candidate.source_code,
                )

                # Path -> source maps (helpers plus the target file), before and after.
                original_code_combined = original_helper_code.copy()
                original_code_combined[explanation.file_path] = self.function_to_optimize_source_code
                new_code_combined = new_helper_code.copy()
                new_code_combined[explanation.file_path] = new_code
                self.process_review(
                    original_code_baseline,
                    best_optimization,
                    generated_tests,
                    test_functions_to_remove,
                    concolic_test_str,
                    original_code_combined,
                    new_code_combined,
                    explanation,
                    function_to_all_tests,
                    exp_type,
                    original_helper_code,
                    code_context,
                    function_references,
                )
        return best_optimization

    def process_review(
        self,
        original_code_baseline: OriginalCodeBaseline,
        best_optimization: BestOptimization,
        generated_tests: GeneratedTestsList,
        test_functions_to_remove: list[str],
        concolic_test_str: str | None,
        original_code_combined: dict[Path, str],
        new_code_combined: dict[Path, str],
        explanation: Explanation,
        function_to_all_tests: dict[str, set[FunctionCalledInTest]],
        exp_type: str,
        original_helper_code: dict[Path, str],
        code_context: CodeOptimizationContext,
        function_references: str,
    ) -> None:
        """Build the final explanation, request a review, and publish the result.

        Steps, in order: annotate the generated tests with runtimes, request a
        new explanation from the AI service, log the success event, request an
        optimization review, then create a PR, create a staging entry, or mark
        the optimization as successful.  Finally the written code is reverted
        when the CLI flags require it.
        """
        from codeflash_python.api.types import OptimizationReviewResult
        from codeflash_python.models.function_types import qualified_name_with_modules_from_root

        assert self.args is not None
        assert self.aiservice_client is not None
        coverage_message = (
            original_code_baseline.coverage_results.build_message()
            if original_code_baseline.coverage_results
            else "Coverage data not available"
        )

        generated_tests = remove_functions_from_generated_tests(generated_tests, test_functions_to_remove)
        map_gen_test_file_to_no_of_tests = original_code_baseline.behavior_test_results.file_to_no_of_tests(
            test_functions_to_remove
        )

        original_runtime_by_test = original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case()
        optimized_runtime_by_test = (
            best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case()
        )

        generated_tests = add_runtime_comments_to_generated_tests(
            generated_tests, original_runtime_by_test, optimized_runtime_by_test, self.test_cfg.tests_project_rootdir
        )

        generated_tests_str = ""
        code_lang = self.function_to_optimize.language
        for test in generated_tests.generated_tests:
            # Include a generated test file only if it still has at least one test.
            if any(
                test_file.name == test.behavior_file_path.name and count > 0
                for test_file, count in map_gen_test_file_to_no_of_tests.items()
            ):
                formatted_generated_test = format_generated_code(
                    test.generated_original_test_source, self.args.formatter_cmds
                )
                generated_tests_str += f"```{code_lang}\n{formatted_generated_test}\n```"
                generated_tests_str += "\n\n"

        if concolic_test_str:
            formatted_generated_test = format_generated_code(concolic_test_str, self.args.formatter_cmds)
            generated_tests_str += f"```{code_lang}\n{formatted_generated_test}\n```\n\n"

        existing_tests, replay_tests, _concolic_tests = existing_tests_source_for(
            qualified_name_with_modules_from_root(self.function_to_optimize, self.project_root),
            function_to_all_tests,
            test_cfg=self.test_cfg,
            original_runtimes_all=original_runtime_by_test,
            optimized_runtimes_all=optimized_runtime_by_test,
            test_files_registry=self.test_files,
        )
        original_throughput_str = None
        optimized_throughput_str = None
        throughput_improvement_str = None
        original_concurrency_ratio_str = None
        optimized_concurrency_ratio_str = None
        concurrency_improvement_str = None

        if (
            self.function_to_optimize.is_async
            and original_code_baseline.async_throughput is not None
            and best_optimization.async_throughput is not None
        ):
            original_throughput_str = f"{original_code_baseline.async_throughput} operations/second"
            optimized_throughput_str = f"{best_optimization.async_throughput} operations/second"
            throughput_improvement_value = throughput_gain(
                original_throughput=original_code_baseline.async_throughput,
                optimized_throughput=best_optimization.async_throughput,
            )
            throughput_improvement_str = f"{throughput_improvement_value * 100:.1f}%"

        if original_code_baseline.concurrency_metrics is not None and best_optimization.concurrency_metrics is not None:
            original_concurrency_ratio_str = f"{original_code_baseline.concurrency_metrics.concurrency_ratio:.2f}x"
            optimized_concurrency_ratio_str = f"{best_optimization.concurrency_metrics.concurrency_ratio:.2f}x"
            conc_improvement_value = concurrency_gain(
                original_code_baseline.concurrency_metrics, best_optimization.concurrency_metrics
            )
            concurrency_improvement_str = f"{conc_improvement_value * 100:.1f}%"

        new_explanation_raw_str = self.aiservice_client.get_new_explanation(
            source_code=code_context.read_writable_code.flat,
            dependency_code=code_context.read_only_context_code,
            trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
            optimized_code=best_optimization.candidate.source_code.flat,
            original_line_profiler_results=original_code_baseline.line_profile_results["str_out"],
            optimized_line_profiler_results=best_optimization.line_profiler_test_results["str_out"],
            original_code_runtime=humanize_runtime(original_code_baseline.runtime),
            optimized_code_runtime=humanize_runtime(best_optimization.runtime),
            speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_optimization.runtime) * 100)}%",
            annotated_tests=generated_tests_str,
            optimization_id=best_optimization.candidate.optimization_id,
            original_explanation=best_optimization.candidate.explanation,
            original_throughput=original_throughput_str,
            optimized_throughput=optimized_throughput_str,
            throughput_improvement=throughput_improvement_str,
            function_references=function_references,
            acceptance_reason=explanation.acceptance_reason.value,
            original_concurrency_ratio=original_concurrency_ratio_str,
            optimized_concurrency_ratio=optimized_concurrency_ratio_str,
            concurrency_improvement=concurrency_improvement_str,
        )
        # Same explanation, with the regenerated message when one came back.
        new_explanation = Explanation(
            raw_explanation_message=new_explanation_raw_str or explanation.raw_explanation_message,
            winning_behavior_test_results=explanation.winning_behavior_test_results,
            winning_benchmarking_test_results=explanation.winning_benchmarking_test_results,
            original_runtime_ns=explanation.original_runtime_ns,
            best_runtime_ns=explanation.best_runtime_ns,
            function_name=explanation.function_name,
            file_path=explanation.file_path,
            benchmark_details=explanation.benchmark_details,
            original_async_throughput=explanation.original_async_throughput,
            best_async_throughput=explanation.best_async_throughput,
            original_concurrency_metrics=explanation.original_concurrency_metrics,
            best_concurrency_metrics=explanation.best_concurrency_metrics,
            acceptance_reason=explanation.acceptance_reason,
        )
        self.log_successful_optimization(new_explanation, generated_tests, exp_type)

        best_optimization.explanation_v2 = new_explanation.explanation_message()

        data = {
            "original_code": original_code_combined,
            "new_code": new_code_combined,
            "explanation": new_explanation,
            "existing_tests_source": existing_tests,
            "generated_original_test_source": generated_tests_str,
            "function_trace_id": self.function_trace_id[:-4] + exp_type
            if self.experiment_id
            else self.function_trace_id,
            "coverage_message": coverage_message,
            "replay_tests": replay_tests,
            # "concolic_tests": concolic_tests,
            "language": self.function_to_optimize.language,
            # "original_line_profiler": original_code_baseline.line_profile_results.get("str_out", ""),
            # "optimized_line_profiler": best_optimization.line_profiler_test_results.get("str_out", ""),
        }

        raise_pr = not self.args.no_pr
        staging_review = self.args.staging_review

        opt_review_result = OptimizationReviewResult(review="", explanation="")
        # The review is requested regardless of the PR / staging flags.  It is
        # best effort: on failure the empty default above is kept.
        try:
            opt_review_result = self.aiservice_client.get_optimization_review(
                **data,
                calling_fn_details=function_references,  # type: ignore[invalid-argument-type]
            )
        except Exception as e:
            logger.debug("optimization review response failed, investigate %s", e)
        data["optimization_review"] = opt_review_result.review
        self.optimization_review = opt_review_result.review

        from git import Repo as GitRepo

        if raise_pr or staging_review:
            # Both the PR and the staging path need the repository root.
            data["root_dir"] = git_root_dir(GitRepo(str(self.args.module_root), search_parent_directories=True))
        if raise_pr and not staging_review and opt_review_result.review != "low":
            data["git_remote"] = self.args.git_remote
            # Remove language from data dict as check_create_pr doesn't accept it
            pr_data = {k: v for k, v in data.items() if k != "language"}
            check_create_pr(**pr_data)  # type: ignore[invalid-argument-type]
        elif staging_review:
            create_staging(**data)  # type: ignore[invalid-argument-type]
        else:
            # Mark optimization success since no PR will be created
            mark_optimization_success(
                trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None
            )

        # If worktree mode, do not revert code and helpers, otherwise we would have an empty diff when writing the patch in the lsp
        if self.args.worktree:
            return

        if raise_pr and (
            self.args.all
            or env_utils.get_pr_number()
            or self.args.replay_test
            or (self.args.file and not self.args.function)
        ):
            self.revert_code_and_helpers(original_helper_code)
            return

        if staging_review:
            # always revert code and helpers when staging review
            self.revert_code_and_helpers(original_helper_code)

    def log_successful_optimization(
        self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str
    ) -> None:
        """Send the ``cli-optimize-success`` telemetry event for this optimization.

        ``generated_tests`` is accepted for interface compatibility and is not
        read by this method.
        """
        ph(
            "cli-optimize-success",
            {
                "function_trace_id": self.function_trace_id[:-4] + exp_type
                if self.experiment_id
                else self.function_trace_id,
                "speedup_x": explanation.speedup_x,
                "speedup_pct": explanation.speedup_pct,
                "best_runtime": explanation.best_runtime_ns,
                "original_runtime": explanation.original_runtime_ns,
                "winning_test_results": {
                    tt.to_name(): v
                    for tt, v in explanation.winning_behavior_test_results.get_test_pass_fail_report_by_type().items()
                },
            },
        )
def choose_weights(**importance: float) -> list[float]:
    """Choose normalized weights from relative importance values.

    Example:
        choose_weights(runtime=3, diff=1)
        -> [0.75, 0.25]

    Args:
        **importance: keyword args of metric=importance (relative numbers).

    Returns:
        A list of weights in the same order as the arguments.

    Raises:
        ValueError: If the importance values sum to zero (including when no
            values are given).

    """
    total = sum(importance.values())
    if total == 0:
        raise ValueError("At least one importance value must be > 0")

    return [v / total for v in importance.values()]


def normalize_by_max(values: list[float]) -> list[float]:
    """Scale ``values`` so that the largest one becomes 1.0.

    Args:
        values: Metric values to normalize.

    Returns:
        A new list with every value divided by the maximum.  An empty input
        yields an empty list, and a maximum of zero yields all zeros instead
        of dividing by zero.

    """
    if not values:
        # max() raises ValueError on an empty sequence; nothing to normalize.
        return []
    mx = max(values)
    if mx == 0:
        return [0.0] * len(values)
    return [v / mx for v in values]


def create_score_dictionary_from_metrics(weights: list[float], *metrics: list[float]) -> dict[int, float]:
    """Combine multiple metrics into a single weighted score dictionary.

    Each metric is a list of values (smaller = better).
    The total score for each index is the weighted sum of its values
    across all metrics:

        score[index] = Σ (value * weight)

    Args:
        weights: A list of weights, one per metric. Larger weight = more influence.
        *metrics: Lists of values (one list per metric, aligned by index).

    Returns:
        A dictionary mapping each index to its combined weighted score.

    Raises:
        ValueError: If the number of weights differs from the number of metrics.

    """
    if len(weights) != len(metrics):
        raise ValueError("Number of weights must match number of metrics")

    combined: dict[int, float] = {}

    for weight, metric in zip(weights, metrics):
        for idx, value in enumerate(metric):
            combined[idx] = combined.get(idx, 0) + value * weight

    return combined
def diff_length(a: str, b: str) -> int:
    """Measure how large the unified diff between two strings is.

    Both strings are split into lines that keep their line endings, diffed
    with ``difflib.unified_diff`` (``lineterm=""``), and the resulting diff
    lines are joined with ``"\\n"`` before being measured.

    Args:
        a (str): Original string.
        b (str): Modified string.

    Returns:
        int: Number of characters in the joined diff text; 0 when the strings
        produce no diff.

    """
    before = a.splitlines(keepends=True)
    after = b.splitlines(keepends=True)
    joined_diff = "\n".join(difflib.unified_diff(before, after, lineterm=""))
    return len(joined_diff)


def create_rank_dictionary_compact(int_array: list[int]) -> dict[int, int]:
    """Map every index of ``int_array`` to the rank of its value.

    Args:
        int_array: A list of integers.

    Returns:
        A dictionary where keys are original indices and values are the
        rank of the element in ascending order (0 = smallest).  Equal values
        keep their original relative order.

    """
    # Positions of the array ordered by the value stored at each position.
    ordered_positions = sorted(range(len(int_array)), key=int_array.__getitem__)

    ranks: dict[int, int] = {}
    for rank, position in enumerate(ordered_positions):
        ranks[position] = rank
    return ranks
+ + """ + # Sort the indices of the array based on their corresponding values + sorted_indices = sorted(range(len(int_array)), key=lambda i: int_array[i]) + + # Create a dictionary mapping the original index to its rank (its position in the sorted list) + return {original_index: rank for rank, original_index in enumerate(sorted_indices)} diff --git a/src/codeflash_python/optimizer_mixins/test_execution.py b/src/codeflash_python/optimizer_mixins/test_execution.py new file mode 100644 index 000000000..c3fbf5aa9 --- /dev/null +++ b/src/codeflash_python/optimizer_mixins/test_execution.py @@ -0,0 +1,319 @@ +from __future__ import annotations + +import logging +import subprocess +from collections import defaultdict +from pathlib import Path +from typing import TYPE_CHECKING + +from codeflash_python.code_utils.config_consts import INDIVIDUAL_TESTCASE_TIMEOUT, TOTAL_LOOPING_TIME_EFFECTIVE +from codeflash_python.models.models import TestingMode, TestResults, TestType +from codeflash_python.verification.instrument_existing_tests import inject_profiling_into_existing_test +from codeflash_python.verification.parse_test_output import parse_test_results +from codeflash_python.verification.test_output_utils import parse_concurrency_metrics +from codeflash_python.verification.test_runner import ( + run_behavioral_tests, + run_benchmarking_tests, + run_line_profile_tests, +) + +if TYPE_CHECKING: + from codeflash_python.models.models import ( + CodeOptimizationContext, + ConcurrencyMetrics, + CoverageData, + FunctionCalledInTest, + TestFiles, + ) + from codeflash_python.optimizer_mixins._protocol import FunctionOptimizerProtocol as _Base +else: + _Base = object + +logger = logging.getLogger("codeflash_python") + + +class TestExecutionMixin(_Base): + def run_and_parse_tests( + self, + testing_type: TestingMode, + test_env: dict[str, str], + test_files: TestFiles, + optimization_iteration: int, + testing_time: float = TOTAL_LOOPING_TIME_EFFECTIVE, + *, + enable_coverage: bool = False, 
+ pytest_min_loops: int = 5, + pytest_max_loops: int = 250, + code_context: CodeOptimizationContext | None = None, + line_profiler_output_file: Path | None = None, + ) -> tuple[TestResults | dict, CoverageData | None]: + assert self.project_root is not None + coverage_database_file = None + coverage_config_file = None + try: + if testing_type == TestingMode.BEHAVIOR: + result_file_path, run_result, coverage_database_file, coverage_config_file = run_behavioral_tests( + test_paths=test_files, + test_env=test_env, + cwd=self.project_root, + timeout=INDIVIDUAL_TESTCASE_TIMEOUT, + enable_coverage=enable_coverage, + candidate_index=optimization_iteration, + ) + elif testing_type == TestingMode.LINE_PROFILE: + result_file_path, run_result = run_line_profile_tests( + test_paths=test_files, + test_env=test_env, + cwd=self.project_root, + timeout=INDIVIDUAL_TESTCASE_TIMEOUT, + line_profile_output_file=line_profiler_output_file, + ) + elif testing_type == TestingMode.PERFORMANCE: + result_file_path, run_result = run_benchmarking_tests( + test_paths=test_files, + test_env=test_env, + cwd=self.project_root, + timeout=INDIVIDUAL_TESTCASE_TIMEOUT, + min_loops=pytest_min_loops, + max_loops=pytest_max_loops, + target_duration_seconds=testing_time, + ) + else: + msg = f"Unexpected testing type: {testing_type}" + raise ValueError(msg) + except subprocess.TimeoutExpired: + logger.exception( + "Error running tests in %s.\nTimeout Error", ", ".join(str(f) for f in test_files.test_files) + ) + return TestResults(), None + if testing_type in {TestingMode.BEHAVIOR, TestingMode.PERFORMANCE}: + assert self.test_cfg is not None + results, coverage_results = parse_test_results( + test_xml_path=result_file_path, + test_files=test_files, + test_config=self.test_cfg, + optimization_iteration=optimization_iteration, + run_result=run_result, + function_name=self.function_to_optimize.qualified_name, + source_file=self.function_to_optimize.file_path, + code_context=code_context, + 
coverage_database_file=coverage_database_file, + coverage_config_file=coverage_config_file, + ) + if testing_type == TestingMode.PERFORMANCE: + results.perf_stdout = run_result.stdout + return results, coverage_results + return self.parse_line_profile_test_results(line_profiler_output_file) + + def run_behavioral_validation( + self, + code_context: CodeOptimizationContext, + original_helper_code: dict[Path, str], + file_path_to_helper_classes: dict[Path, set[str]], + ) -> tuple[TestResults, CoverageData | None] | None: + """Run behavioral tests only. Returns (results, coverage) or None if no tests ran.""" + test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1) + if self.function_to_optimize.is_async: + self.instrument_async_for_mode(TestingMode.BEHAVIOR) + try: + self.instrument_capture(file_path_to_helper_classes) + behavioral_results, coverage_results = self.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=self.test_files, + optimization_iteration=0, + testing_time=TOTAL_LOOPING_TIME_EFFECTIVE, + enable_coverage=True, + code_context=code_context, + ) + finally: + self.write_code_and_helpers( + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + ) + if isinstance(behavioral_results, TestResults) and behavioral_results: + return behavioral_results, coverage_results + return None + + def instrument_existing_tests(self, function_to_all_tests: dict[str, set[FunctionCalledInTest]]) -> set[Path]: + from codeflash_python.models.function_types import qualified_name_with_modules_from_root + from codeflash_python.models.models import TestFile + + assert self.project_root is not None + existing_test_files_count = 0 + replay_test_files_count = 0 + concolic_coverage_test_files_count = 0 + unique_instrumented_test_files = set() + + func_qualname = qualified_name_with_modules_from_root(self.function_to_optimize, self.project_root) + if 
func_qualname not in function_to_all_tests: + logger.info("Did not find any pre-existing tests for '%s', will only use generated tests.", func_qualname) + return unique_instrumented_test_files + + test_file_invocation_positions = defaultdict(list) + tests_in_file_set = function_to_all_tests.get(func_qualname) + if tests_in_file_set is None: + return unique_instrumented_test_files + for tests_in_file in tests_in_file_set: + test_file_invocation_positions[ + (tests_in_file.tests_in_file.test_file, tests_in_file.tests_in_file.test_type) + ].append(tests_in_file) + + for test_file, test_type in test_file_invocation_positions: + path_obj_test_file = Path(test_file) + if test_type == TestType.EXISTING_UNIT_TEST: + existing_test_files_count += 1 + elif test_type == TestType.REPLAY_TEST: + replay_test_files_count += 1 + elif test_type == TestType.CONCOLIC_COVERAGE_TEST: + concolic_coverage_test_files_count += 1 + else: + msg = f"Unexpected test type: {test_type}" + raise ValueError(msg) + + if existing_test_files_count > 0 or replay_test_files_count > 0 or concolic_coverage_test_files_count > 0: + logger.info( + "Discovered %s existing unit test file%s, %s replay test file%s, and " + "%s concolic coverage test file%s for %s", + existing_test_files_count, + "s" if existing_test_files_count != 1 else "", + replay_test_files_count, + "s" if replay_test_files_count != 1 else "", + concolic_coverage_test_files_count, + "s" if concolic_coverage_test_files_count != 1 else "", + func_qualname, + ) + + assert self.test_cfg is not None + for (test_file, test_type), tests_in_file_list in test_file_invocation_positions.items(): + path_obj_test_file = Path(test_file) + # Use language-specific instrumentation + success, injected_behavior_test = inject_profiling_into_existing_test( + test_path=path_obj_test_file, + call_positions=[test.position for test in tests_in_file_list], + function_to_optimize=self.function_to_optimize, + tests_project_root=self.test_cfg.tests_project_rootdir, + 
mode=TestingMode.BEHAVIOR, + ) + if not success: + logger.debug("Failed to instrument test file %s for behavior testing", test_file) + continue + + success, injected_perf_test = inject_profiling_into_existing_test( + test_path=path_obj_test_file, + call_positions=[test.position for test in tests_in_file_list], + function_to_optimize=self.function_to_optimize, + tests_project_root=self.test_cfg.tests_project_rootdir, + mode=TestingMode.PERFORMANCE, + ) + if not success: + logger.debug("Failed to instrument test file %s for performance testing", test_file) + continue + + def get_instrumented_path(original_path: str, suffix: str) -> Path: + path_obj = Path(original_path) + return path_obj.parent / f"{path_obj.stem}{suffix}{path_obj.suffix}" + + new_behavioral_test_path = get_instrumented_path(test_file, "__perfinstrumented") + new_perf_test_path = get_instrumented_path(test_file, "__perfonlyinstrumented") + + if injected_behavior_test is not None: + with new_behavioral_test_path.open("w", encoding="utf8") as _f: + _f.write(injected_behavior_test) + logger.debug("[PIPELINE] Wrote instrumented behavior test to %s", new_behavioral_test_path) + else: + msg = "injected_behavior_test is None" + raise ValueError(msg) + + if injected_perf_test is not None: + with new_perf_test_path.open("w", encoding="utf8") as _f: + _f.write(injected_perf_test) + logger.debug("[PIPELINE] Wrote instrumented perf test to %s", new_perf_test_path) + + unique_instrumented_test_files.add(new_behavioral_test_path) + unique_instrumented_test_files.add(new_perf_test_path) + + if not self.test_files.get_by_original_file_path(path_obj_test_file): + self.test_files.add( + TestFile( + instrumented_behavior_file_path=new_behavioral_test_path, + benchmarking_file_path=new_perf_test_path, + original_source=None, + original_file_path=Path(test_file), + test_type=test_type, + tests_in_file=[t.tests_in_file for t in tests_in_file_list], + ) + ) + + instrumented_count = len(unique_instrumented_test_files) // 2 
# each test produces behavior + perf files + if instrumented_count > 0: + logger.info( + "Instrumented %s existing unit test file%s for %s", + instrumented_count, + "s" if instrumented_count != 1 else "", + func_qualname, + ) + return unique_instrumented_test_files + + def run_concurrency_benchmark( + self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], test_env: dict[str, str] + ) -> ConcurrencyMetrics | None: + """Run concurrency benchmark to measure sequential vs concurrent execution for async functions. + + This benchmark detects blocking vs non-blocking async code by comparing: + - Sequential execution time (running N iterations one after another) + - Concurrent execution time (running N iterations in parallel with asyncio.gather) + + Blocking code (like time.sleep) will have similar sequential and concurrent times. + Non-blocking code (like asyncio.sleep) will be much faster when run concurrently. + + Returns: + ConcurrencyMetrics if benchmark ran successfully, None otherwise. 
+ + """ + if not self.function_to_optimize.is_async: + return None + + from codeflash_python.verification.async_instrumentation import add_async_decorator_to_function + + assert self.project_root is not None + try: + # Add concurrency decorator to the source function + add_async_decorator_to_function( + self.function_to_optimize.file_path, + self.function_to_optimize, + TestingMode.CONCURRENCY, + project_root=self.project_root, + ) + + # Run the concurrency benchmark tests + concurrency_results, _ = self.run_and_parse_tests( + testing_type=TestingMode.PERFORMANCE, # Use performance mode for running + test_env=test_env, + test_files=self.test_files, + optimization_iteration=0, + testing_time=5.0, # Short benchmark time + enable_coverage=False, + code_context=code_context, + pytest_min_loops=1, + pytest_max_loops=3, + ) + except Exception as e: + logger.debug("Concurrency benchmark failed: %s", e) + return None + finally: + # Restore original code + self.write_code_and_helpers( + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + ) + + # Parse concurrency metrics from stdout + from codeflash_python.models.models import TestResults as TestResultsInternal + + if ( + concurrency_results + and isinstance(concurrency_results, TestResultsInternal) + and concurrency_results.perf_stdout + ): + return parse_concurrency_metrics(concurrency_results, self.function_to_optimize.function_name) + + return None diff --git a/src/codeflash_python/optimizer_mixins/test_generation.py b/src/codeflash_python/optimizer_mixins/test_generation.py new file mode 100644 index 000000000..97ecae2ae --- /dev/null +++ b/src/codeflash_python/optimizer_mixins/test_generation.py @@ -0,0 +1,259 @@ +from __future__ import annotations + +import concurrent.futures +import logging +from pathlib import Path +from typing import TYPE_CHECKING + +from codeflash_core.danom import Err, Ok +from codeflash_python.code_utils.config_consts import 
INDIVIDUAL_TESTCASE_TIMEOUT, EffortKeys, get_effort_value +from codeflash_python.models.models import GeneratedTests, GeneratedTestsList +from codeflash_python.verification.verifier import generate_tests + +if TYPE_CHECKING: + from codeflash_core.danom import Result + from codeflash_python.models.models import ( + CodeOptimizationContext, + CodeStringsMarkdown, + FunctionCalledInTest, + FunctionSource, + ) + from codeflash_python.optimizer_mixins._protocol import FunctionOptimizerProtocol as _Base +else: + _Base = object + +from codeflash_python.verification.concolic import generate_concolic_tests + +logger = logging.getLogger("codeflash_python") + + +class TestGenerationMixin(_Base): + def generate_tests( + self, + testgen_context: CodeStringsMarkdown, + helper_functions: list[FunctionSource], + testgen_helper_fqns: list[str], + generated_test_paths: list[Path], + generated_perf_test_paths: list[Path], + ) -> Result[tuple[int, GeneratedTestsList, dict[str, set[FunctionCalledInTest]], str], str]: + """Generate unit tests and concolic tests for the function.""" + assert self.args is not None + n_tests = get_effort_value(EffortKeys.N_GENERATED_TESTS, self.effort) + assert len(generated_test_paths) == n_tests + + if not self.args.no_gen_tests: + helper_fqns = testgen_helper_fqns or [definition.fully_qualified_name for definition in helper_functions] + future_tests = self.submit_test_generation_tasks( + self.executor, testgen_context.markdown, helper_fqns, generated_test_paths, generated_perf_test_paths + ) + + future_concolic_tests = self.executor.submit( + generate_concolic_tests, + self.test_cfg, + self.args.project_root, + self.function_to_optimize, + self.function_to_optimize_ast, + ) + + if not self.args.no_gen_tests: + # Wait for test futures to complete + futures_to_wait = [*future_tests] + if future_concolic_tests is not None: + futures_to_wait.append(future_concolic_tests) + concurrent.futures.wait(futures_to_wait) + elif future_concolic_tests is not None: + 
concurrent.futures.wait([future_concolic_tests]) + # Process test generation results + tests: list[GeneratedTests] = [] + if not self.args.no_gen_tests: + for future in future_tests: + res = future.result() + if res: + ( + generated_test_source, + instrumented_behavior_test_source, + instrumented_perf_test_source, + raw_generated_test_source, + test_behavior_path, + test_perf_path, + ) = res + tests.append( + GeneratedTests( + generated_original_test_source=generated_test_source, + instrumented_behavior_test_source=instrumented_behavior_test_source, + instrumented_perf_test_source=instrumented_perf_test_source, + raw_generated_test_source=raw_generated_test_source, + behavior_file_path=test_behavior_path, + perf_file_path=test_perf_path, + ) + ) + + if not tests: + logger.warning( + "Failed to generate and instrument tests for %s", self.function_to_optimize.function_name + ) + return Err(f"/!\\ NO TESTS GENERATED for {self.function_to_optimize.function_name}") + + if future_concolic_tests is not None: + function_to_concolic_tests, concolic_test_str = future_concolic_tests.result() + else: + function_to_concolic_tests, concolic_test_str = {}, None + count_tests = len(tests) + if concolic_test_str: + count_tests += 1 + + generated_tests = GeneratedTestsList(generated_tests=tests) + return Ok((count_tests, generated_tests, function_to_concolic_tests, concolic_test_str)) + + def submit_test_generation_tasks( + self, + executor: concurrent.futures.ThreadPoolExecutor, + source_code_being_tested: str, + helper_function_names: list[str], + generated_test_paths: list[Path], + generated_perf_test_paths: list[Path], + ) -> list[concurrent.futures.Future]: + assert self.aiservice_client is not None + assert self.test_cfg is not None + return [ + executor.submit( + generate_tests, + self.aiservice_client, + source_code_being_tested, + self.function_to_optimize, + helper_function_names, + Path(self.original_module_path), + self.test_cfg, # type: ignore[arg-type] + 
INDIVIDUAL_TESTCASE_TIMEOUT, + self.function_trace_id, + test_index, + test_path, + test_perf_path, + self.is_numerical_code, + ) + for test_index, (test_path, test_perf_path) in enumerate( + zip(generated_test_paths, generated_perf_test_paths) + ) + ] + + def generate_and_instrument_tests( + self, code_context: CodeOptimizationContext + ) -> Result[ + tuple[ + GeneratedTestsList, + dict[str, set[FunctionCalledInTest]], + str, + list[Path], + list[Path], + set[Path], + dict[Path, str] | None, + ], + str, + ]: + """Generate and instrument tests for the function.""" + from codeflash_python.code_utils.code_utils import get_run_tmp_file + from codeflash_python.models.models import TestFile, TestType + from codeflash_python.verification.verification_utils import get_test_file_path + + n_tests = get_effort_value(EffortKeys.N_GENERATED_TESTS, self.effort) + source_file = Path(self.function_to_optimize.file_path) + generated_test_paths = [ + get_test_file_path( + self.test_cfg.tests_root, + self.function_to_optimize.function_name, + test_index, + test_type="unit", + source_file_path=source_file, + ) + for test_index in range(n_tests) + ] + generated_perf_test_paths = [ + get_test_file_path( + self.test_cfg.tests_root, + self.function_to_optimize.function_name, + test_index, + test_type="perf", + source_file_path=source_file, + ) + for test_index in range(n_tests) + ] + + test_results = self.generate_tests( + testgen_context=code_context.testgen_context, + helper_functions=code_context.helper_functions, + testgen_helper_fqns=code_context.testgen_helper_fqns, + generated_test_paths=generated_test_paths, + generated_perf_test_paths=generated_perf_test_paths, + ) + + if not test_results.is_ok(): + # Result type doesn't have unwrap_err, manually get error + return test_results # type: ignore[return-value] + + count_tests, generated_tests, function_to_concolic_tests, concolic_test_str = test_results.unwrap() + + generated_tests = self.fixup_generated_tests(generated_tests) + + 
logger.debug("[PIPELINE] Processing %s generated tests", count_tests) + for i, generated_test in enumerate(generated_tests.generated_tests): + logger.debug( + "[PIPELINE] Test %s: behavior_path=%s, perf_path=%s", + i + 1, + generated_test.behavior_file_path, + generated_test.perf_file_path, + ) + + with generated_test.behavior_file_path.open("w", encoding="utf8") as f: + f.write(generated_test.instrumented_behavior_test_source) + logger.debug("[PIPELINE] Wrote behavioral test to %s", generated_test.behavior_file_path) + + debug_file_path = get_run_tmp_file(Path("perf_test_debug.py")) + with debug_file_path.open("w", encoding="utf-8") as debug_f: + debug_f.write(generated_test.instrumented_perf_test_source) + + with generated_test.perf_file_path.open("w", encoding="utf8") as f: + f.write(generated_test.instrumented_perf_test_source) + logger.debug("[PIPELINE] Wrote perf test to %s", generated_test.perf_file_path) + + # File paths are expected to be absolute - resolved at their source (CLI, TestConfig, etc.) + test_file_obj = TestFile( + instrumented_behavior_file_path=generated_test.behavior_file_path, + benchmarking_file_path=generated_test.perf_file_path, + original_file_path=None, + original_source=generated_test.generated_original_test_source, + test_type=TestType.GENERATED_REGRESSION, + tests_in_file=None, # This is currently unused. We can discover the tests in the file if needed. 
+ ) + self.test_files.add(test_file_obj) + logger.debug( + "[PIPELINE] Added test file to collection: behavior=%s, perf=%s", + test_file_obj.instrumented_behavior_file_path, + test_file_obj.benchmarking_file_path, + ) + + logger.info("Generated test %s/%s", i + 1, count_tests) + if concolic_test_str: + logger.info("Generated test %s/%s", count_tests, count_tests) + + function_to_all_tests = { + key: self.function_to_tests.get(key, set()) | function_to_concolic_tests.get(key, set()) + for key in set(self.function_to_tests) | set(function_to_concolic_tests) + } + instrumented_unittests_created_for_function = self.instrument_existing_tests(function_to_all_tests) + + assert self.args is not None + original_conftest_content = None + if self.args.override_fixtures: + original_conftest_content = self.instrument_test_fixtures(generated_test_paths + generated_perf_test_paths) + + return Ok( + ( + generated_tests, + function_to_concolic_tests, + concolic_test_str, + generated_test_paths, + generated_perf_test_paths, + instrumented_unittests_created_for_function, + original_conftest_content, + ) + ) diff --git a/src/codeflash_python/optimizer_mixins/test_review.py b/src/codeflash_python/optimizer_mixins/test_review.py new file mode 100644 index 000000000..c27b3a631 --- /dev/null +++ b/src/codeflash_python/optimizer_mixins/test_review.py @@ -0,0 +1,322 @@ +from __future__ import annotations + +import logging +from collections import defaultdict +from pathlib import Path +from typing import TYPE_CHECKING + +from codeflash_core.danom import Err, Ok +from codeflash_python.code_utils.code_utils import encoded_tokens_len, module_name_from_file_path +from codeflash_python.code_utils.config_consts import ( + COVERAGE_THRESHOLD, + INDIVIDUAL_TESTCASE_TIMEOUT, + MAX_TEST_REPAIR_CYCLES, + OPTIMIZATION_CONTEXT_TOKEN_LIMIT, +) +from codeflash_python.models.models import TestType +from codeflash_python.telemetry.posthog_cf import ph +from codeflash_python.verification.edit_generated_tests 
import remove_test_functions +from codeflash_python.verification.test_runner import process_generated_test_strings + +if TYPE_CHECKING: + from typing import Any + + from codeflash_core.danom import Result + from codeflash_python.models.models import CodeOptimizationContext, CoverageData, GeneratedTestsList, TestResults + from codeflash_python.optimizer_mixins._protocol import FunctionOptimizerProtocol as _Base +else: + _Base = object + +logger = logging.getLogger("codeflash_python") + + +class TestReviewMixin(_Base): + def review_and_repair_tests( + self, + generated_tests: GeneratedTestsList, + code_context: CodeOptimizationContext, + original_helper_code: dict[Path, str], + ) -> Result[tuple[GeneratedTestsList, TestResults | None, CoverageData | None], str]: + """Run behavioral tests, review quality per-function, repair flagged functions. + + Flow (up to MAX_TEST_REPAIR_CYCLES): + behavioral -> collect failures -> AI review passing functions -> repair flagged -> loop + No benchmarking runs here -- only behavioral validation. + + Returns (generated_tests, behavioral_results, coverage) where behavioral/coverage are + non-None when the last cycle passed with no repairs (results can be reused by baseline). 
+ """ + file_path_to_helper_classes = self.build_helper_classes_map(code_context) + behavioral_results: TestResults | None = None + coverage_results: CoverageData | None = None + previous_repair_errors: dict[int, dict[str, str]] = {} + # Apply token limit to function source (same progressive fallback as optimization/testgen context) + function_source_for_prompt = self.function_to_optimize_source_code + if encoded_tokens_len(function_source_for_prompt) > OPTIMIZATION_CONTEXT_TOKEN_LIMIT: + logger.debug("Function source exceeds token limit for review, extracting function only") + func = self.function_to_optimize + source_lines = self.function_to_optimize_source_code.splitlines(keepends=True) + func_start = (func.doc_start_line or func.starting_line or 1) - 1 + func_end = func.ending_line or len(source_lines) + function_source_for_prompt = "".join(source_lines[func_start:func_end]) + max_cycles = getattr(self.args, "testgen_review_turns", None) or MAX_TEST_REPAIR_CYCLES + for cycle in range(max_cycles): + validation = self.run_behavioral_validation(code_context, original_helper_code, file_path_to_helper_classes) + if validation is None: + return Err("Generated tests failed behavioral validation.") + behavioral_results, coverage_results = validation + + failed_by_file: dict[Path, list[str]] = defaultdict(list) + for result in behavioral_results.test_results: + if ( + result.test_type == TestType.GENERATED_REGRESSION + and not result.did_pass + and result.id.test_function_name + ): + failed_by_file[result.file_name].append(result.id.test_fn_qualified_name()) + + test_failure_messages = behavioral_results.test_failures or {} + + tests_for_review = [] + for i, gt in enumerate(generated_tests.generated_tests): + failed_fns = failed_by_file.get(gt.behavior_file_path, []) + failure_details = {fn: test_failure_messages[fn] for fn in failed_fns if fn in test_failure_messages} + tests_for_review.append( + { + "test_source": gt.raw_generated_test_source or 
gt.generated_original_test_source, + "test_index": i, + "failed_test_functions": failed_fns, + "failure_messages": failure_details, + } + ) + + coverage_summary = "" + coverage_details: dict[str, Any] | None = None + if coverage_results and coverage_results.coverage is not None: + coverage_summary = f"{coverage_results.coverage:.1f}%" + mc = coverage_results.main_func_coverage + coverage_details = { + "coverage_percentage": coverage_results.coverage, + "threshold_percentage": COVERAGE_THRESHOLD, + "function_start_line": self.function_to_optimize.starting_line, + "main_function": { + "name": mc.name, + "coverage": mc.coverage, + "executed_lines": sorted(mc.executed_lines), + "unexecuted_lines": sorted(mc.unexecuted_lines), + "executed_branches": mc.executed_branches, + "unexecuted_branches": mc.unexecuted_branches, + }, + } + dc = coverage_results.dependent_func_coverage + if dc: + coverage_details["dependent_function"] = { + "name": dc.name, + "coverage": dc.coverage, + "executed_lines": sorted(dc.executed_lines), + "unexecuted_lines": sorted(dc.unexecuted_lines), + "executed_branches": dc.executed_branches, + "unexecuted_branches": dc.unexecuted_branches, + } + + assert self.aiservice_client is not None + review_results = self.aiservice_client.review_generated_tests( + tests=tests_for_review, + function_source_code=function_source_for_prompt, + function_name=self.function_to_optimize.function_name, + trace_id=self.function_trace_id, + coverage_summary=coverage_summary, + coverage_details=coverage_details, + language=self.function_to_optimize.language, + ) + + all_to_repair = [r for r in review_results if r.functions_to_repair] + + if not all_to_repair: + return Ok((generated_tests, behavioral_results, coverage_results)) + + total_issues = 0 + for review in all_to_repair: + for _f in review.functions_to_repair: + total_issues += 1 + + any_repaired = False + repaired_files = 0 + # Snapshot all sources before repair so we can show diffs and revert on failure + 
original_sources: dict[int, str] = { + r.test_index: generated_tests.generated_tests[r.test_index].generated_original_test_source + for r in all_to_repair + } + pre_repair_snapshots: dict[int, tuple[str, str, str, str | None]] = { + r.test_index: ( + generated_tests.generated_tests[r.test_index].generated_original_test_source, + generated_tests.generated_tests[r.test_index].instrumented_behavior_test_source, + generated_tests.generated_tests[r.test_index].instrumented_perf_test_source, + generated_tests.generated_tests[r.test_index].raw_generated_test_source, + ) + for r in all_to_repair + } + repaired_indices: set[int] = set() + for review in all_to_repair: + gt = generated_tests.generated_tests[review.test_index] + ph( + "cli-testgen-repair", + { + "test_index": review.test_index, + "cycle": cycle + 1, + "functions": [f.function_name for f in review.functions_to_repair], + }, + ) + + test_module_path = Path( + module_name_from_file_path(gt.behavior_file_path, self.test_cfg.tests_project_rootdir) + ) + assert self.aiservice_client is not None + repair_result = self.aiservice_client.repair_generated_tests( + test_source=gt.generated_original_test_source, + functions_to_repair=review.functions_to_repair, + function_source_code=function_source_for_prompt, + module_source_code=self.function_to_optimize_source_code, + function_to_optimize=self.function_to_optimize, + helper_function_names=[f.fully_qualified_name for f in code_context.helper_functions], + module_path=Path(self.original_module_path), + test_module_path=test_module_path, + test_framework=self.test_cfg.test_framework, + test_timeout=INDIVIDUAL_TESTCASE_TIMEOUT, + trace_id=self.function_trace_id, + language=self.function_to_optimize.language, + coverage_details=coverage_details, + previous_repair_errors=previous_repair_errors.get(review.test_index), + ) + + if repair_result is None: + logger.debug("Repair failed for test %s, keeping original", review.test_index) + continue + + repaired_source, 
behavior_source, perf_source = repair_result + raw_repaired_source = repaired_source + repaired_source, behavior_source, perf_source = process_generated_test_strings( + generated_test_source=repaired_source, + instrumented_behavior_test_source=behavior_source, + instrumented_perf_test_source=perf_source, + function_to_optimize=self.function_to_optimize, + test_path=gt.behavior_file_path, + test_cfg=self.test_cfg, + project_module_system=None, + ) + + gt.generated_original_test_source = repaired_source + gt.instrumented_behavior_test_source = behavior_source + gt.instrumented_perf_test_source = perf_source + gt.raw_generated_test_source = raw_repaired_source + + gt.behavior_file_path.write_text(behavior_source, encoding="utf8") + gt.perf_file_path.write_text(perf_source, encoding="utf8") + any_repaired = True + repaired_files += 1 + repaired_indices.add(review.test_index) + + if not any_repaired: + logger.warning("All repair API calls failed; proceeding with unrepaired tests") + break + + validation = self.run_behavioral_validation(code_context, original_helper_code, file_path_to_helper_classes) + if validation is None: + for idx in repaired_indices: + gt = generated_tests.generated_tests[idx] + orig_source, orig_behavior, orig_perf, orig_raw = pre_repair_snapshots[idx] + gt.generated_original_test_source = orig_source + gt.instrumented_behavior_test_source = orig_behavior + gt.instrumented_perf_test_source = orig_perf + gt.raw_generated_test_source = orig_raw + gt.behavior_file_path.write_text(orig_behavior, encoding="utf8") + gt.perf_file_path.write_text(orig_perf, encoding="utf8") + return Err("Repaired tests failed behavioral validation.") + behavioral_results, coverage_results = validation + + # Collect failing and all test function names per file + still_failing_by_file: dict[Path, set[str]] = defaultdict(set) + all_fns_by_file: dict[Path, set[str]] = defaultdict(set) + for result in behavioral_results.test_results: + if result.test_type == 
TestType.GENERATED_REGRESSION and result.id.test_function_name: + fn_name = result.id.test_fn_qualified_name() + all_fns_by_file[result.file_name].add(fn_name) + if not result.did_pass: + still_failing_by_file[result.file_name].add(fn_name) + + reverted_indices = set() + partially_fixed_indices = set() + removed_fns_by_index: dict[int, set[str]] = {} + for idx in repaired_indices: + gt = generated_tests.generated_tests[idx] + failing_fns = still_failing_by_file.get(gt.behavior_file_path) + if not failing_fns: + continue + + all_fns_in_file = all_fns_by_file.get(gt.behavior_file_path, set()) + if failing_fns >= all_fns_in_file and all_fns_in_file: + # ALL functions fail -> full revert to pre-repair state + orig_source, orig_behavior, orig_perf, orig_raw = pre_repair_snapshots[idx] + gt.generated_original_test_source = orig_source + gt.instrumented_behavior_test_source = orig_behavior + gt.instrumented_perf_test_source = orig_perf + gt.raw_generated_test_source = orig_raw + gt.behavior_file_path.write_text(orig_behavior, encoding="utf8") + gt.perf_file_path.write_text(orig_perf, encoding="utf8") + reverted_indices.add(idx) + else: + # Partial failure -> remove only failing functions, keep passing ones + fns_to_remove = list(failing_fns) + removed_fns_by_index[idx] = set(fns_to_remove) + gt.generated_original_test_source = remove_test_functions( + gt.generated_original_test_source, fns_to_remove + ) + gt.instrumented_behavior_test_source = remove_test_functions( + gt.instrumented_behavior_test_source, fns_to_remove + ) + gt.instrumented_perf_test_source = remove_test_functions( + gt.instrumented_perf_test_source, fns_to_remove + ) + if gt.raw_generated_test_source is not None: + gt.raw_generated_test_source = remove_test_functions( + gt.raw_generated_test_source, fns_to_remove + ) + gt.behavior_file_path.write_text(gt.instrumented_behavior_test_source, encoding="utf8") + gt.perf_file_path.write_text(gt.instrumented_perf_test_source, encoding="utf8") + 
partially_fixed_indices.add(idx) + + # Show diffs only for repairs that survived re-validation + successful_repairs = [r for r in all_to_repair if r.test_index not in reverted_indices] + if successful_repairs: + self.display_repaired_functions(generated_tests, successful_repairs, original_sources) + + modified_indices = reverted_indices | partially_fixed_indices + if modified_indices: + messages = [] + if reverted_indices: + messages.append(f"reverted {len(reverted_indices)} test file(s)") + if partially_fixed_indices: + messages.append(f"removed failing functions from {len(partially_fixed_indices)} test file(s)") + # Collect error messages from failed functions so the next cycle can learn + revalidation_failures = behavioral_results.test_failures or {} + for idx in modified_indices: + gt = generated_tests.generated_tests[idx] + removed_fns = removed_fns_by_index.get(idx, set()) + errors_for_file: dict[str, str] = {} + for result in behavioral_results.test_results: + if ( + result.file_name == gt.behavior_file_path + and result.test_type == TestType.GENERATED_REGRESSION + and not result.did_pass + and result.id.test_function_name + ): + fn_name = result.id.test_fn_qualified_name() + if fn_name not in removed_fns: + errors_for_file[fn_name] = revalidation_failures.get(fn_name, "Test failed") + if errors_for_file: + previous_repair_errors[idx] = errors_for_file + # Invalidate behavioral results since files were modified + behavioral_results = None + coverage_results = None + + return Ok((generated_tests, behavioral_results, coverage_results)) diff --git a/src/codeflash_python/picklepatch/__init__.py b/src/codeflash_python/picklepatch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/codeflash_python/picklepatch/pickle_patcher.py b/src/codeflash_python/picklepatch/pickle_patcher.py new file mode 100644 index 000000000..7440c34c6 --- /dev/null +++ b/src/codeflash_python/picklepatch/pickle_patcher.py @@ -0,0 +1,373 @@ +"""PicklePatcher - A 
class PicklePatcher:
    """Pickle objects safely, replacing unpicklable components with placeholders.

    All methods are static. The only shared state is ``unpicklable_types``, a
    process-wide cache of types that already failed to pickle; instances of a
    cached type are replaced by a placeholder without another pickling attempt.
    """

    # Class-level cache of unpicklable types
    unpicklable_types: ClassVar[set[type]] = set()

    @staticmethod
    def dumps(obj: object, protocol: int | None = None, max_depth: int = 100, **kwargs) -> bytes:  # noqa: ANN003
        """Safely pickle an object, replacing unpicklable parts with placeholders.

        Args:
        ----
            obj: The object to pickle
            protocol: The pickle protocol version to use
            max_depth: Maximum recursion depth
            **kwargs: Additional arguments for pickle/dill.dumps

        Returns:
        -------
            bytes: Pickled data with placeholders for unpicklable objects

        """
        return PicklePatcher.recursive_pickle(obj, max_depth, path=[], protocol=protocol, **kwargs)

    @staticmethod
    def loads(pickled_data: bytes) -> object:
        """Unpickle data that may contain placeholders.

        Args:
        ----
            pickled_data: Pickled data with possible placeholders

        Returns:
        -------
            The unpickled object with placeholders for unpicklable parts

        """
        # NOTE(security): dill.loads executes arbitrary code embedded in the
        # payload, exactly like pickle.loads. Only pass data produced by a
        # trusted process (e.g. this class's own dumps()).
        return dill.loads(pickled_data)

    @staticmethod
    def create_placeholder(obj: object, error_msg: str, path: list[str]) -> PicklePlaceholder:
        """Create a placeholder for an unpicklable object.

        As a side effect the object's type is added to ``unpicklable_types``,
        unless the object is a builtin container (see the comment below).

        Args:
        ----
            obj: The original unpicklable object
            error_msg: Error message explaining why it couldn't be pickled
            path: Path to this object in the object graph

        Returns:
        -------
            PicklePlaceholder: A placeholder object

        """
        obj_type = type(obj)
        try:
            # Truncate so a huge repr does not bloat the pickled payload.
            obj_str = str(obj)[:100]
        except Exception:  # noqa: BLE001 - str() of arbitrary objects can raise anything
            obj_str = f"<unprintable {obj_type.__name__} object>"

        placeholder = PicklePlaceholder(obj_type.__name__, obj_str, error_msg, path)

        # Add this type to our known unpicklable types cache. Builtin containers
        # are excluded: whether a dict/list/tuple/set pickles depends on its
        # contents, and caching e.g. ``dict`` would make recursive_pickle replace
        # every later dict with a placeholder before even trying to pickle it.
        if not isinstance(obj, (dict, list, tuple, set, frozenset)):
            PicklePatcher.unpicklable_types.add(obj_type)
        return placeholder

    @staticmethod
    def pickle(
        obj: object,
        path: list[str] | None = None,  # noqa: ARG004
        protocol: int | None = None,
        **kwargs: Any,
    ) -> tuple[bool, bytes | str]:
        """Try to pickle an object using pickle first, then dill.

        No placeholder is created here; on failure the caller receives the
        error text and decides what to do.

        Args:
        ----
            obj: The object to pickle
            path: Path to this object in the object graph (currently unused)
            protocol: The pickle protocol version to use
            **kwargs: Additional arguments for pickle/dill.dumps

        Returns:
        -------
            tuple: (success, result) where success is a boolean and result is either:
                - Pickled bytes if successful
                - Error message (from the dill attempt) if not successful

        """
        # Try standard pickle first
        try:
            return True, pickle.dumps(obj, protocol=protocol, **kwargs)
        except (pickle.PickleError, TypeError, AttributeError, ValueError):
            # Then try dill (which is more powerful)
            try:
                return True, dill.dumps(obj, protocol=protocol, **kwargs)
            except (dill.PicklingError, TypeError, AttributeError, ValueError) as e:
                return False, str(e)

    @staticmethod
    def recursive_pickle(
        obj: object,
        max_depth: int,
        path: list[str] | None = None,
        protocol: int | None = None,
        **kwargs,  # noqa: ANN003
    ) -> bytes:
        """Recursively try to pickle an object, replacing unpicklable parts with placeholders.

        Args:
        ----
            obj: The object to pickle
            max_depth: Maximum recursion depth
            path: Current path in the object graph
            protocol: The pickle protocol version to use
            **kwargs: Additional arguments for pickle/dill.dumps

        Returns:
        -------
            bytes: Pickled data with placeholders for unpicklable objects

        """
        if path is None:
            path = []

        obj_type = type(obj)

        # Check if this type is known to be unpicklable
        if obj_type in PicklePatcher.unpicklable_types:
            placeholder = PicklePatcher.create_placeholder(obj, "Known unpicklable type", path)
            return dill.dumps(placeholder, protocol=protocol, **kwargs)

        # Check for max depth
        if max_depth <= 0:
            placeholder = PicklePatcher.create_placeholder(obj, "Max recursion depth exceeded", path)
            return dill.dumps(placeholder, protocol=protocol, **kwargs)

        # Try standard pickling
        success, result = PicklePatcher.pickle(obj, path, protocol, **kwargs)
        if success:
            return cast("bytes", result)

        error_msg = cast("str", result)  # Error message from pickling attempt

        # Handle different container types
        if isinstance(obj, dict):
            return PicklePatcher.handle_dict(obj, max_depth, error_msg, path, protocol=protocol, **kwargs)
        if isinstance(obj, (list, tuple, set)):
            return PicklePatcher.handle_sequence(obj, max_depth, error_msg, path, protocol=protocol, **kwargs)
        if hasattr(obj, "__dict__"):
            result = PicklePatcher.handle_object(obj, max_depth, error_msg, path, protocol=protocol, **kwargs)

            # handle_object returns a pickled placeholder when it could not
            # rebuild the instance; in that case remember the type as a failure.
            unpickled = dill.loads(result)
            if isinstance(unpickled, PicklePlaceholder):
                PicklePatcher.unpicklable_types.add(obj_type)
            return result

        # For other unpicklable objects, use a placeholder
        placeholder = PicklePatcher.create_placeholder(obj, error_msg, path)
        return dill.dumps(placeholder, protocol=protocol, **kwargs)

    @staticmethod
    def handle_dict(
        obj_dict: dict[Any, Any],
        max_depth: int,
        error_msg: str,  # noqa: ARG004
        path: list[str],
        protocol: int | None = None,
        **kwargs: Any,
    ) -> bytes:
        """Handle pickling for dictionary objects.

        Picklable keys and values are kept as-is. An unpicklable key is replaced
        by a descriptive string; an unpicklable value is pickled recursively and,
        if that raises, replaced by a placeholder.

        Args:
        ----
            obj_dict: The dictionary to pickle
            max_depth: Maximum recursion depth
            error_msg: Error message from the original pickling attempt (unused)
            path: Current path in the object graph
            protocol: The pickle protocol version to use
            **kwargs: Additional arguments for pickle/dill.dumps

        Returns:
        -------
            bytes: Pickled data with placeholders for unpicklable objects

        """
        if not isinstance(obj_dict, dict):
            placeholder = PicklePatcher.create_placeholder(
                obj_dict, f"Expected a dictionary, got {type(obj_dict).__name__}", path
            )
            return dill.dumps(placeholder, protocol=protocol, **kwargs)

        result = {}

        for index, (key, value) in enumerate(obj_dict.items()):
            # Process the key
            key_success, _ = PicklePatcher.pickle(key, path, protocol, **kwargs)
            if key_success:
                key_result = key
            else:
                # If the key can't be pickled, use a string representation
                try:
                    key_str = str(key)[:50]
                except Exception:  # noqa: BLE001 - str() of arbitrary objects can raise anything
                    key_str = f"<unprintable_key_{type(key).__name__}>"
                key_result = f"<unpicklable_key:{key_str}>"
                # Two distinct keys can share a (truncated) string form; add the
                # entry's position so neither value silently overwrites the other.
                if key_result in result:
                    key_result = f"<unpicklable_key:{key_str}#{index}>"

            # Process the value
            value_path = [*path, f"[{repr(key)[:20]}]"]
            value_success, _ = PicklePatcher.pickle(value, value_path, protocol, **kwargs)

            if value_success:
                value_result = value
            else:
                # Try recursive pickling for the value
                try:
                    value_bytes = PicklePatcher.recursive_pickle(
                        value, max_depth - 1, value_path, protocol=protocol, **kwargs
                    )
                    value_result = dill.loads(value_bytes)
                except Exception as inner_e:  # noqa: BLE001 - best effort: degrade to a placeholder
                    value_result = PicklePatcher.create_placeholder(value, str(inner_e), value_path)

            result[key_result] = value_result

        return dill.dumps(result, protocol=protocol, **kwargs)

    @staticmethod
    def handle_sequence(
        obj_seq: list[Any] | tuple[Any, ...] | set[Any],
        max_depth: int,
        error_msg: str,  # noqa: ARG004
        path: list[str],
        protocol: int | None = None,
        **kwargs: Any,
    ) -> bytes:
        """Handle pickling for sequence types (list, tuple, set).

        Args:
        ----
            obj_seq: The sequence to pickle
            max_depth: Maximum recursion depth
            error_msg: Error message from the original pickling attempt (unused)
            path: Current path in the object graph
            protocol: The pickle protocol version to use
            **kwargs: Additional arguments for pickle/dill.dumps

        Returns:
        -------
            bytes: Pickled data with placeholders for unpicklable objects

        """
        result: list[Any] = []

        for i, item in enumerate(obj_seq):
            item_path = [*path, f"[{i}]"]

            # Try to pickle the item directly
            success, _ = PicklePatcher.pickle(item, item_path, protocol, **kwargs)
            if success:
                result.append(item)
                continue

            # If we couldn't pickle directly, try recursively
            try:
                item_bytes = PicklePatcher.recursive_pickle(item, max_depth - 1, item_path, protocol=protocol, **kwargs)
                result.append(dill.loads(item_bytes))
            except Exception as inner_e:  # noqa: BLE001 - best effort: degrade to a placeholder
                # If recursive pickling fails, use a placeholder
                placeholder = PicklePatcher.create_placeholder(item, str(inner_e), item_path)
                result.append(placeholder)

        # Convert back to the original type
        if isinstance(obj_seq, tuple):
            result_final: list[Any] | tuple[Any, ...] | set[Any] = tuple(result)
        elif isinstance(obj_seq, set):
            # A patched element may be unhashable; in that case keep the list.
            result_final = result
            with contextlib.suppress(Exception):
                result_final = set(result)
        else:
            result_final = result

        return dill.dumps(result_final, protocol=protocol, **kwargs)

    @staticmethod
    def handle_object(
        obj: object, max_depth: int, error_msg: str, path: list[str], protocol: int | None = None, **kwargs: Any
    ) -> bytes:
        """Handle pickling for custom objects with __dict__.

        Builds a fresh instance of the same class without calling ``__init__``,
        copies every attribute (patching the unpicklable ones) and pickles the
        copy. If any step fails, a placeholder for the whole object is pickled
        instead.

        Args:
        ----
            obj: The object to pickle
            max_depth: Maximum recursion depth
            error_msg: Error message from the original pickling attempt
            path: Current path in the object graph
            protocol: The pickle protocol version to use
            **kwargs: Additional arguments for pickle/dill.dumps

        Returns:
        -------
            bytes: Pickled data with placeholders for unpicklable objects

        """
        # Try to create a new instance of the same class
        try:
            # First try to create an empty instance
            new_obj = object.__new__(type(obj))

            # Handle __dict__ attributes if they exist
            if hasattr(obj, "__dict__"):
                obj_dict = obj.__dict__
                assert isinstance(obj_dict, dict)
                for attr_name, attr_value in obj_dict.items():
                    assert isinstance(attr_name, str)
                    attr_path: list[str] = [*path, attr_name]

                    # Try to pickle directly first
                    success, _ = PicklePatcher.pickle(attr_value, attr_path, protocol, **kwargs)
                    if success:
                        setattr(new_obj, attr_name, attr_value)
                        continue

                    # If direct pickling fails, try recursive pickling
                    try:
                        attr_bytes = PicklePatcher.recursive_pickle(
                            attr_value, max_depth - 1, attr_path, protocol=protocol, **kwargs
                        )
                        setattr(new_obj, attr_name, dill.loads(attr_bytes))
                    except Exception as inner_e:  # noqa: BLE001 - best effort: degrade to a placeholder
                        # Use placeholder for unpicklable attribute
                        placeholder = PicklePatcher.create_placeholder(attr_value, str(inner_e), attr_path)
                        setattr(new_obj, attr_name, placeholder)

            # Try to pickle the patched object
            success, result = PicklePatcher.pickle(new_obj, path, protocol, **kwargs)
            if success:
                assert isinstance(result, bytes)
                return result
            # Fall through to placeholder creation
        except Exception:  # noqa: BLE001, S110 - deliberate: any failure degrades to a placeholder
            pass  # Fall through to placeholder creation

        # If we get here, just use a placeholder
        placeholder = PicklePatcher.create_placeholder(obj, error_msg, path)
        return dill.dumps(placeholder, protocol=protocol, **kwargs)
class PicklePlaceholderAccessError(Exception):
    """Custom exception raised when attempting to access an unpicklable object."""


class PicklePlaceholder:
    """A placeholder for an object that couldn't be pickled.

    The placeholder records the original object's type name, a short string
    form, the pickling error and the path to the object in its parent graph.
    Reading an unknown attribute, setting any attribute, or calling the
    placeholder raises a PicklePlaceholderAccessError.
    """

    def __init__(self, obj_type: str, obj_str: str, error_msg: str, path: list[str] | None = None) -> None:
        # Write through __dict__ directly: __setattr__ is overridden below and
        # raises for every assignment, including ones made from __init__.
        self.__dict__["obj_type"] = obj_type
        self.__dict__["obj_str"] = obj_str
        self.__dict__["error_msg"] = error_msg
        self.__dict__["path"] = path if path is not None else []

    def __getattr__(self, name: str) -> Any:
        """Raise PicklePlaceholderAccessError for any attribute not found normally.

        Python only calls this hook after regular lookup fails, so the four
        fields stored in ``__dict__`` by ``__init__`` remain readable.
        """
        path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object"
        msg = (
            f"Attempt to access unpickleable object: Cannot access attribute '{name}' on unpicklable object at {path_str}. "
            f"Original type: {self.__dict__['obj_type']}. Error: {self.__dict__['error_msg']}"
        )
        raise PicklePlaceholderAccessError(msg)

    def __setattr__(self, name: str, value: Any) -> None:
        """Reject every attribute assignment with PicklePlaceholderAccessError."""
        # Reuse __getattr__ purely for its raise; nothing is ever stored.
        self.__getattr__(name)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Raise PicklePlaceholderAccessError: the placeholder is never callable."""
        path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object"
        msg = (
            f"Attempt to access unpickleable object: Cannot call unpicklable object at {path_str}. "
            f"Original type: {self.__dict__['obj_type']}. Error: {self.__dict__['error_msg']}"
        )
        raise PicklePlaceholderAccessError(msg)

    def __repr__(self) -> str:
        """Return a non-empty description naming the path, original type and string form."""
        try:
            path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root"
            return f"<PicklePlaceholder at {path_str}: {self.__dict__['obj_type']} - {self.__dict__['obj_str']}>"
        except Exception:  # noqa: BLE001 - repr must never raise (e.g. fields missing on a bare instance)
            return "<PicklePlaceholder: (error displaying details)>"

    def __str__(self) -> str:
        return self.__repr__()

    def __reduce__(self) -> tuple:
        """Pickle as a constructor call so the four fields survive a round trip."""
        return (
            PicklePlaceholder,
            (self.__dict__["obj_type"], self.__dict__["obj_str"], self.__dict__["error_msg"], self.__dict__["path"]),
        )
PluginTestLifecycleMixin, PluginResultsMixin): # type: ignore[cyclic-class-definition] + """Implements the codeflash_core LanguagePlugin protocol for Python. + + Converts between core types and internal types at the boundary. + """ + + def __init__(self, project_root: Path) -> None: + self.project_root = project_root + self.last_internal_context = None # cache for get_candidates + self.current_function: FunctionToOptimize | None = None # cache for coverage + self.tests_project_rootdir: Path | None = None # cached from test_config + self.is_numerical_code: bool | None = None # cached from generate_tests + self.ai_client = None + self.pending_code_markdown: str = "" # set by optimizer before replace_function + self.cancel_event: threading.Event | None = None # set by optimizer for cooperative cancellation + self.call_graph_index = None + self.dependency_counts: dict[str, int] = {} + + def is_cancelled(self) -> bool: + return self.cancel_event is not None and self.cancel_event.is_set() + + def get_ai_client(self) -> AiServiceClient: + if self.ai_client is None: + from codeflash_python.api.aiservice import AiServiceClient + + self.ai_client = AiServiceClient() + return self.ai_client + + # -- cleanup, comparison, environment validation -------------------------- + + def cleanup_run(self, tests_root: Path) -> None: + import contextlib + import shutil + + from codeflash_python.code_utils.code_utils import get_run_tmp_file + from codeflash_python.optimization.optimizer import Optimizer as PyOptimizer + + # Remove leftover instrumented test files + if tests_root.exists(): + leftover = PyOptimizer.find_leftover_instrumented_test_files(tests_root) + for p in leftover: + with contextlib.suppress(OSError): + p.unlink(missing_ok=True) + + # Remove leftover return-value files (indices 0-30 match max_total in evaluate_candidates) + for i in range(31): + with contextlib.suppress(OSError): + get_run_tmp_file(Path(f"test_return_values_{i}.bin")).unlink(missing_ok=True) + with 
contextlib.suppress(OSError): + get_run_tmp_file(Path(f"test_return_values_{i}.sqlite")).unlink(missing_ok=True) + + # Remove the shared temp directory + if hasattr(get_run_tmp_file, "tmpdir_path"): + shutil.rmtree(get_run_tmp_file.tmpdir_path, ignore_errors=True) + del get_run_tmp_file.tmpdir_path + + # Close the call graph index DB connection + if self.call_graph_index and self.call_graph_index is not False: + self.call_graph_index.close() + self.call_graph_index = None + + def compare_outputs(self, baseline_output: object, candidate_output: object) -> bool: + from codeflash_python.verification.comparator import comparator + + return comparator(baseline_output, candidate_output) + + def validate_environment(self, config: object) -> bool: + from codeflash_python.code_utils.env_utils import check_formatter_installed + + if hasattr(config, "formatter_cmds") and config.formatter_cmds: + return check_formatter_installed(config.formatter_cmds) # type: ignore[arg-type] + return True + + # -- discover_functions -------------------------------------------------- + + def discover_functions(self, paths: list[Path]) -> list[FunctionToOptimize]: + results: list[FunctionToOptimize] = [] + for path in paths: + try: + source = path.read_text(encoding="utf-8") + except (OSError, UnicodeDecodeError) as exc: + logger.warning("Skipping %s: %s", path, exc) + continue + + try: + from codeflash_python.discovery.function_visitors import discover_functions + + internal_fns = discover_functions(source, path) + except Exception as exc: + logger.warning("Skipping %s: failed to parse (%s)", path, exc) + continue + for fn in internal_fns: + # Attach source code so the core optimizer has it + lines = source.splitlines() + if fn.starting_line and fn.ending_line: + fn.source_code = "\n".join(lines[fn.starting_line - 1 : fn.ending_line]) + results.append(fn) + return results + + # -- build_index / rank_functions ----------------------------------------- + + def get_call_graph_index(self) -> 
CallGraphIndex | None: + if self.call_graph_index is None: + # Skip in CI — the cache DB doesn't persist between runs on ephemeral runners + if os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS"): + self.call_graph_index = False + return None + try: + from codeflash_python.context.call_graph_index import CallGraphIndex + + self.call_graph_index = CallGraphIndex(self.project_root) + except Exception: + logger.info("Failed to initialize CallGraphIndex, falling back to per-function Jedi analysis") + self.call_graph_index = False + return None + return self.call_graph_index if self.call_graph_index is not False else None + + def build_index(self, files: list[Path], on_progress=None) -> None: + graph = self.get_call_graph_index() + if graph is not None: + graph.build_index(files, on_progress=on_progress) + + def rank_functions( + self, + functions: list[FunctionToOptimize], + trace_file: Path | None = None, + test_counts: dict[tuple[Path, str], int] | None = None, + ) -> list[FunctionToOptimize]: + if not functions: + return functions + + # Primary: rank by trace-based addressable time (filters low-importance functions) + if trace_file and trace_file.exists(): + try: + from codeflash_python.benchmarking.function_ranker import FunctionRanker + + ranker = FunctionRanker(trace_file) + ranked = ranker.rank_functions(functions) + if test_counts: + ranked.sort( + key=lambda f: ( + -ranker.get_function_addressable_time(f), + -test_counts.get((f.file_path, f.qualified_name), 0), + ) + ) + logger.debug( + "Ranked %d functions by addressable time (filtered %d low-importance)", + len(ranked), + len(functions) - len(ranked), + ) + return ranked + except Exception: + logger.warning("Trace-based ranking failed, falling back to dependency count") + + # Fallback: rank by dependency count (most complex first) + graph = self.get_call_graph_index() + if graph is None: + return functions + from collections import defaultdict + + file_to_qns: dict[Path, set[str]] = defaultdict(set) + 
for func in functions: + file_to_qns[func.file_path].add(func.qualified_name) + callee_counts = graph.count_callees_per_function(dict(file_to_qns)) + self.dependency_counts = {qn: count for (_, qn), count in callee_counts.items()} + if test_counts: + ranked = sorted( + enumerate(functions), + key=lambda x: ( + -callee_counts.get((x[1].file_path, x[1].qualified_name), 0), + -test_counts.get((x[1].file_path, x[1].qualified_name), 0), + x[0], + ), + ) + else: + ranked = sorted( + enumerate(functions), key=lambda x: (-callee_counts.get((x[1].file_path, x[1].qualified_name), 0), x[0]) + ) + logger.debug("Ranked %d functions by dependency count (most complex first)", len(ranked)) + return [func for _, func in ranked] + + def get_dependency_counts(self) -> dict[str, int]: + return self.dependency_counts + + # -- extract_context ----------------------------------------------------- + + def extract_context(self, function: FunctionToOptimize) -> CodeContext: + from codeflash_python.context.code_context_extractor import get_code_optimization_context + from codeflash_python.context.types import function_sources_to_helpers + + internal_fn = function + ctx = get_code_optimization_context(internal_fn, self.project_root, call_graph=self.get_call_graph_index()) + self.last_internal_context = ctx + self.current_function = function + + helpers = function_sources_to_helpers(ctx.helper_functions) + + return CodeContext( + target_function=function, + target_code=ctx.read_writable_code.flat if ctx.read_writable_code else function.source_code, + target_file=function.file_path, + helper_functions=helpers, + read_only_context=ctx.read_only_context_code, + ) + + # -- run_tests ----------------------------------------------------------- + + def run_tests( + self, + test_config: TestConfig, + test_files: list[Path] | None = None, + test_iteration: int = 0, + enable_coverage: bool = False, + ) -> TestResults | tuple[TestResults, CoverageData | None]: + if test_files is not None: + files_to_run 
= test_files + else: + files_to_run = sorted(test_config.tests_root.rglob("test_*.py")) + if not files_to_run: + files_to_run = sorted(test_config.tests_root.rglob("*_test.py")) + + if not files_to_run: + return TestResults(passed=True) + + # Clean up stale return-value files before this iteration (matches original) + from codeflash_python.code_utils.code_utils import get_run_tmp_file + + for ext in (".bin", ".sqlite"): + get_run_tmp_file(Path(f"test_return_values_{test_iteration}{ext}")).unlink(missing_ok=True) + + env = make_test_env(test_config.project_root, test_iteration=test_iteration) + timeout = int(test_config.timeout) + + results, _, cov_db, cov_config = run_tests( + test_files=files_to_run, + cwd=test_config.project_root, + env=env, + timeout=timeout, + enable_coverage=enable_coverage, + ) + + # Read return values from SQLite written by instrumented tests + return_values = read_return_values(test_iteration) + + outcomes = [] + for r in results: + # Match JUnit test name to SQLite test_function_name + # The pytest plugin strips parametrize brackets from CODEFLASH_TEST_FUNCTION + base_name = r.test_name.split("[", 1)[0] if "[" in r.test_name else r.test_name + ret_vals = return_values.get(base_name) + output = tuple(ret_vals) if ret_vals else None + + outcomes.append( + TestOutcome( + test_id=r.test_name, + status=TestOutcomeStatus.PASSED if r.passed else TestOutcomeStatus.FAILED, + duration=r.runtime_ns / 1e9 if r.runtime_ns else 0.0, + error_message=r.error_message or "", + output=output, + ) + ) + + test_results = TestResults(passed=all(r.passed for r in results), outcomes=outcomes, error=None) + + if enable_coverage: + coverage_data = self.load_coverage(cov_db, cov_config) + return test_results, coverage_data + + return test_results + + def load_coverage(self, cov_db: Path | None, cov_config: Path | None) -> CoverageData | None: + """Load coverage data from SQLite database and convert to core CoverageData.""" + if cov_db is None or cov_config is None: 
+ return None + + function = self.current_function + code_context = self.last_internal_context + if function is None or code_context is None: + return None + + try: + from codeflash_core.models import CoverageData as CoreCoverageData + from codeflash_core.models import FunctionCoverage as CoreFunctionCoverage + from codeflash_python.verification.coverage_utils import CoverageUtils + + internal_cov = CoverageUtils.load_from_sqlite_database( + database_path=cov_db, + config_path=cov_config, + function_name=function.qualified_name, + code_context=code_context, + source_code_path=function.file_path, + ) + + main_fc = internal_cov.main_func_coverage + core_main = CoreFunctionCoverage( + name=main_fc.name, + coverage=main_fc.coverage, + executed_lines=list(main_fc.executed_lines), + unexecuted_lines=list(main_fc.unexecuted_lines), + executed_branches=list(main_fc.executed_branches), + unexecuted_branches=list(main_fc.unexecuted_branches), + ) + + core_dep = None + if internal_cov.dependent_func_coverage: + dep = internal_cov.dependent_func_coverage + core_dep = CoreFunctionCoverage( + name=dep.name, + coverage=dep.coverage, + executed_lines=list(dep.executed_lines), + unexecuted_lines=list(dep.unexecuted_lines), + executed_branches=list(dep.executed_branches), + unexecuted_branches=list(dep.unexecuted_branches), + ) + + from codeflash_python.code_utils.config_consts import COVERAGE_THRESHOLD + + return CoreCoverageData( + file_path=function.file_path, + coverage=internal_cov.coverage, + function_name=function.qualified_name, + main_func_coverage=core_main, + dependent_func_coverage=core_dep, + threshold_percentage=COVERAGE_THRESHOLD, + ) + except Exception: + logger.debug("Failed to load coverage data", exc_info=True) + return None + + # -- replace_function ---------------------------------------------------- + + def replace_function(self, file: Path, function: FunctionToOptimize, new_code: str) -> None: + internal_ctx = self.last_internal_context + code_markdown = 
self.pending_code_markdown + + if internal_ctx is not None and code_markdown: + try: + self.replace_function_full(function, internal_ctx, code_markdown) + return + except Exception: + logger.debug("Full replace_function failed, falling back to simple replacement", exc_info=True) + + # Fallback: simple single-file replacement + source = file.read_text(encoding="utf-8") + internal_fn = function + modified = replace_function_simple(source, internal_fn, new_code) + file.write_text(modified, encoding="utf-8") + + def replace_function_full(self, function: FunctionToOptimize, internal_ctx: object, code_markdown: str) -> None: + """Port of FunctionOptimizer.replace_function_and_helpers_with_optimized_code.""" + from collections import defaultdict + + from codeflash_python.context.unused_helper_detection import ( + detect_unused_helper_functions, + revert_unused_helper_functions, + ) + from codeflash_python.models.models import CodeStringsMarkdown + from codeflash_python.static_analysis.code_replacer import replace_function_definitions_in_module + + optimized_code = CodeStringsMarkdown.parse_markdown_code(code_markdown) + + internal_fn = function + + # Group functions by file (target + helpers where definition_type in ("function", None)) + functions_by_file: dict[Path, set[str]] = defaultdict(set) + functions_by_file[function.file_path].add(internal_fn.qualified_name) + for helper in internal_ctx.helper_functions: # type: ignore[attr-defined] + if helper.definition_type in ("function", None): + functions_by_file[helper.file_path].add(helper.qualified_name) + + # Capture original helper code for unused-helper revert + original_helper_code: dict[Path, str] = {} + for hp in functions_by_file: + if hp != function.file_path and hp.exists(): + original_helper_code[hp] = hp.read_text("utf-8") + + # Replace in each file + for module_abspath, qualified_names in functions_by_file.items(): + replace_function_definitions_in_module( + function_names=list(qualified_names), + 
def validate_candidate(self, code: str) -> bool:
    """Return True when *code* parses as Python source, False otherwise.

    Only the syntax is checked via ``ast.parse``; the code is never executed.
    """
    import ast

    try:
        ast.parse(code)
    except (SyntaxError, ValueError, RecursionError):
        # SyntaxError: ordinary malformed source.
        # ValueError: source containing null bytes is reported this way on Python < 3.12.
        # RecursionError: pathologically nested source can exhaust the parser's stack.
        # All three mean "this candidate cannot be used", so report False instead of crashing.
        return False
    return True
+ """ + from codeflash_python.benchmarking.parse_line_profile_test_output import parse_line_profile_results + from codeflash_python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator + + internal_fn = function + code_context = self.last_internal_context + if code_context is None: + logger.warning("No code context available for line profiler") + return "" + + # Read original source of function file + helper files for restore + original_sources: dict[Path, str] = {} + try: + original_sources[function.file_path] = function.file_path.read_text("utf-8") + except (OSError, UnicodeDecodeError): + logger.warning("Cannot read function file %s for line profiler", function.file_path) + return "" + + # Check JIT decorators in function file + if contains_jit_decorator(original_sources[function.file_path]): + logger.info("Skipping line profiler for %s - code contains JIT decorator", function.function_name) + return "" + + # Save and check helper file sources + for helper in code_context.helper_functions: + hp = helper.file_path + if hp not in original_sources: + try: + content = hp.read_text("utf-8") + except (OSError, UnicodeDecodeError): + continue + original_sources[hp] = content + if contains_jit_decorator(content): + logger.info( + "Skipping line profiler for %s - helper code contains JIT decorator", function.function_name + ) + return "" + + # Determine test files + if test_files is not None: + files_to_run = test_files + else: + files_to_run = sorted(test_config.tests_root.rglob("test_*.py")) + if not files_to_run: + files_to_run = sorted(test_config.tests_root.rglob("*_test.py")) + if not files_to_run: + return "" + + try: + # Inject line profiler decorators and imports into function + helper files + lprof_output_file = add_decorator_imports(internal_fn, code_context) + + # Run tests with LINE_PROFILE=1 env var + env = make_test_env(test_config.project_root, test_iteration=0) + env["LINE_PROFILE"] = "1" + + 
def get_candidates(self, context: CodeContext, trace_id: str = "") -> list[Candidate]:
    """Request optimization candidates for the target code from the AI service.

    When a cached internal context exists, its markdown-formatted writable code and
    read-only context are sent; otherwise the plain fields of *context* are used.
    Suggestions whose flat source is empty are dropped.

    Raises:
        AssertionError: if *trace_id* is empty (checked after the client is created).

    Errors from ``get_ai_client`` or ``optimize_code`` are not caught here.
    """
    ai_client = self.get_ai_client()
    assert trace_id, "trace_id must be provided"

    # Prefer the cached internal context: it carries the markdown form the API expects.
    cached_ctx = self.last_internal_context
    if cached_ctx is None:
        source_code, dependency_code = context.target_code, context.read_only_context
    else:
        source_code = cached_ctx.read_writable_code.markdown
        dependency_code = cached_ctx.read_only_context_code

    suggestions = ai_client.optimize_code(
        source_code=source_code,
        dependency_code=dependency_code,
        trace_id=trace_id,
        language="python",
        is_numerical_code=self.is_numerical_code,
    )

    accepted = []
    for suggestion in suggestions:
        flat_code = suggestion.source_code.flat if suggestion.source_code else ""
        markdown_code = suggestion.source_code.markdown if suggestion.source_code else ""
        if not flat_code:
            # Nothing usable came back for this suggestion.
            continue
        accepted.append(
            Candidate(
                code=flat_code,
                explanation=suggestion.explanation or "",
                source="optimize",
                code_markdown=markdown_code,
            )
        )
    return accepted
def refine_candidate(
    self, context: CodeContext, candidate: ScoredCandidate, baseline_bench: BenchmarkResults, trace_id: str = ""
) -> list[Candidate]:
    """Ask the AI service to refine an already-scored candidate.

    Builds one refinement request from the original code, the candidate's code
    (markdown form when it has one, flat code otherwise) and both runtimes
    converted from seconds to integer nanoseconds. Each non-empty result becomes
    a ``Candidate`` with source ``"refine"`` and the scored candidate as parent.

    Returns an empty list when the client cannot be created or the API call fails;
    both failures are logged.

    Raises:
        AssertionError: if *trace_id* is empty.
    """
    assert trace_id, "trace_id must be provided"
    from codeflash_python.api.types import AIServiceRefinerRequest

    try:
        ai_client = self.get_ai_client()
    except Exception:
        logger.exception("Failed to create AI client for refinement")
        return []

    cached_ctx = self.last_internal_context
    if cached_ctx:
        source_code = cached_ctx.read_writable_code.markdown
        dependency_code = cached_ctx.read_only_context_code
    else:
        source_code = context.target_code
        dependency_code = context.read_only_context

    parent = candidate.candidate
    optimized_code = parent.code_markdown or parent.code

    refine_request = AIServiceRefinerRequest(
        optimization_id=parent.candidate_id,
        original_source_code=source_code,
        read_only_dependency_code=dependency_code,
        original_code_runtime=int(baseline_bench.total_time * 1e9),
        optimized_source_code=optimized_code,
        optimized_explanation=parent.explanation,
        optimized_code_runtime=int(candidate.benchmark_results.total_time * 1e9),
        speedup=format_speedup_pct(candidate.speedup),
        trace_id=trace_id,
        # Line-profiler output is not collected on this path; the API gets empty strings.
        original_line_profiler_results="",
        optimized_line_profiler_results="",
    )

    try:
        refined = ai_client.optimize_code_refinement([refine_request])
    except Exception:
        logger.exception("Code refinement API call failed")
        return []

    accepted = []
    for item in refined:
        flat_code = item.source_code.flat if item.source_code else ""
        markdown_code = item.source_code.markdown if item.source_code else ""
        if not flat_code:
            continue
        accepted.append(
            Candidate(
                code=flat_code,
                explanation=item.explanation or "",
                source="refine",
                parent_id=candidate.candidate.candidate_id,
                code_markdown=markdown_code,
            )
        )
    return accepted
def make_test_env(
    project_root: Path | str, *, loop_index: int = 0, test_iteration: int = 0, tracer_disable: int = 1
) -> dict[str, str]:
    """Build the environment mapping used when launching codeflash test runs.

    The result is a fresh copy of ``os.environ`` (the process environment itself is
    left untouched) in which *project_root* is placed at the front of ``PYTHONPATH``
    and the ``CODEFLASH_LOOP_INDEX``, ``CODEFLASH_TEST_ITERATION`` and
    ``CODEFLASH_TRACER_DISABLE`` variables are set to the stringified arguments.
    """
    env = dict(os.environ)
    root = str(project_root)
    inherited = env.get("PYTHONPATH", "")
    # An unset or empty PYTHONPATH must not leave a dangling separator behind.
    env["PYTHONPATH"] = os.pathsep.join((root, inherited)) if inherited else root
    env.update(
        CODEFLASH_LOOP_INDEX=str(loop_index),
        CODEFLASH_TEST_ITERATION=str(test_iteration),
        CODEFLASH_TRACER_DISABLE=str(tracer_disable),
    )
    return env
def coverage_data_to_details_dict(cov_data: CoverageData) -> dict[str, Any]:
    """Flatten a ``CoverageData`` object into the plain dict sent to the repair API.

    The result always holds the overall and threshold percentages plus a
    ``"main_function"`` section. A ``"dependent_function"`` section with the same
    shape is added only when ``cov_data.dependent_func_coverage`` is truthy.
    Executed and unexecuted line collections are emitted as sorted lists; branch
    data is passed through unchanged.
    """
    payload: dict[str, Any] = {
        "coverage_percentage": cov_data.coverage,
        "threshold_percentage": cov_data.threshold_percentage,
    }

    sections = [("main_function", cov_data.main_func_coverage)]
    dependent = cov_data.dependent_func_coverage
    if dependent:
        sections.append(("dependent_function", dependent))

    # Both sections share one layout, so build them the same way.
    for section_name, func_cov in sections:
        payload[section_name] = {
            "name": func_cov.name,
            "coverage": func_cov.coverage,
            "executed_lines": sorted(func_cov.executed_lines),
            "unexecuted_lines": sorted(func_cov.unexecuted_lines),
            "executed_branches": func_cov.executed_branches,
            "unexecuted_branches": func_cov.unexecuted_branches,
        }
    return payload
def rank_candidates(
    self, scored: list[ScoredCandidate], context: CodeContext, trace_id: str = ""
) -> list[int] | None:
    """Ask the AI service to rank the scored candidates.

    For every scored candidate a unified diff against ``context.target_code`` is
    computed and sent together with the candidate id and its speedup.

    Returns whatever ``generate_ranking`` returns, or ``None`` when the client
    cannot be created or the ranking call raises (both are logged).

    Raises:
        AssertionError: if *trace_id* is empty.
    """
    assert trace_id, "trace_id must be provided"
    from codeflash_core.diff import unified_diff

    try:
        ai_client = self.get_ai_client()
    except Exception:
        logger.exception("Failed to create AI client for ranking")
        return None

    # Three parallel lists, index-aligned with `scored`.
    diffs = []
    optimization_ids = []
    speedups = []
    for entry in scored:
        diffs.append(unified_diff(context.target_code, entry.candidate.code, context.target_file))
        optimization_ids.append(entry.candidate.candidate_id)
        speedups.append(entry.speedup)

    try:
        ranking = ai_client.generate_ranking(
            trace_id=trace_id, diffs=diffs, optimization_ids=optimization_ids, speedups=speedups
        )
    except Exception:
        logger.exception("Ranking API call failed")
        return None
    return ranking
def log_results(
    self,
    result: OptimizationResult,
    trace_id: str,
    all_speedups: dict[str, float] | None = None,
    all_runtimes: dict[str, float] | None = None,
    all_correct: dict[str, bool] | None = None,
) -> None:
    """Report the outcome of an optimization run to the AI service.

    Each ``all_*`` mapping is used when it is non-empty; otherwise a single-entry
    mapping for the winning candidate is built from *result*. Runtimes are
    converted from seconds to integer nanoseconds before sending.

    The baseline runtime is derived as ``total_time * 1e9 * speedup`` when the
    speedup is positive, else ``None``.
    NOTE(review): this treats ``speedup`` as the baseline/optimized runtime
    ratio -- confirm that matches how ``speedup`` is defined upstream.

    Failures to create the client or to send the log are logged and swallowed.
    """
    try:
        ai_client = self.get_ai_client()
    except Exception:
        logger.exception("Failed to create AI client for logging")
        return

    # Prefer data accumulated over all candidates; otherwise describe only the winner.
    if all_speedups:
        speedup_by_candidate = all_speedups
    else:
        speedup_by_candidate = {result.candidate.candidate_id: result.speedup}

    if all_correct:
        correctness_by_candidate = all_correct
    else:
        correctness_by_candidate = {result.candidate.candidate_id: result.test_results.passed}

    # The API contract is nanoseconds; internal timings are seconds.
    if all_runtimes:
        runtime_ns_by_candidate = {cid: int(seconds * 1e9) for cid, seconds in all_runtimes.items()}
    else:
        runtime_ns_by_candidate = {
            result.candidate.candidate_id: int(result.benchmark_results.total_time * 1e9)
        }

    baseline_ns = None
    if result.speedup > 0:
        baseline_ns = int(result.benchmark_results.total_time * 1e9 * result.speedup)

    try:
        ai_client.log_results(
            function_trace_id=trace_id,
            speedup_ratio=speedup_by_candidate,
            original_runtime=baseline_ns,
            optimized_runtime=runtime_ns_by_candidate,
            is_correct=correctness_by_candidate,
            optimized_line_profiler_results=None,
            metadata={"best_optimization_id": result.candidate.candidate_id},
        )
    except Exception:
        logger.exception("Result logging API call failed")
def generate_tests(
    self, function: FunctionToOptimize, context: CodeContext, test_config: TestConfig, trace_id: str = ""
) -> GeneratedTestSuite | None:
    """Generate behavior and performance test files for *function* via the AI service.

    Makes ``num_tests`` (currently 2) generation attempts. For each attempt that
    returns a result, the instrumented behavior and perf sources are written to
    disk and recorded as a ``GeneratedTestFile``.

    Side effects visible in this method:
        * ``self.is_numerical_code`` and ``self.tests_project_rootdir`` are set.
        * ``test_config.tests_root`` (and each test file's parent directory) is created.
        * Generated test files are written to disk.

    Returns:
        A ``GeneratedTestSuite`` with every successfully generated file, or ``None``
        when the AI client cannot be created or no attempt produced a test.

    Raises:
        AssertionError: if *trace_id* is empty (checked after the client is created).
    """
    from codeflash_python.code_utils.code_utils import module_name_from_file_path
    from codeflash_python.verification.verification_utils import get_test_file_path
    from codeflash_python.verification.verifier import generate_tests as _generate_tests

    try:
        client = self.get_ai_client()
    except Exception:
        logger.exception("Failed to create AI client for test generation")
        return None

    assert trace_id, "trace_id must be provided"
    internal_fn = function

    # Prefer the cached internal context (markdown form); fall back to the plain target code.
    internal_ctx = self.last_internal_context
    source_code = internal_ctx.read_writable_code.markdown if internal_ctx else context.target_code

    # Compute is_numerical_code matching original analyze_code_characteristics
    flat_code = internal_ctx.read_writable_code.flat if internal_ctx else context.target_code
    try:
        from codeflash_python.static_analysis.numerical_detection import is_numerical_code as _is_numerical_code

        numerical = _is_numerical_code(code_string=flat_code)
    except Exception:
        # Detection is best-effort: any failure (including the import) leaves the flag unknown.
        numerical = None

    # Stored on the instance; get_candidates reads self.is_numerical_code.
    self.is_numerical_code = numerical

    # Cache tests_project_rootdir for use in repair_generated_tests
    tests_project_rootdir = test_config.tests_project_rootdir or test_config.project_root
    self.tests_project_rootdir = tests_project_rootdir

    module_path = Path(module_name_from_file_path(function.file_path, test_config.project_root))
    helper_names = [h.qualified_name for h in context.helper_functions]

    test_dir = test_config.tests_root
    test_dir.mkdir(parents=True, exist_ok=True)

    test_files: list[GeneratedTestFile] = []
    num_tests = 2

    for i in range(num_tests):
        # Each attempt gets its own pair of paths: one behavior ("unit") file and one perf file.
        behavior_path = get_test_file_path(test_dir, function.function_name, iteration=i, test_type="unit")
        perf_path = get_test_file_path(test_dir, function.function_name, iteration=i, test_type="perf")

        try:
            result = _generate_tests(
                aiservice_client=client,
                source_code_being_tested=source_code,
                function_to_optimize=internal_fn,
                helper_function_names=helper_names,
                module_path=module_path,
                test_cfg_project_root=tests_project_rootdir,
                test_timeout=int(test_config.timeout),
                function_trace_id=trace_id,
                test_index=i,
                test_path=behavior_path,
                test_perf_path=perf_path,
                is_numerical_code=numerical,
            )
        except Exception:
            # One failed attempt must not abort the remaining ones.
            logger.exception("Failed to generate test %d for %s", i, function.qualified_name)
            continue

        if result is None:
            continue

        # Only the first three of the six returned values are used here.
        gen_source, behavior_source, perf_source, _raw, _, _ = result

        # Write test files to disk
        behavior_path.parent.mkdir(parents=True, exist_ok=True)
        behavior_path.write_text(behavior_source, encoding="utf-8")
        perf_path.parent.mkdir(parents=True, exist_ok=True)
        perf_path.write_text(perf_source, encoding="utf-8")

        test_files.append(
            GeneratedTestFile(
                behavior_test_path=behavior_path,
                perf_test_path=perf_path,
                behavior_test_source=behavior_source,
                perf_test_source=perf_source,
                original_test_source=gen_source,
            )
        )

    if not test_files:
        return None

    return GeneratedTestSuite(test_files=test_files)
def repair_generated_tests(
    self,
    suite: GeneratedTestSuite,
    reviews: list[TestReviewResult],
    context: CodeContext,
    trace_id: str = "",
    previous_repair_errors: dict[str, str] | None = None,
    coverage_data: CoverageData | None = None,
) -> GeneratedTestSuite | None:
    """Repair the generated test files flagged by *reviews* via the AI service.

    For every review that lists functions to repair and whose ``test_index`` is
    below ``len(suite.test_files)``, the repair endpoint is called. On success the
    returned sources are post-processed, written over the existing behavior and
    perf test files, and the suite entry is replaced.

    Returns:
        A new ``GeneratedTestSuite`` built from a copy of ``suite.test_files`` with
        repaired entries swapped in. This is returned even when nothing was
        repaired. ``None`` is returned only when the AI client cannot be created.

    Raises:
        AssertionError: if *trace_id* is empty (checked after the client is created).
    """
    from codeflash_python.api.types import FunctionRepairInfo
    from codeflash_python.code_utils.code_utils import module_name_from_file_path

    try:
        client = self.get_ai_client()
    except Exception:
        logger.exception("Failed to create AI client for test repair")
        return None

    coverage_details = coverage_data_to_details_dict(coverage_data) if coverage_data is not None else None
    internal_fn = context.target_function
    assert trace_id, "trace_id must be provided"

    # Work on a shallow copy so the caller's suite list is not mutated.
    new_test_files = list(suite.test_files)

    for review in reviews:
        if not review.functions_to_repair:
            continue

        idx = review.test_index
        # NOTE(review): only the upper bound is checked; a negative index would
        # wrap around to a file counted from the end -- confirm indices are never negative.
        if idx >= len(suite.test_files):
            continue

        tf = suite.test_files[idx]

        repair_infos = [
            FunctionRepairInfo(function_name=f.function_name, reason=f.reason) for f in review.functions_to_repair
        ]

        # self.tests_project_rootdir is the value cached by generate_tests; fall back to the project root.
        tests_project_rootdir = self.tests_project_rootdir or self.project_root
        module_path = Path(module_name_from_file_path(context.target_file, self.project_root))
        test_module_path = Path(module_name_from_file_path(tf.behavior_test_path, tests_project_rootdir))

        helper_names = [h.qualified_name for h in context.helper_functions]

        try:
            result = client.repair_generated_tests(
                test_source=tf.original_test_source,
                functions_to_repair=repair_infos,
                function_source_code=context.target_code,
                function_to_optimize=internal_fn,
                helper_function_names=helper_names,
                module_path=module_path,
                test_module_path=test_module_path,
                test_framework="pytest",
                # NOTE(review): fixed at 60 here, whereas generate_tests passes test_config.timeout.
                test_timeout=60,
                trace_id=trace_id,
                language="python",
                previous_repair_errors=previous_repair_errors,
                # NOTE(review): the same target_code is sent as both function and module source.
                module_source_code=context.target_code,
                coverage_details=coverage_details,
            )
        except Exception:
            # A failed repair leaves the original entry for this index untouched.
            logger.exception("Test repair API call failed for test %d", idx)
            continue

        if result is None:
            continue

        gen_source, behavior_source, perf_source = result

        # Process (replace temp dir placeholders)
        gen_source, behavior_source, perf_source = process_generated_test_strings(
            generated_test_source=gen_source,
            instrumented_behavior_test_source=behavior_source,
            instrumented_perf_test_source=perf_source,
            function_to_optimize=internal_fn,
            test_path=tf.behavior_test_path,
            test_cfg=None,
            project_module_system=None,
        )

        # Write repaired tests
        tf.behavior_test_path.write_text(behavior_source, encoding="utf-8")
        tf.perf_test_path.write_text(perf_source, encoding="utf-8")

        # Same paths as before; only the sources change.
        new_test_files[idx] = GeneratedTestFile(
            behavior_test_path=tf.behavior_test_path,
            perf_test_path=tf.perf_test_path,
            behavior_test_source=behavior_source,
            perf_test_source=perf_source,
            original_test_source=gen_source,
        )

    return GeneratedTestSuite(test_files=new_test_files)
def _rows_to_pipe_table(headers: list[str], rows: list[list[str]]) -> str:
    """Render *rows* as a pipe-format table followed by a newline, or "" when there are no rows."""
    if not rows:
        return ""
    return (
        tabulate(headers=headers, tabular_data=rows, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True)
        + "\n"
    )


def existing_tests_source_for(
    function_qualified_name_with_modules_from_root: str,
    function_to_tests: dict[str, set[FunctionCalledInTest]],
    test_cfg: TestConfig,
    original_runtimes_all: dict[InvocationId, list[int]],
    optimized_runtimes_all: dict[InvocationId, list[int]],
    test_files_registry: TestFiles | None = None,
) -> tuple[str, str, str]:
    """Build per-test runtime comparison tables for the tests that exercise a function.

    Runtimes are summed per (test file, qualified test name) using the minimum of each
    invocation's timings. Only tests with a non-zero runtime on both sides produce a row.
    Rows are routed by file name into replay tests (``__replay_test_``), concolic tests
    (``codeflash_concolic``) or regular existing tests.

    Args:
        function_qualified_name_with_modules_from_root: Key into ``function_to_tests``.
        function_to_tests: Mapping from function name to the tests that call it.
        test_cfg: Test configuration; ``tests_root`` is used to shorten file paths.
        original_runtimes_all: Timings per invocation for the original code.
        optimized_runtimes_all: Timings per invocation for the optimized code.
        test_files_registry: Optional registry of instrumented test files.

    Returns:
        ``(existing_table, replay_table, concolic_table)``; each is a pipe table ending in a
        newline, or an empty string when it has no rows.
    """
    logger.debug(
        "[PR-DEBUG] existing_tests_source_for called with func=%s", function_qualified_name_with_modules_from_root
    )
    logger.debug("[PR-DEBUG] function_to_tests keys: %s", list(function_to_tests.keys()))
    logger.debug("[PR-DEBUG] original_runtimes_all has %s entries", len(original_runtimes_all))
    logger.debug("[PR-DEBUG] optimized_runtimes_all has %s entries", len(optimized_runtimes_all))
    test_files = function_to_tests.get(function_qualified_name_with_modules_from_root)
    if not test_files:
        logger.debug("[PR-DEBUG] No test_files found for %s", function_qualified_name_with_modules_from_root)
        return "", "", ""
    logger.debug("[PR-DEBUG] Found %s test_files", len(test_files))
    for tf in test_files:
        logger.debug("[PR-DEBUG] test_file: %s, test_type=%s", tf.tests_in_file.test_file, tf.tests_in_file.test_type)
    rows_existing: list[list[str]] = []
    rows_concolic: list[list[str]] = []
    rows_replay: list[list[str]] = []
    headers = ["Test File::Test Function", "Original \u23f1\ufe0f", "Optimized \u23f1\ufe0f", "Speedup"]
    tests_root = test_cfg.tests_root
    original_tests_to_runtimes: dict[Path, dict[str, int]] = {}
    optimized_tests_to_runtimes: dict[Path, dict[str, int]] = {}

    # Build lookup from instrumented path -> original path using the test_files_registry
    # Include both behavior and benchmarking paths since test results might come from either
    # NOTE(review): this mapping is only logged; nothing below consults it. Confirm whether
    # invocation paths were meant to be translated through it before matching.
    instrumented_to_original: dict[Path, Path] = {}
    if test_files_registry:
        for registry_tf in test_files_registry.test_files:
            if registry_tf.original_file_path:
                if registry_tf.instrumented_behavior_file_path:
                    instrumented_to_original[registry_tf.instrumented_behavior_file_path.resolve()] = (
                        registry_tf.original_file_path.resolve()
                    )
                    logger.debug(
                        "[PR-DEBUG] Mapping (behavior): %s -> %s",
                        registry_tf.instrumented_behavior_file_path.name,
                        registry_tf.original_file_path.name,
                    )
                if registry_tf.benchmarking_file_path:
                    instrumented_to_original[registry_tf.benchmarking_file_path.resolve()] = (
                        registry_tf.original_file_path.resolve()
                    )
                    logger.debug(
                        "[PR-DEBUG] Mapping (perf): %s -> %s",
                        registry_tf.benchmarking_file_path.name,
                        registry_tf.original_file_path.name,
                    )

    # Resolve all paths to absolute for consistent comparison
    non_generated_tests: set[Path] = set()
    for test_file in test_files:
        resolved = test_file.tests_in_file.test_file.resolve()
        non_generated_tests.add(resolved)
        logger.debug("[PR-DEBUG] Added to non_generated_tests: %s", resolved)
    all_invocation_ids = original_runtimes_all.keys() | optimized_runtimes_all.keys()
    logger.debug("[PR-DEBUG] Processing %s invocation_ids", len(all_invocation_ids))
    matched_count = 0
    skipped_count = 0
    for invocation_id in all_invocation_ids:
        test_module_path = invocation_id.test_module_path
        # Dotted module name -> absolute ``.py`` path, comparable with non_generated_tests.
        abs_path = Path(test_module_path.replace(".", os.sep)).with_suffix(".py").resolve()
        if abs_path not in non_generated_tests:
            skipped_count += 1
            if skipped_count <= 5:  # cap the debug noise
                logger.debug("[PR-DEBUG] SKIP: abs_path=%s", abs_path.name)
                logger.debug("[PR-DEBUG] Expected one of: %s", [p.name for p in list(non_generated_tests)[:3]])
            continue
        matched_count += 1
        logger.debug("[PR-DEBUG] MATCHED: %s", abs_path.name)
        if abs_path not in original_tests_to_runtimes:
            original_tests_to_runtimes[abs_path] = {}
        if abs_path not in optimized_tests_to_runtimes:
            optimized_tests_to_runtimes[abs_path] = {}
        qualified_name = (
            invocation_id.test_class_name + "." + invocation_id.test_function_name  # type: ignore[operator]
            if invocation_id.test_class_name
            else invocation_id.test_function_name
        )
        # Seed both sides with 0 so the two dicts always share the same keys.
        if qualified_name not in original_tests_to_runtimes[abs_path]:
            original_tests_to_runtimes[abs_path][qualified_name] = 0  # type: ignore[index]
        if qualified_name not in optimized_tests_to_runtimes[abs_path]:
            optimized_tests_to_runtimes[abs_path][qualified_name] = 0  # type: ignore[index]
        if invocation_id in original_runtimes_all:
            original_tests_to_runtimes[abs_path][qualified_name] += min(original_runtimes_all[invocation_id])  # type: ignore[index]
        if invocation_id in optimized_runtimes_all:
            optimized_tests_to_runtimes[abs_path][qualified_name] += min(optimized_runtimes_all[invocation_id])  # type: ignore[index]
    logger.debug("[PR-DEBUG] SUMMARY: matched=%s, skipped=%s", matched_count, skipped_count)
    logger.debug("[PR-DEBUG] original_tests_to_runtimes has %s files", len(original_tests_to_runtimes))

    # Both runtime dicts share the same keys (see the seeding above), so iterating the
    # original side covers everything.
    for filename in sorted(original_tests_to_runtimes):
        for qualified_name in sorted(original_tests_to_runtimes[filename]):
            original_ns = original_tests_to_runtimes[filename][qualified_name]
            optimized_ns = optimized_tests_to_runtimes[filename][qualified_name]
            # A zero means the test has no timing on that side; no comparison is possible.
            if original_ns == 0 or optimized_ns == 0:
                continue
            print_filename = filename.resolve().relative_to(tests_root.resolve()).as_posix()
            perf_gain = format_perf(
                performance_gain(original_runtime_ns=original_ns, optimized_runtime_ns=optimized_ns) * 100
            )
            # Warning sign when the optimized code is slower for this test, check mark otherwise.
            marker = "\u26a0\ufe0f" if optimized_ns > original_ns else "\u2705"
            row = [
                f"`{print_filename}::{qualified_name}`",
                f"{format_time(original_ns)}",
                f"{format_time(optimized_ns)}",
                f"{perf_gain}%{marker}",
            ]
            if "__replay_test_" in print_filename:
                rows_replay.append(row)
            elif "codeflash_concolic" in print_filename:
                rows_concolic.append(row)
            else:
                rows_existing.append(row)

    output_existing = _rows_to_pipe_table(headers, rows_existing)
    output_concolic = _rows_to_pipe_table(headers, rows_concolic)
    output_replay = _rows_to_pipe_table(headers, rows_replay)
    return output_existing, output_replay, output_concolic
original_async_throughput=explanation.original_async_throughput, + best_async_throughput=explanation.best_async_throughput, + ), + existing_tests=existing_tests_source, + generated_tests=generated_original_test_source, + trace_id=function_trace_id, + coverage_message=coverage_message, + replay_tests=replay_tests, + concolic_tests=concolic_tests, + optimization_review=optimization_review, + original_line_profiler=original_line_profiler, + optimized_line_profiler=optimized_line_profiler, + ) + if response.ok: + logger.info("Suggestions were successfully made to PR #%s", pr_number) + else: + logger.error( + "Optimization was successful, but I failed to suggest changes to PR #%s. Response from server was: %s", + pr_number, + response.text, + ) + else: + logger.info("Creating a new PR with the optimized code...") + + owner, repo = get_repo_owner_and_name(git_repo, git_remote) + logger.info("Pushing to %s - Owner: %s, Repo: %s", git_remote, owner, repo) + + if not check_and_push_branch(git_repo, git_remote, wait_for_push=True): + logger.warning("\u23ed\ufe0f Branch is not pushed, skipping PR creation...") + return + relative_path = explanation.file_path.resolve().relative_to(root_dir.resolve()).as_posix() + base_branch = get_current_branch() + build_file_changes = { + Path(p).resolve().relative_to(root_dir.resolve()).as_posix(): FileDiffContent( + oldContent=original_code[p], newContent=new_code[p] + ) + for p in original_code + } + + response = cfapi.create_pr( + owner=owner, + repo=repo, + base_branch=base_branch, + file_changes=build_file_changes, + pr_comment=PrComment( + optimization_explanation=explanation.explanation_message(), + best_runtime=explanation.best_runtime_ns, + original_runtime=explanation.original_runtime_ns, + function_name=explanation.function_name, + relative_file_path=relative_path, + speedup_x=explanation.speedup_x, + speedup_pct=explanation.speedup_pct, + winning_behavior_test_results=explanation.winning_behavior_test_results, + 
class AcceptanceReason(Enum):
    """Metric that justified accepting an optimization, as reported by ``get_acceptance_reason``."""

    RUNTIME = "runtime"  # runtime gain exceeded the noise floor
    THROUGHPUT = "throughput"  # async throughput gain exceeded its threshold
    CONCURRENCY = "concurrency"  # concurrency-ratio gain exceeded its threshold
    NONE = "none"  # no metric improved enough
def speedup_critic(
    candidate_result: OptimizedCandidateResult,
    original_code_runtime: int,
    best_runtime_until_now: int | None,
    *,
    disable_gh_action_noise: bool = False,
    original_async_throughput: int | None = None,
    best_throughput_until_now: int | None = None,
    original_concurrency_metrics: ConcurrencyMetrics | None = None,
    best_concurrency_ratio_until_now: float | None = None,
) -> bool:
    """Decide whether a correct optimization candidate should be surfaced to the user.

    Runtime:
    - The runtime gain must exceed a noise floor. The floor is 3 * MIN_IMPROVEMENT_THRESHOLD
      when the original runtime is below 10_000 ns (10 microseconds) and
      MIN_IMPROVEMENT_THRESHOLD otherwise.
    - The floor is doubled when ``env_utils.is_ci()`` reports a CI environment, unless
      ``disable_gh_action_noise`` is set.
    - The candidate must also beat ``best_runtime_until_now`` when one is given.

    Async throughput (only when both original and candidate throughput are available):
    - The throughput gain must exceed MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD and the candidate
      must beat ``best_throughput_until_now`` when one is given.

    Concurrency (only when both original and candidate metrics are available):
    - The concurrency-ratio gain must exceed MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD and the
      candidate must beat ``best_concurrency_ratio_until_now`` when one is given.

    Returns:
        With throughput data present: True if ANY of runtime, throughput or concurrency is
        accepted. Without throughput data: True only if runtime is accepted (concurrency is
        not consulted in that case).
    """
    # Runtime performance evaluation
    noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD
    if not disable_gh_action_noise and env_utils.is_ci():
        noise_floor = noise_floor * 2  # CI runners are noisy, so demand a larger gain

    perf_gain = performance_gain(
        original_runtime_ns=original_code_runtime, optimized_runtime_ns=candidate_result.best_test_runtime
    )
    runtime_improved = perf_gain > noise_floor
    runtime_is_best = best_runtime_until_now is None or candidate_result.best_test_runtime < best_runtime_until_now
    runtime_acceptance = runtime_improved and runtime_is_best

    has_throughput_data = original_async_throughput is not None and candidate_result.async_throughput is not None

    # NOTE(review): these defaults only take effect when throughput data exists but
    # original_async_throughput is not positive; throughput then counts as "improved"
    # without any gain being measured. get_acceptance_reason treats that case as not
    # improved -- confirm which behavior is intended.
    throughput_improved = True
    throughput_is_best = True
    if has_throughput_data:
        if original_async_throughput > 0:
            throughput_gain_value = throughput_gain(
                original_throughput=original_async_throughput, optimized_throughput=candidate_result.async_throughput
            )
            throughput_improved = throughput_gain_value > MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD

        throughput_is_best = (
            best_throughput_until_now is None or candidate_result.async_throughput > best_throughput_until_now
        )

    # Concurrency evaluation
    concurrency_improved = False
    concurrency_is_best = True
    if original_concurrency_metrics is not None and candidate_result.concurrency_metrics is not None:
        conc_gain = concurrency_gain(original_concurrency_metrics, candidate_result.concurrency_metrics)
        concurrency_improved = conc_gain > MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD
        concurrency_is_best = (
            best_concurrency_ratio_until_now is None
            or candidate_result.concurrency_metrics.concurrency_ratio > best_concurrency_ratio_until_now
        )

    if not has_throughput_data:
        return runtime_acceptance

    # Accept if ANY of: runtime, throughput, or concurrency improves significantly
    throughput_acceptance = throughput_improved and throughput_is_best
    concurrency_acceptance = concurrency_improved and concurrency_is_best
    return throughput_acceptance or runtime_acceptance or concurrency_acceptance
+ """ + noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_runtime_ns < 10000 else MIN_IMPROVEMENT_THRESHOLD + if env_utils.is_ci(): + noise_floor = noise_floor * 2 + + perf_gain = performance_gain(original_runtime_ns=original_runtime_ns, optimized_runtime_ns=optimized_runtime_ns) + runtime_improved = perf_gain > noise_floor + + throughput_improved = False + if ( + original_async_throughput is not None + and optimized_async_throughput is not None + and original_async_throughput > 0 + ): + throughput_gain_value = throughput_gain( + original_throughput=original_async_throughput, optimized_throughput=optimized_async_throughput + ) + throughput_improved = throughput_gain_value > MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD + + concurrency_improved = False + if original_concurrency_metrics is not None and optimized_concurrency_metrics is not None: + conc_gain = concurrency_gain(original_concurrency_metrics, optimized_concurrency_metrics) + concurrency_improved = conc_gain > MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD + + # Return reason with priority: concurrency > throughput > runtime + if original_async_throughput is not None and optimized_async_throughput is not None: + if concurrency_improved: + return AcceptanceReason.CONCURRENCY + if throughput_improved: + return AcceptanceReason.THROUGHPUT + if runtime_improved: + return AcceptanceReason.RUNTIME + return AcceptanceReason.NONE + + if runtime_improved: + return AcceptanceReason.RUNTIME + return AcceptanceReason.NONE + + +def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult | OriginalCodeBaseline) -> bool: + test_results = candidate_result.behavior_test_results + report = test_results.get_test_pass_fail_report_by_type() + + pass_count = 0 + for test_type in report: + pass_count += report[test_type]["passed"] + + if pass_count >= MIN_TESTCASE_PASSED_THRESHOLD: + return True + # If one or more tests passed, check if least one of them was a successful REPLAY_TEST + return bool(pass_count >= 1 and 
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
class Explanation:
    """Immutable summary of an accepted optimization, used to render human-readable reports."""

    raw_explanation_message: str
    winning_behavior_test_results: TestResults
    winning_benchmarking_test_results: TestResults
    original_runtime_ns: int
    best_runtime_ns: int
    function_name: str
    file_path: Path
    benchmark_details: list[BenchmarkDetail] | None = None
    original_async_throughput: int | None = None
    best_async_throughput: int | None = None
    original_concurrency_metrics: ConcurrencyMetrics | None = None
    best_concurrency_metrics: ConcurrencyMetrics | None = None
    acceptance_reason: AcceptanceReason = AcceptanceReason.RUNTIME

    @property
    def perf_improvement_line(self) -> str:
        """One-line headline such as ``"25% runtime improvement (0.25x faster)."``."""
        improvement_type = {
            AcceptanceReason.RUNTIME: "runtime",
            AcceptanceReason.THROUGHPUT: "throughput",
            AcceptanceReason.CONCURRENCY: "concurrency",
            AcceptanceReason.NONE: "",
        }.get(self.acceptance_reason, "")

        if improvement_type:
            return f"{self.speedup_pct} {improvement_type} improvement ({self.speedup_x} faster)."
        return f"{self.speedup_pct} improvement ({self.speedup_x} faster)."

    @property
    def speedup(self) -> float:
        """Return the relative improvement of the metric that caused acceptance.

        Falls back to the runtime gain when the data for the accepting metric is missing.
        Returns 0.0 when the best runtime is zero, instead of raising ZeroDivisionError.
        """
        if (
            self.acceptance_reason == AcceptanceReason.CONCURRENCY
            and self.original_concurrency_metrics
            and self.best_concurrency_metrics
        ):
            return concurrency_gain(self.original_concurrency_metrics, self.best_concurrency_metrics)

        if (
            self.acceptance_reason == AcceptanceReason.THROUGHPUT
            and self.original_async_throughput is not None
            and self.best_async_throughput is not None
            and self.original_async_throughput > 0
        ):
            return throughput_gain(
                original_throughput=self.original_async_throughput, optimized_throughput=self.best_async_throughput
            )

        # Same convention as critic.performance_gain: no measurable gain for a zero runtime.
        if self.best_runtime_ns == 0:
            return 0.0
        return (self.original_runtime_ns / self.best_runtime_ns) - 1

    @property
    def speedup_x(self) -> str:
        """``speedup`` formatted as a multiplier, e.g. ``"1.50x"``."""
        return f"{self.speedup:,.2f}x"

    @property
    def speedup_pct(self) -> str:
        """``speedup`` formatted as a whole-number percentage, e.g. ``"150%"``."""
        return f"{self.speedup * 100:,.0f}%"

    def __str__(self) -> str:
        """Render the full report: headline, metric description, benchmarks, message, test results."""
        original_runtime_human = humanize_runtime(self.original_runtime_ns)
        best_runtime_human = humanize_runtime(self.best_runtime_ns)

        benchmark_info = ""

        if self.benchmark_details:
            headers = ["Benchmark Module Path", "Test Function", "Original Runtime", "Expected New Runtime", "Speedup"]
            rows = [
                [
                    detail.benchmark_name,
                    detail.test_function,
                    detail.original_timing,
                    detail.expected_new_timing,
                    f"{detail.speedup_percent:.2f}%",
                ]
                for detail in self.benchmark_details
            ]
            # Each column is as wide as its widest cell (header included).
            col_widths = [len(h) for h in headers]
            for row in rows:
                for i, cell in enumerate(row):
                    col_widths[i] = max(col_widths[i], len(cell))
            # One separator drives both the row format and the rule length so they stay aligned.
            column_separator = "  "
            fmt = column_separator.join(f"{{:<{w}}}" for w in col_widths)
            rule_width = sum(col_widths) + len(column_separator) * (len(headers) - 1)
            lines = ["Benchmark Performance Details", fmt.format(*headers), "-" * rule_width]
            lines.extend(fmt.format(*row) for row in rows)
            benchmark_info = "\n".join(lines) + "\n\n"

        if (
            self.acceptance_reason == AcceptanceReason.CONCURRENCY
            and self.original_concurrency_metrics
            and self.best_concurrency_metrics
        ):
            orig_ratio = self.original_concurrency_metrics.concurrency_ratio
            best_ratio = self.best_concurrency_metrics.concurrency_ratio
            performance_description = (
                f"Concurrency ratio improved from {orig_ratio:.2f}x to {best_ratio:.2f}x "
                f"(concurrent execution now {best_ratio:.2f}x faster than sequential)\n\n"
            )
        elif (
            self.acceptance_reason == AcceptanceReason.THROUGHPUT
            and self.original_async_throughput is not None
            and self.best_async_throughput is not None
        ):
            performance_description = (
                f"Throughput improved from {self.original_async_throughput} to {self.best_async_throughput} operations/second "
                f"(runtime: {original_runtime_human} → {best_runtime_human})\n\n"
            )
        else:
            performance_description = f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n"

        return (
            f"Optimized {self.function_name} in {self.file_path}\n"
            f"{self.perf_improvement_line}\n"
            + performance_description
            + (benchmark_info if benchmark_info else "")
            + self.raw_explanation_message
            + " \n\n"
            + "The new optimized code was tested for correctness. The results are listed below.\n"
            f"{TestResults.report_to_string(self.winning_behavior_test_results.get_test_pass_fail_report_by_type())}\n"
        )

    def explanation_message(self) -> str:
        """Return the raw explanation text without any of the generated report sections."""
        return self.raw_explanation_message
def github_pr_url(owner: str, repo: str, pr_number: str) -> str:
    """Return the github.com URL of pull request *pr_number* in ``owner/repo``."""
    template = "https://github.com/{}/{}/pull/{}"
    return template.format(owner, repo, pr_number)
class FileDiffContent(BaseModel):
    """Before/after contents of one changed file."""

    # camelCase field names are kept on purpose (hence the N815 suppressions).
    # NOTE(review): presumably they mirror the API payload keys -- confirm against the server.
    oldContent: str  # noqa: N815
    newContent: str  # noqa: N815
class CodeflashConfig(BaseModel):
    """Canonical in-memory form of the Codeflash settings.

    Instances convert to and from the ``[tool.codeflash]`` table of
    ``pyproject.toml``. Path settings are held as strings, relative to the
    project root.
    """

    # Core settings (always present after detection)
    language: str = Field(default="python", description="Project language")
    module_root: str = Field(default=".", description="Root directory containing source code")
    tests_root: str | None = Field(default=None, description="Root directory containing tests")

    # Tooling settings (auto-detected, can be overridden)
    test_framework: str | None = Field(default=None, description="Test framework: pytest")
    formatter_cmds: list[str] = Field(default_factory=list, description="Formatter commands")

    # Optional settings
    ignore_paths: list[str] = Field(default_factory=list, description="Paths to ignore")
    benchmarks_root: str | None = Field(default=None, description="Benchmarks directory")

    # Git settings
    git_remote: str = Field(default="origin", description="Git remote for PRs")

    # Privacy settings
    disable_telemetry: bool = Field(default=False, description="Disable telemetry")

    # Python-specific settings
    pytest_cmd: str = Field(default="pytest", description="Pytest command")
    disable_imports_sorting: bool = Field(default=False, description="Disable import sorting")
    override_fixtures: bool = Field(default=False, description="Override test fixtures")

    model_config = ConfigDict(extra="allow")  # Unknown keys are accepted for forward compatibility

    def to_pyproject_dict(self) -> dict[str, Any]:
        """Render the settings as a kebab-case ``[tool.codeflash]`` mapping.

        ``module-root`` is always emitted. Every other key appears only when its
        value is set and differs from the value a reader would assume anyway.
        """
        # No formatter commands -> explicit "disabled" marker; the lone black
        # command is omitted from the table; anything else is written verbatim.
        if not self.formatter_cmds:
            formatter_entry: list[str] | None = ["disabled"]
        elif self.formatter_cmds != ["black $file"]:
            formatter_entry = self.formatter_cmds
        else:
            formatter_entry = None

        # (key, value, include?) in the exact order the keys must appear.
        candidates: tuple[tuple[str, Any, bool], ...] = (
            ("module-root", self.module_root, True),
            ("tests-root", self.tests_root, bool(self.tests_root)),
            ("ignore-paths", self.ignore_paths, bool(self.ignore_paths)),
            ("formatter-cmds", formatter_entry, formatter_entry is not None),
            ("benchmarks-root", self.benchmarks_root, bool(self.benchmarks_root)),
            ("git-remote", self.git_remote, bool(self.git_remote) and self.git_remote != "origin"),
            ("disable-telemetry", True, bool(self.disable_telemetry)),
            ("pytest-cmd", self.pytest_cmd, bool(self.pytest_cmd) and self.pytest_cmd != "pytest"),
            ("disable-imports-sorting", True, bool(self.disable_imports_sorting)),
            ("override-fixtures", True, bool(self.override_fixtures)),
        )
        return {key: value for key, value, include in candidates if include}

    @classmethod
    def from_detected_project(cls, detected: Any) -> CodeflashConfig:
        """Build a config from a detection result, making its paths relative to the project root."""
        root = detected.project_root

        if detected.module_root != root:
            module_root = str(detected.module_root.relative_to(root))
        else:
            module_root = "."

        tests_root = None
        if detected.tests_root:
            tests_root = str(detected.tests_root.relative_to(root))

        # The project root itself is never listed as an ignored path.
        relative_ignores = [str(entry.relative_to(root)) for entry in detected.ignore_paths if entry != root]

        return cls(
            language="python",
            module_root=module_root,
            tests_root=tests_root,
            test_framework=detected.test_runner,
            formatter_cmds=detected.formatter_cmds,
            ignore_paths=relative_ignores,
            pytest_cmd=detected.test_runner,
        )

    @classmethod
    def from_pyproject_dict(cls, data: dict[str, Any], project_root: Path | None = None) -> CodeflashConfig:
        """Build a config from a ``[tool.codeflash]`` mapping, turning kebab-case keys into snake_case."""
        _ = project_root  # Reserved for future path resolution

        snake_cased = {key.replace("-", "_"): value for key, value in data.items()}
        if "language" not in snake_cased:
            snake_cased["language"] = "python"
        return cls(**snake_cased)
def create_pyproject_toml(project_root: Path) -> Result[str, str]:
    """Create a minimal pyproject.toml containing an empty ``[tool.codeflash]`` table.

    Args:
        project_root: Directory in which ``pyproject.toml`` is created.

    Returns:
        ``Ok`` with a confirmation message on success, or ``Err`` when the file
        already exists or cannot be written.

    """
    pyproject_path = project_root / "pyproject.toml"
    already_exists = f"pyproject.toml already exists at {pyproject_path}"

    if pyproject_path.exists():
        return Err(already_exists)

    try:
        doc = tomlkit.document()
        doc.add(tomlkit.comment("Created by Codeflash"))
        doc.add(tomlkit.nl())

        tool_table = tomlkit.table()
        codeflash_table = tomlkit.table()
        codeflash_table.add(tomlkit.comment("Codeflash configuration - https://docs.codeflash.ai"))
        tool_table["codeflash"] = codeflash_table
        doc["tool"] = tool_table

        # Exclusive-create mode: if another process created the file after the
        # check above, fail instead of silently truncating its content.
        with pyproject_path.open("x", encoding="utf8") as f:
            f.write(tomlkit.dumps(doc))

    except FileExistsError:
        return Err(already_exists)
    except Exception as e:
        return Err(f"Failed to create pyproject.toml: {e}")

    return Ok(f"Created {pyproject_path}")
@dataclass
class DetectedProject:
    """Outcome of project auto-detection.

    Paths are absolute. ``confidence`` runs from 0.0 (guessing) to 1.0 (certain).
    """

    # Core detection results
    language: str
    project_root: Path
    module_root: Path
    tests_root: Path | None

    # Tooling detection
    test_runner: str
    formatter_cmds: list[str]

    # Absolute paths that should be ignored
    ignore_paths: list[Path] = field(default_factory=list)

    # How certain the detection is (0.0 - 1.0)
    confidence: float = 0.8

    # Per-setting notes on how each value was found (for debugging/display)
    detection_details: dict[str, str] = field(default_factory=dict)

    def to_display_dict(self) -> dict[str, str]:
        """Return the settings as label -> text pairs suitable for printing."""
        commands = self.formatter_cmds
        formatter_text = commands[0] if commands else "none detected"
        extra_commands = len(commands) - 1
        if extra_commands > 0:
            formatter_text = f"{formatter_text} (+{extra_commands} more)"

        # Only the first three ignored entries are spelled out by name.
        ignoring_text = ", ".join(entry.name for entry in self.ignore_paths[:3])
        hidden_entries = len(self.ignore_paths) - 3
        if hidden_entries > 0:
            ignoring_text = f"{ignoring_text} (+{hidden_entries} more)"

        language_text = self.language.capitalize()

        if self.module_root == self.project_root:
            module_text = "."
        else:
            module_text = str(self.module_root.relative_to(self.project_root))

        if self.tests_root:
            tests_text = str(self.tests_root.relative_to(self.project_root))
        else:
            tests_text = "not detected"

        return {
            "Language": language_text,
            "Module root": module_text,
            "Tests root": tests_text,
            "Test runner": self.test_runner,
            "Formatter": formatter_text or "none",
            "Ignoring": ignoring_text or "defaults only",
        }
def detect_ignore_paths(project_root: Path) -> tuple[list[Path], str]:
    """Detect paths to ignore during optimization.

    Two sources are consulted, in this order:

    1. A fixed list of well-known cache/virtualenv/build names, kept when the
       entry exists directly under ``project_root``.
    2. Literal entries of ``project_root/.gitignore``. Comments, negations
       (``!``) and glob patterns are skipped; only entries that exist on disk
       are kept.

    Args:
        project_root: Directory whose entries are inspected.

    Returns:
        ``(ignore_paths, detail)`` where ``detail`` names the sources used,
        joined by ``" + "``, or ``"none"`` when neither applied.

    """
    default_ignores = (
        "__pycache__",
        ".pytest_cache",
        ".mypy_cache",
        ".ruff_cache",
        "venv",
        ".venv",
        "env",
        ".env",
        "dist",
        "build",
        ".egg-info",
        ".tox",
        ".nox",
        "htmlcov",
        ".coverage",
    )

    ignore_paths: list[Path] = []
    for name in default_ignores:
        candidate = project_root / name
        if candidate.exists():
            ignore_paths.append(candidate)

    # Set mirror of the list: constant-time duplicate checks while the list
    # keeps the discovery order.
    seen: set[Path] = set(ignore_paths)
    sources: list[str] = ["defaults"] if ignore_paths else []

    gitignore_path = project_root / ".gitignore"
    if gitignore_path.exists():
        try:
            content = gitignore_path.read_text(encoding="utf8")
            for raw_line in content.splitlines():
                entry = raw_line.strip()
                # Blank lines, comments and negated patterns carry no path to ignore.
                if not entry or entry.startswith(("#", "!")):
                    continue
                pattern = entry.rstrip("/").lstrip("/")
                # Glob patterns cannot be mapped to a single concrete path.
                if any(ch in pattern for ch in "*?["):
                    continue
                path = project_root / pattern
                # Entries such as "/", "." or "./" collapse to the project root;
                # ignoring the root would exclude the whole project.
                if path == project_root or path in seen:
                    continue
                if path.exists():
                    ignore_paths.append(path)
                    seen.add(path)

            sources.append(".gitignore")
        except (OSError, ValueError):
            # Best effort: an unreadable or undecodable .gitignore (UnicodeDecodeError
            # is a ValueError) leaves whatever was collected so far.
            pass

    detail = " + ".join(sources) if sources else "none"
    return ignore_paths, detail
def detect_python_test_runner(project_root: Path) -> tuple[str, str]:
    """Report which marker file, if any, indicates pytest for this project.

    Markers are examined in a fixed order: ``pytest.ini``, ``pyproject.toml``
    (needs a ``[tool.pytest]`` table), ``conftest.py``, then ``setup.cfg``
    (needs a pytest section). The runner is ``"pytest"`` in every case; the
    second element says which marker matched, or ``"default"`` when none did.
    """
    runner = "pytest"

    if (project_root / "pytest.ini").exists():
        return runner, "pytest.ini found"

    pyproject_path = project_root / "pyproject.toml"
    if pyproject_path.exists():
        try:
            with pyproject_path.open("rb") as f:
                data = tomlkit.parse(f.read())
            if "tool" in data and "pytest" in data["tool"]:  # type: ignore[unsupported-operator]
                return runner, "pyproject.toml [tool.pytest]"
        except Exception:
            # An unreadable pyproject.toml simply does not count as a marker.
            pass

    if (project_root / "conftest.py").exists():
        return runner, "conftest.py found"

    setup_cfg_path = project_root / "setup.cfg"
    if setup_cfg_path.exists():
        try:
            cfg_text = setup_cfg_path.read_text(encoding="utf8")
            if any(section in cfg_text for section in ("[tool:pytest]", "[pytest]")):
                return runner, "setup.cfg [pytest]"
        except Exception:
            pass

    # Nothing matched: fall back to pytest (most common)
    return runner, "default"
def handle_first_run(
    args: Namespace | None = None, skip_confirm: bool = False, skip_api_key: bool = False
) -> Namespace | None:
    """Run the first-time setup flow: detect, show, confirm, API key, save.

    Args:
        args: Optional CLI args namespace to update; a new one is created when omitted.
        skip_confirm: Skip the confirmation prompt (--yes flag).
        skip_api_key: Skip the API key step.

    Returns:
        The namespace populated with the detected settings, or None when
        detection fails, the user does not confirm, or the API key step fails.

    """
    from argparse import Namespace

    try:
        detected = detect_project()
    except Exception as e:
        show_detection_error(str(e))
        return None

    show_welcome()
    show_detected_settings(detected)

    if not skip_confirm:
        answer = prompt_confirmation()
        if answer == "n":
            show_cancelled()
            return None
        if answer == "customize":
            print("\nRun codeflash init for full customization.\n")
            return None

    if not skip_api_key and not handle_api_key():
        return None

    not_saved_notice = "Continuing with detected settings (not saved).\n"
    outcome = write_config(detected)
    # write_config normally returns a Result; the import-fallback stub returns a tuple.
    if hasattr(outcome, "is_ok") and outcome.is_ok():  # type: ignore[union-attr]
        print(f"\n{outcome.unwrap()}\n")  # type: ignore[union-attr]
    elif hasattr(outcome, "error"):
        print(f"\n{outcome.error}\n")
        print(not_saved_notice)
    else:
        saved, message = outcome  # type: ignore[misc]
        print(f"\n{message}\n")
        if not saved:
            print(not_saved_notice)

    namespace = Namespace() if args is None else args

    # Detected values always replace whatever the namespace held before.
    detected_values = {
        "module_root": str(detected.module_root),
        "tests_root": str(detected.tests_root) if detected.tests_root else None,
        "project_root": str(detected.project_root),
        "formatter_cmds": detected.formatter_cmds,
        "ignore_paths": [str(entry) for entry in detected.ignore_paths],
        "pytest_cmd": detected.test_runner,
        "language": detected.language,
    }
    for attr_name, attr_value in detected_values.items():
        setattr(namespace, attr_name, attr_value)

    # These are only filled in when the caller has not set them.
    for attr_name, fallback in (("disable_telemetry", False), ("git_remote", "origin")):
        if not hasattr(namespace, attr_name):
            setattr(namespace, attr_name, fallback)

    return namespace
+ + Returns: + True if API key is available, False if user cancelled. + + """ + try: + from codeflash_python.code_utils.env_utils import get_codeflash_api_key + except ImportError: + # Stub + def get_codeflash_api_key() -> str | None: + return os.getenv("CODEFLASH_API_KEY") + + # Check for existing API key + try: + existing_key = get_codeflash_api_key() + if existing_key: + display_key = f"{existing_key[:3]}****{existing_key[-4:]}" + print(f"Found API key: {display_key}\n") + return True + except OSError: + pass + + # Prompt for API key + print("API Key Required") + print(" Get your API key at: https://app.codeflash.ai/app/apikeys\n") + + try: + api_key = input(" Enter API key (or press Enter to open browser): ").strip() + except (KeyboardInterrupt, EOFError): + return False + + if not api_key: + # Open browser + import click + + click.launch("https://app.codeflash.ai/app/apikeys") + print("\n Opening browser...") + try: + api_key = input(" Enter API key: ").strip() + except (KeyboardInterrupt, EOFError): + return False + + if not api_key: + print("\nAPI key required. Run codeflash init to set up.\n") + return False + + if not api_key.startswith("cf-"): + print("\nInvalid API key format. Should start with 'cf-'.\n") + return False + + # Save API key to environment + os.environ["CODEFLASH_API_KEY"] = api_key + + # Try to save to shell rc + try: + from codeflash_python.code_utils.shell_utils import save_api_key_to_rc + + result = save_api_key_to_rc(api_key) + if hasattr(result, "is_ok") and result.is_ok(): + print(f"\nAPI key saved. 
def get_code(functions_to_optimize: list[FunctionToOptimize]) -> tuple[str | None, set[tuple[str, str]]]:
    """Return the source text for a function, or for methods of one class, in a Python module.

    ``functions_to_optimize`` holds either a single module-level function, a
    single method of a module-level class, or several methods of the same
    module-level class. Anything else (deeper nesting, non-class parents,
    methods of different classes) is rejected.

    For methods, the returned text starts with a class "skeleton": the class
    header, the class docstring if there is one, and every other dunder method
    of the class, followed by the requested methods themselves.

    Returns:
        ``(code, dunder_methods)`` where ``dunder_methods`` is the set of
        ``(class_name, method_name)`` pairs included as context. ``(None, set())``
        is returned when the input is unsupported, the module does not parse,
        or no target was found.

    """
    if (
        not functions_to_optimize
        or (functions_to_optimize[0].parents and functions_to_optimize[0].parents[0].type != "ClassDef")
        or (
            len(functions_to_optimize[0].parents) > 1
            or ((len(functions_to_optimize) > 1) and len({fn.parents[0] for fn in functions_to_optimize}) != 1)
        )
    ):
        return None, set()

    file_path: Path = functions_to_optimize[0].file_path
    # 1-based inclusive (start, end) line ranges making up the class skeleton.
    class_skeleton: set[tuple[int, int | None]] = set()
    contextual_dunder_methods: set[tuple[str, str]] = set()
    target_code: str = ""

    def find_target(node_list: list[ast.stmt], name_parts: tuple[str, str] | tuple[str]) -> ast.AST | None:
        """Find ``name_parts[0]`` in ``node_list``; with two parts, descend into that class for the method."""
        target: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Assign | ast.AnnAssign | None = None
        node: ast.stmt
        for node in node_list:
            if (
                # The many mypy issues will be fixed once this code moves to the backend,
                # using Type Guards as we move to 3.10+.
                # We will cover the Type Alias case on the backend since it's a 3.12 feature.
                isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) and node.name == name_parts[0]
            ):
                target = node
                break
            # The next two cases cover type aliases in pre-3.12 syntax, where only single assignment is allowed.
            if (
                isinstance(node, ast.Assign)
                and len(node.targets) == 1
                and isinstance(node.targets[0], ast.Name)
                and node.targets[0].id == name_parts[0]
            ) or (isinstance(node, ast.AnnAssign) and hasattr(node.target, "id") and node.target.id == name_parts[0]):
                if class_skeleton:
                    break
                target = node
                break

        if target is None or len(name_parts) == 1:
            return target

        if not isinstance(target, ast.ClassDef) or len(name_parts) < 2:
            return None
        # At this point, name_parts has at least 2 elements
        method_name: str = name_parts[1]  # type: ignore[misc]
        # Class header: from the `class` line up to the line before the first body statement.
        class_skeleton.add((target.lineno, target.body[0].lineno - 1))
        cbody = target.body
        first_stmt = cbody[0]
        # A docstring is an expression *statement* (ast.Expr) wrapping a string
        # constant. ast.expr, the expression base class, never matches a statement.
        if (
            isinstance(first_stmt, ast.Expr)
            and isinstance(first_stmt.value, ast.Constant)
            and isinstance(first_stmt.value.value, str)
        ):
            class_skeleton.add((first_stmt.lineno, first_stmt.end_lineno))
            cbody = cbody[1:]
        cnode: ast.stmt
        for cnode in cbody:
            # Collect all dunder methods.
            cnode_name: str
            if (
                isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef))
                and len(cnode_name := cnode.name) > 4
                and cnode_name != method_name
                and cnode_name.isascii()
                and cnode_name.startswith("__")
                and cnode_name.endswith("__")
            ):
                contextual_dunder_methods.add((target.name, cnode_name))
                class_skeleton.add((cnode.lineno, cnode.end_lineno))

        return find_target(target.body, (method_name,))

    with file_path.open(encoding="utf8") as file:
        source_code: str = file.read()
    try:
        module_node: ast.Module = ast.parse(source_code)
    except SyntaxError:
        logger.exception("get_code - Syntax error while parsing code")
        return None, set()
    # Get the source code lines for the target node
    lines: list[str] = source_code.splitlines(keepends=True)

    # The guard at the top leaves two shapes only: no parent (module-level
    # function) or exactly one ClassDef parent (methods of the same class).
    qualified_name_parts_list: list[tuple[str, str] | tuple[str]]
    if functions_to_optimize[0].parents:
        qualified_name_parts_list = [(fto.parents[0].name, fto.function_name) for fto in functions_to_optimize]
    else:
        qualified_name_parts_list = [(functions_to_optimize[0].function_name,)]

    for qualified_name_parts in qualified_name_parts_list:
        target_node = find_target(module_node.body, qualified_name_parts)
        if target_node is None:
            continue
        # find_target returns FunctionDef, AsyncFunctionDef, ClassDef, Assign, or AnnAssign - all have lineno/end_lineno
        if not isinstance(
            target_node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Assign, ast.AnnAssign)
        ):
            continue

        if (
            isinstance(target_node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))
            and target_node.decorator_list
        ):
            # Start at the first decorator so the decorators are part of the extracted text.
            target_code += "".join(lines[target_node.decorator_list[0].lineno - 1 : target_node.end_lineno])
        else:
            target_code += "".join(lines[target_node.lineno - 1 : target_node.end_lineno])
    if not target_code:
        return None, set()
    class_list: list[tuple[int, int | None]] = sorted(class_skeleton)
    class_code = "".join(["".join(lines[s_lineno - 1 : e_lineno]) for (s_lineno, e_lineno) in class_list])
    return class_code + target_code, contextual_dunder_methods
def find_preexisting_objects(source_code: str) -> set[tuple[str, tuple[FunctionParent, ...]]]:
    """Return the names of the functions, classes and class methods defined in *source_code*.

    Only statements directly in the module body are inspected, plus the
    functions sitting directly in the body of each top-level class.

    Args:
        source_code: Python source text to scan.

    Returns:
        A set of ``(name, parents)`` pairs. ``parents`` is ``()`` for top-level
        functions and classes, and a one-element tuple holding the enclosing
        class (as a ``FunctionParent`` of type ``"ClassDef"``) for methods.
        If the source cannot be parsed, the error is logged and an empty set
        is returned.
    """
    found: set[tuple[str, tuple[FunctionParent, ...]]] = set()
    try:
        tree: ast.Module = ast.parse(source_code)
    except SyntaxError:
        logger.exception("find_preexisting_objects - Syntax error while parsing code")
        return found

    function_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    for top_level in tree.body:
        if not isinstance(top_level, (*function_types, ast.ClassDef)):
            continue
        found.add((top_level.name, ()))
        if isinstance(top_level, ast.ClassDef):
            # Methods are recorded with their owning class as the single parent.
            found.update(
                (member.name, (FunctionParent(top_level.name, "ClassDef"),))
                for member in top_level.body
                if isinstance(member, function_types)
            )
    return found
def normalize_node(node: ASTNodeT) -> ASTNodeT:
    """Strip docstrings and import statements from *node*, recursively through ``body``.

    The result is used to compare two pieces of code while ignoring
    documentation and import differences.

    The node is modified in place and also returned. Only the ``body`` list of
    each node is descended into; other statement lists such as ``orelse`` are
    left untouched.

    Args:
        node: Any AST node, typically the ``ast.Module`` returned by ``ast.parse``.

    Returns:
        The same node object, with docstrings and imports removed from its body.
    """
    # Test for "is not None" on the raw docstring: a plain truthiness check on
    # the cleaned docstring misses empty (or whitespace-only) docstrings such
    # as `""`, which would then survive normalization and cause false diffs.
    if (
        isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))
        and ast.get_docstring(node, clean=False) is not None
    ):
        node.body = node.body[1:]
    if hasattr(node, "body"):
        node.body = [normalize_node(n) for n in node.body if not isinstance(n, (ast.Import, ast.ImportFrom))]  # type: ignore[assignment,attr-defined]
    return node
"cls"}: + new_params = [args[0], request_param, *args[1:]] + else: + new_params = [request_param, *args] + else: + new_params = [request_param] + + new_param_list = updated_node.params.with_changes(params=new_params) + return updated_node.with_changes(params=new_param_list) + + +class PytestMarkAdder(cst.CSTTransformer): + """Transformer that adds pytest marks to test functions.""" + + def __init__(self, mark_name: str) -> None: + super().__init__() + self.mark_name = mark_name + self.has_pytest_import = False + + def visit_Module(self, node: cst.Module) -> None: + """Check if pytest is already imported.""" + for statement in node.body: + if isinstance(statement, cst.SimpleStatementLine): + for stmt in statement.body: + if isinstance(stmt, cst.Import): + for import_alias in stmt.names: + if isinstance(import_alias, cst.ImportAlias) and import_alias.name.value == "pytest": + self.has_pytest_import = True + + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: + """Add pytest import if not present.""" + if not self.has_pytest_import: + # Create import statement + import_stmt = cst.SimpleStatementLine(body=[cst.Import(names=[cst.ImportAlias(name=cst.Name("pytest"))])]) + # Add import at the beginning + updated_node = updated_node.with_changes(body=[import_stmt, *updated_node.body]) + return updated_node + + def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: + """Add pytest mark to test functions.""" + # Check if the mark already exists + for decorator in updated_node.decorators: + if self.is_pytest_mark(decorator.decorator, self.mark_name): + return updated_node + + # Create the pytest mark decorator + mark_decorator = self.create_pytest_mark() + + # Add the decorator + new_decorators = [*list(updated_node.decorators), mark_decorator] + return updated_node.with_changes(decorators=new_decorators) + + def is_pytest_mark(self, decorator: cst.BaseExpression, mark_name: str) 
-> bool: + """Check if a decorator is a specific pytest mark.""" + if isinstance(decorator, cst.Attribute): + if ( + isinstance(decorator.value, cst.Attribute) + and isinstance(decorator.value.value, cst.Name) + and decorator.value.value.value == "pytest" + and decorator.value.attr.value == "mark" + and decorator.attr.value == mark_name + ): + return True + elif isinstance(decorator, cst.Call) and isinstance(decorator.func, cst.Attribute): + return self.is_pytest_mark(decorator.func, mark_name) + return False + + def create_pytest_mark(self) -> cst.Decorator: + """Create a pytest mark decorator.""" + # Base: pytest.mark.{mark_name} + mark_attr = cst.Attribute( + value=cst.Attribute(value=cst.Name("pytest"), attr=cst.Name("mark")), attr=cst.Name(self.mark_name) + ) + decorator = mark_attr + return cst.Decorator(decorator=decorator) + + +class AutouseFixtureModifier(cst.CSTTransformer): + def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: + if not has_autouse_fixture(original_node): + return updated_node + + else_block = cst.Else(body=updated_node.body) + if_test = cst.parse_expression('request.node.get_closest_marker("codeflash_no_autouse")') + yield_statement = cst.parse_statement("yield") + if_body = cst.IndentedBlock(body=[yield_statement]) + new_if_statement = cst.If(test=if_test, body=if_body, orelse=else_block) + return updated_node.with_changes(body=cst.IndentedBlock(body=[new_if_statement])) + + +def disable_autouse(test_path: Path) -> str: + file_content = test_path.read_text(encoding="utf-8") + module = cst.parse_module(file_content) + add_request_argument = AddRequestArgument() + disable_autouse_fixture = AutouseFixtureModifier() + modified_module = module.visit(add_request_argument) + modified_module = modified_module.visit(disable_autouse_fixture) + test_path.write_text(modified_module.code, encoding="utf-8") + return file_content + + +def modify_autouse_fixture(test_paths: list[Path]) -> 
def replace_functions_in_file(
    source_code: str,
    original_function_names: list[str],
    optimized_code: str,
    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
) -> str:
    """Splice optimized function bodies (and any new helpers) into *source_code*.

    Args:
        source_code: The module source to rewrite.
        original_function_names: Targets to replace, either ``"func"`` or
            ``"Class.method"``. A name with more than one dot is rejected.
        optimized_code: Source containing the optimized definitions.
        preexisting_objects: ``(name, parents)`` pairs describing what already
            existed before optimization. Definitions in *optimized_code* that
            are neither targets nor listed here are treated as new helpers.
            When this set is empty, no new functions or methods are added.

    Returns:
        The rewritten module source. If a target name has more than one level
        of nesting, an error is logged and *source_code* is returned unchanged.
    """
    parsed_function_names = []
    for original_function_name in original_function_names:
        if original_function_name.count(".") == 0:
            class_name, function_name = None, original_function_name
        elif original_function_name.count(".") == 1:
            class_name, function_name = original_function_name.split(".")
        else:
            msg = f"Unable to find {original_function_name}. Returning unchanged source code."
            logger.error(msg)
            return source_code
        parsed_function_names.append((class_name, function_name))

    # Collect functions from optimized code without using MetadataWrapper
    optimized_module = cst.parse_module(optimized_code)
    modified_functions: dict[tuple[str | None, str], cst.FunctionDef] = {}
    new_functions: list[cst.FunctionDef] = []
    new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list)
    new_classes: list[cst.ClassDef] = []
    modified_init_functions: dict[str, cst.FunctionDef] = {}

    function_names_set = set(parsed_function_names)

    # Pass 1: classify every top-level definition of the optimized module as a
    # replacement for a target, a rewritten __init__, or a brand-new helper.
    for node in optimized_module.body:
        if isinstance(node, cst.FunctionDef):
            key = (None, node.name.value)
            if key in function_names_set:
                modified_functions[key] = node
            elif preexisting_objects and (node.name.value, ()) not in preexisting_objects:
                new_functions.append(node)

        elif isinstance(node, cst.ClassDef):
            class_name = node.name.value
            parents = (FunctionParent(name=class_name, type="ClassDef"),)

            if (class_name, ()) not in preexisting_objects:
                new_classes.append(node)

            for child in node.body.body:
                if isinstance(child, cst.FunctionDef):
                    method_key = (class_name, child.name.value)
                    if method_key in function_names_set:
                        modified_functions[method_key] = child
                    elif (
                        child.name.value == "__init__"
                        and preexisting_objects
                        and (class_name, ()) in preexisting_objects
                    ):
                        # __init__ of an existing class is replaced wholesale
                        # even when it is not an explicit target.
                        modified_init_functions[class_name] = child
                    elif preexisting_objects and (child.name.value, parents) not in preexisting_objects:
                        new_class_functions[class_name].append(child)

    original_module = cst.parse_module(source_code)

    # Positions of the last top-level function and class in the ORIGINAL body;
    # they anchor where new definitions are inserted.
    max_function_index = None
    max_class_index = None
    for index, _node in enumerate(original_module.body):
        if isinstance(_node, cst.FunctionDef):
            max_function_index = index
        if isinstance(_node, cst.ClassDef):
            max_class_index = index

    new_body: list[cst.CSTNode] = []
    existing_class_names = set()

    # Pass 2: rebuild the original body one-to-one, swapping in optimized
    # bodies and decorators while keeping each original signature node.
    for node in original_module.body:
        if isinstance(node, cst.FunctionDef):
            key = (None, node.name.value)
            if key in modified_functions:
                modified_func = modified_functions[key]
                new_body.append(node.with_changes(body=modified_func.body, decorators=modified_func.decorators))
            else:
                new_body.append(node)

        elif isinstance(node, cst.ClassDef):
            class_name = node.name.value
            existing_class_names.add(class_name)

            new_members: list[cst.CSTNode] = []
            for child in node.body.body:
                if isinstance(child, cst.FunctionDef):
                    key = (class_name, child.name.value)
                    if key in modified_functions:
                        modified_func = modified_functions[key]
                        new_members.append(
                            child.with_changes(body=modified_func.body, decorators=modified_func.decorators)
                        )
                    elif child.name.value == "__init__" and class_name in modified_init_functions:
                        new_members.append(modified_init_functions[class_name])
                    else:
                        new_members.append(child)
                else:
                    new_members.append(child)

            if class_name in new_class_functions:
                new_members.extend(new_class_functions[class_name])

            new_body.append(node.with_changes(body=node.body.with_changes(body=new_members)))
        else:
            new_body.append(node)

    # Track where (and how many) new classes were spliced in, so that the
    # original-body indices used below can be corrected for the shift.
    class_insertion_idx: int | None = None
    inserted_class_count = 0
    if new_classes:
        unique_classes = [nc for nc in new_classes if nc.name.value not in existing_class_names]
        if unique_classes:
            new_classes_insertion_idx = (
                max_class_index if max_class_index is not None else find_insertion_index_after_imports(original_module)
            )
            new_body = list(
                chain(new_body[:new_classes_insertion_idx], unique_classes, new_body[new_classes_insertion_idx:])
            )
            class_insertion_idx = new_classes_insertion_idx
            inserted_class_count = len(unique_classes)

    if new_functions:
        anchor = max_function_index if max_function_index is not None else max_class_index
        if anchor is None:
            new_body = [*new_functions, *new_body]
        else:
            # Every original statement at or after the class insertion point
            # has moved right by the number of inserted classes. Without this
            # correction the new functions would land before the anchor
            # definition instead of directly after it.
            if class_insertion_idx is not None and anchor >= class_insertion_idx:
                anchor += inserted_class_count
            new_body = [*new_body[: anchor + 1], *new_functions, *new_body[anchor + 1 :]]

    updated_module = original_module.with_changes(body=new_body)
    return updated_module.code
def replace_function_definitions_in_module(
    function_names: list[str],
    optimized_code: CodeStringsMarkdown,
    module_abspath: Path,
    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
    project_root_path: Path,
    should_add_global_assignments: bool = True,
) -> bool:
    """Apply the optimized definitions for one module and write the result to disk.

    Args:
        function_names: Names of the functions to replace.
        optimized_code: Optimized code blocks, looked up by the module's path
            relative to *project_root_path*.
        module_abspath: Absolute path of the module to rewrite.
        preexisting_objects: ``(name, parents)`` pairs that existed before optimization.
        project_root_path: Project root used to compute the relative module path.
        should_add_global_assignments: When true, global assignments from the
            optimized code are merged into the source before replacement.

    Returns:
        ``True`` if the file was rewritten, ``False`` if ``is_zero_diff``
        reports no difference, in which case the file is left untouched.
    """
    original_source: str = module_abspath.read_text(encoding="utf8")
    module_relpath = module_abspath.relative_to(project_root_path)
    replacement_code = get_optimized_code_for_module(module_relpath, optimized_code)

    # Global assignments are merged BEFORE the functions are replaced, not after.
    # Edge case: the optimized code introduces a new import plus a global
    # assignment that uses it. If the assignment were not present yet, the import
    # would look unused and be dropped by AddImportsVisitor.add_needed_import
    # inside replace_functions_and_add_imports.
    # Introduced in https://github.com/codeflash-ai/codeflash/pull/448
    if should_add_global_assignments:
        base_source = add_global_assignments(replacement_code, original_source)
    else:
        base_source = original_source

    updated_source: str = replace_functions_and_add_imports(
        base_source, function_names, replacement_code, module_abspath, preexisting_objects, project_root_path
    )

    changed = not is_zero_diff(original_source, updated_source)
    if changed:
        module_abspath.write_text(updated_source, encoding="utf8")
    return changed
def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStringsMarkdown) -> str:
    """Pick the optimized code block that belongs to *relative_path*.

    Lookup order:

    1. A key equal to ``str(relative_path)``.
    2. The only block, if it is keyed by the string ``"None"`` (no file path).
    3. A unique block whose file name matches ``relative_path.name``.

    Args:
        relative_path: Module path relative to the project root.
        optimized_code: Object whose ``file_to_path()`` returns a mapping of
            path string to code.

    Returns:
        The matching code, or ``""`` (after logging a warning) when nothing matches
        or the file-name match is ambiguous.
    """
    code_by_path = optimized_code.file_to_path()

    exact_match = code_by_path.get(str(relative_path))
    if exact_match is not None:
        return exact_match

    # Fallback 1: a lone code block that carries no file path at all.
    if len(code_by_path) == 1 and "None" in code_by_path:
        logger.debug("Using code block with None file_path for %s", relative_path)
        return code_by_path["None"]

    # Fallback 2: the LLM sometimes returns the right file name under a
    # different directory prefix, so compare base names only.
    wanted_name = relative_path.name
    same_name_blocks = []
    for path_key, code in code_by_path.items():
        if path_key == "None":
            continue
        if Path(path_key).name == wanted_name:
            same_name_blocks.append(code)
    if len(same_name_blocks) == 1:
        logger.debug("Using basename-matched code block for %s", relative_path)
        return same_name_blocks[0]

    logger.warning("Optimized code not found for %s, existing files are %s", relative_path, list(code_by_path.keys()))
    return ""
class AssertCleanup:
    """Line-based rewriter that turns assertions into bare expressions.

    ``assert <expr> == <expected>`` becomes ``<expr>`` and
    ``self.assertXxx(<first>, ...)`` becomes ``<first>``. Each line is handled
    independently with regular expressions, so the input does not have to be
    parseable Python.
    """

    def __init__(self) -> None:
        # Pre-compiling regular expressions for faster execution
        self.assert_re = re.compile(r"\s*assert\s+(.*?)(?:\s*==\s*.*)?$")
        self.unittest_re = re.compile(r"(\s*)self\.assert([A-Za-z]+)\((.*)\)$")

    def transform_asserts(self, code: str) -> str:
        """Rewrite every assertion line in *code*; other lines pass through unchanged.

        Lines are re-joined with ``"\\n"``, so a trailing newline in the input
        is not preserved.
        """
        result_lines = []
        for line in code.splitlines():
            transformed = self.transform_assert_line(line)
            result_lines.append(transformed if transformed is not None else line)
        return "\n".join(result_lines)

    def transform_assert_line(self, line: str) -> str | None:
        """Return the rewritten form of one line, or ``None`` if it is not an assertion."""
        indent = line[: len(line) - len(line.lstrip())]

        assert_match = self.assert_re.match(line)
        if assert_match:
            expression = assert_match.group(1).strip()
            if expression.startswith("not "):
                return f"{indent}{expression}"

            # Drop statement separators left over at the end of the expression.
            expression = expression.rstrip(",;")
            return f"{indent}{expression}"

        unittest_match = self.unittest_re.match(line)
        if unittest_match:
            indent, _assert_method, args = unittest_match.groups()

            if args:
                arg_parts = self.first_top_level_arg(args)
                if arg_parts:
                    return f"{indent}{arg_parts}"

        return None

    def first_top_level_arg(self, args: str) -> str:
        """Return the first comma-separated argument of *args*, stripped.

        Commas nested inside brackets are ignored. Commas and brackets inside
        single- or double-quoted string literals are ignored too; otherwise
        ``"a,b", x`` would be cut to ``"a`` and ``f(")"), 1`` would never be split.
        """
        depth = 0
        quote: str | None = None  # the quote character of the literal we are inside, if any
        escaped = False  # True right after a backslash inside a string literal
        for i, ch in enumerate(args):
            if quote is not None:
                if escaped:
                    escaped = False
                elif ch == "\\":
                    escaped = True
                elif ch == quote:
                    quote = None
                continue
            if ch in "\"'":
                quote = ch
            elif ch in "([{":
                depth += 1
            elif ch in ")]}":
                depth -= 1
            elif ch == "," and depth == 0:
                return args[:i].strip()
        return args.strip()
extract_dependent_function(main_function: str, code_context: CodeOptimizationContext) -> str | Literal[False]: + """Extract the single dependent function from the code context excluding the main function.""" + dependent_functions = set() + + # Compare using bare name since AST extracts bare function names + bare_main = main_function.rsplit(".", 1)[-1] if "." in main_function else main_function + + for code_string in code_context.testgen_context.code_strings: + # Quick heuristic: skip parsing entirely if there is no 'def' token, + # since no function definitions can be present without it. + if "def" not in code_string.code: + continue + + ast_tree = ast.parse(code_string.code) + # Add function names directly, skipping the bare main name. + for node in ast_tree.body: + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + name = node.name + if name == bare_main: + continue + dependent_functions.add(name) + # If more than one dependent function (other than the main) is found, + # we can return False early since the final result cannot be a single name. + if len(dependent_functions) > 1: + return False + + if not dependent_functions: + return False + + if len(dependent_functions) != 1: + return False + + return build_fully_qualified_name(dependent_functions.pop(), code_context) + + +def build_fully_qualified_name(function_name: str, code_context: CodeOptimizationContext) -> str: + # If the name is already qualified (contains a dot), return as-is + if "." 
def generate_candidates(source_code_path: Path) -> set[str]:
    """Build the set of path spellings under which coverage data may list this file.

    The set holds the bare file name, every progressively longer
    ``/``-joined suffix of the path (the first path component is excluded from
    this walk), and the full path in POSIX form.

    Args:
        source_code_path: Path of the source file.

    Returns:
        The candidate path strings.
    """
    suffix = source_code_path.name
    candidates = {suffix, source_code_path.as_posix()}

    # Walk the intermediate directories from innermost to outermost, growing
    # the suffix one component at a time. parts[0] (the root or first
    # component) is skipped; the full path is already covered by as_posix().
    for directory in reversed(source_code_path.parts[1:-1]):
        suffix = f"{directory}/{suffix}"
        candidates.add(suffix)

    return candidates
class GlobalFunctionCollector(cst.CSTVisitor):
    """Visitor that records function definitions found outside any class or function body.

    After a visit, ``functions`` maps each function name to its most recently
    seen definition node and ``function_order`` lists the names in order of
    first appearance.
    """

    def __init__(self) -> None:
        super().__init__()
        self.functions: dict[str, cst.FunctionDef] = {}
        self.function_order: list[str] = []
        # Number of enclosing function/class bodies at the current position.
        self.scope_depth = 0

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
        outside_any_scope = self.scope_depth == 0
        self.scope_depth += 1
        if outside_any_scope:
            function_name = node.name.value
            # A later redefinition overwrites the node but keeps its original
            # position in the ordering.
            if function_name not in self.function_order:
                self.function_order.append(function_name)
            self.functions[function_name] = node
        return True

    def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
        self.scope_depth -= 1

    def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
        # Entering a class body: functions below this point are methods.
        self.scope_depth += 1
        return True

    def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
        self.scope_depth -= 1
def collect_referenced_names(node: cst.CSTNode) -> set[str]:
    """Return the value of every ``cst.Name`` found in *node* or any of its descendants.

    Args:
        node: Root of the CST subtree to scan.

    Returns:
        The set of identifier strings found in the subtree.
    """
    found: set[str] = set()
    # Explicit stack instead of recursion; the result is a set, so visiting
    # order does not matter.
    pending: list[cst.CSTNode] = [node]
    while pending:
        current = pending.pop()
        if isinstance(current, cst.Name):
            found.add(current.value)
        pending.extend(current.children)
    return found
True + + def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: + self.scope_depth -= 1 + + def visit_ClassDef(self, node: cst.ClassDef) -> bool | None: + self.scope_depth += 1 + return True + + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: + self.scope_depth -= 1 + + def visit_If(self, node: cst.If) -> bool | None: + self.if_else_depth += 1 + return True + + def leave_If(self, original_node: cst.If) -> None: + self.if_else_depth -= 1 + + def visit_Else(self, node: cst.Else) -> bool | None: + # Else blocks are already counted as part of the if statement + return True + + def visit_Assign(self, node: cst.Assign) -> bool | None: + # Only process global assignments (not inside functions, classes, etc.) + if self.scope_depth == 0 and self.if_else_depth == 0: # We're at module level + for target in node.targets: + if isinstance(target.target, cst.Name): + name = target.target.value + self.assignments[name] = node + if name not in self.assignment_order: + self.assignment_order.append(name) + return True + + def visit_AnnAssign(self, node: cst.AnnAssign) -> bool | None: + # Handle annotated assignments like: _CACHE: Dict[str, int] = {} + # Only process module-level annotated assignments with a value + if ( + self.scope_depth == 0 + and self.if_else_depth == 0 + and isinstance(node.target, cst.Name) + and node.value is not None + ): + name = node.target.value + self.assignments[name] = node + if name not in self.assignment_order: + self.assignment_order.append(name) + return True + + +def find_insertion_index_after_imports(node: cst.Module) -> int: + """Find the position of the last import statement in the top-level of the module.""" + insert_index = 0 + for i, stmt in enumerate(node.body): + is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any( + isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body + ) + + is_conditional_import = isinstance(stmt, cst.If) and all( + isinstance(inner, 
class GlobalAssignmentTransformer(cst.CSTTransformer):
    """Bring a module's top-level assignments in line with those of a new file.

    Module-level assignments whose name also appears in ``new_assignments`` are swapped
    for the new node. Names that were never encountered are inserted when the module is
    left: right after the imports when their value references no class or function
    defined in the module, otherwise after the last top-level class/function definition.
    """

    def __init__(self, new_assignments: dict[str, cst.Assign | cst.AnnAssign], new_assignment_order: list[str]) -> None:
        super().__init__()
        self.new_assignments = new_assignments
        self.new_assignment_order = new_assignment_order
        # Names that have been replaced in place and therefore must not be appended.
        self.processed_assignments: set[str] = set()
        self.scope_depth = 0
        self.if_else_depth = 0

    def _is_nested(self) -> bool:
        """Return True while inside a function, a class, or an ``if`` statement."""
        return self.scope_depth > 0 or self.if_else_depth > 0

    @staticmethod
    def _as_statement_lines(
        pairs: list[tuple[str, cst.Assign | cst.AnnAssign]],
    ) -> list[cst.SimpleStatementLine]:
        """Wrap each assignment in its own statement line, preceded by one empty line."""
        return [cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()]) for _, assignment in pairs]

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        self.scope_depth += 1

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
        self.scope_depth -= 1
        return updated_node

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        self.scope_depth += 1

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
        self.scope_depth -= 1
        return updated_node

    def visit_If(self, node: cst.If) -> None:
        self.if_else_depth += 1

    def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
        self.if_else_depth -= 1
        return updated_node

    def visit_Else(self, node: cst.Else) -> None:
        # Nothing to track: the surrounding ``if`` has already been counted.
        return None

    def leave_Assign(
        self, original_node: cst.Assign, updated_node: cst.Assign
    ) -> cst.Assign | cst.FlattenSentinel[cst.BaseSmallStatement] | cst.RemovalSentinel:
        if self._is_nested():
            return updated_node

        # The first plain-name target that has a replacement decides the outcome.
        plain_names = (t.target.value for t in original_node.targets if isinstance(t.target, cst.Name))
        for name in plain_names:
            if name in self.new_assignments:
                self.processed_assignments.add(name)
                return cast("cst.Assign", self.new_assignments[name])

        return updated_node

    def leave_AnnAssign(
        self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign
    ) -> cst.AnnAssign | cst.FlattenSentinel[cst.BaseSmallStatement] | cst.RemovalSentinel:
        if self._is_nested() or not isinstance(original_node.target, cst.Name):
            return updated_node

        name = original_node.target.value
        if name not in self.new_assignments:
            return updated_node

        self.processed_assignments.add(name)
        return cast("cst.AnnAssign", self.new_assignments[name])

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
        body = list(updated_node.body)

        # New assignments that did not replace anything, in the order of the new file.
        pending = [
            (name, self.new_assignments[name])
            for name in self.new_assignment_order
            if name in self.new_assignments and name not in self.processed_assignments
        ]
        if not pending:
            return updated_node.with_changes(body=body)

        # Names of the top-level classes and functions an assignment value might refer to.
        defined_names: set[str] = {
            stmt.name.value for stmt in body if isinstance(stmt, (cst.ClassDef, cst.FunctionDef))
        }

        independent: list[tuple[str, cst.Assign | cst.AnnAssign]] = []
        dependent: list[tuple[str, cst.Assign | cst.AnnAssign]] = []
        for name, assignment in pending:
            has_value = isinstance(assignment, (cst.Assign, cst.AnnAssign)) and assignment.value is not None
            if has_value and collect_referenced_names(assignment.value) & defined_names:
                # The value mentions a module-level class/function: it must come after them.
                dependent.append((name, assignment))
            else:
                independent.append((name, assignment))

        if independent:
            at = find_insertion_index_after_imports(updated_node)
            body[at:at] = self._as_statement_lines(independent)

        if dependent:
            # Default to just after the imports, then move past the last definition if any.
            at = find_insertion_index_after_imports(cst.Module(body=body))
            for position, stmt in enumerate(body, start=1):
                if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)):
                    at = position
            body[at:at] = self._as_statement_lines(dependent)

        return updated_node.with_changes(body=body)
def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
    """Merge module-level functions, assignments and statements of one module into another.

    Args:
        src_module_code: Source text of the module providing the new content.
        dst_module_code: Source text of the module to update.

    Returns:
        The updated destination source. ``dst_module_code`` is returned unchanged when the
        source contributes no assignments, no functions missing from the destination, and
        no global statements that the destination does not already contain.

    """
    src_module, src_statements = extract_global_statements(src_module_code)
    dst_module, dst_statements = extract_global_statements(dst_module_code)

    def already_present(candidate: cst.SimpleStatementLine) -> bool:
        return any(candidate is known or candidate.deep_equals(known) for known in dst_statements)

    statements_to_add = [stmt for stmt in src_statements if not already_present(stmt)]

    # Both modules are already parsed above; the collectors below reuse those trees.
    assignment_collector = GlobalAssignmentCollector()
    src_module.visit(assignment_collector)

    src_functions = GlobalFunctionCollector()
    src_module.visit(src_functions)

    dst_functions = GlobalFunctionCollector()
    dst_module.visit(dst_functions)

    # Only functions the destination does not define yet are added.
    functions_to_add = {
        name: func for name, func in src_functions.functions.items() if name not in dst_functions.functions
    }
    function_order = [name for name in src_functions.function_order if name in functions_to_add]
    assignments = assignment_collector.assignments

    if not (assignments or functions_to_add or statements_to_add):
        return dst_module_code

    # Order matters: functions first so assignments and statements can reference them,
    # assignments second, and free-standing global statements last so that they can
    # reference both.
    result = dst_module
    if functions_to_add:
        result = result.visit(GlobalFunctionTransformer(functions_to_add, function_order))
    if assignments:
        result = result.visit(GlobalAssignmentTransformer(assignments, assignment_collector.assignment_order))
    if statements_to_add:
        result = result.visit(GlobalStatementTransformer(statements_to_add))

    return result.code
+ + Examples + -------- + import os ==> "os" + import dbt.adapters.factory ==> "dbt.adapters.factory" + from pathlib import Path ==> "pathlib.Path" + from recce.adapter.base import BaseAdapter ==> "recce.adapter.base.BaseAdapter" + from typing import Any, List, Optional ==> "typing.Any", "typing.List", "typing.Optional" + from recce.util.lineage import ( build_column_key, filter_dependency_maps) ==> "recce.util.lineage.build_column_key", "recce.util.lineage.filter_dependency_maps" + + """ + + def __init__(self) -> None: + self.imports: set[str] = set() + self.depth = 0 # top-level + + def get_full_dotted_name(self, expr: cst.BaseExpression) -> str: + if isinstance(expr, cst.Name): + return expr.value + if isinstance(expr, cst.Attribute): + return f"{self.get_full_dotted_name(expr.value)}.{expr.attr.value}" + return "" + + def collect_imports_from_block(self, block: cst.IndentedBlock | cst.Module) -> None: + for statement in block.body: + if isinstance(statement, cst.SimpleStatementLine): + for child in statement.body: + if isinstance(child, cst.Import): + for alias in child.names: + module = self.get_full_dotted_name(alias.name) + asname = alias.asname.name.value if alias.asname else alias.name.value # type: ignore[attr-defined] + if isinstance(asname, cst.Attribute): + self.imports.add(module) + else: + self.imports.add(module if module == asname else f"{module}.{asname}") + + elif isinstance(child, cst.ImportFrom): + if child.module is None: + continue + module = self.get_full_dotted_name(child.module) + if isinstance(child.names, cst.ImportStar): + continue + for alias in child.names: + if isinstance(alias, cst.ImportAlias): + name = alias.name.value + asname = alias.asname.name.value if alias.asname else name # type: ignore[attr-defined] + self.imports.add(f"{module}.{asname}") + + def visit_Module(self, node: cst.Module) -> None: + self.depth = 0 + self.collect_imports_from_block(node) + + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + 
def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
    """Return the names ``from <module_name> import *`` would import, determined statically.

    The module is looked up under ``project_root`` as ``<dotted/path>.py`` and then as
    ``<dotted/path>/__init__.py``. Its source is parsed with ``ast`` and never executed.

    If the module assigns a list or tuple literal to ``__all__`` (plain or annotated
    assignment), the string elements of that literal are returned; non-string elements
    are ignored. Otherwise the public names are approximated from the top-level
    statements: functions, classes, assignment targets and imported names that do not
    start with an underscore.

    Args:
        module_name: Dotted module name, e.g. ``"pkg.sub.mod"``.
        project_root: Directory the dotted name is resolved against.

    Returns:
        The resolved names. An empty set is returned, and a warning logged, when the
        module file cannot be found or any error occurs while reading or parsing it.

    """
    try:
        relative = module_name.replace(".", "/")
        candidates = (project_root / f"{relative}.py", project_root / f"{relative}/__init__.py")
        module_file = next((path for path in candidates if path.exists()), None)

        if module_file is None:
            logger.warning("Could not find module file for %s, skipping star import resolution", module_name)
            return set()

        tree = ast.parse(module_file.read_text(encoding="utf8"))

        for node in ast.walk(tree):
            if isinstance(node, ast.Assign):
                assigns_all = (
                    len(node.targets) == 1
                    and isinstance(node.targets[0], ast.Name)
                    and node.targets[0].id == "__all__"
                )
            elif isinstance(node, ast.AnnAssign):
                # ``__all__: list[str] = [...]``; a bare annotation has no value to read.
                assigns_all = (
                    node.value is not None and isinstance(node.target, ast.Name) and node.target.id == "__all__"
                )
            else:
                assigns_all = False

            if not assigns_all:
                continue

            if isinstance(node.value, (ast.List, ast.Tuple)):
                # String literals are ast.Constant on every supported Python version.
                # The legacy ast.Str node must not be referenced: it no longer exists
                # on Python 3.14+ and touching it there raises AttributeError.
                return {
                    elt.value
                    for elt in node.value.elts
                    if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
                }
            # ``__all__`` is built dynamically; fall back to the public-name heuristic.
            break

        public_names: set[str] = set()
        for node in tree.body:
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                if not node.name.startswith("_"):
                    public_names.add(node.name)
            elif isinstance(node, ast.Assign):
                for target in node.targets:
                    if isinstance(target, ast.Name) and not target.id.startswith("_"):
                        public_names.add(target.id)
            elif isinstance(node, ast.AnnAssign):
                if isinstance(node.target, ast.Name) and not node.target.id.startswith("_"):
                    public_names.add(node.target.id)
            elif isinstance(node, ast.Import) or (
                isinstance(node, ast.ImportFrom) and not any(alias.name == "*" for alias in node.names)
            ):
                # Nested star imports are not followed.
                for alias in node.names:
                    name = alias.asname or alias.name
                    if not name.startswith("_"):
                        public_names.add(name)

        return public_names

    except Exception as e:
        # Deliberately best-effort: resolution failures must never abort the caller.
        logger.warning("Error resolving star import for %s: %s", module_name, e)
        return set()
if not helper_functions_fqn: + helper_functions_fqn = {f.fully_qualified_name for f in (helper_functions or [])} + + dst_code_fallback = dst_module_code if isinstance(dst_module_code, str) else dst_module_code.code + + src_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, src_path) + dst_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, dst_path) + + dst_context: CodemodContext = CodemodContext( + filename=src_path.name, + full_module_name=dst_module_and_package.name, + full_package_name=dst_module_and_package.package, + ) + gatherer: GatherImportsVisitor = GatherImportsVisitor( + CodemodContext( + filename=src_path.name, + full_module_name=src_module_and_package.name, + full_package_name=src_module_and_package.package, + ) + ) + try: + src_module = cst.parse_module(src_module_code) + # Exclude function/class bodies so GatherImportsVisitor only sees module-level imports. + # Nested imports (inside functions) are part of function logic and must not be + # scheduled for add/remove — RemoveImportsVisitor would strip them as "unused". 
+ module_level_only = src_module.with_changes( + body=[stmt for stmt in src_module.body if not isinstance(stmt, (cst.FunctionDef, cst.ClassDef))] + ) + module_level_only.visit(gatherer) + except Exception as e: + logger.exception("Error parsing source module code: %s", e) + return dst_code_fallback + + dotted_import_collector = DottedImportCollector() + if isinstance(dst_module_code, cst.Module): + parsed_dst_module = dst_module_code + parsed_dst_module.visit(dotted_import_collector) + else: + try: + parsed_dst_module = cst.parse_module(dst_module_code) + parsed_dst_module.visit(dotted_import_collector) + except cst.ParserSyntaxError as e: + logger.exception("Syntax error in destination module code: %s", e) + return dst_code_fallback + + try: + for mod in gatherer.module_imports: + # Skip __future__ imports as they cannot be imported directly + # __future__ imports should only be imported with specific objects i.e from __future__ import annotations + if mod == "__future__": + continue + if mod not in dotted_import_collector.imports: + AddImportsVisitor.add_needed_import(dst_context, mod) + RemoveImportsVisitor.remove_unused_import(dst_context, mod) + aliased_objects = set() + for mod, alias_pairs in gatherer.alias_mapping.items(): + for alias_pair in alias_pairs: + if alias_pair[0] and alias_pair[1]: # Both name and alias exist + aliased_objects.add(f"{mod}.{alias_pair[0]}") + + for mod, obj_seq in gatherer.object_mapping.items(): + for obj in obj_seq: + if ( + f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps + ): + continue # Skip adding imports for helper functions already in the context + + if f"{mod}.{obj}" in aliased_objects: + continue + + # Handle star imports by resolving them to actual symbol names + if obj == "*": + resolved_symbols = resolve_star_import(mod, project_root) + logger.debug("Resolved star import from %s: %s", mod, resolved_symbols) + + for symbol in resolved_symbols: + if ( + 
f"{mod}.{symbol}" not in helper_functions_fqn + and f"{mod}.{symbol}" not in dotted_import_collector.imports + ): + AddImportsVisitor.add_needed_import(dst_context, mod, symbol) + RemoveImportsVisitor.remove_unused_import(dst_context, mod, symbol) + else: + if f"{mod}.{obj}" not in dotted_import_collector.imports: + AddImportsVisitor.add_needed_import(dst_context, mod, obj) + RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj) + except Exception as e: + logger.exception("Error adding imports to destination module code: %s", e) + return dst_code_fallback + + for mod, asname in gatherer.module_aliases.items(): + if not asname: + continue + if f"{mod}.{asname}" not in dotted_import_collector.imports: + AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname) + RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname) + + for mod, alias_pairs in gatherer.alias_mapping.items(): + for alias_pair in alias_pairs: + if f"{mod}.{alias_pair[0]}" in helper_functions_fqn: + continue + + if not alias_pair[0] or not alias_pair[1]: + continue + + if f"{mod}.{alias_pair[1]}" not in dotted_import_collector.imports: + AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1]) + RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1]) + + try: + add_imports_visitor = AddImportsVisitor(dst_context) + transformed_module = add_imports_visitor.transform_module(parsed_dst_module) + transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module) + return transformed_module.code.lstrip("\n") + except Exception as e: + logger.exception("Error adding imports to destination module code: %s", e) + return dst_code_fallback diff --git a/src/codeflash_python/static_analysis/line_profile_utils.py b/src/codeflash_python/static_analysis/line_profile_utils.py new file mode 100644 index 000000000..5d8ff603d --- /dev/null +++ 
class JitDecoratorDetector(ast.NodeVisitor):
    """Walk a module's AST and flag functions decorated with a known JIT decorator.

    Imports are recorded as they are visited so that aliased forms such as
    ``import numba as nb`` / ``@nb.jit`` or ``from numba import jit as my_jit`` /
    ``@my_jit`` resolve to their real module path before being looked up in
    ``JIT_DECORATORS``. The outcome is available as ``found_jit_decorator``.
    """

    def __init__(self) -> None:
        # local name -> (module, imported name); the imported name is None for a plain
        # module import, e.g. {"nb": ("numba", None), "my_jit": ("numba", "jit")}.
        self.import_aliases: dict[str, tuple[str, str | None]] = {}
        self.found_jit_decorator = False

    def visit_Import(self, node: ast.Import) -> None:
        """Record ``import numba`` / ``import numba as nb`` style imports."""
        for alias in node.names:
            self.import_aliases[alias.asname or alias.name] = (alias.name, None)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Record ``from numba import jit`` / ``from numba import jit as my_jit`` style imports."""
        # ``from . import x`` has no module name and is not recorded.
        if node.module is not None:
            for alias in node.names:
                self.import_aliases[alias.asname or alias.name] = (node.module, alias.name)
        self.generic_visit(node)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Set the flag when any decorator of ``node`` is a JIT decorator."""
        if any(self.is_jit_decorator(decorator) for decorator in node.decorator_list):
            # Children of a matching function are not visited.
            self.found_jit_decorator = True
            return
        self.generic_visit(node)

    def is_jit_decorator(self, node: ast.expr) -> bool:
        """Return True if the decorator expression names a known JIT decorator."""
        # Strip call layers: ``@jit()`` and ``@numba.jit(nopython=True)`` are judged by
        # the expression being called.
        while isinstance(node, ast.Call):
            node = node.func

        if isinstance(node, ast.Name):
            return self.check_name_decorator(node.id)
        if isinstance(node, ast.Attribute):
            return self.check_attribute_decorator(node)
        return False

    def check_name_decorator(self, name: str) -> bool:
        """Return True if a bare-name decorator (``@jit``) was imported from a JIT module."""
        module, imported_name = self.import_aliases.get(name, (None, None))
        if imported_name is None:
            # Either the name was never imported or it refers to a whole module.
            return False
        return self.is_known_jit_decorator(module, imported_name)

    def check_attribute_decorator(self, node: ast.Attribute) -> bool:
        """Return True if a dotted decorator (``@numba.jit``, ``@nb.cuda.jit``) is a JIT decorator."""
        parts = self.get_attribute_parts(node)
        if len(parts) < 2:
            return False

        head, *middle, decorator_name = parts
        alias = self.import_aliases.get(head)
        if alias is None:
            # Not an imported name: take the written path at face value.
            base = head
        else:
            module, imported_name = alias
            # ``import numba as nb`` -> "numba"; ``from torch import jit`` -> "torch.jit".
            base = module if imported_name is None else f"{module}.{imported_name}"

        return self.is_known_jit_decorator(".".join([base, *middle]), decorator_name)

    def get_attribute_parts(self, node: ast.Attribute) -> list[str]:
        """Return the names of a dotted chain, e.g. ``['numba', 'cuda', 'jit']``.

        An empty list is returned when the chain is not rooted in a plain name
        (for example ``f().jit``).
        """
        trailing: list[str] = []
        cursor = node
        while isinstance(cursor, ast.Attribute):
            trailing.append(cursor.attr)
            cursor = cursor.value

        if not isinstance(cursor, ast.Name):
            return []
        return [cursor.id, *reversed(trailing)]

    def is_known_jit_decorator(self, module: str, decorator_name: str) -> bool:
        """Return True if ``JIT_DECORATORS`` lists ``decorator_name`` under ``module``."""
        return decorator_name in JIT_DECORATORS.get(module, ())
class LineProfilerDecoratorAdder(cst.CSTTransformer):
    """Transformer that decorates the single function identified by a qualified name."""

    # TODO we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure
    def __init__(self, qualified_name: str, decorator_name: str) -> None:
        """Initialize the transformer.

        Args:
        ----
            qualified_name: Dotted path of the function to decorate (e.g., "MyClass.nested_func.target_func").
            decorator_name: Name of the decorator to add.

        """
        super().__init__()
        self.decorator_name = decorator_name
        self.qualified_name_parts = qualified_name.split(".")

        # Names of the classes and functions currently being traversed, outermost first.
        self.context_stack = []

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        self.context_stack.append(node.name.value)

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
        self.context_stack.pop()
        return updated_node

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        self.context_stack.append(node.name.value)

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
        # The stack still includes this function's own name at this point.
        is_target = self.context_stack == self.qualified_name_parts
        self.context_stack.pop()

        if not is_target:
            return updated_node

        already_decorated = any(self.is_target_decorator(d.decorator) for d in original_node.decorators)  # type: ignore[arg-type]
        if already_decorated:
            return updated_node

        # The new decorator goes first, i.e. it becomes the outermost one.
        added = cst.Decorator(decorator=cst.Name(value=self.decorator_name))
        return updated_node.with_changes(decorators=(added, *updated_node.decorators))

    def is_target_decorator(self, decorator_node: cst.Name | cst.Attribute | cst.Call) -> bool:
        """Return True for ``@<decorator_name>`` and ``@<decorator_name>(...)``."""
        # For a call, judge the expression being called; dotted names never match.
        candidate = decorator_node.func if isinstance(decorator_node, cst.Call) else decorator_node
        return isinstance(candidate, cst.Name) and candidate.value == self.decorator_name
_PROFILE_ALIAS = "codeflash_line_profile"


def _imports_profile(node: cst.ImportFrom, *, require_alias: bool) -> bool:
    """Tell whether *node* is ``from line_profiler import profile``.

    Args:
        node: The ``from ... import ...`` statement to inspect.
        require_alias: When True only ``profile as codeflash_line_profile`` counts; when False an
            unaliased ``profile`` import is accepted as well.

    Returns:
        True if the statement imports ``profile`` from ``line_profiler`` in an accepted form.

    """
    if not isinstance(node.module, cst.Name) or node.module.value != "line_profiler":
        return False
    # ``from line_profiler import *`` has no alias list to iterate over.
    if isinstance(node.names, cst.ImportStar):
        return False
    for alias in node.names:
        if alias.name.value != "profile":
            continue
        if not alias.asname:
            if not require_alias:
                return True
        elif alias.asname.name.value == _PROFILE_ALIAS:  # type: ignore[attr-defined]
            return True
    return False


class ProfileEnableTransformer(cst.CSTTransformer):
    """Insert ``codeflash_line_profile.enable(output_prefix=...)`` after the line_profiler import."""

    def __init__(self, filename: str) -> None:
        # Flag to track if we found the import statement
        self.found_import = False
        # Track indentation of the import statement
        self.import_indentation = None
        self.filename = filename

    def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
        """Record that a matching line_profiler import exists; the node itself is returned unchanged."""
        if _imports_profile(original_node, require_alias=False):
            self.found_import = True
            # Get the indentation from the original node
            if hasattr(original_node, "leading_lines"):
                leading_whitespace = original_node.leading_lines[-1].whitespace if original_node.leading_lines else ""  # type: ignore[index,attr-defined]
                self.import_indentation = leading_whitespace

        return updated_node

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
        """Add the ``enable`` call right after the top-level profile import, if one was seen."""
        if not self.found_import:
            return updated_node

        new_body = list(updated_node.body)
        import_index = self._find_import_index(new_body)

        if import_index is not None:
            # ``repr`` yields a valid string literal even when the path contains quotes or
            # backslashes; for ordinary paths the generated code is unchanged.
            enable_statement = cst.parse_statement(
                f"codeflash_line_profile.enable(output_prefix={str(self.filename)!r})"
            )
            new_body.insert(import_index + 1, enable_statement)

        return updated_node.with_changes(body=new_body)

    @staticmethod
    def _find_import_index(body: list) -> int | None:
        """Return the index of the top-level statement importing ``profile``, or None.

        A statement that binds ``codeflash_line_profile`` wins; otherwise the first unaliased
        ``profile`` import is used, which is what a single matching import always resolved to.
        """
        fallback = None
        for index, stmt in enumerate(body):
            if not isinstance(stmt, cst.SimpleStatementLine):
                continue
            for small_stmt in stmt.body:
                if not isinstance(small_stmt, cst.ImportFrom):
                    continue
                if _imports_profile(small_stmt, require_alias=True):
                    return index
                if fallback is None and _imports_profile(small_stmt, require_alias=False):
                    fallback = index
        return fallback


def add_decorator_to_qualified_function(module: cst.Module, qualified_name: str, decorator_name: str) -> cst.Module:
    """Add a decorator to a function with the exact qualified name in the source code.

    Args:
        module: The Python source code as a CST module.
        qualified_name: The fully qualified name of the function to add the decorator to
            (e.g., "MyClass.nested_func.target_func").
        decorator_name: The name of the decorator to add.

    Returns:
        The modified CST module.

    """
    return module.visit(LineProfilerDecoratorAdder(qualified_name, decorator_name))


def add_profile_enable(original_code: str, line_profile_output_file: str) -> str:
    """Return *original_code* with the profiler ``enable`` call inserted after its profile import."""
    module = cst.parse_module(original_code)
    return module.visit(ProfileEnableTransformer(line_profile_output_file)).code


class ImportAdder(cst.CSTTransformer):
    """Prepend ``import_statement`` to a module unless the profile import is already present."""

    def __init__(self, import_statement) -> None:
        self.import_statement = import_statement
        self.has_import = False
        # When the statement being added binds the alias, an unaliased ``profile`` import does
        # not make it redundant: the decorator name would still be undefined.
        self._require_alias = f" as {_PROFILE_ALIAS}" in str(import_statement)

    def leave_Module(self, original_node, updated_node):  # noqa: ANN201
        """Put the import at the top of the module body when it was not found during the visit."""
        if self.has_import:
            return updated_node

        import_node = cst.parse_statement(self.import_statement)
        return updated_node.with_changes(body=[import_node, *updated_node.body])

    def visit_ImportFrom(self, node) -> None:
        """Mark the import as present when ``profile`` is already imported from line_profiler."""
        if _imports_profile(node, require_alias=self._require_alias):
            self.has_import = True


def add_decorator_imports(function_to_optimize: FunctionToOptimize, code_context: CodeOptimizationContext) -> Path:
    """Add a profile decorator to a function in a Python file and all its helper functions.

    Every file holding the target function or one of its helpers is rewritten in place: the
    functions get ``@codeflash_line_profile``, the aliased import is added and imports are sorted.
    Only the target function's file additionally receives the ``enable`` call.

    Returns:
        The path obtained from ``get_run_tmp_file`` that is passed as the profiler output prefix.

    """
    # Group the qualified names by the file that defines them.
    file_paths = defaultdict(list)
    line_profile_output_file = get_run_tmp_file(Path("baseline_lprof"))
    file_paths[function_to_optimize.file_path].append(function_to_optimize.qualified_name)
    for helper in code_context.helper_functions:
        file_paths[helper.file_path].append(helper.qualified_name)

    for file_path, fns_present in file_paths.items():
        module_node = cst.parse_module(file_path.read_text("utf-8"))
        for fn_name in fns_present:
            module_node = add_decorator_to_qualified_function(module_node, fn_name, "codeflash_line_profile")
        module_node = module_node.visit(ImportAdder("from line_profiler import profile as codeflash_line_profile"))
        modified_code = sort_imports(code=module_node.code, float_to_top=True)
        with file_path.open("w", encoding="utf-8") as file:
            file.write(modified_code)

    # Adding profile.enable line for changing the savepath of the data, do this only for the
    # main file and not the helper files
    file_contents = function_to_optimize.file_path.read_text("utf-8")
    modified_code = add_profile_enable(file_contents, line_profile_output_file.as_posix())
    function_to_optimize.file_path.write_text(modified_code, "utf-8")
    return line_profile_output_file
has_numba = find_spec("numba") is not None

NUMERICAL_MODULES = frozenset({"numpy", "torch", "numba", "jax", "tensorflow", "math", "scipy"})
# Modules that require numba to be installed for optimization
NUMBA_REQUIRED_MODULES = frozenset({"numpy", "math", "scipy"})


class NumericalUsageChecker(ast.NodeVisitor):
    """AST visitor that checks if a function uses numerical computing libraries.

    ``found_numerical`` flips to True on the first call, attribute access or bare name whose root
    identifier is in ``numerical_names``; once set, the visitor stops descending.
    """

    def __init__(self, numerical_names: set[str]) -> None:
        self.numerical_names = numerical_names
        self.found_numerical = False

    def visit_Call(self, node: ast.Call) -> None:
        """Check function calls for numerical library usage."""
        if self.found_numerical:
            return
        root = self.get_root_name(node.func)
        if root and root in self.numerical_names:
            self.found_numerical = True
            return
        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute) -> None:
        """Check attribute access for numerical library usage."""
        if self.found_numerical:
            return
        root = self.get_root_name(node)
        if root and root in self.numerical_names:
            self.found_numerical = True
            return
        self.generic_visit(node)

    def visit_Name(self, node: ast.Name) -> None:
        """Check name references for numerical library usage."""
        if not self.found_numerical and node.id in self.numerical_names:
            self.found_numerical = True

    def get_root_name(self, node: ast.expr) -> str | None:
        """Get the root name from an expression (e.g., 'np' from 'np.array')."""
        # Peel attribute accesses until the left-most expression is reached.
        while isinstance(node, ast.Attribute):
            node = node.value
        return node.id if isinstance(node, ast.Name) else None


def collect_numerical_imports(tree: ast.Module) -> tuple[set[str], set[str]]:
    """Collect names that reference numerical computing libraries from imports.

    Imports anywhere in the tree are considered, not only module-level ones.

    Returns:
        A tuple of (numerical_names, modules_used) where:
        - numerical_names: set of names/aliases that reference numerical libraries
        - modules_used: set of actual module names (e.g., "numpy", "math") being imported

    """
    numerical_names: set[str] = set()
    modules_used: set[str] = set()

    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            # import numpy / import numpy as np / import numpy.linalg
            for alias in node.names:
                module_root = alias.name.split(".")[0]
                if module_root in NUMERICAL_MODULES:
                    numerical_names.add(alias.asname or module_root)
                    modules_used.add(module_root)
        elif isinstance(node, ast.ImportFrom) and node.module:
            module_root = node.module.split(".")[0]
            if module_root not in NUMERICAL_MODULES:
                continue
            # from numpy import array, zeros as z
            for alias in node.names:
                if alias.name == "*":
                    # Can't track star imports, but mark the module as numerical
                    numerical_names.add(module_root)
                else:
                    numerical_names.add(alias.asname or alias.name)
            modules_used.add(module_root)

    return numerical_names, modules_used


def _find_function_in_body(body: list[ast.stmt], name_parts: list[str]) -> ast.FunctionDef | None:
    """Resolve *name_parts* inside *body*: every part but the last names a class, the last a function."""
    head, rest = name_parts[0], name_parts[1:]
    for node in body:
        if not rest:
            if isinstance(node, ast.FunctionDef) and node.name == head:
                return node
        elif isinstance(node, ast.ClassDef) and node.name == head:
            # Keep scanning later same-named classes when this one lacks the target.
            found = _find_function_in_body(node.body, rest)
            if found is not None:
                return found
    return None


def find_function_node(tree: ast.Module, name_parts: list[str]) -> ast.FunctionDef | None:
    """Find a function node in the AST given its qualified name parts.

    Note: This function only finds regular (sync) functions, not async functions.
    Functions nested inside other functions are not searched; classes may be nested to any depth.

    Args:
        tree: The parsed AST module
        name_parts: List of name parts, e.g., ["function_name"], ["ClassName", "method_name"] or
            ["Outer", "Inner", "method_name"]

    Returns:
        The function node if found, None otherwise

    """
    if not name_parts:
        return None
    return _find_function_in_body(tree.body, name_parts)


def is_numerical_code(code_string: str, function_name: str | None = None) -> bool:
    """Check if a function uses numerical computing libraries.

    Detects usage of numpy, torch, numba, jax, tensorflow, scipy, and math libraries
    within the specified function.

    Note: When numba is not installed, the result is False if every numerical module imported
    anywhere in the file is one of math, numpy or scipy, as numba is required to optimize such
    code. The gate looks at the file's imports, not only at the modules the function touches.

    Args:
        code_string: The entire file's content as a string
        function_name: The name of the function to check. Can be a simple name like "foo"
                      or a qualified name like "ClassName.method_name" for methods,
                      staticmethods, or classmethods. When omitted, only the file's imports
                      are considered.

    Returns:
        True if the function uses any numerical computing library functions, False otherwise.
        Also False when the code does not parse or the function cannot be found.

    Examples:
        >>> code = '''
        ... import torch
        ... def process_data(x):
        ...     return torch.sum(x)
        ... '''
        >>> is_numerical_code(code, "process_data")
        True

        >>> code = '''
        ... def simple_func(x):
        ...     return x + 1
        ... '''
        >>> is_numerical_code(code, "simple_func")
        False

    """
    try:
        tree = ast.parse(code_string)
    except SyntaxError:
        return False

    # Collect names that reference numerical modules from imports
    numerical_names, modules_used = collect_numerical_imports(tree)
    # True when the only numerical modules in play need numba and numba is missing.
    blocked_by_missing_numba = not has_numba and modules_used.issubset(NUMBA_REQUIRED_MODULES)

    if not function_name:
        return bool(modules_used) and not blocked_by_missing_numba

    # Split the function name to handle class methods
    target_function = find_function_node(tree, function_name.split("."))
    if target_function is None:
        return False

    # Check if the function body uses any numerical library
    checker = NumericalUsageChecker(numerical_names)
    checker.visit(target_function)
    if not checker.found_numerical:
        return False

    return not blocked_by_missing_numba
@dataclass
class FunctionCallLocation:
    """Represents a location where the target function is called."""

    calling_function: str
    line: int
    column: int


@dataclass
class FunctionDefinitionInfo:
    """Contains information about a function definition."""

    name: str
    node: ast.FunctionDef
    source_code: str
    start_line: int
    end_line: int
    is_method: bool
    class_name: str | None = None


class FunctionCallFinder(ast.NodeVisitor):
    """AST visitor that finds all function definitions that call a specific qualified function.

    Args:
        target_function_name: The qualified name of the function to find (e.g., "module.function" or "function")
        target_filepath: The filepath where the target function is defined
        source_lines: The analysed source split into lines (line endings kept), used to slice out
            the text of each calling function

    """

    def __init__(self, target_function_name: str, target_filepath: str, source_lines: list[str]) -> None:
        self.target_function_name = target_function_name
        self.target_filepath = target_filepath
        self.source_lines = source_lines  # Store original source lines for extraction

        # Parse the target function name into parts
        self.target_parts = target_function_name.split(".")
        self.target_base_name = self.target_parts[-1]

        # Track current context
        self.current_function_stack: list[tuple[str, ast.FunctionDef]] = []
        self.current_class_stack: list[str] = []

        # Track imports to resolve qualified names
        self.imports: dict[str, str] = {}  # Maps imported names to their full paths

        # Results
        self.function_calls: list[FunctionCallLocation] = []
        self.calling_functions: set[str] = set()
        self.function_definitions: dict[str, FunctionDefinitionInfo] = {}

        # Track if we found calls in the current function
        self.found_call_in_current_function = False
        self.functions_with_nested_calls: set[str] = set()

    def visit_Import(self, node: ast.Import) -> None:
        """Track regular imports."""
        for alias in node.names:
            # ``import a.b as c`` binds ``c``; ``import a.b`` is recorded under its last part.
            bound_name = alias.asname if alias.asname else alias.name.split(".")[-1]
            self.imports[bound_name] = alias.name
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Track from imports."""
        if node.module:
            for alias in node.names:
                if alias.name == "*":
                    # from module import *
                    self.imports["*"] = node.module
                else:
                    # from module import name [as alias]
                    self.imports[alias.asname or alias.name] = f"{node.module}.{alias.name}"
        self.generic_visit(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Track when entering a class definition."""
        self.current_class_stack.append(node.name)
        self.generic_visit(node)
        self.current_class_stack.pop()

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Track when entering a function definition."""
        self.visit_function_def(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """Track when entering an async function definition."""
        self.visit_function_def(node)  # type: ignore[arg-type]

    def _store_definition(self, name: str, node: ast.FunctionDef, class_context: list[str]) -> None:
        """Record *node* under *name*, deriving method/class info from *class_context*."""
        source_code = self.extract_source_code(node)
        self.function_definitions[name] = FunctionDefinitionInfo(
            name=name,
            node=node,
            source_code=source_code,
            start_line=node.lineno,
            end_line=node.end_lineno if hasattr(node, "end_lineno") else node.lineno,
            is_method=bool(class_context),
            class_name=class_context[-1] if class_context else None,
        )

    def visit_function_def(self, node: ast.FunctionDef) -> None:
        """Visit a (sync or async) function and record it when its body calls the target.

        A function whose nested function calls the target is recorded as well.
        """
        # Build the full qualified name including class if applicable
        if self.current_class_stack:
            full_name = f"{'.'.join(self.current_class_stack)}.{node.name}"
        else:
            full_name = node.name

        self.current_function_stack.append((full_name, node))
        self.found_call_in_current_function = False

        # Visit the function body
        self.generic_visit(node)

        if self.found_call_in_current_function:
            if full_name not in self.function_definitions:
                self._store_definition(full_name, node, self.current_class_stack)

            # Handle nested functions - mark parent as containing nested calls
            if len(self.current_function_stack) > 1:
                parent_name, parent_node = self.current_function_stack[-2]
                self.functions_with_nested_calls.add(parent_name)

                # Also store the parent function if not already stored
                if parent_name not in self.function_definitions:
                    # The parent counts as a method only when it is the outermost function.
                    parent_class_context = self.current_class_stack if len(self.current_function_stack) == 2 else []
                    self._store_definition(parent_name, parent_node, parent_class_context)

        self.current_function_stack.pop()

        # Reset flag for parent function
        if self.current_function_stack:
            enclosing_name = self.current_function_stack[-1][0]
            self.found_call_in_current_function = enclosing_name in self.calling_functions

    def visit_Call(self, node: ast.Call) -> None:
        """Check if this call matches our target function."""
        # Calls outside any function are not attributed to a caller.
        if self.current_function_stack and self.is_target_function_call(node):
            current_func_name = self.current_function_stack[-1][0]
            self.function_calls.append(
                FunctionCallLocation(calling_function=current_func_name, line=node.lineno, column=node.col_offset)
            )
            self.calling_functions.add(current_func_name)
            self.found_call_in_current_function = True

        self.generic_visit(node)

    def is_target_function_call(self, node: ast.Call) -> bool:
        """Determine if this call node is calling our target function."""
        call_name = self.get_call_name(node.func)
        if not call_name:
            return False

        # A match on the full name or on the bare base name is accepted as-is: the bare name may
        # be an imported alias or a direct call from the same file.
        if call_name in (self.target_function_name, self.target_base_name):
            return True

        # Check for qualified calls with imports
        call_parts = call_name.split(".")
        base_import = self.imports.get(call_parts[0])
        if base_import is None:
            return False

        # Resolve the full path using imports
        full_path = f"{base_import}.{'.'.join(call_parts[1:])}" if len(call_parts) > 1 else base_import
        return full_path == self.target_function_name or full_path.endswith(f".{self.target_function_name}")

    def get_call_name(self, func_node) -> str | None:
        """Extract the dotted name being called, or None when the callee is not a plain name chain."""
        if isinstance(func_node, ast.Name):
            return func_node.id
        if not isinstance(func_node, ast.Attribute):
            return None

        # Collect attribute names right-to-left until the chain bottoms out.
        parts = []
        cursor = func_node
        while isinstance(cursor, ast.Attribute):
            parts.append(cursor.attr)
            cursor = cursor.value
        if not isinstance(cursor, ast.Name):
            # e.g. ``factory().method()`` or ``items[0].method()``
            return None
        parts.append(cursor.id)
        return ".".join(reversed(parts))

    def extract_source_code(self, node: ast.FunctionDef) -> str:
        """Extract source code for a function node using original source lines."""
        if not self.source_lines or not hasattr(node, "lineno"):
            # Fallback to ast.unparse if available (Python 3.9+)
            try:
                return ast.unparse(node)
            except AttributeError:
                return f"# Source code extraction not available for {node.name}"

        # Get the lines for this function
        start_line = node.lineno - 1  # Convert to 0-based index
        end_line = node.end_lineno if hasattr(node, "end_lineno") else len(self.source_lines)
        func_lines = self.source_lines[start_line:end_line]

        # Find the minimum indentation (excluding empty lines)
        min_indent = float("inf")
        for line in func_lines:
            if line.strip():
                min_indent = min(min_indent, len(line) - len(line.lstrip()))

        # Inside a class keep 4 spaces of indentation; otherwise remove all common indentation.
        dedent_amount = max(0, min_indent - 4) if self.current_class_stack else min_indent

        result_lines = []
        for line in func_lines:
            # Only dedent non-empty lines
            if line.strip() and len(line) > dedent_amount:
                result_lines.append(line[dedent_amount:])
            else:
                result_lines.append(line)

        return "".join(result_lines).rstrip()

    def get_results(self) -> dict[str, str]:
        """Get the results of the analysis.

        Returns:
            A dictionary mapping qualified function names to their source code definitions.

        """
        return {info.name: info.source_code for info in self.function_definitions.values()}


def find_function_calls(source_code: str, target_function_name: str, target_filepath: str) -> dict[str, str]:
    """Find all function definitions that call a specific target function.

    Args:
        source_code: The Python source code to analyze
        target_function_name: The qualified name of the function to find (e.g., "module.function")
        target_filepath: The filepath where the target function is defined

    Returns:
        A dictionary mapping qualified function names to their source code definitions.
        Example: {"function_a": "def function_a(): ...", "MyClass.method_one": "def method_one(self): ..."}

    Raises:
        SyntaxError: If *source_code* cannot be parsed.

    """
    tree = ast.parse(source_code)

    # Keep line endings so the extracted slices can be joined back verbatim.
    visitor = FunctionCallFinder(target_function_name, target_filepath, source_code.splitlines(keepends=True))
    visitor.visit(tree)
    return visitor.get_results()
def find_references(
    function: FunctionToOptimize, project_root: Path, tests_root: Path | None = None, max_files: int = 500
) -> list:
    """Find all references (call sites) to a function across the codebase.

    Args:
        function: The function whose references are looked up; its file is read from disk.
        project_root: Root passed to the jedi project used for the lookup.
        tests_root: When given, references located under this directory are dropped.
        max_files: Accepted for interface compatibility; not used by this implementation.

    Returns:
        A list of ``ReferenceInfo`` objects, one per distinct (file, line, column). The list is
        empty when the definition cannot be located or when any error occurs (a warning is logged).

    """
    from codeflash_python.context.types import ReferenceInfo

    try:
        # Read as UTF-8 rather than the platform default, matching the other readers in this file.
        source = function.file_path.read_text(encoding="utf-8")
        script = jedi.Script(code=source, path=function.file_path)

        # Locate the definition so jedi can be asked for references at that position.
        function_pos = None
        for name in script.get_names(all_scopes=True, definitions=True):
            if name.type != "function" or name.name != function.function_name:
                continue
            if function.class_name:
                parent = name.parent()
                if not (parent and parent.name == function.class_name and parent.type == "class"):
                    continue
            function_pos = (name.line, name.column)
            break

        if function_pos is None:
            return []

        script = jedi.Script(code=source, path=function.file_path, project=jedi.Project(path=project_root))
        references = script.get_references(line=function_pos[0], column=function_pos[1])

        result: list[ReferenceInfo] = []
        seen_locations: set[tuple[Path, int, int]] = set()
        # Each referencing file is read once, however many references it contains.
        lines_by_path: dict[Path, list[str]] = {}

        for ref in references:
            if not ref.module_path:
                continue
            ref_path = Path(ref.module_path)
            # Skip the definition itself.
            if ref_path == function.file_path and ref.line == function_pos[0]:
                continue
            if tests_root:
                try:
                    ref_path.relative_to(tests_root)
                except ValueError:
                    pass
                else:
                    # The reference lives under the tests directory.
                    continue
            loc_key = (ref_path, ref.line, ref.column)
            if loc_key in seen_locations:
                continue
            seen_locations.add(loc_key)

            # Best effort: an unreadable file only costs the context line.
            try:
                if ref_path not in lines_by_path:
                    lines_by_path[ref_path] = ref_path.read_text(encoding="utf-8").splitlines()
                lines = lines_by_path[ref_path]
                context = lines[ref.line - 1] if ref.line <= len(lines) else ""
            except Exception:
                context = ""

            caller_function = None
            try:
                parent = ref.parent()
                if parent and parent.type == "function":
                    caller_function = parent.name
            except Exception:
                pass

            result.append(
                ReferenceInfo(
                    file_path=ref_path,
                    line=ref.line,
                    column=ref.column,
                    end_line=ref.line,
                    end_column=ref.column + len(function.function_name),
                    context=context.strip(),
                    reference_type="call",
                    import_name=function.function_name,
                    caller_function=caller_function,
                )
            )
        return result
    except Exception as e:
        logger.warning("Failed to find references for %s: %s", function.function_name, e)
        return []


def extract_calling_function_source(source_code: str, function_name: str, ref_line: int) -> str | None:
    """Extract the source code of a calling function in Python.

    Args:
        source_code: Full text of the file containing the caller.
        function_name: Name of the (sync or async) function to look for.
        ref_line: 1-based line that must fall inside the function's ``def``..end span.

    Returns:
        The function's lines from its ``def`` line to its last line joined with newlines, or None
        when no such function encloses *ref_line* or the source cannot be processed.

    """
    try:
        lines = source_code.splitlines()
        for node in ast.walk(ast.parse(source_code)):
            if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) or node.name != function_name:
                continue
            end_line = node.end_lineno or node.lineno
            if node.lineno <= ref_line <= end_line:
                return "\n".join(lines[node.lineno - 1 : end_line])
    except Exception:
        return None
    return None


def get_opt_review_metrics(
    source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path, language: str
) -> str:
    """Get function reference metrics for optimization review.

    Uses static analysis to find references in Python code.

    Args:
        source_code: Source code of the file containing the function (not read here).
        file_path: Path to the file.
        qualified_name: Qualified name of the function; everything before the last dot is used
            as the parent class name.
        project_root: Root of the project.
        tests_root: Root of the tests directory.
        language: The programming language.

    Returns:
        Markdown-formatted string with code blocks showing calling functions, or an empty string
        when there are no references or an error occurs.

    """
    from codeflash_core.models import FunctionParent, FunctionToOptimize

    start_time = time.perf_counter()

    try:
        # Split "Class.method" into its parent and function parts; a bare name has no parent.
        class_name, _, function_name = qualified_name.rpartition(".")
        parents: list[FunctionParent] = [FunctionParent(name=class_name, type="ClassDef")] if class_name else []

        # We don't have full line info here, so we'll use defaults
        func_info = FunctionToOptimize(
            function_name=function_name,
            file_path=file_path,
            parents=parents,
            starting_line=1,
            ending_line=1,
            language=str(language),
        )

        references = find_references(func_info, project_root, tests_root, max_files=500)
        if not references:
            return ""

        calling_fns_details = format_references_as_markdown(references, file_path, project_root, language)

    except Exception as e:
        logger.debug("Error getting function references: %s", e)
        calling_fns_details = ""

    end_time = time.perf_counter()
    logger.debug("Got function references in %.2f seconds", end_time - start_time)
    return calling_fns_details


def format_references_as_markdown(references: list, file_path: Path, project_root: Path, language: str) -> str:
    """Format references as markdown code blocks with calling function code.

    Args:
        references: List of ReferenceInfo objects.
        file_path: Path to the source file (its import/re-export references are excluded).
        project_root: Root of the project; files outside it are skipped.
        language: The programming language (the fence label is always "python").

    Returns:
        Markdown-formatted string with one fenced block per referencing file.

    """
    # Group references by file
    refs_by_file: dict[Path, list] = {}
    for ref in references:
        # Exclude the source file's definition/import references
        if ref.file_path == file_path and ref.reference_type in ("import", "reexport"):
            continue
        refs_by_file.setdefault(ref.file_path, []).append(ref)

    blocks: list[str] = []
    context_len = 0

    for ref_file, file_refs in refs_by_file.items():
        # The budget is checked per file, so the last file handled may overshoot it.
        if context_len > MAX_CONTEXT_LEN_REVIEW:
            break

        try:
            path_relative = ref_file.relative_to(project_root)
        except ValueError:
            continue

        lang_hint = "python"

        # Read the file to extract calling function context
        try:
            file_content = ref_file.read_text(encoding="utf-8")
            lines = file_content.splitlines()
        except Exception:
            continue

        # Get unique caller functions from this file
        callers_seen: set[str] = set()
        caller_contexts: list[str] = []

        for ref in file_refs:
            caller = ref.caller_function or ""
            if caller in callers_seen:
                continue
            callers_seen.add(caller)

            if ref.caller_function:
                # Try to extract the full calling function
                snippet = extract_calling_function_source(file_content, ref.caller_function, ref.line)
                if not snippet:
                    continue
            else:
                # Module-level call - show a few lines of context
                first = max(0, ref.line - 3)
                last = min(len(lines), ref.line + 2)
                snippet = "\n".join(lines[first:last])
            caller_contexts.append(snippet)
            context_len += len(snippet)

        if caller_contexts:
            body = "\n".join(caller_contexts)
            blocks.append(f"```{lang_hint}:{path_relative.as_posix()}\n{body}\n```\n")

    return "".join(blocks)
class ImportedInternalModuleAnalysis(BaseModel):
    """Immutable record describing one imported module that lives inside the project."""

    model_config = ConfigDict(frozen=True)

    name: str
    full_name: str
    file_path: Path

    @field_validator("name")
    @classmethod
    def name_is_identifier(cls, v: str) -> str:
        """Accept *v* only when it is a valid Python identifier."""
        if v.isidentifier():
            return v
        msg = "must be an identifier"
        raise ValueError(msg)

    @field_validator("full_name")
    @classmethod
    def full_name_is_dotted_identifier(cls, v: str) -> str:
        """Accept *v* only when every dot-separated part is a valid identifier."""
        # An empty part (leading, trailing or doubled dot) is not an identifier either.
        if all(part.isidentifier() for part in v.split(".")):
            return v
        msg = "must be a dotted identifier"
        raise ValueError(msg)

    @field_validator("file_path")
    @classmethod
    def file_path_exists(cls, v: Path | None) -> Path | None:
        """Reject a path that is set but does not exist on disk."""
        if not v or v.exists():
            return v
        msg = "must be an existing path"
        raise ValueError(msg)


class FunctionKind(Enum):
    """The ways a function can be bound: plain, static, class-level or instance-level."""

    FUNCTION = 0
    STATIC_METHOD = 1
    CLASS_METHOD = 2
    INSTANCE_METHOD = 3


def parse_imports(code: str) -> list[ast.Import | ast.ImportFrom]:
    """Return every ``import`` / ``from ... import`` node found anywhere in *code*."""
    tree = ast.parse(code)
    import_types = (ast.Import, ast.ImportFrom)
    return [candidate for candidate in ast.walk(tree) if isinstance(candidate, import_types)]
base_parts.extend(module.split(".")) + return ".".join(base_parts) + + +def get_module_full_name(node: ast.Import | ast.ImportFrom, current_module: str) -> list[str]: + if isinstance(node, ast.Import): + return [alias.name for alias in node.names] + base_module = resolve_relative_name(node.module, node.level, current_module) + if base_module is None: + return [] + if node.module is None and node.level > 0: + return [f"{base_module}.{alias.name}" for alias in node.names] + return [base_module] + + +def is_internal_module(module_name: str, project_root: Path) -> bool: + module_path = module_name.replace(".", "/") + possible_paths = [project_root / f"{module_path}.py", project_root / module_path / "__init__.py"] + return any(path.exists() for path in possible_paths) + + +def get_module_file_path(module_name: str, project_root: Path) -> Path | None: + module_path = module_name.replace(".", "/") + possible_paths = [project_root / f"{module_path}.py", project_root / module_path / "__init__.py"] + for path in possible_paths: + if path.exists(): + return path.resolve() + return None + + +def analyze_imported_modules( + code_str: str, module_file_path: Path, project_root: Path +) -> list[ImportedInternalModuleAnalysis]: + """Statically finds and analyzes all imported internal modules.""" + module_rel_path = module_file_path.relative_to(project_root).with_suffix("") + current_module = ".".join(module_rel_path.parts) + imports = parse_imports(code_str) + module_names: set[str] = set() + for node in imports: + module_names.update(get_module_full_name(node, current_module)) + internal_modules = {module_name for module_name in module_names if is_internal_module(module_name, project_root)} + return [ + ImportedInternalModuleAnalysis(name=str(mod_name).split(".")[-1], full_name=mod_name, file_path=file_path) + for mod_name in internal_modules + if (file_path := get_module_file_path(mod_name, project_root)) is not None + ] + + +def get_first_top_level_object_def_ast( + object_name: 
str, object_type: type[ObjectDefT], node: ast.AST +) -> ObjectDefT | None: + for child in ast.iter_child_nodes(node): + if isinstance(child, object_type) and child.name == object_name: + return child + if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + continue + if descendant := get_first_top_level_object_def_ast(object_name, object_type, child): + return descendant + return None + + +def get_first_top_level_function_or_method_ast( + function_name: str, parents: list[FunctionParent], node: ast.AST +) -> ast.FunctionDef | ast.AsyncFunctionDef | None: + if not parents: + result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, node) + if result is not None: + return result + return get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, node) + if parents[0].type == "ClassDef" and ( + class_node := get_first_top_level_object_def_ast(parents[0].name, ast.ClassDef, node) + ): + result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, class_node) + if result is not None: + return result + return get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, class_node) + return None + + +def function_kind(node: ast.FunctionDef | ast.AsyncFunctionDef, parents: list[FunctionParent]) -> FunctionKind | None: + if not parents or parents[0].type in ["FunctionDef", "AsyncFunctionDef"]: + return FunctionKind.FUNCTION + if parents[0].type == "ClassDef": + for decorator in node.decorator_list: + if isinstance(decorator, ast.Name): + if decorator.id == "classmethod": + return FunctionKind.CLASS_METHOD + if decorator.id == "staticmethod": + return FunctionKind.STATIC_METHOD + return FunctionKind.INSTANCE_METHOD + return None + + +def has_typed_parameters(node: ast.FunctionDef | ast.AsyncFunctionDef, parents: list[FunctionParent]) -> bool: + kind = function_kind(node, parents) + if kind in [FunctionKind.FUNCTION, FunctionKind.STATIC_METHOD]: + return all(arg.annotation for arg in 
node.args.args) + if kind in [FunctionKind.CLASS_METHOD, FunctionKind.INSTANCE_METHOD]: + return all(arg.annotation for arg in node.args.args[1:]) + return False diff --git a/src/codeflash_python/telemetry/__init__.py b/src/codeflash_python/telemetry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/codeflash_python/telemetry/posthog_cf.py b/src/codeflash_python/telemetry/posthog_cf.py new file mode 100644 index 000000000..f2ffe90a1 --- /dev/null +++ b/src/codeflash_python/telemetry/posthog_cf.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import logging +from typing import Any + +from posthog import Posthog + +from codeflash_python.api.cfapi import get_user_id +from codeflash_python.version import __version__ + +logger = logging.getLogger("codeflash_python") + +_posthog = None + + +def initialize_posthog(*, enabled: bool = True) -> None: + """Enable or disable PostHog. + + :param enabled: Whether to enable PostHog. + """ + if not enabled: + return + + global _posthog + _posthog = Posthog(project_api_key="phc_aUO790jHd7z1SXwsYCz8dRApxueplZlZWeDSpKc5hol", host="https://us.posthog.com") + _posthog.log.setLevel(logging.CRITICAL) # Suppress PostHog logging + ph("cli-telemetry-enabled") + + +def ph(event: str, properties: dict[str, Any] | None = None) -> None: + """Log an event to PostHog. + + :param event: The name of the event. + :param properties: A dictionary of properties to attach to the event. 
+ """ + if _posthog is None: + return + + properties = properties or {} + properties.update({"cli_version": __version__}) + + user_id = get_user_id() + + if user_id: + _posthog.capture(distinct_id=user_id, event=event, properties=properties) + else: + logger.debug("Failed to log event to PostHog: User ID could not be retrieved.") diff --git a/src/codeflash_python/verification/__init__.py b/src/codeflash_python/verification/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/codeflash_python/verification/addopts.py b/src/codeflash_python/verification/addopts.py new file mode 100644 index 000000000..c42999402 --- /dev/null +++ b/src/codeflash_python/verification/addopts.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import configparser +import logging +from collections.abc import Generator +from contextlib import contextmanager +from pathlib import Path + +import tomlkit + +from codeflash_python.code_utils.config_parser import get_all_closest_config_files + +logger = logging.getLogger("codeflash_python") + +BLACKLIST_ADDOPTS = ("--benchmark", "--sugar", "--codespeed", "--cov", "--profile", "--junitxml", "-n") + + +def filter_args(addopts_args: list[str]) -> list[str]: + # Convert BLACKLIST_ADDOPTS to a set for faster lookup of simple matches + # But keep tuple for startswith + blacklist = BLACKLIST_ADDOPTS + # Precompute the length for re-use + n = len(addopts_args) + filtered_args = [] + i = 0 + while i < n: + current_arg = addopts_args[i] + if current_arg.startswith(blacklist): + i += 1 + if i < n and not addopts_args[i].startswith("-"): + i += 1 + else: + filtered_args.append(current_arg) + i += 1 + return filtered_args + + +def modify_addopts(config_file: Path) -> tuple[str, bool]: + file_type = config_file.suffix.lower() + filename = config_file.name + config = None + if file_type not in {".toml", ".ini", ".cfg"} or not config_file.exists(): + return "", False + # Read original file + with Path.open(config_file, 
encoding="utf-8") as f: + content = f.read() + try: + if filename == "pyproject.toml": + # use tomlkit + data = tomlkit.parse(content) + original_addopts = data.get("tool", {}).get("pytest", {}).get("ini_options", {}).get("addopts", "") + # nothing to do if no addopts present + if original_addopts == "": + return content, False + if isinstance(original_addopts, list): + original_addopts = " ".join(original_addopts) + original_addopts = original_addopts.replace("=", " ") + addopts_args = ( + original_addopts.split() + ) # any number of space characters as delimiter, doesn't look at = which is fine + else: + # use configparser + config = configparser.ConfigParser() + config.read_string(content) + data = {section: dict(config[section]) for section in config.sections()} + if config_file.name in {"pytest.ini", ".pytest.ini", "tox.ini"}: + original_addopts = data.get("pytest", {}).get("addopts", "") # should only be a string + else: + original_addopts = data.get("tool:pytest", {}).get("addopts", "") # should only be a string + original_addopts = original_addopts.replace("=", " ") + addopts_args = original_addopts.split() + new_addopts_args = filter_args(addopts_args) + if new_addopts_args == addopts_args: + return content, False + # change addopts now + if file_type == ".toml": + data["tool"]["pytest"]["ini_options"]["addopts"] = " ".join(new_addopts_args) # type: ignore[index,call-overload] + # Write modified file + with Path.open(config_file, "w", encoding="utf-8") as f: + f.write(tomlkit.dumps(data)) + return content, True + elif config_file.name in {"pytest.ini", ".pytest.ini", "tox.ini"}: + assert config is not None + config.set("pytest", "addopts", " ".join(new_addopts_args)) + # Write modified file + with Path.open(config_file, "w", encoding="utf-8") as f: + config.write(f) + return content, True + else: + assert config is not None + config.set("tool:pytest", "addopts", " ".join(new_addopts_args)) + # Write modified file + with Path.open(config_file, "w", 
encoding="utf-8") as f: + config.write(f) + return content, True + + except Exception: + logger.debug("Trouble parsing") + return content, False # not modified + + +@contextmanager +def custom_addopts() -> Generator[None, None, None]: + closest_config_files = get_all_closest_config_files() + + original_content = {} + + try: + for config_file in closest_config_files: + original_content[config_file] = modify_addopts(config_file) + yield + + finally: + # Restore original file + for file, (content, was_modified) in original_content.items(): + if was_modified: + with Path.open(file, "w", encoding="utf-8") as f: + f.write(content) diff --git a/src/codeflash_python/verification/async_instrumentation.py b/src/codeflash_python/verification/async_instrumentation.py new file mode 100644 index 000000000..6046e1c3e --- /dev/null +++ b/src/codeflash_python/verification/async_instrumentation.py @@ -0,0 +1,326 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import libcst as cst + +from codeflash_python.code_utils.formatter import sort_imports +from codeflash_python.models.models import TestingMode + +if TYPE_CHECKING: + from pathlib import Path + + from codeflash_core.models import FunctionToOptimize + +logger = logging.getLogger("codeflash_python") + + +class AsyncDecoratorAdder(cst.CSTTransformer): + """Transformer that adds async decorator to async function definitions.""" + + def __init__(self, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR) -> None: + """Initialize the transformer. + + Args: + ---- + function: The FunctionToOptimize object representing the target async function. + mode: The testing mode to determine which decorator to apply. 
+ + """ + super().__init__() + self.function = function + self.mode = mode + self.qualified_name_parts = function.qualified_name.split(".") + self.context_stack = [] + self.added_decorator = False + + # Choose decorator based on mode + if mode == TestingMode.BEHAVIOR: + self.decorator_name = "codeflash_behavior_async" + elif mode == TestingMode.CONCURRENCY: + self.decorator_name = "codeflash_concurrency_async" + else: + self.decorator_name = "codeflash_performance_async" + + def visit_ClassDef(self, node: cst.ClassDef) -> None: + # Track when we enter a class + self.context_stack.append(node.name.value) + + def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: + # Pop the context when we leave a class + self.context_stack.pop() + return updated_node + + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + # Track when we enter a function + self.context_stack.append(node.name.value) + + def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: + # Check if this is an async function and matches our target + if original_node.asynchronous is not None and self.context_stack == self.qualified_name_parts: + # Check if the decorator is already present + has_decorator = any(self.is_target_decorator(decorator.decorator) for decorator in original_node.decorators) # type: ignore[invalid-argument-type] + + # Only add the decorator if it's not already there + if not has_decorator: + new_decorator = cst.Decorator(decorator=cst.Name(value=self.decorator_name)) + + # Add our new decorator to the existing decorators + updated_decorators = [new_decorator, *list(updated_node.decorators)] + updated_node = updated_node.with_changes(decorators=tuple(updated_decorators)) + self.added_decorator = True + + # Pop the context when we leave a function + self.context_stack.pop() + return updated_node + + def is_target_decorator(self, decorator_node: cst.Name | cst.Attribute | cst.Call) -> 
bool: + """Check if a decorator matches our target decorator name.""" + if isinstance(decorator_node, cst.Name): + return decorator_node.value in { + "codeflash_trace_async", + "codeflash_behavior_async", + "codeflash_performance_async", + "codeflash_concurrency_async", + } + if isinstance(decorator_node, cst.Call) and isinstance(decorator_node.func, cst.Name): + return decorator_node.func.value in { + "codeflash_trace_async", + "codeflash_behavior_async", + "codeflash_performance_async", + "codeflash_concurrency_async", + } + return False + + +ASYNC_HELPER_INLINE_CODE = """import asyncio +import gc +import os +import sqlite3 +import time +from functools import wraps +from pathlib import Path +from tempfile import TemporaryDirectory + +import dill as pickle + + +def get_run_tmp_file(file_path): + if not hasattr(get_run_tmp_file, "tmpdir"): + get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_") + return Path(get_run_tmp_file.tmpdir.name) / file_path + + +def extract_test_context_from_env(): + test_module = os.environ["CODEFLASH_TEST_MODULE"] + test_class = os.environ.get("CODEFLASH_TEST_CLASS", None) + test_function = os.environ["CODEFLASH_TEST_FUNCTION"] + if test_module and test_function: + return (test_module, test_class if test_class else None, test_function) + raise RuntimeError( + "Test context environment variables not set - ensure tests are run through codeflash test runner" + ) + + +def codeflash_behavior_async(func): + @wraps(func) + async def async_wrapper(*args, **kwargs): + loop = asyncio.get_running_loop() + function_name = func.__name__ + line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"] + loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"]) + test_module_name, test_class_name, test_name = extract_test_context_from_env() + test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}" + if not hasattr(async_wrapper, "index"): + async_wrapper.index = {} + if test_id in async_wrapper.index: + 
async_wrapper.index[test_id] += 1 + else: + async_wrapper.index[test_id] = 0 + codeflash_test_index = async_wrapper.index[test_id] + invocation_id = f"{line_id}_{codeflash_test_index}" + class_prefix = (test_class_name + ".") if test_class_name else "" + test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}" + print(f"!$######{test_stdout_tag}######$!") + iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0") + db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite")) + codeflash_con = sqlite3.connect(db_path) + codeflash_cur = codeflash_con.cursor() + codeflash_cur.execute( + "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, " + "test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + "runtime INTEGER, return_value BLOB, verification_type TEXT)" + ) + exception = None + counter = loop.time() + gc.disable() + try: + ret = func(*args, **kwargs) + counter = loop.time() + return_value = await ret + codeflash_duration = int((loop.time() - counter) * 1_000_000_000) + except Exception as e: + codeflash_duration = int((loop.time() - counter) * 1_000_000_000) + exception = e + finally: + gc.enable() + print(f"!######{test_stdout_tag}######!") + pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps((args, kwargs, return_value)) + codeflash_cur.execute( + "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + test_module_name, + test_class_name, + test_name, + function_name, + loop_index, + invocation_id, + codeflash_duration, + pickled_return_value, + "function_call", + ), + ) + codeflash_con.commit() + codeflash_con.close() + if exception: + raise exception + return return_value + return async_wrapper + + +def codeflash_performance_async(func): + @wraps(func) + async def async_wrapper(*args, **kwargs): + loop = asyncio.get_running_loop() + function_name = func.__name__ + line_id = 
os.environ["CODEFLASH_CURRENT_LINE_ID"] + loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"]) + test_module_name, test_class_name, test_name = extract_test_context_from_env() + test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}" + if not hasattr(async_wrapper, "index"): + async_wrapper.index = {} + if test_id in async_wrapper.index: + async_wrapper.index[test_id] += 1 + else: + async_wrapper.index[test_id] = 0 + codeflash_test_index = async_wrapper.index[test_id] + invocation_id = f"{line_id}_{codeflash_test_index}" + class_prefix = (test_class_name + ".") if test_class_name else "" + test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}" + print(f"!$######{test_stdout_tag}######$!") + exception = None + counter = loop.time() + gc.disable() + try: + ret = func(*args, **kwargs) + counter = loop.time() + return_value = await ret + codeflash_duration = int((loop.time() - counter) * 1_000_000_000) + except Exception as e: + codeflash_duration = int((loop.time() - counter) * 1_000_000_000) + exception = e + finally: + gc.enable() + print(f"!######{test_stdout_tag}:{codeflash_duration}######!") + if exception: + raise exception + return return_value + return async_wrapper + + +def codeflash_concurrency_async(func): + @wraps(func) + async def async_wrapper(*args, **kwargs): + function_name = func.__name__ + concurrency_factor = int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10")) + test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "") + test_class_name = os.environ.get("CODEFLASH_TEST_CLASS", "") + test_function = os.environ.get("CODEFLASH_TEST_FUNCTION", "") + loop_index = os.environ.get("CODEFLASH_LOOP_INDEX", "0") + gc.disable() + try: + seq_start = time.perf_counter_ns() + for _ in range(concurrency_factor): + result = await func(*args, **kwargs) + sequential_time = time.perf_counter_ns() - seq_start + finally: + gc.enable() + gc.disable() + try: + conc_start = 
time.perf_counter_ns() + tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)] + await asyncio.gather(*tasks) + concurrent_time = time.perf_counter_ns() - conc_start + finally: + gc.enable() + tag = f"{test_module_name}:{test_class_name}:{test_function}:{function_name}:{loop_index}" + print(f"!@######CONC:{tag}:{sequential_time}:{concurrent_time}:{concurrency_factor}######@!") + return result + return async_wrapper +""" + +ASYNC_HELPER_FILENAME = "codeflash_async_wrapper.py" + + +def get_decorator_name_for_mode(mode: TestingMode) -> str: + if mode == TestingMode.BEHAVIOR: + return "codeflash_behavior_async" + if mode == TestingMode.CONCURRENCY: + return "codeflash_concurrency_async" + return "codeflash_performance_async" + + +def write_async_helper_file(target_dir: Path) -> Path: + """Write the async decorator helper file to the target directory.""" + helper_path = target_dir / ASYNC_HELPER_FILENAME + if not helper_path.exists(): + helper_path.write_text(ASYNC_HELPER_INLINE_CODE, "utf-8") + return helper_path + + +def add_async_decorator_to_function( + source_path: Path, + function: FunctionToOptimize, + mode: TestingMode = TestingMode.BEHAVIOR, + project_root: Path | None = None, +) -> bool: + """Add async decorator to an async function definition and write back to file. + + Writes a helper file containing the decorator implementation to project_root (or source directory + as fallback) and adds a standard import + decorator to the source file. 
+ + """ + if not function.is_async: + return False + + try: + with source_path.open(encoding="utf8") as f: + source_code = f.read() + + module = cst.parse_module(source_code) + + # Add the decorator to the function + decorator_transformer = AsyncDecoratorAdder(function, mode) + module = module.visit(decorator_transformer) + + if decorator_transformer.added_decorator: + # Write the helper file to project_root (on sys.path) or source dir as fallback + helper_dir = project_root if project_root is not None else source_path.parent + write_async_helper_file(helper_dir) + # Add the import via CST so sort_imports can place it correctly + decorator_name = get_decorator_name_for_mode(mode) + import_node = cst.parse_statement(f"from codeflash_async_wrapper import {decorator_name}") + module = module.with_changes(body=[import_node, *list(module.body)]) + + modified_code = sort_imports(code=module.code, float_to_top=True) + except Exception as e: + logger.exception("Error adding async decorator to function %s: %s", function.qualified_name, e) + return False + else: + if decorator_transformer.added_decorator: + with source_path.open("w", encoding="utf8") as f: + f.write(modified_code) + logger.debug("Applied async %s instrumentation to %s", mode.value, source_path) + return True + return False diff --git a/src/codeflash_python/verification/codeflash_capture.py b/src/codeflash_python/verification/codeflash_capture.py new file mode 100644 index 000000000..fd4f40fba --- /dev/null +++ b/src/codeflash_python/verification/codeflash_capture.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +# This file should not have any dependencies on codeflash +import functools +import gc +import inspect +import os +import sqlite3 +import time +import warnings +from enum import Enum +from pathlib import Path +from typing import Callable + +import dill as pickle +from dill import PicklingWarning + +from codeflash_python.picklepatch.pickle_patcher import PicklePatcher + 
class VerificationType(str, Enum):
    """Value stored in the ``verification_type`` column of ``test_results``."""

    FUNCTION_CALL = (
        "function_call"  # Correctness verification for a test function, checks input values and output values
    )
    INIT_STATE_FTO = "init_state_fto"  # Correctness verification for fto class instance attributes after init
    INIT_STATE_HELPER = "init_state_helper"  # Correctness verification for helper class instance attributes after init


# ``co_name`` of the code object that runs a module's top-level statements.
_MODULE_LEVEL_CODE_NAME = "<module>"


def _class_name_from_frame_locals(frame_locals: dict) -> str | None:
    """Return the class name of a local ``self`` in ``frame_locals``, or ``None`` if there is none."""
    if (
        "self" in frame_locals
        and hasattr(frame_locals["self"], "__class__")
        and hasattr(frame_locals["self"].__class__, "__name__")
    ):
        return frame_locals["self"].__class__.__name__
    return None


def get_test_info_from_stack(tests_root: str) -> tuple[str, str | None, str, str]:
    """Extract test information by walking the call stack from the current frame.

    The walk stops at the first frame that is either a function whose name starts with
    ``test_``, or module-level code of a module whose last name component starts with
    ``test_`` and whose file lies under ``tests_root``. Values the walk could not
    determine fall back to the ``CODEFLASH_TEST_FUNCTION`` / ``CODEFLASH_TEST_MODULE`` /
    ``CODEFLASH_TEST_CLASS`` environment variables.

    Returns:
        ``(test_module_name, test_class_name, test_name, line_id)`` where ``line_id`` is
        the line number (as a string) being executed in the matched frame, or ``""``.

    """
    test_module_name: str = ""
    test_class_name: str | None = None
    test_name: str = ""
    line_id: str = ""

    # Get current frame and skip our own function's frame
    frame = inspect.currentframe()
    if frame is not None:
        frame = frame.f_back

    # Walk the stack
    while frame is not None:
        function_name = frame.f_code.co_name

        # Check if function name indicates a test (e.g., starts with "test_")
        if function_name.startswith("test_"):
            test_name = function_name
            test_module = inspect.getmodule(frame)
            if hasattr(test_module, "__name__"):
                test_module_name = test_module.__name__
            line_id = str(frame.f_lineno)
            # Check if it's a method in a class
            test_class_name = _class_name_from_frame_locals(frame.f_locals)
            break

        # Check for instantiation on the module level. Comparing against "" can never
        # match, because module-level code objects are named "<module>".
        if (
            function_name == _MODULE_LEVEL_CODE_NAME
            and "__name__" in frame.f_globals
            and frame.f_globals["__name__"].split(".")[-1].startswith("test_")
            and Path(frame.f_code.co_filename).resolve().is_relative_to(Path(tests_root))
        ):
            test_module_name = frame.f_globals["__name__"]
            line_id = str(frame.f_lineno)
            test_class_name = _class_name_from_frame_locals(frame.f_locals)
            break

        # Go to the previous frame
        frame = frame.f_back

    # If stack walking didn't find test info, fall back to environment variables
    if not test_name:
        env_test_function = os.environ.get("CODEFLASH_TEST_FUNCTION")
        if env_test_function:
            test_name = env_test_function
            if not test_module_name:
                test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "")
            if not test_class_name:
                env_class = os.environ.get("CODEFLASH_TEST_CLASS")
                test_class_name = env_class if env_class else None

    # Ensure test_name is a string, not None
    final_test_name: str = test_name if test_name else ""
    return test_module_name, test_class_name, final_test_name, line_id


def codeflash_capture(function_name: str, tmp_dir_path: str, tests_root: str, is_fto: bool = False) -> Callable:
    """Define a decorator to instrument the init function, collect test info, and capture the instance state.

    Args:
        function_name: Name recorded in the ``function_getting_tested`` column and stdout tag.
        tmp_dir_path: Prefix of the sqlite file; ``_<CODEFLASH_TEST_ITERATION>.sqlite`` is appended.
        tests_root: Passed to ``get_test_info_from_stack`` to recognise module-level test code.
        is_fto: Selects ``INIT_STATE_FTO`` instead of ``INIT_STATE_HELPER`` as verification type.

    The wrapper requires ``CODEFLASH_LOOP_INDEX`` and ``CODEFLASH_TEST_ITERATION`` in the
    environment (``KeyError`` otherwise) and re-raises any exception from the wrapped
    callable after recording it.

    """

    def decorator(wrapped: Callable) -> Callable:
        @functools.wraps(wrapped)
        def wrapper(*args, **kwargs) -> None:  # noqa: ANN002, ANN003
            # Dynamic information retrieved from stack
            test_module_name, test_class_name, test_name, line_id = get_test_info_from_stack(tests_root)

            # Get env variables
            loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
            codeflash_iteration = os.environ["CODEFLASH_TEST_ITERATION"]

            # Create test_id
            test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"

            # Initialize index tracking if needed, handles multiple instances created in the same test line number
            if not hasattr(wrapper, "index"):
                wrapper.index = {}  # type: ignore[attr-defined]

            # Update index for this test
            if test_id in wrapper.index:  # type: ignore[attr-defined]
                wrapper.index[test_id] += 1  # type: ignore[attr-defined]
            else:
                wrapper.index[test_id] = 0  # type: ignore[attr-defined]

            codeflash_test_index = wrapper.index[test_id]  # type: ignore[attr-defined]

            # Generate invocation id
            invocation_id = f"{line_id}_{codeflash_test_index}"
            test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"
            print(f"!$######{test_stdout_tag}######$!")
            # Connect to sqlite
            codeflash_con = sqlite3.connect(f"{tmp_dir_path}_{codeflash_iteration}.sqlite")
            try:
                codeflash_cur = codeflash_con.cursor()

                # Record timing information
                exception = None
                gc.disable()
                try:
                    counter = time.perf_counter_ns()
                    wrapped(*args, **kwargs)
                    codeflash_duration = time.perf_counter_ns() - counter
                except Exception as e:
                    codeflash_duration = time.perf_counter_ns() - counter
                    exception = e
                finally:
                    gc.enable()
                print(f"!######{test_stdout_tag}######!")

                # Capture instance state after initialization
                # self is always the first argument, this is ensured during instrumentation
                instance = args[0]
                if hasattr(instance, "__dict__"):
                    instance_state = instance.__dict__
                elif hasattr(instance, "__slots__"):
                    # For classes using __slots__, capture slot values
                    instance_state = {
                        slot: getattr(instance, slot, None) for slot in instance.__slots__ if hasattr(instance, slot)
                    }
                else:
                    # For C extension types or other special classes (e.g., Playwright's Page),
                    # capture all non-private, non-callable attributes
                    instance_state = {
                        attr: getattr(instance, attr)
                        for attr in dir(instance)
                        if not attr.startswith("_") and not callable(getattr(instance, attr, None))
                    }
                codeflash_cur.execute(
                    "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)"
                )

                # Write to sqlite
                pickled_return_value = pickle.dumps(exception) if exception else PicklePatcher.dumps(instance_state)
                codeflash_cur.execute(
                    "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
                    (
                        test_module_name,
                        test_class_name,
                        test_name,
                        function_name,
                        loop_index,
                        invocation_id,
                        codeflash_duration,
                        pickled_return_value,
                        VerificationType.INIT_STATE_FTO if is_fto else VerificationType.INIT_STATE_HELPER,
                    ),
                )
                codeflash_con.commit()
            finally:
                # One connection is opened per instrumented __init__ call; close it on
                # every path so they do not accumulate.
                codeflash_con.close()
            if exception:
                raise exception

        return wrapper

    return decorator
import tensorflow as tf +if HAS_SQLALCHEMY: + import sqlalchemy # type: ignore[import-not-found] +if HAS_PYARROW: + import pyarrow as pa # type: ignore[import-not-found] +if HAS_PANDAS: + import pandas # noqa: ICN001 +if HAS_TORCH: + import torch +if HAS_NUMBA: + import numba + from numba.core.dispatcher import Dispatcher + from numba.typed import Dict as NumbaDict + from numba.typed import List as NumbaList +if HAS_PYRSISTENT: + import pyrsistent # type: ignore[import-not-found] + +# Pattern to match pytest temp directories: /tmp/pytest-of-/pytest-/ +# These paths vary between test runs but are logically equivalent +PYTEST_TEMP_PATH_PATTERN = re.compile(r"/tmp/pytest-of-[^/]+/pytest-\d+/") # noqa: S108 + +# Pattern to match Python tempfile directories: /tmp/tmp/ +# Created by tempfile.mkdtemp() or tempfile.TemporaryDirectory() +PYTHON_TEMPFILE_PATTERN = re.compile(r"/tmp/tmp[a-zA-Z0-9_]+/") # noqa: S108 + +_DICT_KEYS_TYPE = type({}.keys()) +_DICT_VALUES_TYPE = type({}.values()) +_DICT_ITEMS_TYPE = type({}.items()) + +_EQUALITY_TYPES = ( + int, + bool, + complex, + type(None), + type(Ellipsis), + decimal.Decimal, + set, + bytes, + bytearray, + memoryview, + frozenset, + enum.Enum, + type, + range, + slice, + OrderedDict, + types.GenericAlias, + *((_union_type,) if (_union_type := getattr(types, "UnionType", None)) else ()), +) + + +def normalize_temp_path(path: str) -> str: + """Normalize temporary file paths by replacing session-specific components. 
+ + Handles two types of temp paths: + - Pytest: /tmp/pytest-of-/pytest-/ -> /tmp/pytest-temp/ + - Python tempfile: /tmp/tmp/ -> /tmp/python-temp/ + """ + path = PYTEST_TEMP_PATH_PATTERN.sub("/tmp/pytest-temp/", path) # noqa: S108 + return PYTHON_TEMPFILE_PATTERN.sub("/tmp/python-temp/", path) # noqa: S108 + + +def is_temp_path(s: str) -> bool: + """Check if a string looks like a temp path (pytest or Python tempfile).""" + return PYTEST_TEMP_PATH_PATTERN.search(s) is not None or PYTHON_TEMPFILE_PATTERN.search(s) is not None + + +def extract_exception_from_message(msg: str) -> BaseException | None: + """Try to extract a wrapped exception type from an error message. + + Looks for patterns like "got ExceptionType('..." that indicate a wrapped exception. + Returns a synthetic exception of that type if found in builtins, None otherwise. + """ + # Pattern: "got ExceptionType('message')" or "got ExceptionType("message")" + # This pattern is used by torch._dynamo and potentially other libraries + match = re.search(r"got (\w+)\(['\"]", msg) + if match: + exc_name = match.group(1) + # Try to find this exception type in builtins + import builtins + + exc_class = getattr(builtins, exc_name, None) + if exc_class is not None and isinstance(exc_class, type) and issubclass(exc_class, BaseException): + return exc_class() + return None + + +def get_wrapped_exception(exc: BaseException) -> BaseException | None: + """Get the wrapped exception if this is a simple wrapper. + + Returns the inner exception if: + - exc is an ExceptionGroup with exactly one exception + - exc has a __cause__ (explicit chaining via 'raise X from Y') + - exc message contains a wrapped exception type pattern (e.g., "got IndexError('...")") + + Returns None if exc is not a wrapper or wraps multiple exceptions. 
def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
    """Compare two objects for equality recursively.

    The checks run in a fixed order: exceptions, type-name match, builtin containers,
    strings (with temp-path normalization), plain equality types, floats, then the
    optional third-party libraries that are installed, and finally generic fallbacks
    (attrs classes, itertools objects, user-defined ``__eq__``, ``__dict__`` and
    ``__slots__``). The order matters: several array-like types raise when used in a
    boolean context, so they must be handled before the generic checks.

    Args:
        orig: The reference value.
        new: The value to compare against ``orig``.
        superset_obj: If True, ``new`` is allowed to have more keys/attributes than
            ``orig``; the keys/values present in ``orig`` must still be equivalent.

    Returns:
        True when the two objects are considered equivalent, False otherwise. Any
        exception raised while comparing (including RecursionError) is logged,
        reported to Sentry, and turned into False.
    """
    try:
        # Handle exceptions specially - before type check to allow wrapper comparison
        if isinstance(orig, BaseException) and isinstance(new, BaseException):
            if isinstance(orig, PicklePlaceholderAccessError) or isinstance(new, PicklePlaceholderAccessError):
                # If this error was raised, there was an attempt to access the PicklePlaceholder, which represents an unpickleable object.
                # The test results should be rejected as the behavior of the unpickleable object is unknown.
                logger.debug("Unable to verify behavior of unpickleable object in replay test")
                return False

            # If types match exactly, compare attributes
            if type(orig) is type(new):
                orig_dict = {k: v for k, v in orig.__dict__.items() if not k.startswith("_")}
                new_dict = {k: v for k, v in new.__dict__.items() if not k.startswith("_")}
                return comparator(orig_dict, new_dict, superset_obj)

            # Types differ - check if one is a wrapper over the other
            # Check if orig wraps something that matches new
            wrapped_orig = get_wrapped_exception(orig)
            if wrapped_orig is not None and comparator(wrapped_orig, new, superset_obj):
                return True

            # Check if new wraps something that matches orig
            wrapped_new = get_wrapped_exception(new)
            if wrapped_new is not None and comparator(orig, wrapped_new, superset_obj):
                return True

            return False

        if type(orig) is not type(new):
            type_obj = type(orig)
            new_type_obj = type(new)
            # distinct type objects are created at runtime, even if the class code is exactly the same, so we can only compare the names
            if type_obj.__name__ != new_type_obj.__name__ or type_obj.__qualname__ != new_type_obj.__qualname__:
                return False
        if isinstance(orig, (list, tuple, deque, ChainMap)):
            if len(orig) != len(new):
                return False
            return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))

        # Handle strings separately to normalize temp paths
        if isinstance(orig, str):
            if orig == new:
                return True
            # If strings differ, check if they're temp paths that differ only in session number
            if is_temp_path(orig) and is_temp_path(new):
                return normalize_temp_path(orig) == normalize_temp_path(new)
            return False

        if isinstance(orig, _EQUALITY_TYPES):
            return orig == new
        if isinstance(orig, float):
            if math.isnan(orig) and math.isnan(new):
                return True
            return math.isclose(orig, new)

        # Handle weak references (e.g., found in torch.nn.LSTM/GRU modules)
        if isinstance(orig, weakref.ref):
            orig_referent = orig()
            new_referent = new()
            # Both dead refs are equal, otherwise compare referents
            if orig_referent is None and new_referent is None:
                return True
            if orig_referent is None or new_referent is None:
                return False
            return comparator(orig_referent, new_referent, superset_obj)

        if HAS_JAX:
            # Handle JAX arrays first to avoid boolean context errors in other conditions
            if isinstance(orig, jax.Array):
                if orig.dtype != new.dtype:
                    return False
                if orig.shape != new.shape:
                    return False
                return bool(jnp.allclose(orig, new, equal_nan=True))

        # Handle xarray objects before numpy to avoid boolean context errors
        if HAS_XARRAY:
            if isinstance(orig, (xarray.Dataset, xarray.DataArray)):
                return orig.identical(new)

        # Handle TensorFlow objects early to avoid boolean context errors
        if HAS_TENSORFLOW:
            if isinstance(orig, tf.Tensor):
                if orig.dtype != new.dtype:
                    return False
                if orig.shape != new.shape:
                    return False
                # Use numpy conversion for proper NaN handling
                return comparator(orig.numpy(), new.numpy(), superset_obj)

            if isinstance(orig, tf.Variable):
                if orig.dtype != new.dtype:
                    return False
                if orig.shape != new.shape:
                    return False
                return comparator(orig.numpy(), new.numpy(), superset_obj)

            if isinstance(orig, tf.dtypes.DType):
                return orig == new

            if isinstance(orig, tf.TensorShape):
                return orig == new

            if isinstance(orig, tf.SparseTensor):
                if not comparator(orig.dense_shape.numpy(), new.dense_shape.numpy(), superset_obj):
                    return False
                return comparator(orig.indices.numpy(), new.indices.numpy(), superset_obj) and comparator(
                    orig.values.numpy(), new.values.numpy(), superset_obj
                )

            if isinstance(orig, tf.RaggedTensor):
                if orig.dtype != new.dtype:
                    return False
                if orig.shape.rank != new.shape.rank:
                    return False
                return comparator(orig.to_list(), new.to_list(), superset_obj)

        if HAS_SQLALCHEMY:
            try:
                # inspect() is called only for its side effect: it raises
                # NoInspectionAvailable when the object is not something SQLAlchemy can inspect.
                sqlalchemy.inspection.inspect(orig)
                sqlalchemy.inspection.inspect(new)
                orig_keys = orig.__dict__
                new_keys = new.__dict__
                for key in list(orig_keys.keys()):
                    if key.startswith("_"):
                        continue
                    if key not in new_keys or not comparator(orig_keys[key], new_keys[key], superset_obj):
                        return False
                return True

            except sqlalchemy.exc.NoInspectionAvailable:
                pass

        # scipy condition because dok_matrix type is also an instance of dict, but dict comparison doesn't work for it
        if isinstance(orig, dict) and not (HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix)):
            if superset_obj:
                return all(k in new and comparator(v, new[k], superset_obj) for k, v in orig.items())
            if len(orig) != len(new):
                return False
            for key in orig:
                if key not in new:
                    return False
                if not comparator(orig[key], new[key], superset_obj):
                    return False
            return True

        # Handle mappingproxy (read-only dict view, commonly seen as class.__dict__)
        if isinstance(orig, types.MappingProxyType):
            return comparator(dict(orig), dict(new), superset_obj)

        # Handle dict view types (dict_keys, dict_values, dict_items)
        if isinstance(orig, _DICT_KEYS_TYPE):
            return comparator(set(orig), set(new))
        if isinstance(orig, _DICT_VALUES_TYPE):
            return comparator(list(orig), list(new))
        if isinstance(orig, _DICT_ITEMS_TYPE):
            return comparator(dict(orig), dict(new), superset_obj)

        if HAS_NUMPY:
            if isinstance(orig, (np.datetime64, np.timedelta64)):
                # Handle NaT (Not a Time) - numpy's equivalent of NaN for datetime
                if np.isnat(orig) and np.isnat(new):
                    return True
                if np.isnat(orig) or np.isnat(new):
                    return False
                return orig == new

            if isinstance(orig, np.ndarray):
                if orig.dtype != new.dtype:
                    return False
                if orig.shape != new.shape:
                    return False
                # Handle 0-d arrays specially to avoid "iteration over a 0-d array" error
                if orig.ndim == 0:
                    try:
                        return np.allclose(orig, new, equal_nan=True)
                    except Exception:
                        return bool(orig == new)
                try:
                    return np.allclose(orig, new, equal_nan=True)
                except Exception:
                    # fails at "ufunc 'isfinite' not supported for the input types"
                    return bool(np.all([comparator(x, y, superset_obj) for x, y in zip(orig, new)]))

            if isinstance(orig, (np.floating, np.complexfloating)):
                return bool(np.isclose(orig, new, equal_nan=True))

            if isinstance(orig, (np.integer, np.bool_, np.byte)):
                return orig == new

            if isinstance(orig, np.void):
                if orig.dtype != new.dtype:
                    return False
                return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields)

            # Handle np.dtype instances (including numpy.dtypes.* classes like Float64DType, Int64DType, etc.)
            if isinstance(orig, np.dtype):
                return orig == new

            # Handle numpy random generators
            if isinstance(orig, np.random.Generator):
                # Compare the underlying BitGenerator state
                orig_state = orig.bit_generator.state
                new_state = new.bit_generator.state
                return comparator(orig_state, new_state, superset_obj)

            if isinstance(orig, np.random.RandomState):
                # Compare the internal state
                orig_state = orig.get_state(legacy=False)
                new_state = new.get_state(legacy=False)
                return comparator(orig_state, new_state, superset_obj)

        if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix):
            if orig.dtype != new.dtype:
                return False
            if orig.get_shape() != new.get_shape():
                return False
            return (orig != new).nnz == 0

        if HAS_PYARROW:
            if isinstance(orig, pa.Table):
                if orig.schema != new.schema:
                    return False
                if orig.num_rows != new.num_rows:
                    return False
                return bool(orig.equals(new))

            if isinstance(orig, pa.RecordBatch):
                if orig.schema != new.schema:
                    return False
                if orig.num_rows != new.num_rows:
                    return False
                return bool(orig.equals(new))

            if isinstance(orig, pa.ChunkedArray):
                if orig.type != new.type:
                    return False
                if len(orig) != len(new):
                    return False
                return bool(orig.equals(new))

            if isinstance(orig, pa.Array):
                if orig.type != new.type:
                    return False
                if len(orig) != len(new):
                    return False
                return bool(orig.equals(new))

            if isinstance(orig, pa.Scalar):
                if orig.type != new.type:
                    return False
                # Handle null scalars
                if not orig.is_valid and not new.is_valid:
                    return True
                if not orig.is_valid or not new.is_valid:
                    return False
                return bool(orig.equals(new))

            if isinstance(orig, (pa.Schema, pa.Field, pa.DataType)):
                return bool(orig.equals(new))

        if HAS_PANDAS:
            if isinstance(
                orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray)
            ):
                return bool(orig.equals(new))

            if isinstance(orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period)):
                return orig == new
            if pandas.isna(orig) and pandas.isna(new):
                return True

        if isinstance(orig, array.array):
            if orig.typecode != new.typecode:
                return False
            if len(orig) != len(new):
                return False
            return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))

        # This should be at the end of all numpy checking
        try:
            if HAS_NUMPY and np.isnan(orig):
                return np.isnan(new)
        except Exception:
            pass
        try:
            if HAS_NUMPY and np.isinf(orig):
                return np.isinf(new)
        except Exception:
            pass

        if HAS_TORCH:
            if isinstance(orig, torch.Tensor):
                if orig.dtype != new.dtype:
                    return False
                if orig.shape != new.shape:
                    return False
                if orig.requires_grad != new.requires_grad:
                    return False
                if orig.device != new.device:
                    return False
                return torch.allclose(orig, new, equal_nan=True)

            if isinstance(orig, torch.dtype):
                return orig == new

            if isinstance(orig, torch.device):
                return orig == new

        if HAS_NUMBA:
            # Handle numba typed List
            if isinstance(orig, NumbaList):
                if len(orig) != len(new):
                    return False
                return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))

            # Handle numba typed Dict
            if isinstance(orig, NumbaDict):
                if superset_obj:
                    # Allow new dict to have more keys, but all orig keys must exist with equal values
                    return all(key in new and comparator(orig[key], new[key], superset_obj) for key in orig)
                if len(orig) != len(new):
                    return False
                for key in orig:
                    if key not in new:
                        return False
                    if not comparator(orig[key], new[key], superset_obj):
                        return False
                return True

            # Handle numba type objects (e.g., numba.int64, numba.float64, numba.Array, etc.)
            if isinstance(orig, numba.core.types.Type):
                return orig == new

            # Handle numba JIT-compiled functions (CPUDispatcher, etc.)
            if isinstance(orig, Dispatcher):
                # Compare by identity of the underlying Python function
                # Two JIT functions are equal if they wrap the same Python function
                return orig.py_func is new.py_func

        if HAS_PYRSISTENT:
            if isinstance(
                orig,
                (
                    pyrsistent.PMap,
                    pyrsistent.PVector,
                    pyrsistent.PSet,
                    pyrsistent.PRecord,
                    pyrsistent.PClass,
                    pyrsistent.PBag,
                    pyrsistent.PList,
                    pyrsistent.PDeque,
                ),
            ):
                return orig == new

        if hasattr(orig, "__attrs_attrs__") and hasattr(new, "__attrs_attrs__"):
            orig_dict = {}
            new_dict = {}

            # Only attributes that participate in attrs equality (attr.eq) are compared.
            for attr in orig.__attrs_attrs__:
                if attr.eq:
                    attr_name = attr.name
                    orig_dict[attr_name] = getattr(orig, attr_name, None)
                    new_dict[attr_name] = getattr(new, attr_name, None)

            if superset_obj:
                new_attrs_dict = {}
                for attr in new.__attrs_attrs__:
                    if attr.eq:
                        attr_name = attr.name
                        new_attrs_dict[attr_name] = getattr(new, attr_name, None)
                return all(
                    k in new_attrs_dict and comparator(v, new_attrs_dict[k], superset_obj) for k, v in orig_dict.items()
                )
            return comparator(orig_dict, new_dict, superset_obj)

        # Handle itertools infinite iterators
        if isinstance(orig, itertools.count):
            # repr reliably reflects internal state, e.g. "count(5)" or "count(5, 2)"
            return repr(orig) == repr(new)

        if isinstance(orig, itertools.repeat):
            # repr reliably reflects internal state, e.g. "repeat(5)" or "repeat(5, 3)"
            return repr(orig) == repr(new)

        if isinstance(orig, itertools.cycle):
            # cycle has no useful repr and no public attributes; use __reduce__ to extract state.
            # __reduce__ returns (cls, (remaining_iter,), (saved_items, first_pass_done)).
            # NOTE: consuming the remaining_iter is destructive to the cycle object, but this is
            # acceptable since the comparator is the final consumer of captured return values.
            # NOTE: __reduce__ on itertools.cycle was removed in Python 3.14.
            try:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", DeprecationWarning)
                    orig_reduce = orig.__reduce__()
                    new_reduce = new.__reduce__()
                orig_remaining = list(orig_reduce[1][0])
                new_remaining = list(new_reduce[1][0])
                orig_saved, orig_started = orig_reduce[2]
                new_saved, new_started = new_reduce[2]
                if orig_started != new_started:
                    return False
                return comparator(orig_remaining, new_remaining, superset_obj) and comparator(
                    orig_saved, new_saved, superset_obj
                )
            except TypeError:
                # Python 3.14+: __reduce__ removed. Fall back to consuming elements from both
                # cycles and comparing. Since the comparator is the final consumer, this is safe.
                sample_size = 200
                orig_sample = [next(orig) for _ in range(sample_size)]
                new_sample = [next(new) for _ in range(sample_size)]
                return comparator(orig_sample, new_sample, superset_obj)

        # Handle remaining itertools types (chain, islice, starmap, product, permutations, etc.)
        # by materializing into lists. count/repeat/cycle are already handled above.
        # NOTE: materializing is destructive (consumes the iterator) and will hang on infinite input,
        # but the three infinite itertools types are already handled above.
        if type(orig).__module__ == "itertools":
            if isinstance(orig, itertools.groupby):
                # groupby yields (key, group_iterator) — materialize groups too
                orig_groups = [(k, list(g)) for k, g in orig]
                new_groups = [(k, list(g)) for k, g in new]
                return comparator(orig_groups, new_groups, superset_obj)
            return comparator(list(orig), list(new), superset_obj)

        # re.Pattern can be made better by DFA Minimization and then comparing
        if isinstance(
            orig, (datetime.datetime, datetime.date, datetime.timedelta, datetime.time, datetime.timezone, re.Pattern)
        ):
            return orig == new

        # If the object passed has a user defined __eq__ method, use that
        # This could fail if the user defined __eq__ is defined with C-extensions
        try:
            if hasattr(orig, "__eq__") and isinstance(orig.__eq__, types.MethodType):
                return orig == new
        except Exception:
            pass

        # For class objects
        if hasattr(orig, "__dict__") and hasattr(new, "__dict__"):
            orig_keys = orig.__dict__
            new_keys = new.__dict__
            if type(orig_keys) == types.MappingProxyType and type(new_keys) == types.MappingProxyType:
                # meta class objects
                if orig != new:
                    return False
                orig_keys = dict(orig_keys)
                new_keys = dict(new_keys)
                orig_keys = {k: v for k, v in orig_keys.items() if not k.startswith("__")}
                new_keys = {k: v for k, v in new_keys.items() if not k.startswith("__")}

            if superset_obj:
                # allow new object to be a superset of the original object
                return all(k in new_keys and comparator(v, new_keys[k], superset_obj) for k, v in orig_keys.items())

            if isinstance(orig, ast.AST):
                # "parent" is skipped: it is a back-reference and would make the comparison recurse upwards
                orig_keys = {k: v for k, v in orig.__dict__.items() if k != "parent"}
                new_keys = {k: v for k, v in new.__dict__.items() if k != "parent"}
            return comparator(orig_keys, new_keys, superset_obj)

        # For objects with __slots__ but no __dict__, compare slot attributes
        if hasattr(type(orig), "__slots__"):
            all_slots = set()
            for cls in type(orig).__mro__:
                if hasattr(cls, "__slots__"):
                    all_slots.update(cls.__slots__)
            orig_vals = {s: getattr(orig, s, None) for s in all_slots}
            new_vals = {s: getattr(new, s, None) for s in all_slots}
            if superset_obj:
                return all(k in new_vals and comparator(v, new_vals[k], superset_obj) for k, v in orig_vals.items())
            return comparator(orig_vals, new_vals, superset_obj)

        if type(orig) in {types.BuiltinFunctionType, types.BuiltinMethodType}:
            return new == orig
        if isinstance(orig, ET.Element):
            return isinstance(new, ET.Element) and ET.tostring(orig) == ET.tostring(new)
        if isinstance(
            orig,
            (
                _thread.LockType,
                _thread.RLock,
                threading.Event,
                threading.Condition,
                sqlite3.Connection,
                sqlite3.Cursor,
                io.IOBase,
            ),
        ):
            # Stateful handles cannot be compared by value; matching types is the best available check.
            return type(orig) is type(new)
        # A bare object() instance carries no state, so two of them are always equivalent.
        # str(type(x)) is never empty, so comparing against "" made this branch unreachable.
        if str(type(orig)) == "<class 'object'>":
            return True
        logger.warning("Unknown comparator input type: %s", type(orig))
        sentry_sdk.capture_exception(RuntimeError(f"Unknown comparator input type: {type(orig)}"))
        return False
    except RecursionError as e:
        logger.exception("RecursionError while comparing objects: %s", e)
        sentry_sdk.capture_exception(e)
        return False
    except Exception as e:
        logger.exception("Error while comparing objects: %s", e)
        sentry_sdk.capture_exception(e)
        return False
def generate_concolic_tests(
    test_cfg: TestConfig, project_root: Path, function_to_optimize: FunctionToOptimize, function_to_optimize_ast: Any
) -> tuple[dict, str]:
    """Generate concolic tests for a function by running ``crosshair cover`` in a subprocess.

    Generation is attempted only when the ``crosshair`` module can be found,
    ``test_cfg.concolic_test_root_dir`` is set, the AST node is an ``ast.FunctionDef``
    and ``has_typed_parameters`` accepts it. The CrossHair output is validated,
    cleaned, written to ``test_concolic_coverage.py`` in a fresh temporary directory
    under the concolic test root, and then passed through unit-test discovery.

    Returns:
        A tuple of (result of discovering the generated tests, cleaned test-suite
        source). Both are empty ({} and "") when generation is skipped, times out,
        CrossHair exits non-zero, or the generated test is invalid.
    """
    crosshair_missing = importlib.util.find_spec("crosshair") is None

    started_at = time.perf_counter()
    discovered_tests: dict = {}
    suite_code = ""

    if crosshair_missing:
        logger.debug("Skipping concolic test generation (crosshair-tool is not installed)")
        return discovered_tests, suite_code

    should_generate = (
        test_cfg.concolic_test_root_dir
        and isinstance(function_to_optimize_ast, ast.FunctionDef)
        and has_typed_parameters(function_to_optimize_ast, function_to_optimize.parents)
    )
    if should_generate:
        logger.info("Generating concolic opcode coverage tests for the original code...")
        try:
            env = make_env_with_project_root(project_root)
            # Dotted module path of the source file relative to the project root,
            # followed by the function's qualified name.
            module_dotted = (
                function_to_optimize.file_path.relative_to(project_root).with_suffix("").as_posix().replace("/", ".")
            )
            target = ".".join([module_dotted, function_to_optimize.qualified_name])
            command = [
                SAFE_SYS_EXECUTABLE,
                "-m",
                "crosshair",
                "cover",
                "--example_output_format=pytest",
                "--per_condition_timeout=20",
                target,
            ]
            cover_result = subprocess.run(
                command, capture_output=True, text=True, cwd=project_root, check=False, timeout=600, env=env
            )
        except subprocess.TimeoutExpired:
            logger.debug("CrossHair Cover test generation timed out")
            return discovered_tests, suite_code

        if cover_result.returncode != 0:
            logger.debug("Error running CrossHair Cover%s", ": " + cover_result.stderr if cover_result.stderr else ".")
        else:
            raw_suite: str = cover_result.stdout
            if not is_valid_concolic_test(raw_suite, project_root=str(project_root)):
                logger.debug("CrossHair generated invalid test, skipping")
                return discovered_tests, suite_code

            suite_code = clean_concolic_tests(raw_suite)
            suite_dir = Path(tempfile.mkdtemp(dir=test_cfg.concolic_test_root_dir))
            (suite_dir / "test_concolic_coverage.py").write_text(suite_code, encoding="utf8")

            discovery_cfg = TestConfig(
                tests_root=suite_dir, tests_project_rootdir=test_cfg.concolic_test_root_dir, project_root=project_root
            )
            discovered_tests, discovered_count, _ = discover_unit_tests(discovery_cfg)
            logger.info(
                "Created %d concolic unit test case%s ", discovered_count, "s" if discovered_count != 1 else ""
            )
            ph("cli-optimize-concolic-tests", {"num_tests": discovered_count})

    logger.debug("Generated concolic tests in %.2f seconds", time.perf_counter() - started_at)
    return discovered_tests, suite_code
class CoverageUtils:
    """Coverage utils class for interfacing with Coverage."""

    @staticmethod
    def load_from_sqlite_database(
        database_path: Path,
        config_path: Path,
        function_name: str,
        code_context: CodeOptimizationContext,
        source_code_path: Path,
    ) -> CoverageData:
        """Load coverage data from an SQLite database, mimicking the behavior of load_from_coverage_file.

        A JSON report for ``source_code_path`` is written next to the database
        (``<database>.report.json``), parsed, and always removed again before returning.

        Returns:
            The coverage of ``function_name`` and its dependent function (if any).
            An empty ``CoverageData`` is returned when the database is missing or
            empty, or when coverage has no data for the source file.
        """
        from coverage import Coverage
        from coverage.jsonreport import JsonReporter

        cov = Coverage(data_file=database_path, config_file=config_path, data_suffix=True, auto_data=True, branch=True)

        if not database_path.exists() or not database_path.stat().st_size:
            logger.debug("Coverage database %s is empty or does not exist", database_path)
            sentry_sdk.capture_message(f"Coverage database {database_path} is empty or does not exist")
            return CoverageData.create_empty(source_code_path, function_name, code_context)
        cov.load()

        reporter = JsonReporter(cov)
        temp_json_file = database_path.with_suffix(".report.json")
        try:
            with temp_json_file.open("w", encoding="utf-8") as f:
                try:
                    reporter.report(morfs=[source_code_path.as_posix()], outfile=f)
                except NoDataError:
                    sentry_sdk.capture_message(f"No coverage data found for {function_name} in {source_code_path}")
                    return CoverageData.create_empty(source_code_path, function_name, code_context)
            # Read back with the same encoding the report was written with.
            with temp_json_file.open(encoding="utf-8") as f:
                original_coverage_data = json.load(f)

            coverage_data, status = CoverageUtils.parse_coverage_file(temp_json_file, source_code_path)
        finally:
            # The report is only a scratch file: remove it on every exit path,
            # including the NoDataError return and unexpected exceptions.
            temp_json_file.unlink(missing_ok=True)

        main_func_coverage, dependent_func_coverage = CoverageUtils.fetch_function_coverages(
            function_name, code_context, coverage_data, original_cov_data=original_coverage_data
        )

        total_executed_lines, total_unexecuted_lines = CoverageUtils.aggregate_coverage(
            main_func_coverage, dependent_func_coverage
        )

        total_lines = total_executed_lines | total_unexecuted_lines
        coverage = len(total_executed_lines) / len(total_lines) * 100 if total_lines else 0.0
        # coverage = (lines covered of the original function + its 1 level deep helpers) / (lines spanned by original function + its 1 level deep helpers), if no helpers then just the original function coverage

        functions_being_tested = [main_func_coverage.name]
        if dependent_func_coverage:
            functions_being_tested.append(dependent_func_coverage.name)

        graph = CoverageUtils.build_graph(main_func_coverage, dependent_func_coverage)

        return CoverageData(
            file_path=source_code_path,
            coverage=coverage,
            function_name=function_name,
            functions_being_tested=functions_being_tested,
            graph=graph,
            code_context=code_context,
            main_func_coverage=main_func_coverage,
            dependent_func_coverage=dependent_func_coverage,
            status=status,
        )

    @staticmethod
    def parse_coverage_file(
        coverage_file_path: Path, source_code_path: Path
    ) -> tuple[dict[str, dict[str, Any]], CoverageStatus]:
        """Read a coverage JSON report and return the per-function data for one source file.

        The report's ``files`` mapping is probed with every key produced by
        ``generate_candidates(source_code_path)``; the first key that is present wins.

        Returns:
            ``(functions, CoverageStatus.PARSED_SUCCESSFULLY)`` on a hit, or
            ``({}, CoverageStatus.NOT_FOUND)`` when no candidate key is present.
        """
        with coverage_file_path.open(encoding="utf-8") as f:
            coverage_data = json.load(f)

        candidates = generate_candidates(source_code_path)

        logger.debug("Looking for coverage data in %s", " -> ".join(candidates))
        for candidate in candidates:
            try:
                cov: dict[str, dict[str, Any]] = coverage_data["files"][candidate]["functions"]
            except KeyError:
                continue
            logger.debug("Coverage data found for %s in %s", source_code_path, candidate)
            return cov, CoverageStatus.PARSED_SUCCESSFULLY

        logger.debug("No coverage data found for %s in %s", source_code_path, candidates)
        return {}, CoverageStatus.NOT_FOUND

    @staticmethod
    def _coverage_from_entry(name: str, entry: dict[str, Any]) -> FunctionCoverage:
        """Build a FunctionCoverage from one per-function entry of the JSON report.

        Raises:
            KeyError: If the entry lacks one of the expected report fields.
        """
        return FunctionCoverage(
            name=name,
            coverage=entry["summary"]["percent_covered"],
            executed_lines=entry["executed_lines"],
            unexecuted_lines=entry["missing_lines"],
            executed_branches=entry["executed_branches"],
            unexecuted_branches=entry["missing_branches"],
        )

    @staticmethod
    def _empty_coverage(name: str) -> FunctionCoverage:
        """Build a FunctionCoverage with zero coverage and no recorded lines or branches."""
        return FunctionCoverage(
            name=name, coverage=0, executed_lines=[], unexecuted_lines=[], executed_branches=[], unexecuted_branches=[]
        )

    @staticmethod
    def fetch_function_coverages(
        function_name: str,
        code_context: CodeOptimizationContext,
        coverage_data: dict[str, dict[str, Any]],
        original_cov_data: dict[str, dict[str, Any]],
    ) -> tuple[FunctionCoverage, FunctionCoverage | None]:
        """Return coverage for the main function and, if there is one, its dependent function.

        The main function is looked up under its fully qualified name; when the entry
        is missing or incomplete, zero coverage is reported for it. The second tuple
        item is None when ``extract_dependent_function`` yields nothing.
        """
        resolved_name = build_fully_qualified_name(function_name, code_context)
        try:
            main_function_coverage = CoverageUtils._coverage_from_entry(resolved_name, coverage_data[resolved_name])
        except KeyError:
            main_function_coverage = CoverageUtils._empty_coverage(resolved_name)

        dependent_function = extract_dependent_function(function_name, code_context)
        dependent_func_coverage = (
            CoverageUtils.grab_dependent_function_from_coverage_data(
                dependent_function, coverage_data, original_cov_data
            )
            if dependent_function
            else None
        )

        return main_function_coverage, dependent_func_coverage

    @staticmethod
    def aggregate_coverage(
        main_func_coverage: FunctionCoverage, dependent_func_coverage: FunctionCoverage | None
    ) -> tuple[set[int], set[int]]:
        """Union the executed and unexecuted line numbers of both functions.

        Returns:
            ``(executed_lines, unexecuted_lines)`` as sets; only the main function
            contributes when ``dependent_func_coverage`` is None.
        """
        total_executed_lines = set(main_func_coverage.executed_lines)
        total_unexecuted_lines = set(main_func_coverage.unexecuted_lines)

        if dependent_func_coverage:
            total_executed_lines.update(dependent_func_coverage.executed_lines)
            total_unexecuted_lines.update(dependent_func_coverage.unexecuted_lines)

        return total_executed_lines, total_unexecuted_lines

    @staticmethod
    def build_graph(
        main_func_coverage: FunctionCoverage, dependent_func_coverage: FunctionCoverage | None
    ) -> dict[str, dict[str, Collection[object]]]:
        """Map each function name to its executed/unexecuted lines (as sets) and branches (as given)."""
        functions = [main_func_coverage]
        if dependent_func_coverage:
            functions.append(dependent_func_coverage)

        return {
            func.name: {
                "executed_lines": set(func.executed_lines),
                "unexecuted_lines": set(func.unexecuted_lines),
                "executed_branches": func.executed_branches,
                "unexecuted_branches": func.unexecuted_branches,
            }
            for func in functions
        }

    @staticmethod
    def grab_dependent_function_from_coverage_data(
        dependent_function_name: str,
        coverage_data: dict[str, dict[str, Any]],
        original_cov_data: dict[str, dict[str, Any]],
    ) -> FunctionCoverage:
        """Grab the dependent function from the coverage data.

        Lookup order:
          1. ``coverage_data[dependent_function_name]``.
          2. Every file of the full report (``original_cov_data["files"]``), matching
             either the exact name or, for dotted names, any function whose name ends
             with ``.<dependent_function_name>``.
          3. Zero coverage when nothing matches.

        Raises:
            ValueError: If the full report lacks an expected key while it is searched.
        """
        try:
            return CoverageUtils._coverage_from_entry(dependent_function_name, coverage_data[dependent_function_name])
        except KeyError:
            pass

        msg = f"Coverage data not found for dependent function {dependent_function_name} in the coverage data"
        try:
            for file_data in original_cov_data["files"].values():
                for function, entry in file_data["functions"].items():
                    if function == dependent_function_name or (
                        "." in dependent_function_name and function.endswith(f".{dependent_function_name}")
                    ):
                        return CoverageUtils._coverage_from_entry(dependent_function_name, entry)
        except KeyError:
            raise ValueError(msg) from None

        logger.debug(
            "Coverage data not found for dependent function %s in the original coverage data", dependent_function_name
        )
        return CoverageUtils._empty_coverage(dependent_function_name)
+ For example: {"torch": "th", "tensorflow": "tf", "jax": "jax"} + + """ + frameworks: dict[str, str] = {} + try: + tree = ast.parse(code) + except SyntaxError: + return frameworks + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + module_name = alias.name.split(".")[0] + if module_name == "torch": + # Use asname if available, otherwise use the module name + frameworks["torch"] = alias.asname if alias.asname else module_name + elif module_name == "tensorflow": + frameworks["tensorflow"] = alias.asname if alias.asname else module_name + elif module_name == "jax": + frameworks["jax"] = alias.asname if alias.asname else module_name + elif isinstance(node, ast.ImportFrom) and node.module: + module_name = node.module.split(".")[0] + if module_name == "torch" and "torch" not in frameworks: + frameworks["torch"] = module_name + elif module_name == "tensorflow" and "tensorflow" not in frameworks: + frameworks["tensorflow"] = module_name + elif module_name == "jax" and "jax" not in frameworks: + frameworks["jax"] = module_name + + return frameworks + + +def create_device_sync_precompute_statements(used_frameworks: dict[str, str] | None) -> list[ast.stmt]: + """Create AST statements to pre-compute device sync conditions before profiling. + + This moves the conditional checks (like is_available(), hasattr(), etc.) outside + the timing block to avoid their overhead affecting the measurements. 
+ + Args: + used_frameworks: Dict mapping framework names to their import aliases + + Returns: + List of AST statements that pre-compute sync conditions into boolean variables + + """ + if not used_frameworks: + return [] + + precompute_statements: list[ast.stmt] = [] + + # PyTorch: pre-compute whether to sync CUDA or MPS + if "torch" in used_frameworks: + torch_alias = used_frameworks["torch"] + # _codeflash_should_sync_cuda = torch.cuda.is_available() and torch.cuda.is_initialized() + precompute_statements.append( + ast.Assign( + targets=[ast.Name(id="_codeflash_should_sync_cuda", ctx=ast.Store())], + value=ast.BoolOp( + op=ast.And(), + values=[ + ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="cuda", ctx=ast.Load() + ), + attr="is_available", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ), + ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="cuda", ctx=ast.Load() + ), + attr="is_initialized", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ), + ], + ), + lineno=1, + ) + ) + # _codeflash_should_sync_mps = (not _codeflash_should_sync_cuda and + # hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and + # hasattr(torch.mps, 'synchronize')) + precompute_statements.append( + ast.Assign( + targets=[ast.Name(id="_codeflash_should_sync_mps", ctx=ast.Store())], + value=ast.BoolOp( + op=ast.And(), + values=[ + ast.UnaryOp(op=ast.Not(), operand=ast.Name(id="_codeflash_should_sync_cuda", ctx=ast.Load())), + ast.Call( + func=ast.Name(id="hasattr", ctx=ast.Load()), + args=[ + ast.Attribute( + value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="backends", ctx=ast.Load() + ), + ast.Constant(value="mps"), + ], + keywords=[], + ), + ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="backends", ctx=ast.Load() + ), + attr="mps", + ctx=ast.Load(), + ), 
+ attr="is_available", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ), + ast.Call( + func=ast.Name(id="hasattr", ctx=ast.Load()), + args=[ + ast.Attribute( + value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="mps", ctx=ast.Load() + ), + ast.Constant(value="synchronize"), + ], + keywords=[], + ), + ], + ), + lineno=1, + ) + ) + + # JAX: pre-compute whether jax.block_until_ready exists + if "jax" in used_frameworks: + jax_alias = used_frameworks["jax"] + # _codeflash_should_sync_jax = hasattr(jax, 'block_until_ready') + precompute_statements.append( + ast.Assign( + targets=[ast.Name(id="_codeflash_should_sync_jax", ctx=ast.Store())], + value=ast.Call( + func=ast.Name(id="hasattr", ctx=ast.Load()), + args=[ast.Name(id=jax_alias, ctx=ast.Load()), ast.Constant(value="block_until_ready")], + keywords=[], + ), + lineno=1, + ) + ) + + # TensorFlow: pre-compute whether tf.test.experimental.sync_devices exists + if "tensorflow" in used_frameworks: + tf_alias = used_frameworks["tensorflow"] + # _codeflash_should_sync_tf = hasattr(tf.test.experimental, 'sync_devices') + precompute_statements.append( + ast.Assign( + targets=[ast.Name(id="_codeflash_should_sync_tf", ctx=ast.Store())], + value=ast.Call( + func=ast.Name(id="hasattr", ctx=ast.Load()), + args=[ + ast.Attribute( + value=ast.Attribute( + value=ast.Name(id=tf_alias, ctx=ast.Load()), attr="test", ctx=ast.Load() + ), + attr="experimental", + ctx=ast.Load(), + ), + ast.Constant(value="sync_devices"), + ], + keywords=[], + ), + lineno=1, + ) + ) + + return precompute_statements + + +def create_device_sync_statements( + used_frameworks: dict[str, str] | None, for_return_value: bool = False +) -> list[ast.stmt]: + """Create AST statements for device synchronization using pre-computed conditions. 
+ + Args: + used_frameworks: Dict mapping framework names to their import aliases + (e.g., {'torch': 'th', 'tensorflow': 'tf', 'jax': 'jax'}) + for_return_value: If True, creates sync for after function call (includes JAX block_until_ready) + + Returns: + List of AST statements for device synchronization using pre-computed boolean variables + + """ + if not used_frameworks: + return [] + + sync_statements: list[ast.stmt] = [] + + # PyTorch synchronization using pre-computed conditions + if "torch" in used_frameworks: + torch_alias = used_frameworks["torch"] + # if _codeflash_should_sync_cuda: + # torch.cuda.synchronize() + # elif _codeflash_should_sync_mps: + # torch.mps.synchronize() + cuda_sync = ast.If( + test=ast.Name(id="_codeflash_should_sync_cuda", ctx=ast.Load()), + body=[ + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="cuda", ctx=ast.Load() + ), + attr="synchronize", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ) + ) + ], + orelse=[ + ast.If( + test=ast.Name(id="_codeflash_should_sync_mps", ctx=ast.Load()), + body=[ + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="mps", ctx=ast.Load() + ), + attr="synchronize", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ) + ) + ], + orelse=[], + ) + ], + ) + sync_statements.append(cuda_sync) + + # JAX synchronization (only after function call, using block_until_ready on return value) + if "jax" in used_frameworks and for_return_value: + jax_alias = used_frameworks["jax"] + # if _codeflash_should_sync_jax: + # jax.block_until_ready(return_value) + jax_sync = ast.If( + test=ast.Name(id="_codeflash_should_sync_jax", ctx=ast.Load()), + body=[ + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=jax_alias, ctx=ast.Load()), attr="block_until_ready", ctx=ast.Load() + ), + args=[ast.Name(id="return_value", ctx=ast.Load())], + 
keywords=[], + ) + ) + ], + orelse=[], + ) + sync_statements.append(jax_sync) + + # TensorFlow synchronization using pre-computed condition + if "tensorflow" in used_frameworks: + tf_alias = used_frameworks["tensorflow"] + # if _codeflash_should_sync_tf: + # tf.test.experimental.sync_devices() + tf_sync = ast.If( + test=ast.Name(id="_codeflash_should_sync_tf", ctx=ast.Load()), + body=[ + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id=tf_alias, ctx=ast.Load()), attr="test", ctx=ast.Load() + ), + attr="experimental", + ctx=ast.Load(), + ), + attr="sync_devices", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ) + ) + ], + orelse=[], + ) + sync_statements.append(tf_sync) + + return sync_statements diff --git a/src/codeflash_python/verification/edit_generated_tests.py b/src/codeflash_python/verification/edit_generated_tests.py new file mode 100644 index 000000000..aad437d25 --- /dev/null +++ b/src/codeflash_python/verification/edit_generated_tests.py @@ -0,0 +1,343 @@ +from __future__ import annotations + +import ast +import logging +import os +import re +from pathlib import Path +from typing import TYPE_CHECKING + +import libcst as cst +from libcst import MetadataWrapper +from libcst.metadata import PositionProvider + +from codeflash_python.code_utils.time_utils import format_perf, format_time +from codeflash_python.models.models import GeneratedTests, GeneratedTestsList +from codeflash_python.result.critic import performance_gain + +if TYPE_CHECKING: + from codeflash_python.models.models import InvocationId + + +logger = logging.getLogger("codeflash_python") + + +class CommentMapper(ast.NodeVisitor): + def __init__( + self, test: GeneratedTests, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int] + ) -> None: + self.results: dict[int, str] = {} + self.test: GeneratedTests = test + self.original_runtimes = original_runtimes + self.optimized_runtimes = optimized_runtimes + 
self.abs_path = test.behavior_file_path.with_suffix("") + self.context_stack: list[str] = [] + + def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: + self.context_stack.append(node.name) + for inner_node in node.body: + if isinstance(inner_node, ast.FunctionDef): + self.visit_FunctionDef(inner_node) + elif isinstance(inner_node, ast.AsyncFunctionDef): + self.visit_AsyncFunctionDef(inner_node) + self.context_stack.pop() + return node + + def get_comment(self, match_key: str) -> str: + # calculate speedup and output comment + original_time = self.original_runtimes[match_key] + optimized_time = self.optimized_runtimes[match_key] + perf_gain = format_perf( + abs(performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time) * 100) + ) + status = "slower" if optimized_time > original_time else "faster" + # Create the runtime comment + return f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})" + + def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: + self.process_function_def_common(node) + return node + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef: + self.process_function_def_common(node) + return node + + def process_function_def_common(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None: + self.context_stack.append(node.name) + i = len(node.body) - 1 + test_qualified_name = ".".join(self.context_stack) + key = test_qualified_name + "#" + str(self.abs_path) + while i >= 0: + line_node = node.body[i] + if isinstance(line_node, (ast.With, ast.For, ast.While, ast.If)): + j = len(line_node.body) - 1 + while j >= 0: + compound_line_node: ast.stmt = line_node.body[j] + nodes_to_check = [compound_line_node] + nodes_to_check.extend(getattr(compound_line_node, "body", [])) + for internal_node in nodes_to_check: + if isinstance(internal_node, (ast.stmt, ast.Assign)): + inv_id = str(i) + "_" + str(j) + match_key = key + "#" + inv_id + if 
match_key in self.original_runtimes and match_key in self.optimized_runtimes: + self.results[internal_node.lineno] = self.get_comment(match_key) + j -= 1 + else: + inv_id = str(i) + match_key = key + "#" + inv_id + if match_key in self.original_runtimes and match_key in self.optimized_runtimes: + self.results[line_node.lineno] = self.get_comment(match_key) + i -= 1 + self.context_stack.pop() + + +def get_fn_call_linenos( + test: GeneratedTests, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int] +) -> dict[int, str]: + line_comment_ast_mapper = CommentMapper(test, original_runtimes, optimized_runtimes) + source_code = test.generated_original_test_source + tree = ast.parse(source_code) + line_comment_ast_mapper.visit(tree) + return line_comment_ast_mapper.results + + +class CommentAdder(cst.CSTTransformer): + """Transformer that adds comments to specified lines.""" + + # Declare metadata dependencies + METADATA_DEPENDENCIES = (PositionProvider,) + + def __init__(self, line_to_comments: dict[int, str]) -> None: + """Initialize the transformer with target line numbers. 
+ + Args: + line_to_comments: Mapping of line numbers (1-indexed) to comments + + """ + self.line_to_comments = line_to_comments + super().__init__() + + def leave_SimpleStatementLine( + self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine + ) -> cst.SimpleStatementLine: + """Add comment to simple statement lines.""" + pos = self.get_metadata(PositionProvider, original_node) + + if pos and pos.start.line in self.line_to_comments: + # Create a comment with trailing whitespace + comment = cst.TrailingWhitespace( + whitespace=cst.SimpleWhitespace(" "), comment=cst.Comment(self.line_to_comments[pos.start.line]) + ) + + # Update the trailing whitespace of the line itself + return updated_node.with_changes(trailing_whitespace=comment) + + return updated_node + + def leave_SimpleStatementSuite( + self, original_node: cst.SimpleStatementSuite, updated_node: cst.SimpleStatementSuite + ) -> cst.SimpleStatementSuite: + """Add comment to simple statement suites (e.g., after if/for/while).""" + pos = self.get_metadata(PositionProvider, original_node) + + if pos and pos.start.line in self.line_to_comments: + # Create a comment with trailing whitespace + comment = cst.TrailingWhitespace( + whitespace=cst.SimpleWhitespace(" "), comment=cst.Comment(self.line_to_comments[pos.start.line]) + ) + + # Update the trailing whitespace of the suite + return updated_node.with_changes(trailing_whitespace=comment) + + return updated_node + + +def is_python_file(file_path: Path) -> bool: + """Check if a file is a Python file.""" + return file_path.suffix == ".py" + + +def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]], tests_project_rootdir: Path) -> dict[str, int]: + unique_inv_ids: dict[str, int] = {} + logger.debug("[unique_inv_id] Processing %s invocation IDs", len(inv_id_runtimes)) + for inv_id, runtimes in inv_id_runtimes.items(): + test_qualified_name = ( + inv_id.test_class_name + "." 
+ inv_id.test_function_name # type: ignore[operator] + if inv_id.test_class_name + else inv_id.test_function_name + ) + + test_module_path = inv_id.test_module_path + if "/" in test_module_path or "\\" in test_module_path: + abs_path = tests_project_rootdir / Path(test_module_path) + else: + abs_path = tests_project_rootdir / Path(test_module_path.replace(".", os.sep)).with_suffix(".py") + + abs_path_str = str(abs_path.resolve().with_suffix("")) + # Include both unit test and perf test paths for runtime annotations + # (performance test runtimes are used for annotations) + if ("__unit_test_" not in abs_path_str and "__perf_test_" not in abs_path_str) or not test_qualified_name: + logger.debug("[unique_inv_id] Skipping: path=%s, test_qualified_name=%s", abs_path_str, test_qualified_name) + continue + key = test_qualified_name + "#" + abs_path_str + parts = inv_id.iteration_id.split("_").__len__() # type: ignore[union-attr] + cur_invid = inv_id.iteration_id.split("_")[0] if parts < 3 else "_".join(inv_id.iteration_id.split("_")[:-1]) # type: ignore[union-attr] + match_key = key + "#" + cur_invid + logger.debug("[unique_inv_id] Adding key: %s with runtime %s", match_key, min(runtimes)) + if match_key not in unique_inv_ids: + unique_inv_ids[match_key] = 0 + unique_inv_ids[match_key] += min(runtimes) + logger.debug("[unique_inv_id] Result has %s entries", len(unique_inv_ids)) + return unique_inv_ids + + +def add_runtime_comments_to_generated_tests( + generated_tests: GeneratedTestsList, + original_runtimes: dict[InvocationId, list[int]], + optimized_runtimes: dict[InvocationId, list[int]], + tests_project_rootdir: Path | None = None, +) -> GeneratedTestsList: + """Add runtime performance comments to function calls in generated tests.""" + original_runtimes_dict = unique_inv_id(original_runtimes, tests_project_rootdir or Path()) + optimized_runtimes_dict = unique_inv_id(optimized_runtimes, tests_project_rootdir or Path()) + # Process each generated test + modified_tests 
= [] + for test in generated_tests.generated_tests: + is_python = is_python_file(test.behavior_file_path) + + if is_python: + # Use Python libcst-based comment insertion + try: + tree = cst.parse_module(test.generated_original_test_source) + wrapper = MetadataWrapper(tree) + line_to_comments = get_fn_call_linenos(test, original_runtimes_dict, optimized_runtimes_dict) + comment_adder = CommentAdder(line_to_comments) + modified_tree = wrapper.visit(comment_adder) + modified_source = modified_tree.code + modified_test = GeneratedTests( + generated_original_test_source=modified_source, + instrumented_behavior_test_source=test.instrumented_behavior_test_source, + instrumented_perf_test_source=test.instrumented_perf_test_source, + behavior_file_path=test.behavior_file_path, + perf_file_path=test.perf_file_path, + ) + modified_tests.append(modified_test) + except Exception as e: + # If parsing fails, keep the original test + logger.debug("Failed to add runtime comments to test: %s", e) + modified_tests.append(test) + else: + modified_tests.append(test) + + return GeneratedTestsList(generated_tests=modified_tests) + + +def remove_functions_from_generated_tests( + generated_tests: GeneratedTestsList, test_functions_to_remove: list[str] +) -> GeneratedTestsList: + # Pre-compile patterns for all function names to remove + function_patterns = compile_function_patterns(test_functions_to_remove) + new_generated_tests = [] + + for generated_test in generated_tests.generated_tests: + source = generated_test.generated_original_test_source + + # Apply all patterns without redundant searches + for pattern in function_patterns: + # Use finditer and sub only if necessary to avoid unnecessary .search()/.sub() calls + for match in pattern.finditer(source): + # Skip if "@pytest.mark.parametrize" present + # Only the matched function's code is targeted + if "@pytest.mark.parametrize" in match.group(0): + continue + # Remove function from source + # If match, remove the function by 
substitution in the source + # Replace using start/end indices for efficiency + start, end = match.span() + source = source[:start] + source[end:] + # After removal, break since .finditer() is from left to right, and only one match expected per function in source + break + + generated_test.generated_original_test_source = source + new_generated_tests.append(generated_test) + + return GeneratedTestsList(generated_tests=new_generated_tests) + + +# Pre-compile all function removal regexes upfront for efficiency. +def compile_function_patterns(test_functions_to_remove: list[str]) -> list[re.Pattern[str]]: + return [ + re.compile( + rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?(async\s+)?def\s+{re.escape(func)}\(.*?\):.*?(?=\n(async\s+)?def\s|$)", + re.DOTALL, + ) + for func in test_functions_to_remove + ] + + +def remove_test_functions(test_source: str, functions_to_remove: list[str]) -> str: + """Remove specific test functions from Python test source using libcst. + + Handles both bare function names (top-level) and qualified names (ClassName.method_name). + If all test methods are removed from a class, the class is removed too. + """ + bare_names: set[str] = set() + qualified_names: set[str] = set() + for name in functions_to_remove: + if "." 
in name: + qualified_names.add(name) + else: + bare_names.add(name) + + class TestFunctionRemover(cst.CSTTransformer): + def __init__(self) -> None: + self.class_stack: list[str] = [] + self.emptied_classes: set[str] = set() + + def visit_ClassDef(self, node: cst.ClassDef) -> bool: + self.class_stack.append(node.name.value) + return True + + def leave_ClassDef( + self, original_node: cst.ClassDef, updated_node: cst.ClassDef + ) -> cst.ClassDef | cst.RemovalSentinel: + class_name = self.class_stack.pop() + if class_name in self.emptied_classes: + self.emptied_classes.discard(class_name) + body = updated_node.body + if isinstance(body, cst.IndentedBlock): + has_meaningful_body = any( + not ( + isinstance(s, cst.SimpleStatementLine) + and len(s.body) == 1 + and isinstance(s.body[0], (cst.Pass, cst.Expr)) + and ( + isinstance(s.body[0], cst.Pass) + or (isinstance(s.body[0].value, (cst.SimpleString, cst.ConcatenatedString))) + ) + ) + for s in body.body + ) + if not has_meaningful_body: + return cst.RemovalSentinel.REMOVE + return updated_node + + def leave_FunctionDef( + self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef + ) -> cst.FunctionDef | cst.RemovalSentinel: + fn_name = original_node.name.value + if fn_name in bare_names and not self.class_stack: + return cst.RemovalSentinel.REMOVE + if self.class_stack: + qualified = f"{self.class_stack[-1]}.{fn_name}" + if qualified in qualified_names: + self.emptied_classes.add(self.class_stack[-1]) + return cst.RemovalSentinel.REMOVE + return updated_node + + try: + tree = cst.parse_module(test_source) + modified = tree.visit(TestFunctionRemover()) + return modified.code + except Exception: + return test_source diff --git a/src/codeflash_python/verification/equivalence.py b/src/codeflash_python/verification/equivalence.py new file mode 100644 index 000000000..825d20a06 --- /dev/null +++ b/src/codeflash_python/verification/equivalence.py @@ -0,0 +1,204 @@ +from __future__ import annotations + +import logging 
+import re +import reprlib +import sys +from typing import TYPE_CHECKING + +import libcst as cst + +from codeflash_python.api.types import TestDiff, TestDiffScope +from codeflash_python.models.models import TestResults, TestType, VerificationType +from codeflash_python.verification.comparator import comparator + +if TYPE_CHECKING: + from pathlib import Path + + from codeflash_python.models.models import InvocationId, TestResults + + +logger = logging.getLogger("codeflash_python") + + +def shorten_pytest_error(pytest_error_string: str) -> str: + return "\n".join(re.findall(r"^[E>] +(.*)$", pytest_error_string, re.MULTILINE)) + + +INCREASED_RECURSION_LIMIT = 5000 + + +def get_test_src_code(invocation_id: InvocationId, test_path: Path) -> str | None: + """Extract the source code of a test function from a test file using CST parsing.""" + if not test_path.exists(): + return None + try: + test_src = test_path.read_text(encoding="utf-8") + module_node = cst.parse_module(test_src) + except Exception: + # Handle case where test_function_name might be None + test_fn_name = invocation_id.test_function_name if invocation_id.test_function_name else "unknown" + return ( + f"// Test: {test_fn_name}\n" + f"// File: {test_path.name}\n" + f"// Testing function: {invocation_id.function_getting_tested}" + ) + + if invocation_id.test_class_name: + for stmt in module_node.body: + if isinstance(stmt, cst.ClassDef) and stmt.name.value == invocation_id.test_class_name: + for member in stmt.body.body: + if ( + isinstance(member, cst.FunctionDef) + and invocation_id.test_function_name + and member.name.value == invocation_id.test_function_name + ): + return module_node.code_for_node(member).strip() + return None + + if invocation_id.test_function_name: + for stmt in module_node.body: + if isinstance(stmt, cst.FunctionDef) and stmt.name.value == invocation_id.test_function_name: + return module_node.code_for_node(stmt).strip() + return None + + +reprlib_repr = reprlib.Repr() 
+reprlib_repr.maxstring = 1500 +test_diff_repr = reprlib_repr.repr + + +def safe_repr(obj: object) -> str: + """Safely get repr of an object, handling Mock objects with corrupted state.""" + try: + return repr(obj) + except (AttributeError, TypeError, RecursionError) as e: + return f"" + + +def compare_test_results( + original_results: TestResults, candidate_results: TestResults, pass_fail_only: bool = False +) -> tuple[bool, list[TestDiff]]: + # This is meant to be only called with test results for the first loop index + if len(original_results) == 0 or len(candidate_results) == 0: + return False, [] # empty test results are not equal + original_recursion_limit = sys.getrecursionlimit() + if original_recursion_limit < INCREASED_RECURSION_LIMIT: + sys.setrecursionlimit(INCREASED_RECURSION_LIMIT) # Increase recursion limit to avoid RecursionError + test_ids_superset = original_results.get_all_unique_invocation_loop_ids().union( + set(candidate_results.get_all_unique_invocation_loop_ids()) + ) + test_diffs: list[TestDiff] = [] + did_all_timeout: bool = True + for test_id in test_ids_superset: + original_test_result = original_results.get_by_unique_invocation_loop_id(test_id) + cdd_test_result = candidate_results.get_by_unique_invocation_loop_id(test_id) + + if cdd_test_result is not None and original_test_result is None: + continue + # If helper function instance_state verification is not present, that's ok. 
continue + if ( + original_test_result is not None + and original_test_result.verification_type + and original_test_result.verification_type == VerificationType.INIT_STATE_HELPER + and cdd_test_result is None + ): + continue + if original_test_result is None or cdd_test_result is None: + continue + did_all_timeout = did_all_timeout and (original_test_result.timed_out or False) + if original_test_result.timed_out: + continue + superset_obj = False + if original_test_result.verification_type and ( + original_test_result.verification_type + in {VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO} + ): + superset_obj = True + + candidate_test_failures = candidate_results.test_failures + original_test_failures = original_results.test_failures + cdd_pytest_error = ( + candidate_test_failures.get(original_test_result.id.test_fn_qualified_name(), "") + if candidate_test_failures + else "" + ) + if cdd_pytest_error: + cdd_pytest_error = shorten_pytest_error(cdd_pytest_error) + original_pytest_error = ( + original_test_failures.get(original_test_result.id.test_fn_qualified_name(), "") + if original_test_failures + else "" + ) + if original_pytest_error: + original_pytest_error = shorten_pytest_error(original_pytest_error) + + if original_test_result.test_type in { + TestType.EXISTING_UNIT_TEST, + TestType.CONCOLIC_COVERAGE_TEST, + TestType.GENERATED_REGRESSION, + TestType.REPLAY_TEST, + } and (cdd_test_result.did_pass != original_test_result.did_pass): + test_diffs.append( + TestDiff( + scope=TestDiffScope.DID_PASS, + original_value=str(original_test_result.did_pass), + candidate_value=str(cdd_test_result.did_pass), + test_src_code=get_test_src_code(original_test_result.id, original_test_result.file_name), + candidate_pytest_error=cdd_pytest_error, + original_pass=original_test_result.did_pass, + candidate_pass=cdd_test_result.did_pass, + original_pytest_error=original_pytest_error, + ) + ) + + elif not pass_fail_only and not comparator( + 
original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj + ): + test_diffs.append( + TestDiff( + scope=TestDiffScope.RETURN_VALUE, + original_value=test_diff_repr(safe_repr(original_test_result.return_value)), + candidate_value=test_diff_repr(safe_repr(cdd_test_result.return_value)), + test_src_code=get_test_src_code(original_test_result.id, original_test_result.file_name), + candidate_pytest_error=cdd_pytest_error, + original_pass=original_test_result.did_pass, + candidate_pass=cdd_test_result.did_pass, + original_pytest_error=original_pytest_error, + ) + ) + + try: + logger.debug( + "File Name: %s\nTest Type: %s\nVerification Type: %s\nInvocation ID: %s\nOriginal return value: %r\nCandidate return value: %r\n", + original_test_result.file_name, + original_test_result.test_type, + original_test_result.verification_type, + original_test_result.id, + original_test_result.return_value, + cdd_test_result.return_value, + ) + except Exception as e: + logger.exception(e) + elif ( + not pass_fail_only + and (original_test_result.stdout and cdd_test_result.stdout) + and not comparator(original_test_result.stdout, cdd_test_result.stdout) + ): + test_diffs.append( + TestDiff( + scope=TestDiffScope.STDOUT, + original_value=str(original_test_result.stdout), + candidate_value=str(cdd_test_result.stdout), + test_src_code=get_test_src_code(original_test_result.id, original_test_result.file_name), + candidate_pytest_error=cdd_pytest_error, + original_pass=original_test_result.did_pass, + candidate_pass=cdd_test_result.did_pass, + original_pytest_error=original_pytest_error, + ) + ) + + sys.setrecursionlimit(original_recursion_limit) + if did_all_timeout: + return False, test_diffs + return len(test_diffs) == 0, test_diffs diff --git a/src/codeflash_python/verification/instrument_codeflash_capture.py b/src/codeflash_python/verification/instrument_codeflash_capture.py new file mode 100644 index 000000000..3a790dd82 --- /dev/null +++ 
b/src/codeflash_python/verification/instrument_codeflash_capture.py @@ -0,0 +1,310 @@ +from __future__ import annotations + +import ast +from pathlib import Path +from typing import TYPE_CHECKING + +from codeflash_python.code_utils.code_utils import get_run_tmp_file +from codeflash_python.code_utils.formatter import sort_imports + +_ATTRS_NAMESPACES = frozenset({"attrs", "attr"}) +_ATTRS_DECORATOR_NAMES = frozenset({"define", "mutable", "frozen", "s", "attrs"}) + +if TYPE_CHECKING: + from codeflash_core.models import FunctionToOptimize + + +def instrument_codeflash_capture( + function_to_optimize: FunctionToOptimize, file_path_to_helper_class: dict[Path, set[str]], tests_root: Path +) -> None: + """Instrument __init__ function with codeflash_capture decorator if it's in a class.""" + # Find the class parent + if len(function_to_optimize.parents) == 1 and function_to_optimize.parents[0].type == "ClassDef": + class_parent = function_to_optimize.parents[0] + else: + return + # Remove duplicate fto class from helper classes + if ( + function_to_optimize.file_path in file_path_to_helper_class + and class_parent.name in file_path_to_helper_class[function_to_optimize.file_path] + ): + file_path_to_helper_class[function_to_optimize.file_path].remove(class_parent.name) + # Instrument fto class + original_code = function_to_optimize.file_path.read_text(encoding="utf-8") + # Add decorator to init + modified_code = add_codeflash_capture_to_init( + target_classes={class_parent.name}, + fto_name=function_to_optimize.function_name, + tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(), + code=original_code, + tests_root=tests_root, + is_fto=True, + ) + function_to_optimize.file_path.write_text(modified_code, encoding="utf-8") + + # Instrument helper classes + for file_path, helper_classes in file_path_to_helper_class.items(): + original_code = file_path.read_text(encoding="utf-8") + modified_code = add_codeflash_capture_to_init( + target_classes=helper_classes, + 
class InitDecorator(ast.NodeTransformer):
    """AST transformer that decorates the ``__init__`` of selected classes with ``codeflash_capture``.

    For every class whose name is in ``target_classes``:

    * an explicit ``__init__(self, ...)`` gets the decorator prepended, unless a
      ``codeflash_capture(...)`` call decorator is already present;
    * a class without such an ``__init__`` gets a synthetic
      ``__init__(self, *args, **kwargs)`` that forwards to ``super().__init__``;
    * dataclasses and ``NamedTuple`` subclasses without ``__init__`` are left untouched;
    * attrs classes without ``__init__`` are wrapped by a module-level patch
      emitted right after the class definition (see ``build_attrs_patch_block``).

    ``inserted_decorator`` is set to ``True`` once a decorator or attrs patch has
    been scheduled; ``has_import`` records whether the module already imports
    ``codeflash_capture``.
    """

    def __init__(
        self, target_classes: set[str], fto_name: str, tmp_dir_path: str, tests_root: Path, *, is_fto: bool = False
    ) -> None:
        self.target_classes = target_classes
        self.fto_name = fto_name
        self.tmp_dir_path = tmp_dir_path
        self.is_fto = is_fto
        self.has_import = False
        self.tests_root = tests_root
        self.inserted_decorator = False
        # class name -> decorator call to apply through a module-level patch
        self.attrs_classes_to_patch: dict[str, ast.Call] = {}

        # Precompute decorator components to avoid reconstructing on every node visit.
        # Only the `function_name` keyword changes per class.
        self.base_decorator_keywords = [
            ast.keyword(arg="tmp_dir_path", value=ast.Constant(value=self.tmp_dir_path)),
            ast.keyword(arg="tests_root", value=ast.Constant(value=self.tests_root.as_posix())),
            ast.keyword(arg="is_fto", value=ast.Constant(value=self.is_fto)),
        ]
        self.base_decorator_func = ast.Name(id="codeflash_capture", ctx=ast.Load())

        # Prebuilt pieces of `super().__init__(*args, **kwargs)`
        self.super_starred = ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())
        self.super_kwarg = ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))
        self.super_func = ast.Attribute(
            value=ast.Call(func=ast.Name(id="super", ctx=ast.Load()), args=[], keywords=[]),
            attr="__init__",
            ctx=ast.Load(),
        )
        self.init_vararg = ast.arg(arg="args")
        self.init_kwarg = ast.arg(arg="kwargs")
        self.init_self_arg = ast.arg(arg="self", annotation=None)

        # Reusable fragments for classes that lack __init__:
        # the `super().__init__(*args, **kwargs)` statement ...
        self.super_call_expr = ast.Expr(
            value=ast.Call(func=self.super_func, args=[self.super_starred], keywords=[self.super_kwarg])
        )
        # ... and the `(self, *args, **kwargs)` signature.
        self.init_arguments = ast.arguments(
            posonlyargs=[],
            args=[self.init_self_arg],
            vararg=self.init_vararg,
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=self.init_kwarg,
            defaults=[],
        )

        # Reusable AST nodes for build_attrs_patch_block
        self.load_ctx = ast.Load()
        self.store_ctx = ast.Store()
        self.args_name_load = ast.Name(id="args", ctx=self.load_ctx)
        self.kwargs_name_load = ast.Name(id="kwargs", ctx=self.load_ctx)
        self.self_arg_node = ast.arg(arg="self")
        self.args_arg_node = ast.arg(arg="args")
        self.kwargs_arg_node = ast.arg(arg="kwargs")
        self.self_name_load = ast.Name(id="self", ctx=self.load_ctx)
        self.starred_args = ast.Starred(value=self.args_name_load, ctx=self.load_ctx)
        self.kwargs_keyword = ast.keyword(arg=None, value=self.kwargs_name_load)

    @staticmethod
    def _import_insert_index(body: list[ast.stmt]) -> int:
        """Return the first index in *body* where a regular import may legally be inserted.

        A module docstring must remain the first statement to stay a docstring, and
        ``from __future__`` imports must precede every other statement, so both are
        skipped.
        """
        index = 0
        if (
            body
            and isinstance(body[0], ast.Expr)
            and isinstance(body[0].value, ast.Constant)
            and isinstance(body[0].value.value, str)
        ):
            index = 1
        while index < len(body):
            stmt = body[index]
            if not (isinstance(stmt, ast.ImportFrom) and stmt.module == "__future__"):
                break
            index += 1
        return index

    def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom:
        """Record whether the module already imports ``codeflash_capture``."""
        if node.module == "codeflash_python.verification.codeflash_capture" and any(
            alias.name == "codeflash_capture" for alias in node.names
        ):
            self.has_import = True
        return node

    def visit_Module(self, node: ast.Module) -> ast.Module:
        """Visit the module, then add attrs patch blocks and the ``codeflash_capture`` import."""
        self.generic_visit(node)

        # Insert module-level monkey-patch wrappers for attrs classes immediately after
        # their class definitions. Only classes directly in the module body are matched.
        if self.attrs_classes_to_patch:
            new_body: list[ast.stmt] = []
            for stmt in node.body:
                new_body.append(stmt)
                if isinstance(stmt, ast.ClassDef) and stmt.name in self.attrs_classes_to_patch:
                    new_body.extend(self.build_attrs_patch_block(stmt.name, self.attrs_classes_to_patch[stmt.name]))
            node.body = new_body

        # Add the import only when something was instrumented and it is not present yet.
        if not self.has_import and self.inserted_decorator:
            import_stmt = ast.parse(
                "from codeflash_python.verification.codeflash_capture import codeflash_capture"
            ).body[0]
            # Keep the docstring and `from __future__` imports first so the unparsed
            # module is valid on its own, without relying on a later import-sorting pass.
            node.body.insert(self._import_insert_index(node.body), import_stmt)

        return node

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        """Instrument *node* when it is one of the target classes; otherwise return it unchanged."""
        if node.name not in self.target_classes:
            return node

        has_init = False
        # Build the decorator node once per class.
        decorator = ast.Call(
            func=self.base_decorator_func,
            args=[],
            keywords=[
                ast.keyword(arg="function_name", value=ast.Constant(value=f"{node.name}.__init__")),
                *self.base_decorator_keywords,
            ],
        )

        # Single pass over the class body: find `__init__(self, ...)` and decorate it.
        for item in node.body:
            if (
                isinstance(item, ast.FunctionDef)
                and item.name == "__init__"
                and item.args.args
                and isinstance(item.args.args[0], ast.arg)
                and item.args.args[0].arg == "self"
            ):
                has_init = True
                already_decorated = any(
                    isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "codeflash_capture"
                    for d in item.decorator_list
                )
                if not already_decorated:
                    item.decorator_list.insert(0, decorator)
                    self.inserted_decorator = True

        if not has_init:
            # Skip dataclasses — their __init__ is auto-generated at class creation time and isn't in the AST.
            # The synthetic __init__ with super().__init__(*args, **kwargs) overrides it and fails because
            # object.__init__() doesn't accept the dataclass field kwargs.
            # TODO: support by saving a reference to the generated __init__ before overriding, e.g.
            # _orig_init = ClassName.__init__; then calling _orig_init(self, *args, **kwargs) in the wrapper
            for dec in node.decorator_list:
                dec_name = self.expr_name(dec)
                if dec_name is not None and dec_name.endswith("dataclass"):
                    return node

            # Skip NamedTuples — their __init__ is synthesized and cannot be overwritten.
            for base in node.bases:
                base_name = self.expr_name(base)
                if base_name is not None and base_name.endswith("NamedTuple"):
                    return node

            # Attrs classes — their __init__ is auto-generated by the decorator at class creation
            # time. With slots=True (the default for @attrs.define), attrs creates a brand-new class
            # object, so the __class__ cell baked into a synthesised
            # `super().__init__(*args, **kwargs)` still refers to the *original* class while `self`
            # is already an instance of the *new* slots class, causing a TypeError.
            # We therefore skip modifying the class body and instead emit a module-level
            # monkey-patch block after the class (handled in visit_Module).
            # NOTE(review): a bare `@define` (from attrs import define) has no namespace part
            # and is not matched by the `len(parts) >= 2` test below — confirm that is intended.
            for dec in node.decorator_list:
                dec_name = self.expr_name(dec)
                if dec_name is not None:
                    parts = dec_name.split(".")
                    if len(parts) >= 2 and parts[-2] in _ATTRS_NAMESPACES and parts[-1] in _ATTRS_DECORATOR_NAMES:
                        if isinstance(dec, ast.Call):
                            for kw in dec.keywords:
                                # init=False: attrs generates no __init__, so there is nothing to wrap.
                                if kw.arg == "init" and isinstance(kw.value, ast.Constant) and kw.value.value is False:
                                    return node
                        self.attrs_classes_to_patch[node.name] = decorator
                        self.inserted_decorator = True
                        return node

            # Plain class without __init__: synthesize
            #     @codeflash_capture(...)
            #     def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs)
            init_func = ast.FunctionDef(
                name="__init__",
                args=self.init_arguments,
                body=[self.super_call_expr],
                decorator_list=[decorator],
                returns=None,
            )

            node.body.insert(0, init_func)
            self.inserted_decorator = True

        return node

    def build_attrs_patch_block(self, class_name: str, decorator: ast.Call) -> list[ast.stmt]:
        """Build the three module-level statements that wrap an attrs-generated ``__init__``.

        The statements save the generated ``__init__``, define a forwarding function,
        and rebind ``ClassName.__init__`` to the decorated forwarding function.
        """
        orig_name = f"_codeflash_orig_{class_name}_init"
        patched_name = f"_codeflash_patched_{class_name}_init"

        class_name_load = ast.Name(id=class_name, ctx=self.load_ctx)

        # _codeflash_orig_ClassName_init = ClassName.__init__
        save_orig = ast.Assign(
            targets=[ast.Name(id=orig_name, ctx=self.store_ctx)],
            value=ast.Attribute(value=class_name_load, attr="__init__", ctx=self.load_ctx),
        )

        # def _codeflash_patched_ClassName_init(self, *args, **kwargs):
        #     return _codeflash_orig_ClassName_init(self, *args, **kwargs)
        patched_func = ast.FunctionDef(
            name=patched_name,
            args=ast.arguments(
                posonlyargs=[],
                args=[self.self_arg_node],
                vararg=self.args_arg_node,
                kwonlyargs=[],
                kw_defaults=[],
                kwarg=self.kwargs_arg_node,
                defaults=[],
            ),
            body=[
                ast.Return(
                    value=ast.Call(
                        func=ast.Name(id=orig_name, ctx=self.load_ctx),
                        args=[self.self_name_load, self.starred_args],
                        keywords=[self.kwargs_keyword],
                    )
                )
            ],
            decorator_list=[],
            returns=None,
        )

        # ClassName.__init__ = codeflash_capture(...)(_codeflash_patched_ClassName_init)
        assign_patched = ast.Assign(
            targets=[
                ast.Attribute(value=ast.Name(id=class_name, ctx=self.load_ctx), attr="__init__", ctx=self.store_ctx)
            ],
            value=ast.Call(func=decorator, args=[ast.Name(id=patched_name, ctx=self.load_ctx)], keywords=[]),
        )

        return [save_orig, patched_func, assign_patched]

    def expr_name(self, node: ast.AST) -> str | None:
        """Return the dotted name of a Name/Attribute/Call expression, or ``None`` for anything else.

        For a call, the name of the called expression is returned.
        """
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Call):
            return self.expr_name(node.func)
        if isinstance(node, ast.Attribute):
            parent = self.expr_name(node.value)
            return f"{parent}.{node.attr}" if parent else node.attr
        return None
def node_in_call_position(node: ast.AST, call_positions: list[CodePosition]) -> bool:
    """Return True when *node* is a call whose source span contains one of *call_positions*.

    A position is contained when its line lies between the call's first and last
    line and, on those boundary lines, its column is not before the call's start
    column nor after its end column. For a call written on a single line both
    column bounds apply at once, so a position elsewhere on that line does not
    match.

    Nodes that are not ``ast.Call`` or that lack start/end location information
    never match. Positions whose ``line_no`` is ``None`` are ignored. When the
    call has no ``end_col_offset`` the end column is treated as unbounded.
    """
    if not isinstance(node, ast.Call):
        return False

    start_line = getattr(node, "lineno", None)
    start_col = getattr(node, "col_offset", None)
    end_line = getattr(node, "end_lineno", None)
    end_col = getattr(node, "end_col_offset", None)
    if start_line is None or start_col is None or end_line is None:
        return False

    for pos in call_positions:
        line = pos.line_no
        if line is None or not start_line <= line <= end_line:
            continue
        col = pos.col_no
        # On the first line the position must not precede the call.
        if line == start_line and col < start_col:
            continue
        # On the last line the position must not lie past the call's end.
        if line == end_line and end_col is not None and col > end_col:
            continue
        return True
    return False
1 and function.parents[0].type == "ClassDef": + self.class_name = function.top_level_parent_name + + def find_and_update_line_node( + self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None + ) -> Iterable[ast.stmt] | None: + # Major optimization: since ast.walk is *very* expensive for big trees and only checks for ast.Call, + # it's much more efficient to visit nodes manually. We'll only descend into expressions/statements. + + # Helper for manual walk + def iter_ast_calls(node): + # Generator to yield each ast.Call in test_node, preserves node identity + stack = [node] + while stack: + n = stack.pop() + if isinstance(n, ast.Call): + yield n + # Instead of using ast.walk (which calls iter_child_nodes under the hood in Python, which copy lists and stack-frames for EVERY node), + # do a specialized BFS with only the necessary attributes + for _field, value in ast.iter_fields(n): + if isinstance(value, list): + for item in reversed(value): + if isinstance(item, ast.AST): + stack.append(item) + elif isinstance(value, ast.AST): + stack.append(value) + + # This change improves from O(N) stack-frames per child-node to a single stack, less python call overhead + return_statement = [test_node] + call_node = None + + # Minor optimization: Convert mode, function_name, test_class_name, qualified_name, etc to locals + fn_obj = self.function_object + module_path = self.module_path + mode = self.mode + qualified_name = fn_obj.qualified_name + + # Use locals for all 'current' values, only look up class/function/constant AST object once. 
+ codeflash_loop_index = ast.Name(id="codeflash_loop_index", ctx=ast.Load()) + codeflash_cur = ast.Name(id="codeflash_cur", ctx=ast.Load()) + codeflash_con = ast.Name(id="codeflash_con", ctx=ast.Load()) + + for node in iter_ast_calls(test_node): + if not node_in_call_position(node, self.call_positions): + continue + + call_node = node + all_args = get_call_arguments(call_node) + # Two possible call types: Name and Attribute + node_func = node.func + + if isinstance(node_func, ast.Name): + function_name = node_func.id + + # Check if this is the function we want to instrument + if function_name != fn_obj.function_name: + continue + + if fn_obj.is_async: + return [test_node] + + # Build once, reuse objects. + inspect_name = ast.Name(id="inspect", ctx=ast.Load()) + bind_call = ast.Assign( + targets=[ast.Name(id="_call__bound__arguments", ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Call( + func=ast.Attribute(value=inspect_name, attr="signature", ctx=ast.Load()), + args=[ast.Name(id=function_name, ctx=ast.Load())], + keywords=[], + ), + attr="bind", + ctx=ast.Load(), + ), + args=all_args.args, + keywords=all_args.keywords, + ), + lineno=test_node.lineno, + col_offset=test_node.col_offset, + ) + + apply_defaults = ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), + attr="apply_defaults", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ), + lineno=test_node.lineno + 1, + col_offset=test_node.col_offset, + ) + + node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load()) + base_args = [ + ast.Name(id=function_name, ctx=ast.Load()), + ast.Constant(value=module_path), + ast.Constant(value=test_class_name or None), + ast.Constant(value=node_name), + ast.Constant(value=qualified_name), + ast.Constant(value=index), + codeflash_loop_index, + ] + # Extend with BEHAVIOR extras if needed + if mode == TestingMode.BEHAVIOR: + base_args += [codeflash_cur, codeflash_con] + # Extend with call args 
(performance) or starred bound args (behavior) + if mode == TestingMode.PERFORMANCE: + base_args += call_node.args + else: + base_args.append( + ast.Starred( + value=ast.Attribute( + value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), + attr="args", + ctx=ast.Load(), + ), + ctx=ast.Load(), + ) + ) + node.args = base_args + # Prepare keywords + if mode == TestingMode.BEHAVIOR: + node.keywords = [ + ast.keyword( + value=ast.Attribute( + value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), + attr="kwargs", + ctx=ast.Load(), + ) + ) + ] + else: + node.keywords = call_node.keywords + + return_statement = ( + [bind_call, apply_defaults, test_node] if mode == TestingMode.BEHAVIOR else [test_node] + ) + break + if isinstance(node_func, ast.Attribute): + function_to_test = node_func.attr + if function_to_test == fn_obj.function_name: + if fn_obj.is_async: + return [test_node] + + # Create the signature binding statements + + # Unparse only once + function_name_expr = ast.parse(ast.unparse(node_func), mode="eval").body + + inspect_name = ast.Name(id="inspect", ctx=ast.Load()) + bind_call = ast.Assign( + targets=[ast.Name(id="_call__bound__arguments", ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Call( + func=ast.Attribute(value=inspect_name, attr="signature", ctx=ast.Load()), + args=[function_name_expr], + keywords=[], + ), + attr="bind", + ctx=ast.Load(), + ), + args=all_args.args, + keywords=all_args.keywords, + ), + lineno=test_node.lineno, + col_offset=test_node.col_offset, + ) + + apply_defaults = ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), + attr="apply_defaults", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ), + lineno=test_node.lineno + 1, + col_offset=test_node.col_offset, + ) + + node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load()) + base_args = [ + function_name_expr, + ast.Constant(value=module_path), + ast.Constant(value=test_class_name or 
None), + ast.Constant(value=node_name), + ast.Constant(value=qualified_name), + ast.Constant(value=index), + codeflash_loop_index, + ] + if mode == TestingMode.BEHAVIOR: + base_args += [codeflash_cur, codeflash_con] + if mode == TestingMode.PERFORMANCE: + base_args += call_node.args + else: + base_args.append( + ast.Starred( + value=ast.Attribute( + value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), + attr="args", + ctx=ast.Load(), + ), + ctx=ast.Load(), + ) + ) + node.args = base_args + if mode == TestingMode.BEHAVIOR: + node.keywords = [ + ast.keyword( + value=ast.Attribute( + value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), + attr="kwargs", + ctx=ast.Load(), + ) + ) + ] + else: + node.keywords = call_node.keywords + + # Return the signature binding statements along with the test_node + return_statement = ( + [bind_call, apply_defaults, test_node] if mode == TestingMode.BEHAVIOR else [test_node] + ) + break + + if call_node is None: + return None + return return_statement + + def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: + # TODO: Ensure that this class inherits from unittest.TestCase. Don't modify non unittest.TestCase classes. 
+ for inner_node in ast.walk(node): + if isinstance(inner_node, ast.FunctionDef): + self.visit_FunctionDef(inner_node, node.name) + + return node + + def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef: + if node.name.startswith("test_"): + did_update = False + i = len(node.body) - 1 + while i >= 0: + line_node = node.body[i] + if isinstance(line_node, (ast.With, ast.For, ast.While, ast.If)): + j = len(line_node.body) - 1 + while j >= 0: + compound_line_node: ast.stmt = line_node.body[j] + internal_node: ast.AST + for internal_node in ast.walk(compound_line_node): + if isinstance(internal_node, (ast.stmt, ast.Assign)): + updated_node = self.find_and_update_line_node( + internal_node, node.name, str(i) + "_" + str(j), test_class_name + ) + if updated_node is not None: + line_node.body[j : j + 1] = updated_node + did_update = True + break + j -= 1 + else: + updated_node = self.find_and_update_line_node(line_node, node.name, str(i), test_class_name) + if updated_node is not None: + node.body[i : i + 1] = updated_node + did_update = True + i -= 1 + if did_update: + node.body = [ + ast.Assign( + targets=[ast.Name(id="codeflash_loop_index", ctx=ast.Store())], + value=ast.Call( + func=ast.Name(id="int", ctx=ast.Load()), + args=[ + ast.Subscript( + value=ast.Attribute( + value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load() + ), + slice=ast.Constant(value="CODEFLASH_LOOP_INDEX"), + ctx=ast.Load(), + ) + ], + keywords=[], + ), + lineno=node.lineno + 2, + col_offset=node.col_offset, + ), + *( + [ + ast.Assign( + targets=[ast.Name(id="codeflash_iteration", ctx=ast.Store())], + value=ast.Subscript( + value=ast.Attribute( + value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load() + ), + slice=ast.Constant(value="CODEFLASH_TEST_ITERATION"), + ctx=ast.Load(), + ), + lineno=node.lineno + 1, + col_offset=node.col_offset, + ), + ast.Assign( + targets=[ast.Name(id="codeflash_con", 
class AsyncCallInstrumenter(ast.NodeTransformer):
    """Mark awaited calls of the target function inside ``test_*`` functions.

    For each top-level statement of a test function that contains
    ``await <target>(...)`` at one of ``call_positions``, an assignment
    ``os.environ["CODEFLASH_CURRENT_LINE_ID"] = "<n>"`` is inserted immediately
    before the statement. ``<n>`` counts such statements per test-function name,
    starting at 0. The statement itself is left unchanged.

    ``did_instrument`` is True once at least one assignment has been inserted.
    """

    def __init__(
        self,
        function: FunctionToOptimize,
        module_path: str,
        call_positions: list[CodePosition],
        mode: TestingMode = TestingMode.BEHAVIOR,
    ) -> None:
        self.mode = mode
        self.function_object = function
        self.class_name = None
        self.only_function_name = function.function_name
        self.module_path = module_path
        self.call_positions = call_positions
        self.did_instrument = False
        # Number of instrumented statements seen so far, keyed by test function name
        self.async_call_counter: dict[str, int] = {}
        if len(function.parents) == 1 and function.parents[0].type == "ClassDef":
            self.class_name = function.top_level_parent_name

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        """Descend into the class so that its test methods are visited."""
        result = self.generic_visit(node)
        assert isinstance(result, ast.ClassDef)
        return result

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
        """Instrument ``async def test_*`` functions; leave every other function untouched."""
        if not node.name.startswith("test_"):
            return node

        result = self.process_test_function(node)
        assert isinstance(result, ast.AsyncFunctionDef)
        return result

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        """Instrument ``def test_*`` functions; leave every other function untouched."""
        if not node.name.startswith("test_"):
            return node

        result = self.process_test_function(node)
        assert isinstance(result, ast.FunctionDef)
        return result

    def process_test_function(
        self, node: ast.AsyncFunctionDef | ast.FunctionDef
    ) -> ast.AsyncFunctionDef | ast.FunctionDef:
        """Rebuild the body of *node*, prefixing matching statements with the line-id assignment."""
        # The counter is keyed by name only, so it continues if the same name is seen again.
        self.async_call_counter.setdefault(node.name, 0)

        new_body: list[ast.stmt] = []
        for stmt in node.body:
            transformed_stmt, needs_line_id = self.optimized_instrument_statement(stmt)

            if needs_line_id:
                current_call_index = self.async_call_counter[node.name]
                self.async_call_counter[node.name] += 1

                # os.environ["CODEFLASH_CURRENT_LINE_ID"] = "<current_call_index>"
                env_assignment = ast.Assign(
                    targets=[
                        ast.Subscript(
                            value=ast.Attribute(
                                value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load()
                            ),
                            slice=ast.Constant(value="CODEFLASH_CURRENT_LINE_ID"),
                            ctx=ast.Store(),
                        )
                    ],
                    value=ast.Constant(value=f"{current_call_index}"),
                    lineno=getattr(stmt, "lineno", 1),
                )
                new_body.append(env_assignment)
                self.did_instrument = True

            new_body.append(transformed_stmt)

        node.body = new_body
        return node

    def instrument_statement(self, stmt: ast.stmt, _node_name: str) -> tuple[ast.stmt, bool]:
        """Report whether *stmt* contains a targeted awaited call.

        Kept for callers of the older signature; ``_node_name`` is unused. The
        check is delegated to ``optimized_instrument_statement`` so both entry
        points always apply the same predicate.
        """
        return self.optimized_instrument_statement(stmt)

    def is_target_call(self, call_node: ast.Call) -> bool:
        """Check if this call node is calling our target async function."""
        target_name = self.function_object.function_name
        callee = call_node.func
        if isinstance(callee, ast.Name):
            return callee.id == target_name
        if isinstance(callee, ast.Attribute):
            return callee.attr == target_name
        return False

    def call_in_positions(self, call_node: ast.Call) -> bool:
        """Return True when *call_node* has location info and covers one of ``call_positions``."""
        if not hasattr(call_node, "lineno") or not hasattr(call_node, "col_offset"):
            return False

        return node_in_call_position(call_node, self.call_positions)

    def optimized_instrument_statement(self, stmt: ast.stmt) -> tuple[ast.stmt, bool]:
        """Return ``(stmt, True)`` when *stmt* contains ``await <target>(...)`` at a tracked position.

        The statement is never modified; the flag tells the caller to insert the
        line-id assignment before it. Uses an explicit stack instead of ``ast.walk``.
        """
        stack = [stmt]
        while stack:
            node = stack.pop()
            if isinstance(node, ast.Await):
                val = node.value
                if isinstance(val, ast.Call) and self.is_target_call(val) and self.call_in_positions(val):
                    return stmt, True
            # Non-AST list items (strings, None) are pushed too; they have no
            # `_fields`, so they are popped without effect.
            for fname in getattr(node, "_fields", ()):
                child = getattr(node, fname, None)
                if isinstance(child, list):
                    stack.extend(child)
                elif isinstance(child, ast.AST):
                    stack.append(child)
        return stmt, False
def inject_async_profiling_into_existing_test(
    test_path: Path,
    call_positions: list[CodePosition],
    function_to_optimize: FunctionToOptimize,
    tests_project_root: Path,
    mode: TestingMode = TestingMode.BEHAVIOR,
) -> tuple[bool, str | None]:
    """Instrument awaited calls of an async function in an existing test file.

    Reads *test_path*, resolves an import alias of the target function, and lets
    ``AsyncCallInstrumenter`` insert an environment-variable assignment before
    each statement that awaits the target at one of *call_positions*.

    Returns ``(True, new_source)`` when at least one statement was instrumented.
    Returns ``(False, None)`` when nothing was instrumented or the file does not
    parse. The file on disk is not modified.
    """
    source_text = test_path.read_text(encoding="utf8")

    try:
        module_ast = ast.parse(source_text)
    except SyntaxError:
        logger.exception("Syntax error in code in file - %s", test_path)
        return False, None

    # TODO: Pass the full name of function here, otherwise we can run into namespace clashes
    module_dotted_path = module_name_from_file_path(test_path, tests_project_root)

    # The test may import the target under an alias; instrument using that name.
    alias_finder = FunctionImportedAsVisitor(function_to_optimize)
    alias_finder.visit(module_ast)

    instrumenter = AsyncCallInstrumenter(alias_finder.imported_as, module_dotted_path, call_positions, mode=mode)
    module_ast = instrumenter.visit(module_ast)
    if not instrumenter.did_instrument:
        return False, None

    # The inserted assignments reference `os.environ`, so the module must import os.
    module_ast.body = [ast.Import(names=[ast.alias(name="os")]), *module_ast.body]
    return True, sort_imports(ast.unparse(module_ast), float_to_top=True)
ast.Import(names=[ast.alias(name="sqlite3")]), + ast.Import(names=[ast.alias(name="dill", asname="pickle")]), + ] + ) + # Add framework imports for GPU sync code (needed when framework is only imported via submodule) + for framework_name, framework_alias in used_frameworks.items(): + if framework_alias == framework_name: + # Only add import if we're using the framework name directly (not an alias) + # This handles cases like "from torch.nn import Module" where torch needs to be imported + new_imports.append(ast.Import(names=[ast.alias(name=framework_name)])) + else: + # If there's an alias, use it (e.g., "import torch as th") + new_imports.append(ast.Import(names=[ast.alias(name=framework_name, asname=framework_alias)])) + additional_functions = [create_wrapper_function(mode, used_frameworks)] + + tree.body = [*new_imports, *additional_functions, *tree.body] + return True, sort_imports(ast.unparse(tree), float_to_top=True) diff --git a/src/codeflash_python/verification/parse_test_output.py b/src/codeflash_python/verification/parse_test_output.py new file mode 100644 index 000000000..6650bf797 --- /dev/null +++ b/src/codeflash_python/verification/parse_test_output.py @@ -0,0 +1,262 @@ +from __future__ import annotations + +import logging +import os +import sqlite3 +import subprocess +from pathlib import Path +from typing import TYPE_CHECKING + +import dill as pickle +from lxml.etree import XMLParser, parse # type: ignore[import-not-found] + +from codeflash_python.code_utils.code_utils import get_run_tmp_file +from codeflash_python.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType, VerificationType +from codeflash_python.verification.path_utils import file_path_from_module_name +from codeflash_python.verification.test_output_utils import merge_test_results, parse_test_failures_from_stdout + +if TYPE_CHECKING: + import subprocess + + from codeflash_core.config import TestConfig + from codeflash_python.models.models import 
def parse_func(file_path: Path) -> XMLParser:
    """Parse the XML file with lxml.etree.XMLParser as the backend.

    The parser is created with ``huge_tree=True`` and handed to ``lxml.etree.parse``
    together with ``file_path``.
    """
    # NOTE(review): the annotation says XMLParser, but the value returned is whatever
    # lxml.etree.parse() produces (a parsed tree, not the parser) - confirm and fix the hint.
    xml_parser = XMLParser(huge_tree=True)
    return parse(file_path, xml_parser)
def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
    """Build ``TestResults`` from the rows of the ``test_results`` SQLite table.

    Every row becomes one ``FunctionTestInvocation``. ``did_pass`` is hard-coded to
    ``True``: a row only proves that the invocation executed and recorded a return
    value; the pass/fail verdict comes from the XML results file.

    Rows are skipped when their test type cannot be resolved from ``test_files`` or
    when their pickled return value cannot be deserialized. A missing file, or any
    failure to open or query the database, yields an empty ``TestResults``.
    """
    test_results = TestResults()
    if not sqlite_file_path.exists():
        logger.warning("No test results for %s found.", sqlite_file_path)
        return test_results

    db = None
    try:
        db = sqlite3.connect(sqlite_file_path)
        data = (
            db.cursor()
            .execute(
                "SELECT test_module_path, test_class_name, test_function_name, "
                "function_getting_tested, loop_index, iteration_id, runtime, return_value,verification_type FROM test_results"
            )
            .fetchall()
        )
    except Exception as e:
        logger.warning("Failed to parse test results from %s. Exception: %s", sqlite_file_path, e)
        return test_results
    finally:
        # connect() itself can fail, in which case there is nothing to close. Closing
        # unconditionally here used to raise AttributeError on None and mask the
        # early return above.
        if db is not None:
            db.close()

    for val in data:
        try:
            test_module_path = val[0]
            test_class_name = val[1] if val[1] else None
            test_function_name = val[2] if val[2] else None
            function_getting_tested = val[3]

            test_file_path = file_path_from_module_name(test_module_path, test_config.tests_project_rootdir)

            loop_index = val[4]
            iteration_id = val[5]
            runtime = val[6]
            verification_type = val[8]
            if verification_type in {VerificationType.INIT_STATE_FTO, VerificationType.INIT_STATE_HELPER}:
                test_type = TestType.INIT_STATE_TEST
            else:
                # Try original_file_path first (for existing tests that were instrumented)
                test_type = test_files.get_test_type_by_original_file_path(test_file_path)
                logger.debug("[PARSE-DEBUG] test_module=%s, test_file_path=%s", test_module_path, test_file_path)
                logger.debug("[PARSE-DEBUG]   by_original_file_path: %s", test_type)
                # If not found, try instrumented_behavior_file_path (for generated tests)
                if test_type is None:
                    test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path)
                    logger.debug("[PARSE-DEBUG]   by_instrumented_file_path: %s", test_type)
                if test_type is None:
                    # Skip results where test type cannot be determined
                    logger.debug("Skipping result for %s: could not determine test type", test_function_name)
                    continue
                logger.debug("[PARSE-DEBUG]   FINAL test_type=%s", test_type)

            # Return values are only read for the first loop; later loops carry timing only.
            ret_val = None
            if loop_index == 1 and val[7]:
                try:
                    # NOTE(review): unpickling executes arbitrary code. The blob is read from
                    # the results database at sqlite_file_path - confirm nothing untrusted can
                    # write to that file.
                    ret_val = (pickle.loads(val[7]),)
                except Exception as e:
                    logger.debug("Failed to deserialize return value for %s: %s", test_function_name, e)
                    continue

            test_results.add(
                function_test_invocation=FunctionTestInvocation(
                    loop_index=loop_index,
                    id=InvocationId(
                        test_module_path=test_module_path,
                        test_class_name=test_class_name,
                        test_function_name=test_function_name,
                        function_getting_tested=function_getting_tested,
                        iteration_id=iteration_id,
                    ),
                    file_name=test_file_path,
                    # Hardcoded: the test did execute and only the return value matters here;
                    # the real did_pass comes from the xml results file.
                    did_pass=True,
                    runtime=runtime,
                    test_framework=test_config.test_framework,
                    test_type=test_type,
                    return_value=ret_val,
                    timed_out=False,
                    verification_type=VerificationType(verification_type) if verification_type else None,
                )
            )
        except Exception:
            # One malformed row must not discard the rows that parsed correctly.
            logger.exception("Failed to parse sqlite test results for %s", sqlite_file_path)
    return test_results
get_run_tmp_file(Path(f"test_return_values_{optimization_iteration}.bin")) + if bin_results_file.exists(): + bin_test_results = parse_test_return_values_bin( + bin_results_file, test_files=test_files, test_config=test_config + ) + for result in bin_test_results: + test_results_data.add(result) + logger.debug("Merged %s results from binary file", len(bin_test_results)) + except AttributeError as e: + logger.exception(e) + + # Cleanup temp files + get_run_tmp_file(Path(f"test_return_values_{optimization_iteration}.bin")).unlink(missing_ok=True) + + get_run_tmp_file(Path("pytest_results.xml")).unlink(missing_ok=True) + get_run_tmp_file(Path("unittest_results.xml")).unlink(missing_ok=True) + get_run_tmp_file(Path(f"test_return_values_{optimization_iteration}.sqlite")).unlink(missing_ok=True) + + results = merge_test_results(test_results_xml, test_results_data, test_config.test_framework) + + all_args = False + coverage = None + if coverage_database_file and source_file and code_context and function_name: + all_args = True + from codeflash_python.verification.coverage_utils import CoverageUtils + + coverage = CoverageUtils.load_from_sqlite_database( + database_path=coverage_database_file, + config_path=coverage_config_file, # type: ignore[invalid-argument-type] + source_code_path=source_file, + code_context=code_context, + function_name=function_name, + ) + if coverage: + coverage.log_coverage() + if run_result: + try: + failures = parse_test_failures_from_stdout(run_result.stdout) + results.test_failures = failures + except Exception as e: + logger.exception(e) + + return results, coverage if all_args else None diff --git a/src/codeflash_python/verification/parse_xml.py b/src/codeflash_python/verification/parse_xml.py new file mode 100644 index 000000000..232b8c153 --- /dev/null +++ b/src/codeflash_python/verification/parse_xml.py @@ -0,0 +1,245 @@ +r"""Python-specific JUnit XML parsing with 6-field timing markers. 
+ +Python uses extended 6-field markers: + Start: !$######module:class_prefix.test_func:func_tested:loop_index:iteration_id######$!\n + End: !######module:class_prefix.test_func:func_tested:loop_index:iteration_id:runtime######! +""" + +from __future__ import annotations + +import logging +import os +import re +from typing import TYPE_CHECKING + +from junitparser.xunit2 import JUnitXml + +from codeflash_python.code_utils.code_utils import module_name_from_file_path +from codeflash_python.models.models import FunctionTestInvocation, InvocationId, TestResults +from codeflash_python.verification.path_utils import file_path_from_module_name + +logger = logging.getLogger("codeflash_python") + +if TYPE_CHECKING: + import subprocess + from pathlib import Path + + from lxml import etree # type: ignore[import-not-found] + + from codeflash_core.config import TestConfig + from codeflash_python.models.models import TestFiles + +matches_re_start = re.compile( + r"!\$######([^:]*)" # group 1: module path + r":((?:[^:.]*\.)*)" # group 2: class prefix with trailing dot, or empty + r"([^.:]*)" # group 3: test function name + r":([^:]*)" # group 4: function being tested + r":([^:]*)" # group 5: loop index + r":([^#]*)" # group 6: iteration id + r"######\$!\n" +) +matches_re_end = re.compile( + r"!######([^:]*)" # group 1: module path + r":((?:[^:.]*\.)*)" # group 2: class prefix with trailing dot, or empty + r"([^.:]*)" # group 3: test function name + r":([^:]*)" # group 4: function being tested + r":([^:]*)" # group 5: loop index + r":([^#]*)" # group 6: iteration_id or iteration_id:runtime + r"######!" 
def parse_func(file_path: Path) -> etree._ElementTree:
    """Parse ``file_path`` with an lxml ``XMLParser`` created with ``huge_tree=True``.

    lxml is imported inside the function, so the import cost is only paid when an
    XML report is actually parsed.
    """
    from lxml.etree import XMLParser, parse  # type: ignore[import-not-found]

    return parse(file_path, XMLParser(huge_tree=True))
continue + test_function = testcase.name.split("[", 1)[0] if "[" in testcase.name else testcase.name + except (AttributeError, TypeError) as e: + msg = ( + f"Accessing testcase.name in parse_test_xml for testcase {testcase!r} in file" + f" {test_xml_file_path} has exception: {e}" + ) + logger.exception(msg) + continue + if test_file_name is None: + if test_class_path: + test_file_path = resolve_test_file_from_class_path(test_class_path, base_dir) + if test_file_path is None: + logger.warning("Could not find the test for file name - %s ", test_class_path) + continue + else: + test_file_path = file_path_from_module_name(test_function, base_dir) + else: + test_file_path = base_dir / test_file_name + assert test_file_path, f"Test file path not found for {test_file_name}" + + if not test_file_path.exists(): + logger.warning("Could not find the test for file name - %s ", test_file_path) + continue + test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path) + if test_type is None: + test_type = test_files.get_test_type_by_original_file_path(test_file_path) + if test_type is None: + registered_paths = [str(tf.instrumented_behavior_file_path) for tf in test_files.test_files] + logger.warning( + "Test type not found for '%s'. Registered test files: %s. 
Skipping test case.", + test_file_path, + registered_paths, + ) + continue + test_module_path = module_name_from_file_path(test_file_path, test_config.tests_project_rootdir) + result = testcase.is_passed + test_class = None + if class_name is not None and class_name.startswith(test_module_path): + test_class = class_name[len(test_module_path) + 1 :] + + loop_index = int(testcase.name.split("[ ")[-1][:-2]) if testcase.name and "[" in testcase.name else 1 + + timed_out = False + if len(testcase.result) > 1: + logger.debug("!!!!!Multiple results for %s in %s!!!", testcase.name or "", test_xml_file_path) + if len(testcase.result) == 1: + message = testcase.result[0].message.lower() + if "failed: timeout >" in message or "timed out" in message: + timed_out = True + + sys_stdout = testcase.system_out or "" + + begin_matches = list(matches_re_start.finditer(sys_stdout)) + end_matches: dict[tuple, re.Match] = {} + for match in matches_re_end.finditer(sys_stdout): + groups = match.groups() + if len(groups[5].split(":")) > 1: + iteration_id = groups[5].split(":")[0] + groups = (*groups[:5], iteration_id) + end_matches[groups] = match + + if not begin_matches: + test_results.add( + FunctionTestInvocation( + loop_index=loop_index, + id=InvocationId( + test_module_path=test_module_path, + test_class_name=test_class, + test_function_name=test_function, + function_getting_tested="", + iteration_id="", + ), + file_name=test_file_path, + runtime=None, + test_framework=test_config.test_framework, + did_pass=result, + test_type=test_type, + return_value=None, + timed_out=timed_out, + stdout="", + ) + ) + else: + for match_index, match in enumerate(begin_matches): + groups = match.groups() + runtime = None + + end_match = end_matches.get(groups) + iteration_id = groups[5] + if end_match: + stdout = sys_stdout[match.end() : end_match.start()] + split_val = end_match.groups()[5].split(":") + if len(split_val) > 1: + iteration_id = split_val[0] + runtime = int(split_val[1]) + else: + 
iteration_id, runtime = split_val[0], None + elif match_index == len(begin_matches) - 1: + stdout = sys_stdout[match.end() :] + else: + stdout = sys_stdout[match.end() : begin_matches[match_index + 1].start()] + + test_results.add( + FunctionTestInvocation( + loop_index=int(groups[4]), + id=InvocationId( + test_module_path=groups[0], + test_class_name=None if groups[1] == "" else groups[1][:-1], + test_function_name=groups[2], + function_getting_tested=groups[3], + iteration_id=iteration_id, + ), + file_name=test_file_path, + runtime=runtime, + test_framework=test_config.test_framework, + did_pass=result, + test_type=test_type, + return_value=None, + timed_out=timed_out, + stdout=stdout, + ) + ) + + if not test_results: + logger.info( + "Tests '%s' failed to run, skipping", [test_file.original_file_path for test_file in test_files.test_files] + ) + if run_result is not None: + stdout, stderr = "", "" + try: + stdout = run_result.stdout.decode() + stderr = run_result.stderr.decode() + except AttributeError: + stdout = run_result.stderr + logger.debug("Test log - STDOUT : %s \nSTDERR : %s", stdout, stderr) + return test_results diff --git a/src/codeflash_python/verification/path_utils.py b/src/codeflash_python/verification/path_utils.py new file mode 100644 index 000000000..da796c193 --- /dev/null +++ b/src/codeflash_python/verification/path_utils.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import os +from functools import lru_cache +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path + + +def file_path_from_module_name(module_name: str, project_root_path: Path) -> Path: + """Get file path from module path.""" + return project_root_path / (module_name.replace(".", os.sep) + ".py") + + +@lru_cache(maxsize=100) +def file_name_from_test_module_name(test_module_name: str, base_dir: Path) -> Path | None: + partial_test_class = test_module_name + while partial_test_class: + test_path = 
file_path_from_module_name(partial_test_class, base_dir) + if (base_dir / test_path).exists(): + return base_dir / test_path + partial_test_class = ".".join(partial_test_class.split(".")[:-1]) + return None diff --git a/src/codeflash_python/verification/pytest_plugin.py b/src/codeflash_python/verification/pytest_plugin.py new file mode 100644 index 000000000..227bc1826 --- /dev/null +++ b/src/codeflash_python/verification/pytest_plugin.py @@ -0,0 +1,592 @@ +from __future__ import annotations + +import contextlib +import inspect + +# System Imports +import logging +import os +import platform +import re +import sys +import time as _time_module +import warnings +from importlib.util import find_spec +from pathlib import Path +from typing import TYPE_CHECKING, Callable +from unittest import TestCase + +# PyTest Imports +import pytest +from pluggy import HookspecMarker + +from codeflash_python.code_utils.config_consts import ( + STABILITY_CENTER_TOLERANCE, + STABILITY_SPREAD_TOLERANCE, + STABILITY_WINDOW_SIZE, +) + +if TYPE_CHECKING: + from _pytest.config import Config, Parser + from _pytest.main import Session + from _pytest.python import Metafunc + +_HAS_NUMPY = find_spec("numpy") is not None + +_PROTECTED_MODULES = frozenset( + {"gc", "inspect", "os", "sys", "time", "functools", "pathlib", "typing", "dill", "pytest", "importlib"} +) + +SECONDS_IN_HOUR: float = 3600 +SECONDS_IN_MINUTE: float = 60 +SHORTEST_AMOUNT_OF_TIME: float = 0 +hookspec = HookspecMarker("pytest") + + +class InvalidTimeParameterError(Exception): + pass + + +class UnexpectedError(Exception): + pass + + +if platform.system() == "Linux": + import resource + + # We set the memory limit to 85% of total system memory + swap when swap exists + swap_file_path = Path("/proc/swaps") + swap_exists = swap_file_path.is_file() + swap_size = 0 + + if swap_exists: + with swap_file_path.open("r") as f: + swap_lines = f.readlines() + swap_exists = len(swap_lines) > 1 # First line is header + + if swap_exists: + # 
def apply_deterministic_patches() -> None:
    """Apply patches to make all sources of randomness deterministic.

    Replaces ``time.time``, ``time.perf_counter``, ``uuid.uuid4``, ``uuid.uuid1``,
    ``random.random`` and ``os.urandom`` with functions returning fixed (or, for
    ``perf_counter``, steadily incrementing) values, and seeds ``random`` and, when
    numpy is installed, ``numpy.random`` with 42.

    Each replacement still calls the function it replaces and discards the result.

    ``datetime.datetime.now`` / ``utcnow`` are NOT replaced here: the originals and
    the mock versions are only stored as attributes on ``builtins``.

    The patches are process-wide and this function offers no way to undo them.
    """
    import datetime
    import random
    import time
    import uuid

    # Capture the current implementations; the mocks below call them so that the
    # patched functions still perform the real call.
    _original_time = time.time
    _original_perf_counter = time.perf_counter
    _original_datetime_now = datetime.datetime.now
    _original_datetime_utcnow = datetime.datetime.utcnow  # type: ignore[attr-defined]
    _original_uuid4 = uuid.uuid4
    _original_uuid1 = uuid.uuid1
    _original_random = random.random

    # Fixed deterministic values
    fixed_timestamp = 1761717605.108106
    # NOTE(review): fixed_timestamp and fixed_datetime are independent constants and
    # do not describe the same instant - confirm that is intended.
    fixed_datetime = datetime.datetime(2021, 1, 1, 2, 5, 10, tzinfo=datetime.timezone.utc)
    fixed_uuid = uuid.UUID("12345678-1234-5678-9abc-123456789012")

    # Counter for perf_counter to maintain relative timing
    _perf_counter_start = fixed_timestamp
    _perf_counter_calls = 0

    def mock_time_time() -> float:
        """Return fixed timestamp while preserving performance characteristics."""
        _original_time()  # Maintain performance characteristics
        return fixed_timestamp

    def mock_perf_counter() -> float:
        """Return incrementing counter for relative timing."""
        nonlocal _perf_counter_calls
        _original_perf_counter()  # Maintain performance characteristics
        _perf_counter_calls += 1
        return _perf_counter_start + (_perf_counter_calls * 0.001)  # Increment by 1ms each call

    def mock_datetime_now(tz: datetime.timezone | None = None) -> datetime.datetime:
        """Return fixed datetime while preserving performance characteristics."""
        _original_datetime_now(tz)  # Maintain performance characteristics
        if tz is None:
            # NOTE(review): fixed_datetime carries tzinfo=UTC, so this returns an aware
            # datetime even though no tz was requested.
            return fixed_datetime
        # replace() swaps the tzinfo without converting the wall-clock fields.
        return fixed_datetime.replace(tzinfo=tz)

    def mock_datetime_utcnow() -> datetime.datetime:
        """Return fixed UTC datetime while preserving performance characteristics."""
        _original_datetime_utcnow()  # type: ignore[attr-defined]  # Maintain performance characteristics
        return fixed_datetime

    def mock_uuid4() -> uuid.UUID:
        """Return fixed UUID4 while preserving performance characteristics."""
        _original_uuid4()  # Maintain performance characteristics
        return fixed_uuid

    def mock_uuid1(node: int | None = None, clock_seq: int | None = None) -> uuid.UUID:
        """Return fixed UUID1 while preserving performance characteristics."""
        _original_uuid1(node, clock_seq)  # Maintain performance characteristics
        return fixed_uuid

    def mock_random() -> float:
        """Return deterministic random value while preserving performance characteristics."""
        _original_random()  # Maintain performance characteristics
        return 0.123456789  # Fixed random value

    # Apply patches (module attributes are rebound, so only lookups made through the
    # module object after this point see the mocks).
    time.time = mock_time_time  # type: ignore[misc]
    time.perf_counter = mock_perf_counter  # type: ignore[misc]
    uuid.uuid4 = mock_uuid4  # type: ignore[misc]
    uuid.uuid1 = mock_uuid1  # type: ignore[misc]

    # Seed random module for other random functions
    random.seed(42)
    random.random = mock_random  # type: ignore[method-assign]

    # For datetime, we need to use a different approach since we can't patch class methods
    # Store original methods for potential later use
    import builtins

    builtins._original_datetime_now = _original_datetime_now  # type: ignore[attr-defined]  # noqa: SLF001
    builtins._original_datetime_utcnow = _original_datetime_utcnow  # type: ignore[attr-defined]  # noqa: SLF001
    builtins._mock_datetime_now = mock_datetime_now  # type: ignore[attr-defined]  # noqa: SLF001
    builtins._mock_datetime_utcnow = mock_datetime_utcnow  # type: ignore[attr-defined]  # noqa: SLF001

    # Patch numpy.random if available
    if _HAS_NUMPY:
        import numpy as np

        # NOTE(review): the Generator returned by default_rng(42) is discarded, so this
        # line has no lasting effect; only the legacy global seed below does.
        np.random.default_rng(42)
        np.random.seed(42)  # Keep legacy seed for compatibility  # noqa: NPY002

    # Patch os.urandom if needed
    try:
        import os

        _original_urandom = os.urandom

        def mock_urandom(n: int) -> bytes:
            _original_urandom(n)  # Maintain performance characteristics
            return b"\x42" * n  # Fixed bytes

        os.urandom = mock_urandom  # type: ignore[method-assign]
    except (ImportError, AttributeError):
        # Best effort: leave os.urandom untouched if it cannot be looked up.
        pass
@pytest.hookimpl(trylast=True)
def pytest_configure(config: Config) -> None:
    """Register the ``loops`` marker and the ``PytestLoops`` plugin, then patch randomness.

    :param config: Pytest config object the marker and plugin are registered on.
    """
    config.addinivalue_line("markers", "loops(n): run the given test function `n` times.")
    config.pluginmanager.register(PytestLoops(config), PytestLoops.name)

    # Apply deterministic patches when the plugin is configured
    apply_deterministic_patches()
def should_stop(
    runtimes: list[int],
    window: int,
    min_window_size: int,
    center_rel_tol: float = STABILITY_CENTER_TOLERANCE,
    spread_rel_tol: float = STABILITY_SPREAD_TOLERANCE,
) -> bool:
    """Decide whether the most recent runtimes are stable enough to stop looping.

    The last ``window`` entries of ``runtimes`` are considered stable when both hold:

    1. every entry is within ``center_rel_tol`` (relative) of the window median;
    2. ``(max - min) / min`` of the window is at most ``spread_rel_tol``.

    Returns ``False`` when fewer than ``window`` or ``min_window_size`` samples exist,
    when ``window`` is not positive, or when the window minimum or median is zero
    (relative deviations are undefined there).
    """
    # A non-positive window means there is no usable estimate yet. Without this guard
    # runtimes[-0:] would silently select the whole list (and index into an empty one).
    if window <= 0:
        return False
    if len(runtimes) < window or len(runtimes) < min_window_size:
        return False

    # One sort serves the median as well as min/max.
    recent_sorted = sorted(runtimes[-window:])
    r_min, r_max = recent_sorted[0], recent_sorted[-1]
    mid = window // 2
    median = recent_sorted[mid] if window % 2 else (recent_sorted[mid - 1] + recent_sorted[mid]) / 2

    # Check the denominators before dividing; the median used to be divided by first,
    # which raised ZeroDivisionError for an all-zero window.
    if r_min == 0 or median == 0:
        return False

    # 1) All recent points close to the median
    centered = all(abs(r - median) / median <= center_rel_tol for r in recent_sorted)
    # 2) Window spread is small
    spread_ok = (r_max - r_min) / r_min <= spread_rel_tol

    return centered and spread_ok
pytest_runtest_logreport(self, report: pytest.TestReport) -> None: + if not self.enable_stability_check: + return + if report.when == "call" and report.passed: + duration_ns = get_runtime_from_stdout(report.capstdout) + if duration_ns: + clean_id = _NODEID_BRACKET_PATTERN.sub("", report.nodeid) + self.runtime_data_by_test_case.setdefault(clean_id, []).append(duration_ns) + + @hookspec(firstresult=True) + def pytest_runtestloop(self, session: Session) -> bool: + """Reimplement the test loop but loop for the user defined amount of time.""" + if session.testsfailed and not session.config.option.continue_on_collection_errors: + msg = "{} error{} during collection".format(session.testsfailed, "s" if session.testsfailed != 1 else "") + raise session.Interrupted(msg) + + if session.config.option.collectonly: + return True + + start_time: float = _ORIGINAL_TIME_TIME() + total_time: float = self.get_total_time(session) + + count: int = 0 + runtimes = [] + elapsed_ns = 0 + + while total_time >= SHORTEST_AMOUNT_OF_TIME: # need to run at least one for normal tests + count += 1 + loop_start = _ORIGINAL_PERF_COUNTER_NS() + for index, item in enumerate(session.items): + item: pytest.Item = item # noqa: PLW0127 + item._report_sections.clear() # clear reports for new test # noqa: SLF001 + + if total_time > SHORTEST_AMOUNT_OF_TIME: + item._nodeid = self.set_nodeid(item._nodeid, count) # noqa: SLF001 + + next_item: pytest.Item | None = session.items[index + 1] if index + 1 < len(session.items) else None + + self.clear_lru_caches(item) + + item.config.hook.pytest_runtest_protocol(item=item, nextitem=next_item) + if session.shouldfail: + raise session.Failed(session.shouldfail) + if session.shouldstop: + raise session.Interrupted(session.shouldstop) + + if self.enable_stability_check: + elapsed_ns += _ORIGINAL_PERF_COUNTER_NS() - loop_start + best_runtime_until_now = sum(min(data) for data in self.runtime_data_by_test_case.values()) + if best_runtime_until_now > 0: + 
def scan_module_clearables(self, module_name: str) -> list[Callable]:
    """Collect the cache-bearing callables exposed by an already imported module.

    A member qualifies when it is callable and has a callable ``cache_clear``
    attribute, unless the top-level package of the function it wraps (or, without
    ``__wrapped__``, of the module that defines it) is listed in ``_PROTECTED_MODULES``.

    :param module_name: Key to look up in ``sys.modules``.
    :return: Qualifying members; empty when the module is not loaded.
    """
    module = sys.modules.get(module_name)
    if not module:
        return []

    clearables: list[Callable] = []
    for _, obj in inspect.getmembers(module):
        if not callable(obj):
            continue

        # Cheap test first: most callables have no cache, so this avoids running
        # inspect.getmodule() on every member.
        if not callable(getattr(obj, "cache_clear", None)):
            continue

        try:
            wrapped = getattr(obj, "__wrapped__", None)
            if wrapped is not None:
                owner_name = getattr(wrapped, "__module__", None)
            else:
                owner = inspect.getmodule(obj)
                owner_name = owner.__name__ if owner is not None else None
        except Exception:
            owner_name = None

        # Compare top-level package names in both branches; the wrapped branch used
        # to compare the full dotted name, so e.g. "importlib.util" was not protected.
        top_module = owner_name.split(".")[0] if isinstance(owner_name, str) else None
        if top_module in _PROTECTED_MODULES:
            continue

        clearables.append(obj)

    return clearables
def timed_out(self, session: Session, start_time: float, count: int) -> bool:
    """Tell the loop runner whether it should stop looping.

    :param session: Pytest session object; loop limits are read from ``session.config.option``.
    :param start_time: Timestamp from ``_ORIGINAL_TIME_TIME()`` taken when looping started.
    :param count: Number of loop passes executed so far.
    :return: True once ``codeflash_max_loops`` is reached, or once at least
        ``codeflash_min_loops`` passes ran and the configured total time has elapsed.
    """
    options = session.config.option

    # Hard cap: never exceed the maximum number of loops.
    if count >= options.codeflash_max_loops:
        return True

    # The time budget is only consulted after the minimum number of loops.
    if count < options.codeflash_min_loops:
        return False

    elapsed_seconds = _ORIGINAL_TIME_TIME() - start_time
    return elapsed_seconds > self.get_total_time(session)
@pytest.hookimpl(tryfirst=True)
def pytest_runtest_setup(self, item: pytest.Item) -> None:
    """Export the current test's module, class and function name as environment variables."""
    owning_module = item.module  # type: ignore[attr-defined]
    module_name = owning_module.__name__ if owning_module else "unknown_module"

    owning_class = getattr(item, "cls", None)
    class_name = (owning_class.__name__ if owning_class else None) or ""

    # Keep only the part of the item name before the first "[" (e.g. a parametrized id suffix).
    function_name = item.name.partition("[")[0]

    os.environ["CODEFLASH_TEST_MODULE"] = module_name
    os.environ["CODEFLASH_TEST_CLASS"] = class_name
    os.environ["CODEFLASH_TEST_FUNCTION"] = function_name
up test context environment variables after each test.""" + for var in ["CODEFLASH_TEST_MODULE", "CODEFLASH_TEST_CLASS", "CODEFLASH_TEST_FUNCTION"]: + os.environ.pop(var, None) diff --git a/src/codeflash_python/verification/test_output_utils.py b/src/codeflash_python/verification/test_output_utils.py new file mode 100644 index 000000000..52e0c8965 --- /dev/null +++ b/src/codeflash_python/verification/test_output_utils.py @@ -0,0 +1,357 @@ +"""Throughput/concurrency metrics, test file resolution, result merging, and failure parsing.""" + +from __future__ import annotations + +import logging +import re +from collections import defaultdict +from typing import TYPE_CHECKING + +from codeflash_python.discovery.discover_unit_tests import discover_parameters_unittest +from codeflash_python.models.models import ConcurrencyMetrics, FunctionTestInvocation, TestResults, VerificationType +from codeflash_python.verification.path_utils import file_name_from_test_module_name + +if TYPE_CHECKING: + from pathlib import Path + +logger = logging.getLogger("codeflash_python") + + +matches_re_start = re.compile( + r"!\$######([^:]*)" # group 1: module path + r":((?:[^:.]*\.)*)" # group 2: class prefix with trailing dot, or empty + r"([^.:]*)" # group 3: test function name + r":([^:]*)" # group 4: function being tested + r":([^:]*)" # group 5: loop index + r":([^#]*)" # group 6: iteration id + r"######\$!\n" +) +matches_re_end = re.compile( + r"!######([^:]*)" # group 1: module path + r":((?:[^:.]*\.)*)" # group 2: class prefix with trailing dot, or empty + r"([^.:]*)" # group 3: test function name + r":([^:]*)" # group 4: function being tested + r":([^:]*)" # group 5: loop index + r":([^#]*)" # group 6: iteration_id or iteration_id:runtime + r"######!" 
start_pattern = re.compile(r"!\$######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+)######\$!")
end_pattern = re.compile(r"!######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+):([^:]+)######!")


def calculate_function_throughput_from_test_results(test_results: TestResults, function_name: str) -> int:
    """Count completed executions of ``function_name`` recorded in the performance stdout.

    An execution counts as completed when its start tag has an end tag with the same
    first five fields.
    Start: !$######test_module:test_function:function_name:loop_index:iteration_id######$!
    End:   !######test_module:test_function:function_name:loop_index:iteration_id:duration######!

    :param test_results: Object whose ``perf_stdout`` holds the captured marker output (may be empty/None).
    :param function_name: Value the third marker field must equal.
    :return: Number of start tags for ``function_name`` that have a matching end tag.
    """
    perf_output = test_results.perf_stdout or ""

    # Drop the trailing duration field so end tags are comparable with start tags.
    completed = {end_fields[:5] for end_fields in end_pattern.findall(perf_output)}

    return sum(
        1
        for start_fields in start_pattern.findall(perf_output)
        if start_fields in completed and start_fields[2] == function_name
    )
+ + Returns ConcurrencyMetrics with: + - sequential_time_ns: Total time for N sequential executions + - concurrent_time_ns: Total time for N concurrent executions + - concurrency_factor: N (number of concurrent executions) + - concurrency_ratio: sequential_time / concurrent_time (higher = better concurrency) + """ + if not test_results.perf_stdout: + return None + + matches = _concurrency_pattern.findall(test_results.perf_stdout) + if not matches: + return None + + # Aggregate metrics for the target function + total_seq, total_conc, factor, count = 0, 0, 0, 0 + for match in matches: + # match[3] is function_name + if len(match) >= 8 and match[3] == function_name: + total_seq += int(match[5]) + total_conc += int(match[6]) + factor = int(match[7]) + count += 1 + + if count == 0: + return None + + avg_seq = total_seq / count + avg_conc = total_conc / count + ratio = avg_seq / avg_conc if avg_conc > 0 else 1.0 + + return ConcurrencyMetrics( + sequential_time_ns=int(avg_seq), + concurrent_time_ns=int(avg_conc), + concurrency_factor=factor, + concurrency_ratio=ratio, + ) + + +def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> Path | None: + """Resolve test file path from pytest's test class path. + + This function handles various cases where pytest's classname in JUnit XML + includes parent directories that may already be part of base_dir. 
+ + Args: + test_class_path: The full class path from pytest (e.g., "project.tests.test_file.TestClass") + base_dir: The base directory for tests (tests project root) + + Returns: + Path to the test file if found, None otherwise + + Examples: + >>> # base_dir = "/path/to/tests" + >>> # test_class_path = "code_to_optimize.tests.unittest.test_file.TestClass" + >>> # Should find: /path/to/tests/unittest/test_file.py + + """ + # First try the full path (Python module path) + test_file_path = file_name_from_test_module_name(test_class_path, base_dir) + + # If we couldn't find the file, try stripping the last component (likely a class name) + # This handles cases like "module.TestClass" where TestClass is a class, not a module + if test_file_path is None and "." in test_class_path: + module_without_class = ".".join(test_class_path.split(".")[:-1]) + test_file_path = file_name_from_test_module_name(module_without_class, base_dir) + + # If still not found, progressively strip prefix components + # This handles cases where pytest's classname includes parent directories that are + # already part of base_dir (e.g., "project.tests.unittest.test_file.TestClass" + # when base_dir is "/.../tests") + if test_file_path is None: + parts = test_class_path.split(".") + # Try stripping 1, 2, 3, ... prefix components + for num_to_strip in range(1, len(parts)): + remaining = ".".join(parts[num_to_strip:]) + test_file_path = file_name_from_test_module_name(remaining, base_dir) + if test_file_path: + break + # Also try without the last component (class name) + if "." 
def merge_test_results(
    xml_test_results: TestResults, bin_test_results: TestResults, test_framework: str
) -> TestResults:
    """Merge the results parsed from the JUnit XML with the results read from the binary output.

    Both inputs are grouped by ``module:class:function:loop_index`` (for the XML side the
    parameter suffix of the function name is removed first). Each XML group is then combined
    with the bin group that has the same key:

    * no bin group: the XML results are kept unchanged;
    * exactly one XML result: every bin result of the group is emitted, taking pass/fail,
      timeout, stdout and file information from that single XML result;
    * several XML results that carry an ``iteration_id``: each one is paired with the bin
      result that has the same ``unique_invocation_loop_id``;
    * otherwise: XML and bin results are paired by position.

    In every case the bin runtime is preferred and the XML runtime is the fallback when the
    bin runtime is missing or zero.

    :param xml_test_results: Results parsed from the JUnit XML report.
    :param bin_test_results: Results read from the binary results produced by instrumentation.
    :param test_framework: ``"pytest"`` or ``"unittest"``; selects how parameterized names are normalized.
    :return: A new ``TestResults`` with the merged invocations.
    """
    merged_test_results = TestResults()

    grouped_xml_results: defaultdict[str, TestResults] = defaultdict(TestResults)
    grouped_bin_results: defaultdict[str, TestResults] = defaultdict(TestResults)

    # This is done to match the right iteration_id which might not be available in the xml
    for result in xml_test_results:
        if test_framework == "pytest":
            if (
                result.id.test_function_name
                and result.id.test_function_name.endswith("]")
                and "[" in result.id.test_function_name
            ):  # parameterized test
                test_function_name = result.id.test_function_name[: result.id.test_function_name.index("[")]
            else:
                test_function_name = result.id.test_function_name
        elif test_framework == "unittest":
            test_function_name = result.id.test_function_name
            if test_function_name:
                is_parameterized, new_test_function_name, _ = discover_parameters_unittest(test_function_name)
                if is_parameterized:  # handle parameterized test
                    test_function_name = new_test_function_name
        else:
            test_function_name = result.id.test_function_name

        grouped_xml_results[
            (result.id.test_module_path or "")
            + ":"
            + (result.id.test_class_name or "")
            + ":"
            + (test_function_name or "")
            + ":"
            + str(result.loop_index)
        ].add(result)

    for result in bin_test_results:
        grouped_bin_results[
            (result.id.test_module_path or "")
            + ":"
            + (result.id.test_class_name or "")
            + ":"
            + (result.id.test_function_name or "")
            + ":"
            + str(result.loop_index)
        ].add(result)

    for result_id in grouped_xml_results:
        xml_results = grouped_xml_results[result_id]
        bin_results = grouped_bin_results.get(result_id)
        if not bin_results:
            merged_test_results.merge(xml_results)
            continue

        if len(xml_results) == 1:
            xml_result = xml_results[0]
            # This means that we only have one FunctionTestInvocation for this test xml. Match them to the bin results
            # Either a whole test function fails or passes.
            for result_bin in bin_results:
                # Prefer XML runtime (from stdout markers) if bin runtime is None/0
                merged_runtime = result_bin.runtime if result_bin.runtime else xml_result.runtime
                merged_test_results.add(
                    FunctionTestInvocation(
                        loop_index=xml_result.loop_index,
                        id=result_bin.id,
                        file_name=xml_result.file_name,
                        runtime=merged_runtime,
                        test_framework=xml_result.test_framework,
                        did_pass=xml_result.did_pass,
                        test_type=xml_result.test_type,
                        return_value=result_bin.return_value,
                        timed_out=xml_result.timed_out,
                        verification_type=VerificationType(result_bin.verification_type)
                        if result_bin.verification_type
                        else None,
                        stdout=xml_result.stdout,
                    )
                )
        elif xml_results.test_results[0].id.iteration_id is not None:
            # This means that we have multiple iterations of the same test function
            # We need to match the iteration_id to the bin results
            for xml_result in xml_results.test_results:
                try:
                    bin_result = bin_results.get_by_unique_invocation_loop_id(xml_result.unique_invocation_loop_id)
                except AttributeError:
                    bin_result = None
                if bin_result is None:
                    merged_test_results.add(xml_result)
                    continue
                # Prefer XML runtime (from stdout markers) if bin runtime is None/0
                merged_runtime = bin_result.runtime if bin_result.runtime else xml_result.runtime
                merged_test_results.add(
                    FunctionTestInvocation(
                        loop_index=xml_result.loop_index,
                        id=xml_result.id,
                        file_name=xml_result.file_name,
                        runtime=merged_runtime,
                        test_framework=xml_result.test_framework,
                        did_pass=bin_result.did_pass,
                        test_type=xml_result.test_type,
                        return_value=bin_result.return_value,
                        timed_out=xml_result.timed_out
                        if merged_runtime is None
                        else False,  # If runtime was measured, then the testcase did not time out
                        verification_type=VerificationType(bin_result.verification_type)
                        if bin_result.verification_type
                        else None,
                        stdout=xml_result.stdout,
                    )
                )
        else:
            # Should happen only if the xml did not have any test invocation id info
            for i, bin_result in enumerate(bin_results.test_results):
                try:
                    xml_result = xml_results.test_results[i]
                except IndexError:
                    xml_result = None
                if xml_result is None:
                    merged_test_results.add(bin_result)
                    continue
                # Prefer XML runtime (from stdout markers) if bin runtime is None/0
                merged_runtime = bin_result.runtime if bin_result.runtime else xml_result.runtime
                merged_test_results.add(
                    FunctionTestInvocation(
                        loop_index=bin_result.loop_index,
                        id=bin_result.id,
                        file_name=bin_result.file_name,
                        runtime=merged_runtime,
                        test_framework=bin_result.test_framework,
                        did_pass=bin_result.did_pass,
                        test_type=bin_result.test_type,
                        return_value=bin_result.return_value,
                        timed_out=xml_result.timed_out,  # only the xml gets the timed_out flag
                        # Use this iteration's own bin result. The name ``result_bin`` belongs to the
                        # single-xml-result branch above and is unbound or stale in this branch.
                        verification_type=VerificationType(bin_result.verification_type)
                        if bin_result.verification_type
                        else None,
                        stdout=xml_result.stdout,
                    )
                )

    return merged_test_results
failure_block = lines[start:end] + + failures: dict[str, str] = {} + current_name = None + current_lines: list[str] = [] + + for line in failure_block: + m = TEST_HEADER_RE.match(line.strip()) + if m: + if current_name is not None: + failures[current_name] = "".join(current_lines) + + current_name = m.group(1) + current_lines = [] + elif current_name: + current_lines.append(line + "\n") + + if current_name: + failures[current_name] = "".join(current_lines) + + return failures diff --git a/src/codeflash_python/verification/test_runner.py b/src/codeflash_python/verification/test_runner.py new file mode 100644 index 000000000..32a619d72 --- /dev/null +++ b/src/codeflash_python/verification/test_runner.py @@ -0,0 +1,511 @@ +from __future__ import annotations + +import logging +import re +import subprocess +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from codeflash_python.code_utils.shell_utils import get_cross_platform_subprocess_run_args +from codeflash_python.models.test_result import TestResult +from codeflash_python.verification.addopts import custom_addopts + +if TYPE_CHECKING: + import threading + from collections.abc import Sequence + +# Pattern to extract timing from stdout markers: !######module:class.test:func:loop:id:duration######! + +logger = logging.getLogger("codeflash_python") + +_TIMING_MARKER_PATTERN = re.compile(r"!######.+:(\d+)######!") + + +def calculate_utilization_fraction(stdout: str, wall_clock_ns: int, test_type: str = "unknown") -> None: + """Calculate and log the function utilization fraction. + + Utilization = sum(function_runtimes_from_markers) / total_wall_clock_time + + This metric shows how much of the test execution time was spent in actual + function calls vs overhead (test framework, I/O, etc.). + + Args: + stdout: The stdout from the test subprocess containing timing markers. + wall_clock_ns: Total wall clock time for the subprocess in nanoseconds. 
def run_tests(
    test_files: Sequence[Path],
    cwd: Path,
    env: dict[str, str],
    timeout: int,
    *,
    min_loops: int = 1,
    max_loops: int = 1,
    target_seconds: float | None = None,
    stability_check: bool = False,
    enable_coverage: bool = False,
) -> tuple[list[TestResult], Path, Path | None, Path | None]:
    """Run pytest on ``test_files`` in a subprocess and parse the resulting JUnit XML.

    :param test_files: Test files (or node ids) passed to pytest as positional arguments.
    :param cwd: Working directory of the subprocess.
    :param env: Base environment; it is copied, never mutated.
    :param timeout: Per-test value forwarded as ``--timeout``; omitted when falsy.
    :param min_loops: Forwarded as ``--codeflash_min_loops``.
    :param max_loops: Forwarded as ``--codeflash_max_loops``.
    :param target_seconds: Forwarded as ``--codeflash_seconds``; defaults to
        ``TOTAL_LOOPING_TIME_EFFECTIVE`` when None.
    :param stability_check: When True, adds ``--codeflash_stability_check=true``.
    :param enable_coverage: When True, pytest is launched through ``coverage run`` with JIT
        compilers disabled via environment variables.
    :return: ``(results, junit_xml_path, coverage_database_file, coverage_config_file)``.
        ``results`` is empty when anything raised; the exception is logged, not propagated.
    """
    import contextlib
    import shlex
    import sys

    from codeflash_python.code_utils.code_utils import get_run_tmp_file
    from codeflash_python.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
    from codeflash_python.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE

    if target_seconds is None:
        target_seconds = TOTAL_LOOPING_TIME_EFFECTIVE

    junit_xml = get_run_tmp_file(Path("pytest_results.xml"))

    pytest_args = [
        "--capture=tee-sys",
        "-q",
        "--codeflash_loops_scope=session",
        f"--codeflash_min_loops={min_loops}",
        f"--codeflash_max_loops={max_loops}",
        f"--codeflash_seconds={target_seconds}",
    ]
    if stability_check:
        pytest_args.append("--codeflash_stability_check=true")
    if timeout:
        pytest_args.append(f"--timeout={timeout}")

    result_args = [f"--junitxml={junit_xml.as_posix()}", "-o", "junit_logging=all"]

    pytest_env = env.copy()
    pytest_env["PYTEST_PLUGINS"] = "codeflash_python.verification.pytest_plugin"

    blocklisted_plugins = ["benchmark", "codspeed", "xdist", "sugar"]
    if min_loops > 1:
        blocklisted_plugins.extend(["cov", "profiling"])

    test_file_args = [str(f) for f in test_files]

    coverage_database_file: Path | None = None
    coverage_config_file: Path | None = None

    try:
        if enable_coverage:
            from codeflash_python.static_analysis.coverage_utils import prepare_coverage_files

            coverage_database_file, coverage_config_file = prepare_coverage_files()
            # Turn JIT compilers off for the coverage run.
            pytest_env["NUMBA_DISABLE_JIT"] = str(1)
            pytest_env["TORCHDYNAMO_DISABLE"] = str(1)
            pytest_env["PYTORCH_JIT"] = str(0)
            pytest_env["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"
            pytest_env["TF_ENABLE_ONEDNN_OPTS"] = str(0)
            # JAX_DISABLE_JIT is a "disable" switch: it must be truthy to turn JIT off.
            # The previous value "0" left JAX JIT enabled, unlike every other flag above.
            pytest_env["JAX_DISABLE_JIT"] = str(1)

            is_windows = sys.platform == "win32"
            if is_windows:
                # On Windows the stale coverage database is deleted directly, best effort.
                if coverage_database_file.exists():
                    with contextlib.suppress(PermissionError, OSError):
                        coverage_database_file.unlink()
            else:
                cov_erase = execute_test_subprocess(
                    shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage erase"), cwd=cwd, env=pytest_env, timeout=30
                )
                logger.debug(cov_erase)

            coverage_cmd = [
                SAFE_SYS_EXECUTABLE,
                "-m",
                "coverage",
                "run",
                f"--rcfile={coverage_config_file.as_posix()}",
                "-m",
            ]
            coverage_cmd.extend(pytest_cmd_tokens(IS_POSIX))

            # "cov" may be in the blocklist (min_loops > 1) but must never be disabled here.
            blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins if plugin != "cov"]
            result = execute_test_subprocess(
                coverage_cmd + pytest_args + blocklist_args + result_args + test_file_args,
                cwd=cwd,
                env=pytest_env,
                timeout=600,
            )
        else:
            pytest_cmd_list = build_pytest_cmd(SAFE_SYS_EXECUTABLE, IS_POSIX)
            blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins]

            result = execute_test_subprocess(
                pytest_cmd_list + pytest_args + blocklist_args + result_args + test_file_args,
                cwd=cwd,
                env=pytest_env,
                timeout=600,
            )

        logger.debug("Result return code: %s, %s", result.returncode, result.stderr or "")
        results = parse_test_results(junit_xml, result.stdout or "")
        return results, junit_xml, coverage_database_file, coverage_config_file

    except Exception as e:
        # Deliberate best effort: callers get an empty result list instead of an exception.
        logger.exception("Test execution failed: %s", e)
        return [], junit_xml, coverage_database_file, coverage_config_file
def run_behavioral_tests(
    test_paths: Any,
    test_env: dict[str, str],
    cwd: Path,
    timeout: int | None = None,
    enable_coverage: bool = False,
    candidate_index: int = 0,
) -> tuple[Path, Any, Path | None, Path | None]:
    """Run the instrumented behavior tests once (min/max loops = 1) in a pytest subprocess.

    :param test_paths: Object exposing ``test_files``; each entry provides ``test_type``,
        ``tests_in_file`` and ``instrumented_behavior_file_path``.
    :param test_env: Base environment; it is copied, never mutated.
    :param cwd: Working directory of the subprocess.
    :param timeout: Per-test value forwarded as ``--timeout`` when not None.
    :param enable_coverage: When True, pytest is launched through ``coverage run`` with JIT
        compilers disabled via environment variables.
    :param candidate_index: Accepted for interface compatibility; not used by this function.
    :return: ``(junit_xml_path, completed_subprocess, coverage_database_file, coverage_config_file)``.
    """
    import contextlib
    import shlex
    import sys

    from codeflash_python.code_utils.code_utils import get_run_tmp_file
    from codeflash_python.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
    from codeflash_python.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE
    from codeflash_python.models.models import TestType
    from codeflash_python.static_analysis.coverage_utils import prepare_coverage_files

    blocklisted_plugins = ["benchmark", "codspeed", "xdist", "sugar"]

    test_files: list[str] = []
    for file in test_paths.test_files:
        if file.test_type == TestType.REPLAY_TEST:
            # Replay files are targeted test by test; a replay file without listed tests adds nothing.
            if file.tests_in_file:
                test_files.extend(
                    [
                        str(file.instrumented_behavior_file_path) + "::" + test.test_function
                        for test in file.tests_in_file
                    ]
                )
        else:
            test_files.append(str(file.instrumented_behavior_file_path))

    pytest_cmd_list = build_pytest_cmd(SAFE_SYS_EXECUTABLE, IS_POSIX)
    test_files = list(set(test_files))  # de-duplicate targets

    common_pytest_args = [
        "--capture=tee-sys",
        "-q",
        "--codeflash_loops_scope=session",
        "--codeflash_min_loops=1",
        "--codeflash_max_loops=1",
        f"--codeflash_seconds={TOTAL_LOOPING_TIME_EFFECTIVE}",
    ]
    if timeout is not None:
        common_pytest_args.append(f"--timeout={timeout}")

    result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
    result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]

    pytest_test_env = test_env.copy()
    pytest_test_env["PYTEST_PLUGINS"] = "codeflash_python.verification.pytest_plugin"

    coverage_database_file: Path | None = None
    coverage_config_file: Path | None = None

    if enable_coverage:
        coverage_database_file, coverage_config_file = prepare_coverage_files()
        # Turn JIT compilers off for the coverage run.
        pytest_test_env["NUMBA_DISABLE_JIT"] = str(1)
        pytest_test_env["TORCHDYNAMO_DISABLE"] = str(1)
        pytest_test_env["PYTORCH_JIT"] = str(0)
        pytest_test_env["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"
        pytest_test_env["TF_ENABLE_ONEDNN_OPTS"] = str(0)
        # JAX_DISABLE_JIT is a "disable" switch: it must be truthy to turn JIT off.
        # The previous value "0" left JAX JIT enabled, unlike every other flag above.
        pytest_test_env["JAX_DISABLE_JIT"] = str(1)

        is_windows = sys.platform == "win32"
        if is_windows:
            # On Windows the stale coverage database is deleted directly, best effort.
            if coverage_database_file.exists():
                with contextlib.suppress(PermissionError, OSError):
                    coverage_database_file.unlink()
        else:
            cov_erase = execute_test_subprocess(
                shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage erase"), cwd=cwd, env=pytest_test_env, timeout=30
            )
            logger.debug(cov_erase)
        coverage_cmd = [
            SAFE_SYS_EXECUTABLE,
            "-m",
            "coverage",
            "run",
            f"--rcfile={coverage_config_file.as_posix()}",
            "-m",
        ]
        coverage_cmd.extend(pytest_cmd_tokens(IS_POSIX))

        blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins if plugin != "cov"]
        results = execute_test_subprocess(
            coverage_cmd + common_pytest_args + blocklist_args + result_args + test_files,
            cwd=cwd,
            env=pytest_test_env,
            timeout=600,
        )
        logger.debug("Result return code: %s, %s", results.returncode, results.stderr or "")
    else:
        blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins]

        results = execute_test_subprocess(
            pytest_cmd_list + common_pytest_args + blocklist_args + result_args + test_files,
            cwd=cwd,
            env=pytest_test_env,
            timeout=600,
        )
        logger.debug("Result return code: %s, %s", results.returncode, results.stderr or "")

    return result_file_path, results, coverage_database_file, coverage_config_file
import IS_POSIX, SAFE_SYS_EXECUTABLE + + blocklisted_plugins = ["codspeed", "cov", "benchmark", "profiling", "xdist", "sugar"] + + pytest_cmd_list = build_pytest_cmd(SAFE_SYS_EXECUTABLE, IS_POSIX) + test_files: list[str] = list({str(file.benchmarking_file_path) for file in test_paths.test_files}) + pytest_args = [ + "--capture=tee-sys", + "-q", + "--codeflash_loops_scope=session", + f"--codeflash_min_loops={min_loops}", + f"--codeflash_max_loops={max_loops}", + f"--codeflash_seconds={target_duration_seconds}", + "--codeflash_stability_check=true", + ] + if timeout is not None: + pytest_args.append(f"--timeout={timeout}") + + result_file_path = get_run_tmp_file(Path("pytest_results.xml")) + result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"] + pytest_test_env = test_env.copy() + pytest_test_env["PYTEST_PLUGINS"] = "codeflash_python.verification.pytest_plugin" + blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins] + results = execute_test_subprocess( + pytest_cmd_list + pytest_args + blocklist_args + result_args + test_files, + cwd=cwd, + env=pytest_test_env, + timeout=600, + ) + return result_file_path, results + + +def run_line_profile_tests( + test_paths: Any, + test_env: dict[str, str], + cwd: Path, + timeout: int | None = None, + line_profile_output_file: Path | None = None, +) -> tuple[Path, Any]: + + from codeflash_python.code_utils.code_utils import get_run_tmp_file + from codeflash_python.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE + from codeflash_python.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE + + blocklisted_plugins = ["codspeed", "cov", "benchmark", "profiling", "xdist", "sugar"] + + pytest_cmd_list = build_pytest_cmd(SAFE_SYS_EXECUTABLE, IS_POSIX) + test_files: list[str] = list({str(file.benchmarking_file_path) for file in test_paths.test_files}) + pytest_args = [ + "--capture=tee-sys", + "-q", + "--codeflash_loops_scope=session", + "--codeflash_min_loops=1", + 
"--codeflash_max_loops=1", + f"--codeflash_seconds={TOTAL_LOOPING_TIME_EFFECTIVE}", + ] + if timeout is not None: + pytest_args.append(f"--timeout={timeout}") + result_file_path = get_run_tmp_file(Path("pytest_results.xml")) + result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"] + pytest_test_env = test_env.copy() + pytest_test_env["PYTEST_PLUGINS"] = "codeflash_python.verification.pytest_plugin" + blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins] + pytest_test_env["LINE_PROFILE"] = "1" + results = execute_test_subprocess( + pytest_cmd_list + pytest_args + blocklist_args + result_args + test_files, + cwd=cwd, + env=pytest_test_env, + timeout=600, + ) + return result_file_path, results + + +def process_generated_test_strings( + generated_test_source: str, + instrumented_behavior_test_source: str, + instrumented_perf_test_source: str, + function_to_optimize: Any, + test_path: Path, + test_cfg: Any, + project_module_system: str | None, +) -> tuple[str, str, str]: + from codeflash_python.code_utils.code_utils import get_run_tmp_file + + temp_run_dir = get_run_tmp_file(Path()).as_posix() + instrumented_behavior_test_source = instrumented_behavior_test_source.replace( + "{codeflash_run_tmp_dir_client_side}", temp_run_dir + ) + instrumented_perf_test_source = instrumented_perf_test_source.replace( + "{codeflash_run_tmp_dir_client_side}", temp_run_dir + ) + return generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source + + +def execute_test_subprocess( + cmd_list: list[str], + cwd: Path, + env: dict[str, str] | None, + timeout: int = 600, + cancel_event: threading.Event | None = None, +) -> subprocess.CompletedProcess: + """Execute a subprocess with the given command list, working directory, environment variables, and timeout. 
+ + If *cancel_event* is provided and becomes set while the process is running, + the subprocess is terminated immediately and a CompletedProcess with + returncode -15 is returned. + """ + import time + + logger.debug("executing test run with command: %s", " ".join(cmd_list)) + with custom_addopts(): + if cancel_event is None: + run_args = get_cross_platform_subprocess_run_args( + cwd=cwd, env=env, timeout=timeout, check=False, text=True, capture_output=True + ) + return subprocess.run(cmd_list, **run_args) # type: ignore[no-matching-overload] # noqa: PLW1510 + + # Use Popen so we can poll for cancellation + run_args = get_cross_platform_subprocess_run_args( + cwd=cwd, env=env, timeout=None, check=False, text=True, capture_output=False + ) + # Remove keys that don't apply to Popen + run_args.pop("check", None) + run_args.pop("timeout", None) + run_args.pop("capture_output", None) + proc = subprocess.Popen(cmd_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE, **run_args) # type: ignore[no-matching-overload] + deadline = time.monotonic() + timeout + try: + while proc.poll() is None: + if cancel_event.is_set(): + proc.terminate() + proc.wait(timeout=5) + return subprocess.CompletedProcess(cmd_list, -15, stdout="", stderr="cancelled") + remaining = deadline - time.monotonic() + if remaining <= 0: + proc.terminate() + proc.wait(timeout=5) + msg = f"Timed out after {timeout}s" + raise subprocess.TimeoutExpired(cmd_list, timeout, output="", stderr=msg) # noqa: TRY301 + # Poll every 200ms + cancel_event.wait(min(0.2, remaining)) + stdout, stderr = proc.communicate(timeout=5) + return subprocess.CompletedProcess(cmd_list, proc.returncode, stdout=stdout or "", stderr=stderr or "") + except BaseException: + proc.kill() + proc.wait() + raise diff --git a/src/codeflash_python/verification/verification_utils.py b/src/codeflash_python/verification/verification_utils.py new file mode 100644 index 000000000..2a28bff3f --- /dev/null +++ 
def get_test_file_path(
    test_dir: Path,
    function_name: str,
    iteration: int = 0,
    test_type: str = "unit",
    source_file_path: Path | None = None,
) -> Path:
    """Return the first non-existing generated-test path for *function_name* in *test_dir*.

    The file name is ``test_<function>__<test_type>_test_<iteration>.py`` with dots in
    *function_name* replaced by underscores. Starting at *iteration*, the counter is
    bumped until a name that does not exist on disk is found.

    Args:
        test_dir: Directory the test file will live in.
        function_name: Possibly dotted name of the function under test.
        iteration: First counter value to try.
        test_type: One of ``"unit"``, ``"inspired"``, ``"replay"``, ``"perf"``.
        source_file_path: Accepted for interface compatibility; not used here.

    Returns:
        A path under *test_dir* that does not currently exist.
    """
    assert test_type in {"unit", "inspired", "replay", "perf"}
    function_name_safe = function_name.replace(".", "_")
    extension = ".py"

    # Iterative on purpose: the recursive form used one stack frame per existing
    # file and hit RecursionError once about a thousand candidates were taken.
    while True:
        path = test_dir / f"test_{function_name_safe}__{test_type}_test_{iteration}{extension}"
        if not path.exists():
            return path
        iteration += 1


def delete_multiple_if_name_main(test_ast: ast.Module) -> ast.Module:
    """Keep only the last top-level ``if __name__ == "__main__":`` block of *test_ast*.

    Earlier top-level guards are deleted in place; the same module object is returned.
    Only guards of the shape ``__name__ == "__main__"`` (name on the left) are matched.
    """
    main_guard_indexes = [
        index
        for index, node in enumerate(test_ast.body)
        if (
            isinstance(node, ast.If)
            and isinstance(node.test, ast.Compare)
            and isinstance(node.test.left, ast.Name)
            and node.test.left.id == "__name__"
            and len(node.test.ops) > 0
            and isinstance(node.test.ops[0], ast.Eq)
            and len(node.test.comparators) > 0
            and isinstance(node.test.comparators[0], ast.Constant)
            and node.test.comparators[0].value == "__main__"
        )
    ]
    # Delete from the back so the remaining indexes stay valid; the last guard is kept.
    for index in reversed(main_guard_indexes[:-1]):
        del test_ast.body[index]
    return test_ast


class ModifyInspiredTests(ast.NodeTransformer):
    """Transformer for modifying inspired test classes.

    Class is currently not in active use.

    Import statements it visits are appended to the caller-supplied ``import_list``
    and dropped from the tree. When the framework is ``"unittest"``, classes deriving
    from ``unittest.TestCase`` or a bare ``TestCase`` get an ``Inspired`` name suffix.
    """

    def __init__(self, import_list: list[ast.stmt], test_framework: str) -> None:
        # import_list is mutated in place so the caller can read the collected imports.
        self.import_list = import_list
        self.test_framework = test_framework

    def visit_Import(self, node: ast.Import) -> None:
        """Collect the import; returning ``None`` removes it from the tree."""
        self.import_list.append(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Collect the from-import; returning ``None`` removes it from the tree."""
        self.import_list.append(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        """Rename unittest test-case classes to ``<Name>Inspired``; leave others as they are."""
        if self.test_framework != "unittest":
            return node
        for base in node.bases:
            is_qualified_test_case = (
                isinstance(base, ast.Attribute)
                and base.attr == "TestCase"
                and isinstance(base.value, ast.Name)
                and base.value.id == "unittest"
            )
            # TODO: Check if this is actually a unittest.TestCase
            is_bare_test_case = isinstance(base, ast.Name) and base.id == "TestCase"
            if is_qualified_test_case or is_bare_test_case:
                node.name = node.name + "Inspired"
                break
        return node


def generate_tests(
    aiservice_client: AiServiceClient,
    source_code_being_tested: str,
    function_to_optimize: FunctionToOptimize,
    helper_function_names: list[str],
    module_path: Path,
    test_cfg_project_root: Path,
    test_timeout: int,
    function_trace_id: str,
    test_index: int,
    test_path: Path,
    test_perf_path: Path,
    is_numerical_code: bool | None = None,
) -> tuple[str, str, str, str | None, Path, Path] | None:
    """Generate regression tests for a single function.

    Wraps AiServiceClient.generate_regression_tests() and processes
    the returned test strings.

    Returns:
        ``(generated_test_source, instrumented_behavior_test_source,
        instrumented_perf_test_source, raw_generated_tests, test_path, test_perf_path)``,
        or ``None`` when the service response is not a non-empty 4-tuple.
    """
    start_time = time.perf_counter()
    test_module_path = Path(module_name_from_file_path(test_path, test_cfg_project_root))

    response = aiservice_client.generate_regression_tests(
        source_code_being_tested=source_code_being_tested,
        function_to_optimize=function_to_optimize,
        helper_function_names=helper_function_names,
        module_path=module_path,
        test_module_path=test_module_path,
        test_framework="pytest",
        test_timeout=test_timeout,
        trace_id=function_trace_id,
        test_index=test_index,
        is_numerical_code=is_numerical_code,
    )

    if not (response and isinstance(response, tuple) and len(response) == 4):
        logger.warning("Failed to generate tests for %s", function_to_optimize.function_name)
        return None

    generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source, raw_generated_tests = (
        response
    )
    generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source = (
        process_generated_test_strings(
            generated_test_source=generated_test_source,
            instrumented_behavior_test_source=instrumented_behavior_test_source,
            instrumented_perf_test_source=instrumented_perf_test_source,
            function_to_optimize=function_to_optimize,
            test_path=test_path,
            test_cfg=None,
            project_module_system=None,
        )
    )

    end_time = time.perf_counter()
    logger.debug("Generated tests in %.2f seconds", end_time - start_time)
    return (
        generated_test_source,
        instrumented_behavior_test_source,
        instrumented_perf_test_source,
        raw_generated_tests,
        test_path,
        test_perf_path,
    )
def merge_unit_tests(unit_test_source: str, inspired_unit_tests: str, test_framework: str) -> str:
    """Merge *inspired_unit_tests* into *unit_test_source* and return the combined source.

    Imports collected from the inspired tests are placed at the top of the merged
    module. Under ``"pytest"``, top-level ``test_*`` functions from the inspired tests
    get an ``__inspired`` suffix. Under ``"unittest"``, only the last top-level
    ``if __name__ == "__main__":`` block is kept.

    Returns:
        The merged source, or *unit_test_source* unchanged if either input fails to parse.
    """
    try:
        inspired_unit_tests_ast = ast.parse(inspired_unit_tests)
        unit_test_source_ast = ast.parse(unit_test_source)
    except SyntaxError as e:
        logger.exception("Syntax error in code: %s", e)
        return unit_test_source
    import_list: list[ast.stmt] = []
    # The transformer fills import_list as a side effect of visiting the tree.
    modified_ast = ModifyInspiredTests(import_list, test_framework).visit(inspired_unit_tests_ast)
    if test_framework == "pytest":
        for node in ast.iter_child_nodes(modified_ast):
            if isinstance(node, ast.FunctionDef) and node.name.startswith("test_"):
                node.name = node.name + "__inspired"
    unit_test_source_ast.body.extend(modified_ast.body)
    unit_test_source_ast.body = import_list + unit_test_source_ast.body
    if test_framework == "unittest":
        unit_test_source_ast = delete_multiple_if_name_main(unit_test_source_ast)
    return ast.unparse(unit_test_source_ast)


def create_wrapper_function(
    mode: TestingMode = TestingMode.BEHAVIOR, used_frameworks: dict[str, str] | None = None
) -> ast.FunctionDef:
    """Build the AST of the ``codeflash_wrap`` function injected into instrumented tests.

    The generated function takes the wrapped callable, identifying strings
    (module, class, test name, function name, line id, loop index) and ``*args`` /
    ``**kwargs``. In ``TestingMode.BEHAVIOR`` it also takes ``codeflash_cur`` and
    ``codeflash_con``. Its body, in order:

    1. Builds ``test_id`` and counts calls per ``test_id`` in ``codeflash_wrap.index``.
    2. Derives ``invocation_id`` and ``test_stdout_tag`` and prints a
       ``!$######<tag>######$!`` start marker.
    3. Disables ``gc``, times the wrapped call with ``time.perf_counter_ns`` and
       records any ``Exception`` instead of letting it propagate immediately.
    4. Re-enables ``gc`` and prints a ``!######<tag>######!`` end marker, which
       carries ``:<duration>`` only in ``TestingMode.PERFORMANCE``.
    5. In ``TestingMode.BEHAVIOR``, pickles the return value (or the exception),
       inserts a row into ``test_results`` and commits.
    6. Re-raises the recorded exception, otherwise returns the wrapped return value.

    Args:
        mode: Selects the behavior-only and performance-only statements described above.
        used_frameworks: Forwarded unchanged to the device-sync statement builders.

    Returns:
        An ``ast.FunctionDef`` named ``codeflash_wrap``.
    """
    lineno = 1
    wrapper_body: list[ast.stmt] = [
        # test_id = f"{module}:{class}:{test}:{line_id}:{loop_index}"
        ast.Assign(
            targets=[ast.Name(id="test_id", ctx=ast.Store())],
            value=ast.JoinedStr(
                values=[
                    ast.FormattedValue(value=ast.Name(id="codeflash_test_module_name", ctx=ast.Load()), conversion=-1),
                    ast.Constant(value=":"),
                    ast.FormattedValue(value=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()), conversion=-1),
                    ast.Constant(value=":"),
                    ast.FormattedValue(value=ast.Name(id="codeflash_test_name", ctx=ast.Load()), conversion=-1),
                    ast.Constant(value=":"),
                    ast.FormattedValue(value=ast.Name(id="codeflash_line_id", ctx=ast.Load()), conversion=-1),
                    ast.Constant(value=":"),
                    ast.FormattedValue(value=ast.Name(id="codeflash_loop_index", ctx=ast.Load()), conversion=-1),
                ]
            ),
            lineno=lineno + 1,
        ),
        # if not hasattr(codeflash_wrap, "index"): codeflash_wrap.index = {}
        ast.If(
            test=ast.UnaryOp(
                op=ast.Not(),
                operand=ast.Call(
                    func=ast.Name(id="hasattr", ctx=ast.Load()),
                    args=[ast.Name(id="codeflash_wrap", ctx=ast.Load()), ast.Constant(value="index")],
                    keywords=[],
                ),
            ),
            body=[
                ast.Assign(
                    targets=[
                        ast.Attribute(
                            value=ast.Name(id="codeflash_wrap", ctx=ast.Load()), attr="index", ctx=ast.Store()
                        )
                    ],
                    value=ast.Dict(keys=[], values=[]),
                    lineno=lineno + 3,
                )
            ],
            orelse=[],
            lineno=lineno + 2,
        ),
        # Per-test_id call counter: += 1 when already seen, otherwise start at 0.
        ast.If(
            test=ast.Compare(
                left=ast.Name(id="test_id", ctx=ast.Load()),
                ops=[ast.In()],
                comparators=[
                    ast.Attribute(value=ast.Name(id="codeflash_wrap", ctx=ast.Load()), attr="index", ctx=ast.Load())
                ],
            ),
            body=[
                ast.AugAssign(
                    target=ast.Subscript(
                        value=ast.Attribute(
                            value=ast.Name(id="codeflash_wrap", ctx=ast.Load()), attr="index", ctx=ast.Load()
                        ),
                        slice=ast.Name(id="test_id", ctx=ast.Load()),
                        ctx=ast.Store(),
                    ),
                    op=ast.Add(),
                    value=ast.Constant(value=1),
                    lineno=lineno + 5,
                )
            ],
            orelse=[
                ast.Assign(
                    targets=[
                        ast.Subscript(
                            value=ast.Attribute(
                                value=ast.Name(id="codeflash_wrap", ctx=ast.Load()), attr="index", ctx=ast.Load()
                            ),
                            slice=ast.Name(id="test_id", ctx=ast.Load()),
                            ctx=ast.Store(),
                        )
                    ],
                    value=ast.Constant(value=0),
                    lineno=lineno + 6,
                )
            ],
            lineno=lineno + 4,
        ),
        # codeflash_test_index = codeflash_wrap.index[test_id]
        ast.Assign(
            targets=[ast.Name(id="codeflash_test_index", ctx=ast.Store())],
            value=ast.Subscript(
                value=ast.Attribute(value=ast.Name(id="codeflash_wrap", ctx=ast.Load()), attr="index", ctx=ast.Load()),
                slice=ast.Name(id="test_id", ctx=ast.Load()),
                ctx=ast.Load(),
            ),
            lineno=lineno + 7,
        ),
        # invocation_id = f"{codeflash_line_id}_{codeflash_test_index}"
        ast.Assign(
            targets=[ast.Name(id="invocation_id", ctx=ast.Store())],
            value=ast.JoinedStr(
                values=[
                    ast.FormattedValue(value=ast.Name(id="codeflash_line_id", ctx=ast.Load()), conversion=-1),
                    ast.Constant(value="_"),
                    ast.FormattedValue(value=ast.Name(id="codeflash_test_index", ctx=ast.Load()), conversion=-1),
                ]
            ),
            lineno=lineno + 8,
        ),
        # test_stdout_tag plus the start marker print; emitted in every mode.
        *(
            [
                ast.Assign(
                    targets=[ast.Name(id="test_stdout_tag", ctx=ast.Store())],
                    value=ast.JoinedStr(
                        values=[
                            ast.FormattedValue(
                                value=ast.Name(id="codeflash_test_module_name", ctx=ast.Load()), conversion=-1
                            ),
                            ast.Constant(value=":"),
                            ast.FormattedValue(
                                value=ast.IfExp(
                                    test=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()),
                                    body=ast.BinOp(
                                        left=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()),
                                        op=ast.Add(),
                                        right=ast.Constant(value="."),
                                    ),
                                    orelse=ast.Constant(value=""),
                                ),
                                conversion=-1,
                            ),
                            ast.FormattedValue(value=ast.Name(id="codeflash_test_name", ctx=ast.Load()), conversion=-1),
                            ast.Constant(value=":"),
                            ast.FormattedValue(
                                value=ast.Name(id="codeflash_function_name", ctx=ast.Load()), conversion=-1
                            ),
                            ast.Constant(value=":"),
                            ast.FormattedValue(
                                value=ast.Name(id="codeflash_loop_index", ctx=ast.Load()), conversion=-1
                            ),
                            ast.Constant(value=":"),
                            ast.FormattedValue(value=ast.Name(id="invocation_id", ctx=ast.Load()), conversion=-1),
                        ]
                    ),
                    lineno=lineno + 9,
                ),
                ast.Expr(
                    value=ast.Call(
                        func=ast.Name(id="print", ctx=ast.Load()),
                        args=[
                            ast.JoinedStr(
                                values=[
                                    ast.Constant(value="!$######"),
                                    ast.FormattedValue(
                                        value=ast.Name(id="test_stdout_tag", ctx=ast.Load()), conversion=-1
                                    ),
                                    ast.Constant(value="######$!"),
                                ]
                            )
                        ],
                        keywords=[],
                    )
                ),
            ]
        ),
        # exception = None
        ast.Assign(
            targets=[ast.Name(id="exception", ctx=ast.Store())], value=ast.Constant(value=None), lineno=lineno + 10
        ),
        # Pre-compute device sync conditions before profiling to avoid overhead during timing
        *create_device_sync_precompute_statements(used_frameworks),
        # gc.disable()
        ast.Expr(
            value=ast.Call(
                func=ast.Attribute(value=ast.Name(id="gc", ctx=ast.Load()), attr="disable", ctx=ast.Load()),
                args=[],
                keywords=[],
            ),
            lineno=lineno + 9,
        ),
        ast.Try(
            body=[
                # Pre-sync: synchronize device before starting timer
                *create_device_sync_statements(used_frameworks, for_return_value=False),
                # counter = time.perf_counter_ns()
                ast.Assign(
                    targets=[ast.Name(id="counter", ctx=ast.Store())],
                    value=ast.Call(
                        func=ast.Attribute(
                            value=ast.Name(id="time", ctx=ast.Load()), attr="perf_counter_ns", ctx=ast.Load()
                        ),
                        args=[],
                        keywords=[],
                    ),
                    lineno=lineno + 11,
                ),
                # return_value = codeflash_wrapped(*args, **kwargs)
                ast.Assign(
                    targets=[ast.Name(id="return_value", ctx=ast.Store())],
                    value=ast.Call(
                        func=ast.Name(id="codeflash_wrapped", ctx=ast.Load()),
                        args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
                        keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
                    ),
                    lineno=lineno + 12,
                ),
                # Post-sync: synchronize device after function call to ensure all device work is complete
                *create_device_sync_statements(used_frameworks, for_return_value=True),
                # codeflash_duration = time.perf_counter_ns() - counter
                ast.Assign(
                    targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())],
                    value=ast.BinOp(
                        left=ast.Call(
                            func=ast.Attribute(
                                value=ast.Name(id="time", ctx=ast.Load()), attr="perf_counter_ns", ctx=ast.Load()
                            ),
                            args=[],
                            keywords=[],
                        ),
                        op=ast.Sub(),
                        right=ast.Name(id="counter", ctx=ast.Load()),
                    ),
                    lineno=lineno + 13,
                ),
            ],
            handlers=[
                # except Exception as e: still record the duration, then remember the exception.
                ast.ExceptHandler(
                    type=ast.Name(id="Exception", ctx=ast.Load()),
                    name="e",
                    body=[
                        ast.Assign(
                            targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())],
                            value=ast.BinOp(
                                left=ast.Call(
                                    func=ast.Attribute(
                                        value=ast.Name(id="time", ctx=ast.Load()),
                                        attr="perf_counter_ns",
                                        ctx=ast.Load(),
                                    ),
                                    args=[],
                                    keywords=[],
                                ),
                                op=ast.Sub(),
                                right=ast.Name(id="counter", ctx=ast.Load()),
                            ),
                            lineno=lineno + 15,
                        ),
                        ast.Assign(
                            targets=[ast.Name(id="exception", ctx=ast.Store())],
                            value=ast.Name(id="e", ctx=ast.Load()),
                            lineno=lineno + 13,
                        ),
                    ],
                    lineno=lineno + 14,
                )
            ],
            orelse=[],
            finalbody=[],
            lineno=lineno + 11,
        ),
        # gc.enable()
        ast.Expr(
            value=ast.Call(
                func=ast.Attribute(value=ast.Name(id="gc", ctx=ast.Load()), attr="enable", ctx=ast.Load()),
                args=[],
                keywords=[],
            )
        ),
        # End marker; the ":<duration>" suffix is added only in PERFORMANCE mode.
        ast.Expr(
            value=ast.Call(
                func=ast.Name(id="print", ctx=ast.Load()),
                args=[
                    ast.JoinedStr(
                        values=[
                            ast.Constant(value="!######"),
                            ast.FormattedValue(value=ast.Name(id="test_stdout_tag", ctx=ast.Load()), conversion=-1),
                            *(
                                [
                                    ast.Constant(value=":"),
                                    ast.FormattedValue(
                                        value=ast.Name(id="codeflash_duration", ctx=ast.Load()), conversion=-1
                                    ),
                                ]
                                if mode == TestingMode.PERFORMANCE
                                else []
                            ),
                            ast.Constant(value="######!"),
                        ]
                    )
                ],
                keywords=[],
            )
        ),
        # BEHAVIOR only: pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
        *(
            [
                ast.Assign(
                    targets=[ast.Name(id="pickled_return_value", ctx=ast.Store())],
                    value=ast.IfExp(
                        test=ast.Name(id="exception", ctx=ast.Load()),
                        body=ast.Call(
                            func=ast.Attribute(
                                value=ast.Name(id="pickle", ctx=ast.Load()), attr="dumps", ctx=ast.Load()
                            ),
                            args=[ast.Name(id="exception", ctx=ast.Load())],
                            keywords=[],
                        ),
                        orelse=ast.Call(
                            func=ast.Attribute(
                                value=ast.Name(id="pickle", ctx=ast.Load()), attr="dumps", ctx=ast.Load()
                            ),
                            args=[ast.Name(id="return_value", ctx=ast.Load())],
                            keywords=[],
                        ),
                    ),
                    lineno=lineno + 18,
                )
            ]
            if mode == TestingMode.BEHAVIOR
            else []
        ),
        # BEHAVIOR only: insert the nine-column result row, then commit.
        *(
            [
                ast.Expr(
                    value=ast.Call(
                        func=ast.Attribute(
                            value=ast.Name(id="codeflash_cur", ctx=ast.Load()), attr="execute", ctx=ast.Load()
                        ),
                        args=[
                            ast.Constant(value="INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"),
                            ast.Tuple(
                                elts=[
                                    ast.Name(id="codeflash_test_module_name", ctx=ast.Load()),
                                    ast.Name(id="codeflash_test_class_name", ctx=ast.Load()),
                                    ast.Name(id="codeflash_test_name", ctx=ast.Load()),
                                    ast.Name(id="codeflash_function_name", ctx=ast.Load()),
                                    ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
                                    ast.Name(id="invocation_id", ctx=ast.Load()),
                                    ast.Name(id="codeflash_duration", ctx=ast.Load()),
                                    ast.Name(id="pickled_return_value", ctx=ast.Load()),
                                    ast.Constant(value=VerificationType.FUNCTION_CALL.value),
                                ],
                                ctx=ast.Load(),
                            ),
                        ],
                        keywords=[],
                    ),
                    lineno=lineno + 20,
                ),
                ast.Expr(
                    value=ast.Call(
                        func=ast.Attribute(
                            value=ast.Name(id="codeflash_con", ctx=ast.Load()), attr="commit", ctx=ast.Load()
                        ),
                        args=[],
                        keywords=[],
                    ),
                    lineno=lineno + 21,
                ),
            ]
            if mode == TestingMode.BEHAVIOR
            else []
        ),
        # if exception: raise exception
        ast.If(
            test=ast.Name(id="exception", ctx=ast.Load()),
            body=[ast.Raise(exc=ast.Name(id="exception", ctx=ast.Load()), cause=None, lineno=lineno + 22)],
            orelse=[],
            lineno=lineno + 22,
        ),
        # return return_value
        ast.Return(value=ast.Name(id="return_value", ctx=ast.Load()), lineno=lineno + 19),
    ]
    return ast.FunctionDef(
        name="codeflash_wrap",
        args=ast.arguments(
            args=[
                ast.arg(arg="codeflash_wrapped", annotation=None),
                ast.arg(arg="codeflash_test_module_name", annotation=None),
                ast.arg(arg="codeflash_test_class_name", annotation=None),
                ast.arg(arg="codeflash_test_name", annotation=None),
                ast.arg(arg="codeflash_function_name", annotation=None),
                ast.arg(arg="codeflash_line_id", annotation=None),
                ast.arg(arg="codeflash_loop_index", annotation=None),
                # The cursor and connection parameters exist only in BEHAVIOR mode.
                *([ast.arg(arg="codeflash_cur", annotation=None)] if mode == TestingMode.BEHAVIOR else []),
                *([ast.arg(arg="codeflash_con", annotation=None)] if mode == TestingMode.BEHAVIOR else []),
            ],
            vararg=ast.arg(arg="args"),
            kwarg=ast.arg(arg="kwargs"),
            posonlyargs=[],
            kwonlyargs=[],
            kw_defaults=[],
            defaults=[],
        ),
        body=wrapper_body,
        lineno=lineno,
        decorator_list=[],
        returns=None,
    )
a/tests/benchmarks/test_benchmark_code_extract_code_context.py +++ b/tests/benchmarks/test_benchmark_code_extract_code_context.py @@ -1,7 +1,7 @@ from argparse import Namespace from pathlib import Path -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context from codeflash.models.models import FunctionParent from codeflash.optimization.optimizer import Optimizer @@ -15,7 +15,7 @@ def test_benchmark_extract(benchmark) -> None: disable_telemetry=True, tests_root=(file_path / "tests").resolve(), test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=Path.cwd(), ) diff --git a/tests/benchmarks/test_benchmark_discover_unit_tests.py b/tests/benchmarks/test_benchmark_discover_unit_tests.py index 6a2f4432e..082c178a4 100644 --- a/tests/benchmarks/test_benchmark_discover_unit_tests.py +++ b/tests/benchmarks/test_benchmark_discover_unit_tests.py @@ -1,7 +1,7 @@ from pathlib import Path from codeflash.discovery.discover_unit_tests import discover_unit_tests -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig def test_benchmark_code_to_optimize_test_discovery(benchmark) -> None: @@ -9,7 +9,7 @@ def test_benchmark_code_to_optimize_test_discovery(benchmark) -> None: tests_path = project_path / "tests" / "pytest" test_config = TestConfig( tests_root=tests_path, - project_root_path=project_path, + project_root=project_path, test_framework="pytest", tests_project_rootdir=tests_path.parent, ) @@ -21,7 +21,7 @@ def test_benchmark_codeflash_test_discovery(benchmark) -> None: tests_path = project_path / "tests" test_config = TestConfig( tests_root=tests_path, - project_root_path=project_path, + project_root=project_path, test_framework="pytest", tests_project_rootdir=tests_path.parent, ) diff --git 
a/tests/code_utils/test_coverage_utils.py b/tests/code_utils/test_coverage_utils.py index 1697f2ba4..8fb6bf0c1 100644 --- a/tests/code_utils/test_coverage_utils.py +++ b/tests/code_utils/test_coverage_utils.py @@ -6,7 +6,7 @@ build_fully_qualified_name, extract_dependent_function, ) -from codeflash.models.function_types import FunctionParent +from codeflash_core.models import FunctionParent from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown from codeflash.verification.coverage_utils import CoverageUtils diff --git a/tests/test_add_needed_imports_from_module.py b/tests/test_add_needed_imports_from_module.py index 198058b28..44a1f3509 100644 --- a/tests/test_add_needed_imports_from_module.py +++ b/tests/test_add_needed_imports_from_module.py @@ -24,7 +24,7 @@ def test_add_needed_imports_from_module0() -> None: from pydantic.dataclasses import dataclass from codeflash.languages.python.static_analysis.code_extractor import get_code, get_code_no_skeleton from codeflash.code_utils.code_utils import path_belongs_to_site_packages -from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize +from codeflash_core.models import FunctionParent, FunctionToOptimize def belongs_to_class(name: Name, class_name: str) -> bool: """Check if the given name belongs to the specified class.""" @@ -78,7 +78,7 @@ def test_add_needed_imports_from_module() -> None: from codeflash.languages.python.static_analysis.code_extractor import get_code, get_code_no_skeleton from codeflash.code_utils.code_utils import path_belongs_to_site_packages -from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize +from codeflash_core.models import FunctionParent, FunctionToOptimize def belongs_to_class(name: Name, class_name: str) -> bool: diff --git a/tests/test_add_runtime_comments.py b/tests/test_add_runtime_comments.py index c70187aa5..077cb8cd9 100644 --- a/tests/test_add_runtime_comments.py +++ 
b/tests/test_add_runtime_comments.py @@ -14,7 +14,7 @@ TestType, VerificationType, ) -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig TestType.__test__ = False TestConfig.__test__ = False @@ -25,7 +25,7 @@ def test_config(): """Create a mock TestConfig for testing.""" config = Mock(spec=TestConfig) - config.project_root_path = Path(__file__).resolve().parent.parent + config.project_root = Path(__file__).resolve().parent.parent config.test_framework = "pytest" config.tests_project_rootdir = Path(__file__).resolve().parent config.tests_root = Path(__file__).resolve().parent @@ -61,7 +61,7 @@ def create_test_invocation( def test_basic_runtime_comment_addition(self, test_config): """Test basic functionality of adding runtime comments.""" # Create test source code - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_bubble_sort(): codeflash_output = bubble_sort([3, 1, 2]) assert codeflash_output == [1, 2, 3] @@ -98,7 +98,7 @@ def test_basic_runtime_comment_addition(self, test_config): def test_multiple_test_functions(self, test_config): """Test handling multiple test functions in the same file.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_bubble_sort(): codeflash_output = quick_sort([3, 1, 2]) assert codeflash_output == [1, 2, 3] @@ -151,7 +151,7 @@ def helper_function(): def test_different_time_formats(self, test_config): """Test that different time ranges are formatted correctly with new precision rules.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_cases = [ (999, 500, "999ns -> 500ns"), # nanoseconds (25_000, 18_000, "25.0μs -> 18.0μs"), # microseconds with precision @@ -195,7 +195,7 @@ def test_different_time_formats(self, test_config): def test_missing_test_results(self, test_config): """Test behavior when test results are missing for a test 
function.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_bubble_sort(): codeflash_output = bubble_sort([3, 1, 2]) assert codeflash_output == [1, 2, 3] @@ -228,7 +228,7 @@ def test_missing_test_results(self, test_config): def test_partial_test_results(self, test_config): """Test behavior when only one set of test results is available.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_bubble_sort(): codeflash_output = bubble_sort([3, 1, 2]) assert codeflash_output == [1, 2, 3] @@ -262,7 +262,7 @@ def test_partial_test_results(self, test_config): def test_multiple_runtimes_uses_minimum(self, test_config): """Test that when multiple runtimes exist, the minimum is used.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_bubble_sort(): codeflash_output = bubble_sort([3, 1, 2]) assert codeflash_output == [1, 2, 3] @@ -314,7 +314,7 @@ def test_multiple_runtimes_uses_minimum(self, test_config): def test_no_codeflash_output_assignment(self, test_config): """Test behavior when test doesn't have codeflash_output assignment.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_bubble_sort(): result = bubble_sort([3, 1, 2]) assert result == [1, 2, 3] @@ -349,7 +349,7 @@ def test_no_codeflash_output_assignment(self, test_config): def test_invalid_python_code_handling(self, test_config): """Test behavior when test source code is invalid Python.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_bubble_sort(: codeflash_output = bubble_sort([3, 1, 2]) assert codeflash_output == [1, 2, 3] @@ -384,7 +384,7 @@ def test_invalid_python_code_handling(self, test_config): def test_multiple_generated_tests(self, test_config): """Test handling multiple generated test objects.""" - 
os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source_1 = """def test_bubble_sort(): codeflash_output = quick_sort([3, 1, 2]) assert codeflash_output == [1, 2, 3] @@ -441,7 +441,7 @@ def test_multiple_generated_tests(self, test_config): def test_preserved_test_attributes(self, test_config): """Test that other test attributes are preserved during modification.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_bubble_sort(): codeflash_output = bubble_sort([3, 1, 2]) assert codeflash_output == [1, 2, 3] @@ -486,7 +486,7 @@ def test_preserved_test_attributes(self, test_config): def test_multistatement_line_handling(self, test_config): """Test that runtime comments work correctly with multiple statements on one line.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_mutation_of_input(): # Test that the input list is mutated in-place and returned arr = [3, 1, 2] @@ -539,7 +539,7 @@ def test_multistatement_line_handling(self, test_config): def test_add_runtime_comments_simple_function(self, test_config): """Test adding runtime comments to a simple test function.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_function(): codeflash_output = some_function() assert codeflash_output == expected @@ -577,7 +577,7 @@ def test_add_runtime_comments_simple_function(self, test_config): def test_add_runtime_comments_class_method(self, test_config): """Test adding runtime comments to a test method within a class.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """class TestClass: def test_function(self): codeflash_output = some_function() @@ -618,7 +618,7 @@ def test_function(self): def test_add_runtime_comments_multiple_assignments(self, test_config): """Test adding runtime comments when there are multiple codeflash_output 
assignments.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_function(): setup_data = prepare_test() codeflash_output = some_function() @@ -674,7 +674,7 @@ def test_add_runtime_comments_multiple_assignments(self, test_config): def test_add_runtime_comments_no_matching_runtimes(self, test_config): """Test that source remains unchanged when no matching runtimes are found.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_function(): codeflash_output = some_function() assert codeflash_output == expected @@ -710,7 +710,7 @@ def test_add_runtime_comments_no_matching_runtimes(self, test_config): def test_add_runtime_comments_no_codeflash_output(self, test_config): """Comments will still be added if codeflash output doesnt exist""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_function(): result = some_function() assert result == expected @@ -749,7 +749,7 @@ def test_add_runtime_comments_no_codeflash_output(self, test_config): def test_add_runtime_comments_multiple_tests(self, test_config): """Test adding runtime comments to multiple generated tests.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source1 = """def test_function1(): codeflash_output = some_function() assert codeflash_output == expected @@ -821,7 +821,7 @@ def test_add_runtime_comments_multiple_tests(self, test_config): def test_add_runtime_comments_performance_regression(self, test_config): """Test adding runtime comments when optimized version is slower (negative performance gain).""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_function(): codeflash_output = some_function() assert codeflash_output == expected @@ -873,7 +873,7 @@ def test_add_runtime_comments_performance_regression(self, test_config): def 
test_basic_runtime_comment_addition_no_cfo(self, test_config): """Test basic functionality of adding runtime comments.""" # Create test source code - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_bubble_sort(): result = bubble_sort([3, 1, 2]) assert result == [1, 2, 3] @@ -911,7 +911,7 @@ def test_basic_runtime_comment_addition_no_cfo(self, test_config): def test_multiple_test_functions_no_cfo(self, test_config): """Test handling multiple test functions in the same file.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_bubble_sort(): result = quick_sort([3, 1, 2]) assert result == [1, 2, 3] @@ -963,7 +963,7 @@ def helper_function(): def test_different_time_formats_no_cfo(self, test_config): """Test that different time ranges are formatted correctly with new precision rules.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_cases = [ (999, 500, "999ns -> 500ns"), # nanoseconds (25_000, 18_000, "25.0μs -> 18.0μs"), # microseconds with precision @@ -1006,7 +1006,7 @@ def test_different_time_formats_no_cfo(self, test_config): def test_missing_test_results_no_cfo(self, test_config): """Test behavior when test results are missing for a test function.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_bubble_sort(): result = bubble_sort([3, 1, 2]) assert result == [1, 2, 3] @@ -1039,7 +1039,7 @@ def test_missing_test_results_no_cfo(self, test_config): def test_partial_test_results_no_cfo(self, test_config): """Test behavior when only one set of test results is available.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_bubble_sort(): result = bubble_sort([3, 1, 2]) assert result == [1, 2, 3] @@ -1073,7 +1073,7 @@ def test_partial_test_results_no_cfo(self, test_config): def 
test_multiple_runtimes_uses_minimum_no_cfo(self, test_config): """Test that when multiple runtimes exist, the minimum is used.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_bubble_sort(): result = bubble_sort([3, 1, 2]) assert result == [1, 2, 3] @@ -1125,7 +1125,7 @@ def test_multiple_runtimes_uses_minimum_no_cfo(self, test_config): def test_no_codeflash_output_assignment_invalid_iteration_id(self, test_config): """Test behavior when test doesn't have codeflash_output assignment.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_bubble_sort(): result = bubble_sort([3, 1, 2]) assert result == [1, 2, 3] @@ -1160,7 +1160,7 @@ def test_no_codeflash_output_assignment_invalid_iteration_id(self, test_config): def test_invalid_python_code_handling_no_cfo(self, test_config): """Test behavior when test source code is invalid Python.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_bubble_sort(: result = bubble_sort([3, 1, 2]) assert result == [1, 2, 3] @@ -1195,7 +1195,7 @@ def test_invalid_python_code_handling_no_cfo(self, test_config): def test_multiple_generated_tests_no_cfo(self, test_config): """Test handling multiple generated test objects.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source_1 = """def test_bubble_sort(): codeflash_output = quick_sort([3, 1, 2]); assert codeflash_output == [1, 2, 3] """ @@ -1251,7 +1251,7 @@ def test_multiple_generated_tests_no_cfo(self, test_config): def test_preserved_test_attributes_no_cfo(self, test_config): """Test that other test attributes are preserved during modification.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_bubble_sort(): result = bubble_sort([3, 1, 2]) assert result == [1, 2, 3] @@ -1296,7 +1296,7 @@ def 
test_preserved_test_attributes_no_cfo(self, test_config): def test_multistatement_line_handling_no_cfo(self, test_config): """Test that runtime comments work correctly with multiple statements on one line.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_mutation_of_input(): # Test that the input list is mutated in-place and returned arr = [3, 1, 2] @@ -1349,7 +1349,7 @@ def test_multistatement_line_handling_no_cfo(self, test_config): def test_add_runtime_comments_simple_function_no_cfo(self, test_config): """Test adding runtime comments to a simple test function.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_function(): result = some_function(); assert result == expected """ @@ -1386,7 +1386,7 @@ def test_add_runtime_comments_simple_function_no_cfo(self, test_config): def test_add_runtime_comments_class_method_no_cfo(self, test_config): """Test adding runtime comments to a test method within a class.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """class TestClass: def test_function(self): result = some_function() @@ -1427,7 +1427,7 @@ def test_function(self): def test_add_runtime_comments_multiple_assignments_no_cfo(self, test_config): """Test adding runtime comments when there are multiple codeflash_output assignments.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_function(): setup_data = prepare_test() codeflash_output = some_function(); assert codeflash_output == expected @@ -1479,7 +1479,7 @@ def test_add_runtime_comments_multiple_assignments_no_cfo(self, test_config): def test_add_runtime_comments_no_matching_runtimes_no_cfo(self, test_config): """Test that source remains unchanged when no matching runtimes are found.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def 
test_function(): result = some_function() assert result == expected @@ -1515,7 +1515,7 @@ def test_add_runtime_comments_no_matching_runtimes_no_cfo(self, test_config): def test_add_runtime_comments_multiple_tests_no_cfo(self, test_config): """Test adding runtime comments to multiple generated tests.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source1 = """def test_function1(): result = some_function() assert result == expected @@ -1587,7 +1587,7 @@ def test_add_runtime_comments_multiple_tests_no_cfo(self, test_config): def test_add_runtime_comments_performance_regression_no_cfo(self, test_config): """Test adding runtime comments when optimized version is slower (negative performance gain).""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_function(): result = some_function(); assert codeflash_output == expected codeflash_output = some_function() @@ -1637,7 +1637,7 @@ def test_add_runtime_comments_performance_regression_no_cfo(self, test_config): def test_runtime_comment_addition_for(self, test_config): """Test basic functionality of adding runtime comments.""" # Create test source code - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_bubble_sort(): a = 2 for i in range(3): @@ -1697,7 +1697,7 @@ def test_runtime_comment_addition_for(self, test_config): def test_runtime_comment_addition_while(self, test_config): """Test basic functionality of adding runtime comments.""" # Create test source code - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_bubble_sort(): i = 0 while i<3: @@ -1757,7 +1757,7 @@ def test_runtime_comment_addition_while(self, test_config): def test_runtime_comment_addition_with(self, test_config): """Test basic functionality of adding runtime comments.""" # Create test source code - os.chdir(test_config.project_root_path) + 
os.chdir(test_config.project_root) test_source = """def test_bubble_sort(): i = 0 with open('a.txt','rb') as f: @@ -1817,7 +1817,7 @@ def test_runtime_comment_addition_with(self, test_config): def test_runtime_comment_addition_lc(self, test_config): """Test basic functionality of adding runtime comments for list comprehension.""" # Create test source code - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_bubble_sort(): i = 0 codeflash_output = [bubble_sort([3, 1, 2]) for _ in range(3)] @@ -1871,7 +1871,7 @@ def test_runtime_comment_addition_lc(self, test_config): def test_runtime_comment_addition_parameterized(self, test_config): """Test basic functionality of adding runtime comments for list comprehension.""" # Create test source code - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """@pytest.mark.parametrize( "input, expected_output", [ @@ -1940,7 +1940,7 @@ def test_bubble_sort(input, expected_output): def test_async_basic_runtime_comment_addition(self, test_config): """Test basic functionality of adding runtime comments to async test functions.""" - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """async def test_async_bubble_sort(): codeflash_output = await async_bubble_sort([3, 1, 2]) assert codeflash_output == [1, 2, 3] @@ -1972,7 +1972,7 @@ def test_async_basic_runtime_comment_addition(self, test_config): assert "codeflash_output = await async_bubble_sort([3, 1, 2]) # 500μs -> 300μs" in modified_source def test_async_multiple_test_functions(self, test_config): - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """async def test_async_bubble_sort(): codeflash_output = await async_quick_sort([3, 1, 2]) assert codeflash_output == [1, 2, 3] @@ -2018,7 +2018,7 @@ def helper_function(): ) def test_async_class_method(self, test_config): - os.chdir(test_config.project_root_path) + 
os.chdir(test_config.project_root) test_source = """class TestAsyncClass: async def test_async_function(self): codeflash_output = await some_async_function() @@ -2057,7 +2057,7 @@ async def test_async_function(self): assert result.generated_tests[0].generated_original_test_source == expected_source def test_async_mixed_sync_and_async_functions(self, test_config): - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """def test_sync_function(): codeflash_output = sync_function([1, 2, 3]) assert codeflash_output == [1, 2, 3] @@ -2107,7 +2107,7 @@ def test_another_sync(): assert "await async_function([4, 5, 6])" in modified_source def test_async_complex_await_patterns(self, test_config): - os.chdir(test_config.project_root_path) + os.chdir(test_config.project_root) test_source = """async def test_complex_async(): # Multiple await calls result1 = await async_func1() diff --git a/tests/test_async_function_discovery.py b/tests/test_async_function_discovery.py index c13151c22..475dfe54c 100644 --- a/tests/test_async_function_discovery.py +++ b/tests/test_async_function_discovery.py @@ -9,7 +9,7 @@ get_functions_to_optimize, inspect_top_level_functions_or_methods, ) -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig @pytest.fixture @@ -230,7 +230,7 @@ def sync_method(self): file_path.write_text(mixed_code) test_config = TestConfig( - tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path() + tests_root=Path("tests"), project_root=Path("."), test_framework="pytest", tests_project_rootdir=Path() ) functions, functions_count, _ = get_functions_to_optimize( @@ -280,7 +280,7 @@ def sync_method(self): file_path.write_text(mixed_code) test_config = TestConfig( - tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path() + tests_root=Path("tests"), project_root=Path("."), test_framework="pytest", 
tests_project_rootdir=Path() ) functions, functions_count, _ = get_functions_to_optimize( diff --git a/tests/test_async_run_and_parse_tests.py b/tests/test_async_run_and_parse_tests.py index 1034b5e51..2e3397bb0 100644 --- a/tests/test_async_run_and_parse_tests.py +++ b/tests/test_async_run_and_parse_tests.py @@ -13,7 +13,7 @@ get_decorator_name_for_mode, inject_profiling_into_existing_test, ) -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.models.models import CodePosition, FunctionParent, TestFile, TestFiles, TestingMode, TestType from codeflash.optimization.optimizer import Optimizer from codeflash.languages.python.instrument_codeflash_capture import instrument_codeflash_capture @@ -85,7 +85,7 @@ async def test_async_sort(): disable_telemetry=True, tests_root=tests_root, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=project_root_path, ) @@ -212,7 +212,7 @@ async def test_async_class_sort(): disable_telemetry=True, tests_root=tests_root, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=project_root_path, ) @@ -337,7 +337,7 @@ async def test_async_perf(): disable_telemetry=True, tests_root=tests_root, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=project_root_path, ) @@ -457,7 +457,7 @@ async def async_error_function(lst): disable_telemetry=True, tests_root=tests_root, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=project_root_path, ) @@ -562,7 +562,7 @@ async def test_async_multi(): disable_telemetry=True, tests_root=tests_root, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=project_root_path, ) @@ -682,7 +682,7 @@ async def test_async_edge_cases(): 
disable_telemetry=True, tests_root=tests_root, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=project_root_path, ) @@ -827,7 +827,7 @@ def test_sync_sort(): disable_telemetry=True, tests_root=tests_root, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=project_root_path, ) @@ -1002,7 +1002,7 @@ async def test_mixed_sorting(): disable_telemetry=True, tests_root=tests_root, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=project_root_path, ) diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index ccfa5410d..da3b08397 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -9,7 +9,7 @@ import pytest -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.languages.python.context.code_context_extractor import ( collect_type_names_from_annotation, enrich_testgen_context, @@ -458,7 +458,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: disable_telemetry=True, tests_root="tests", test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=Path().resolve(), ) @@ -708,7 +708,7 @@ def helper_method(self): disable_telemetry=True, tests_root="tests", test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=Path().resolve(), ) @@ -806,7 +806,7 @@ def helper_method(self): disable_telemetry=True, tests_root="tests", test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=Path().resolve(), ) @@ -902,7 +902,7 @@ def helper_method(self): disable_telemetry=True, tests_root="tests", test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", 
experiment_id=None, test_project_root=Path().resolve(), ) @@ -998,7 +998,7 @@ def helper_method(self): disable_telemetry=True, tests_root="tests", test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=Path().resolve(), ) @@ -1050,7 +1050,7 @@ def helper_method(self): disable_telemetry=True, tests_root="tests", test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=Path().resolve(), ) @@ -1102,7 +1102,7 @@ def helper_method(self): disable_telemetry=True, tests_root="tests", test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=Path().resolve(), ) @@ -1691,7 +1691,7 @@ def outside_method(): disable_telemetry=True, tests_root="tests", test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=Path().resolve(), ) @@ -1962,7 +1962,7 @@ def get_system_details(): disable_telemetry=True, tests_root="tests", test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=Path().resolve(), ) @@ -2223,7 +2223,7 @@ def get_system_details(): disable_telemetry=True, tests_root="tests", test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=Path().resolve(), ) @@ -2383,7 +2383,7 @@ def standalone_function(): disable_telemetry=True, tests_root="tests", test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=Path().resolve(), ) @@ -2461,7 +2461,7 @@ def nested_method(self): disable_telemetry=True, tests_root="tests", test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=Path().resolve(), ) @@ -2513,7 +2513,7 @@ def target_method(self): disable_telemetry=True, tests_root="tests", test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", 
experiment_id=None, test_project_root=Path().resolve(), ) @@ -2566,7 +2566,7 @@ def target_method(self): disable_telemetry=True, tests_root="tests", test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=Path().resolve(), ) @@ -2577,7 +2577,7 @@ def target_method(self): disable_telemetry=True, tests_root="tests", test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=Path().resolve(), ) @@ -2622,7 +2622,7 @@ def simple_method(self): disable_telemetry=True, tests_root="tests", test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=Path().resolve(), ) @@ -3236,7 +3236,7 @@ def target_function(): disable_telemetry=True, tests_root="tests", test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=Path().resolve(), ) @@ -3308,7 +3308,7 @@ def dump_layout(layout_type, layout): disable_telemetry=True, tests_root="tests", test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=Path().resolve(), ) diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index aae043833..c29103aa8 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -7,7 +7,7 @@ import libcst as cst -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer from codeflash.languages.python.static_analysis.code_extractor import ( delete___future___aliased_imports, @@ -22,7 +22,7 @@ replace_functions_in_file, ) from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent, FunctionSource -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig 
os.environ["CODEFLASH_API_KEY"] = "cf-test-key" @@ -53,9 +53,9 @@ def sorter(arr): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap() @@ -833,9 +833,9 @@ def main_method(self): test_config = TestConfig( tests_root=file_path.parent, tests_project_rootdir=file_path.parent, - project_root_path=file_path.parent, + project_root=file_path.parent, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config) code_context = func_optimizer.get_code_optimization_context().unwrap() @@ -1744,9 +1744,9 @@ def new_function2(value): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap() @@ -1823,9 +1823,9 @@ def new_function2(value): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap() @@ -1903,9 +1903,9 @@ def new_function2(value): test_config = TestConfig( 
tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap() @@ -1982,9 +1982,9 @@ def new_function2(value): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap() @@ -2062,9 +2062,9 @@ def new_function2(value): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap() @@ -2152,9 +2152,9 @@ def new_function2(value): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap() @@ -3452,9 +3452,9 @@ def hydrate_input_text_actions_with_field_names( test_config = TestConfig( tests_root=root_dir / "tests/pytest", tests_project_rootdir=root_dir, - 
project_root_path=root_dir, + project_root=root_dir, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap() diff --git a/tests/test_codeflash_capture.py b/tests/test_codeflash_capture.py index 92704d7f1..dec7a3583 100644 --- a/tests/test_codeflash_capture.py +++ b/tests/test_codeflash_capture.py @@ -6,13 +6,13 @@ from codeflash.code_utils.code_utils import get_run_tmp_file from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer from codeflash.models.models import FunctionParent, TestFile, TestFiles, TestingMode, TestType, VerificationType from codeflash.verification.equivalence import compare_test_results from codeflash.languages.python.instrument_codeflash_capture import instrument_codeflash_capture from codeflash.languages.python.test_runner import execute_test_subprocess -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig # Tests for get_stack_info. 
Ensures that when a test is run via pytest, the correct test information is extracted @@ -450,9 +450,9 @@ def __init__(self, x=2): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) fto = FunctionToOptimize( function_name="some_function", @@ -573,9 +573,9 @@ def __init__(self, *args, **kwargs): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) fto = FunctionToOptimize( function_name="some_function", @@ -700,9 +700,9 @@ def __init__(self, x=2): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) fto = FunctionToOptimize( function_name="some_function", @@ -863,9 +863,9 @@ def another_helper(self): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) fto = FunctionToOptimize( function_name="target_function", @@ -1017,9 +1017,9 @@ def another_helper(self): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=fto, test_cfg=test_config) func_optimizer.test_files = TestFiles( @@ -1442,9 +1442,9 @@ def calculate_portfolio_metrics( test_config = TestConfig( tests_root=tests_root, 
tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=fto, test_cfg=test_config) func_optimizer.test_files = TestFiles( @@ -1662,9 +1662,9 @@ def __init__(self, x, y): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) fto = FunctionToOptimize( function_name="__init__", diff --git a/tests/test_existing_tests_source_for.py b/tests/test_existing_tests_source_for.py index 2e11bc6ef..da72e648f 100644 --- a/tests/test_existing_tests_source_for.py +++ b/tests/test_existing_tests_source_for.py @@ -18,7 +18,7 @@ def setup_method(self): # Mock test config self.test_cfg = Mock() self.test_cfg.tests_root = Path(__file__).resolve().parent - self.test_cfg.project_root_path = Path(__file__).resolve().parent.parent + self.test_cfg.project_root = Path(__file__).resolve().parent.parent # Mock invocation ID self.mock_invocation_id = Mock() @@ -31,7 +31,7 @@ def setup_method(self): self.mock_function_called_in_test.tests_in_file = Mock() self.mock_function_called_in_test.tests_in_file.test_file = Path(__file__).resolve().parent / "test_module.py" # Path to pyproject.toml - os.chdir(self.test_cfg.project_root_path) + os.chdir(self.test_cfg.project_root) def test_no_test_files_returns_empty_string(self): """Test that function returns empty string when no test files exist.""" diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 480efcef5..6cce4c63a 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -7,10 +7,10 @@ from codeflash.code_utils.config_parser import parse_config_file from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports -from 
codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.models.models import CodeString, CodeStringsMarkdown from codeflash.languages.function_optimizer import FunctionOptimizer -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig @pytest.fixture @@ -226,7 +226,7 @@ def _run_formatting_test(source_code: str, should_content_change: bool, expected function_to_optimize = FunctionToOptimize(function_name="process_data", parents=[], file_path=target_path) test_cfg = TestConfig( - tests_root=test_dir, project_root_path=test_dir, test_framework="pytest", tests_project_rootdir=test_dir + tests_root=test_dir, project_root=test_dir, test_framework="pytest", tests_project_rootdir=test_dir ) args = argparse.Namespace( diff --git a/tests/test_function_dependencies.py b/tests/test_function_dependencies.py index 0814f8af2..385d18ad6 100644 --- a/tests/test_function_dependencies.py +++ b/tests/test_function_dependencies.py @@ -2,11 +2,11 @@ import pytest -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.either import is_successful from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer from codeflash.models.models import FunctionParent -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig def calculate_something(data): @@ -137,9 +137,9 @@ def test_class_method_dependencies() -> None: test_cfg=TestConfig( tests_root=file_path, tests_project_rootdir=file_path.parent, - project_root_path=file_path.parent, + project_root=file_path.parent, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ), ) with open(file_path) as f: @@ -207,9 +207,9 @@ def test_recursive_function_context() -> None: test_cfg=TestConfig( tests_root=file_path, 
tests_project_rootdir=file_path.parent, - project_root_path=file_path.parent, + project_root=file_path.parent, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ), ) with open(file_path) as f: diff --git a/tests/test_function_discovery.py b/tests/test_function_discovery.py index b6c781a01..1851cba2e 100644 --- a/tests/test_function_discovery.py +++ b/tests/test_function_discovery.py @@ -10,7 +10,7 @@ get_functions_to_optimize, inspect_top_level_functions_or_methods, ) -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig def test_function_eligible_for_optimization() -> None: @@ -144,7 +144,7 @@ def functionA(): ) test_config = TestConfig( - tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path() + tests_root="tests", project_root=".", test_framework="pytest", tests_project_rootdir=Path() ) functions, functions_count, _ = get_functions_to_optimize( optimize_all=None, @@ -247,7 +247,7 @@ def traverse(node_id): ) test_config = TestConfig( - tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path() + tests_root="tests", project_root=".", test_framework="pytest", tests_project_rootdir=Path() ) functions, functions_count, _ = get_functions_to_optimize( optimize_all=None, @@ -279,7 +279,7 @@ def inner_function(): ) test_config = TestConfig( - tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path() + tests_root="tests", project_root=".", test_framework="pytest", tests_project_rootdir=Path() ) functions, functions_count, _ = get_functions_to_optimize( optimize_all=None, @@ -313,7 +313,7 @@ def another_inner_function(): ) test_config = TestConfig( - tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path() + tests_root="tests", project_root=".", test_framework="pytest", tests_project_rootdir=Path() ) functions, functions_count, _ = 
get_functions_to_optimize( optimize_all=None, diff --git a/tests/test_get_code.py b/tests/test_get_code.py index ad040f122..1893df503 100644 --- a/tests/test_get_code.py +++ b/tests/test_get_code.py @@ -3,7 +3,7 @@ import pytest -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.languages.python.static_analysis.code_extractor import get_code from codeflash.models.models import FunctionParent diff --git a/tests/test_get_helper_code.py b/tests/test_get_helper_code.py index 4825926b4..625322e57 100644 --- a/tests/test_get_helper_code.py +++ b/tests/test_get_helper_code.py @@ -4,12 +4,12 @@ import pytest -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.either import is_successful from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer from codeflash.models.models import FunctionParent, get_code_block_splitter from codeflash.optimization.optimizer import Optimizer -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig class HelperClass: @@ -30,7 +30,7 @@ def test_get_outside_method_helper() -> None: disable_telemetry=True, tests_root="tests", test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, ) ) @@ -229,9 +229,9 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: test_config = TestConfig( tests_root="tests", tests_project_rootdir=Path.cwd(), - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_config) with open(file_path) as f: @@ -400,9 +400,9 @@ def test_bubble_sort_deps() -> None: test_config = TestConfig( tests_root=str(file_path.parent / "tests"), 
tests_project_rootdir=file_path.parent.resolve(), - project_root_path=project_root, + project_root=project_root, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_config) with open(file_path) as f: diff --git a/tests/test_inject_profiling_used_frameworks.py b/tests/test_inject_profiling_used_frameworks.py index 826be09c8..4a9b1901c 100644 --- a/tests/test_inject_profiling_used_frameworks.py +++ b/tests/test_inject_profiling_used_frameworks.py @@ -13,7 +13,7 @@ detect_frameworks_from_code, inject_profiling_into_existing_test, ) -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.models.models import CodePosition, TestingMode diff --git a/tests/test_instrument_all_and_run.py b/tests/test_instrument_all_and_run.py index 40fb8fbcf..763aa449f 100644 --- a/tests/test_instrument_all_and_run.py +++ b/tests/test_instrument_all_and_run.py @@ -8,7 +8,7 @@ from codeflash.code_utils.code_utils import get_run_tmp_file from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.models.models import CodePosition, FunctionParent, TestFile, TestFiles, TestingMode, TestType from codeflash.optimization.optimizer import Optimizer from codeflash.verification.equivalence import compare_test_results @@ -138,7 +138,7 @@ def test_sort(): disable_telemetry=True, tests_root=tests_root, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=project_root_path, ) @@ -316,7 +316,7 @@ def test_sort(): disable_telemetry=True, tests_root=tests_root, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, 
test_project_root=project_root_path, ) @@ -431,7 +431,7 @@ def sorter(self, arr): disable_telemetry=True, tests_root=tests_root, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=project_root_path, ) @@ -586,7 +586,7 @@ def test_sort(): disable_telemetry=True, tests_root=tests_root, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=project_root_path, ) @@ -757,7 +757,7 @@ def test_sort(): disable_telemetry=True, tests_root=tests_root, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=project_root_path, ) diff --git a/tests/test_instrument_async_tests.py b/tests/test_instrument_async_tests.py index b1729630d..5be02c6e8 100644 --- a/tests/test_instrument_async_tests.py +++ b/tests/test_instrument_async_tests.py @@ -11,7 +11,7 @@ get_decorator_name_for_mode, inject_profiling_into_existing_test, ) -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.models.models import CodePosition, TestingMode diff --git a/tests/test_instrument_codeflash_capture.py b/tests/test_instrument_codeflash_capture.py index 4ad0fade1..0612beb43 100644 --- a/tests/test_instrument_codeflash_capture.py +++ b/tests/test_instrument_codeflash_capture.py @@ -1,7 +1,7 @@ from pathlib import Path from codeflash.code_utils.code_utils import get_run_tmp_file -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.languages.python.instrument_codeflash_capture import instrument_codeflash_capture from codeflash.models.models import FunctionParent diff --git a/tests/test_instrument_codeflash_trace.py b/tests/test_instrument_codeflash_trace.py index 056743864..d729f831f 100644 --- a/tests/test_instrument_codeflash_trace.py +++ 
b/tests/test_instrument_codeflash_trace.py @@ -7,7 +7,7 @@ add_codeflash_decorator_to_code, instrument_codeflash_trace_decorator, ) -from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize +from codeflash_core.models import FunctionParent, FunctionToOptimize def test_add_decorator_to_normal_function() -> None: diff --git a/tests/test_instrument_line_profiler.py b/tests/test_instrument_line_profiler.py index 5a6a04e6e..66c29ebd6 100644 --- a/tests/test_instrument_line_profiler.py +++ b/tests/test_instrument_line_profiler.py @@ -2,11 +2,11 @@ from pathlib import Path from tempfile import TemporaryDirectory -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator from codeflash.models.models import CodeOptimizationContext -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig def test_add_decorator_imports_helper_in_class(): @@ -17,9 +17,9 @@ def test_add_decorator_imports_helper_in_class(): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func = FunctionToOptimize(function_name="sort_classmethod", parents=[], file_path=code_path) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) @@ -89,9 +89,9 @@ def test_add_decorator_imports_helper_in_nested_class(): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) 
func = FunctionToOptimize(function_name="sort_classmethod", parents=[], file_path=code_path) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) @@ -138,9 +138,9 @@ def test_add_decorator_imports_nodeps(): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) @@ -189,9 +189,9 @@ def test_add_decorator_imports_helper_outside(): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func = FunctionToOptimize(function_name="sorter_deps", parents=[], file_path=code_path) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) @@ -266,9 +266,9 @@ def __init__(self, arr): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_write_path) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) diff --git a/tests/test_instrument_tests.py b/tests/test_instrument_tests.py index b31804259..86d92d963 100644 --- a/tests/test_instrument_tests.py +++ b/tests/test_instrument_tests.py @@ -15,7 +15,7 @@ FunctionImportedAsVisitor, inject_profiling_into_existing_test, ) -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize 
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports from codeflash.models.models import ( @@ -28,7 +28,7 @@ TestsInFile, TestType, ) -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig codeflash_wrap_string = """def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs): test_id = f'{{codeflash_test_module_name}}:{{codeflash_test_class_name}}:{{codeflash_test_name}}:{{codeflash_line_id}}:{{codeflash_loop_index}}' @@ -430,9 +430,9 @@ def test_sort(): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) test_env = os.environ.copy() @@ -691,9 +691,9 @@ def test_sort_parametrized(input, expected_output): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) test_results, coverage_data = func_optimizer.run_and_parse_tests( @@ -980,9 +980,9 @@ def test_sort_parametrized_loop(input, expected_output): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = 
PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) test_results, coverage_data = func_optimizer.run_and_parse_tests( @@ -1337,9 +1337,9 @@ def test_sort(): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) test_results, coverage_data = func_optimizer.run_and_parse_tests( @@ -1719,9 +1719,9 @@ def test_sort(self): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="unittest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) test_results, coverage_data = func_optimizer.run_and_parse_tests( @@ -1969,9 +1969,9 @@ def test_sort(self, input, expected_output): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="unittest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) test_results, coverage_data = func_optimizer.run_and_parse_tests( @@ -2225,9 +2225,9 @@ def test_sort(self): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="unittest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) test_results, coverage_data = func_optimizer.run_and_parse_tests( @@ -2477,9 +2477,9 @@ def test_sort(self, input, expected_output): test_config = 
TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="unittest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=f, test_cfg=test_config) test_results, coverage_data = func_optimizer.run_and_parse_tests( @@ -2985,7 +2985,7 @@ def test_code_replacement10() -> None: disable_telemetry=True, tests_root="tests", test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, ), ) @@ -3033,7 +3033,7 @@ def test_code_replacement10() -> None: codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') get_code_output = 'random code' file_path = Path(__file__).resolve() - opt = Optimizer(Namespace(project_root=str(file_path.parent.resolve()), disable_telemetry=True, tests_root='tests', test_framework='pytest', pytest_cmd='pytest', experiment_id=None)) + opt = Optimizer(Namespace(project_root=str(file_path.parent.resolve()), disable_telemetry=True, tests_root='tests', test_framework='pytest', test_command='pytest', experiment_id=None)) func_top_optimize = FunctionToOptimize(function_name='main_method', file_path=str(file_path), parents=[FunctionParent('MainClass', 'ClassDef')]) with open(file_path) as f: original_code = f.read() @@ -3140,9 +3140,9 @@ def test_sleepfunc_sequence_short(n, expected_total_sleep_time): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) test_files = TestFiles( @@ -3275,9 
+3275,9 @@ def test_sleepfunc_sequence_short(self, n, expected_total_sleep_time): test_config = TestConfig( tests_root=tests_root, tests_project_rootdir=project_root_path, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="unittest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) test_results, coverage_data = func_optimizer.run_and_parse_tests( diff --git a/tests/test_instrumentation_run_results_aiservice.py b/tests/test_instrumentation_run_results_aiservice.py index 69fa82d2e..a1c219768 100644 --- a/tests/test_instrumentation_run_results_aiservice.py +++ b/tests/test_instrumentation_run_results_aiservice.py @@ -10,7 +10,7 @@ from code_to_optimize.bubble_sort_method import BubbleSorter from codeflash.code_utils.code_utils import get_run_tmp_file from codeflash.code_utils.formatter import sort_imports -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.models.models import FunctionParent, TestFile, TestFiles, TestingMode, TestType, VerificationType from codeflash.optimization.optimizer import Optimizer from codeflash.verification.equivalence import compare_test_results @@ -150,7 +150,7 @@ def test_single_element_list(): disable_telemetry=True, tests_root=tests_root, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=project_root_path, ) @@ -293,7 +293,7 @@ def test_single_element_list(): disable_telemetry=True, tests_root=tests_root, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=project_root_path, ) @@ -384,7 +384,7 @@ def sorter(self, arr): disable_telemetry=True, tests_root=tests_root, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=project_root_path, ) @@ 
-438,7 +438,7 @@ def sorter(self, arr): disable_telemetry=True, tests_root=tests_root, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=project_root_path, ) diff --git a/tests/test_java_assertion_removal.py b/tests/test_java_assertion_removal.py index 3e88440fd..007e34479 100644 --- a/tests/test_java_assertion_removal.py +++ b/tests/test_java_assertion_removal.py @@ -8,7 +8,7 @@ from pathlib import Path -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.languages.java.remove_asserts import JavaAssertTransformer, transform_java_assertions diff --git a/tests/test_java_test_discovery.py b/tests/test_java_test_discovery.py index fc9b01f2a..f5f0a46b3 100644 --- a/tests/test_java_test_discovery.py +++ b/tests/test_java_test_discovery.py @@ -20,7 +20,7 @@ get_test_class_for_source_class, is_test_file, ) -from codeflash.models.function_types import FunctionParent, FunctionToOptimize +from codeflash_core.models import FunctionParent, FunctionToOptimize # --------------------------------------------------------------------------- # Helpers diff --git a/tests/test_java_tests_project_rootdir.py b/tests/test_java_tests_project_rootdir.py index 8985fed2b..1ea4925f1 100644 --- a/tests/test_java_tests_project_rootdir.py +++ b/tests/test_java_tests_project_rootdir.py @@ -6,7 +6,7 @@ from codeflash.discovery.discover_unit_tests import discover_unit_tests from codeflash.languages.base import Language from codeflash.languages.current import reset_current_language, set_current_language -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig def test_java_tests_project_rootdir_set_to_tests_root(tmp_path): @@ -23,7 +23,7 @@ def test_java_tests_project_rootdir_set_to_tests_root(tmp_path): # (simulating what happens before the fix) test_cfg = TestConfig( tests_root=tests_root, - 
project_root_path=project_root, + project_root=project_root, tests_project_rootdir=project_root, # Initially set to project root ) @@ -67,7 +67,7 @@ def test_python_tests_project_rootdir_unchanged(tmp_path): # Create test config original_tests_project_rootdir = project_root / "some" / "other" / "dir" test_cfg = TestConfig( - tests_root=tests_root, project_root_path=project_root, tests_project_rootdir=original_tests_project_rootdir + tests_root=tests_root, project_root=project_root, tests_project_rootdir=original_tests_project_rootdir ) # Mock pytest discovery diff --git a/tests/test_javascript_assertion_removal.py b/tests/test_javascript_assertion_removal.py index 8cc12cc4a..14c37264a 100644 --- a/tests/test_javascript_assertion_removal.py +++ b/tests/test_javascript_assertion_removal.py @@ -8,7 +8,7 @@ from pathlib import Path -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.languages.javascript.instrument import TestingMode, instrument_generated_js_test, transform_expect_calls from codeflash.models.models import FunctionParent diff --git a/tests/test_javascript_function_discovery.py b/tests/test_javascript_function_discovery.py index cf76bee2d..5a6db80f4 100644 --- a/tests/test_javascript_function_discovery.py +++ b/tests/test_javascript_function_discovery.py @@ -13,7 +13,7 @@ get_functions_to_optimize, ) from codeflash.languages.base import Language -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig class TestJavaScriptFunctionDiscovery: @@ -322,7 +322,7 @@ def test_get_functions_from_file(self, tmp_path): """) test_config = TestConfig( tests_root=str(tmp_path / "tests"), - project_root_path=str(tmp_path), + project_root=str(tmp_path), tests_project_rootdir=tmp_path / "tests", ) @@ -356,7 +356,7 @@ def test_get_specific_function(self, tmp_path): """) test_config = TestConfig( tests_root=str(tmp_path / 
"tests"), - project_root_path=str(tmp_path), + project_root=str(tmp_path), tests_project_rootdir=tmp_path / "tests", ) @@ -394,7 +394,7 @@ def test_get_class_method(self, tmp_path): """) test_config = TestConfig( tests_root=str(tmp_path / "tests"), - project_root_path=str(tmp_path), + project_root=str(tmp_path), tests_project_rootdir=tmp_path / "tests", ) diff --git a/tests/test_languages/test_code_context_extraction.py b/tests/test_languages/test_code_context_extraction.py index f70a82f01..444058ce6 100644 --- a/tests/test_languages/test_code_context_extraction.py +++ b/tests/test_languages/test_code_context_extraction.py @@ -24,11 +24,11 @@ import pytest -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.languages.base import Language from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig @pytest.fixture @@ -1838,7 +1838,7 @@ def test_with_tricky_helpers(self, ts_support, temp_project): ) test_config = TestConfig( - tests_root=temp_project, tests_project_rootdir=temp_project, project_root_path=temp_project + tests_root=temp_project, tests_project_rootdir=temp_project, project_root=temp_project ) func_optimizer = JavaScriptFunctionOptimizer( function_to_optimize=fto, test_cfg=test_config, aiservice_client=MagicMock() diff --git a/tests/test_languages/test_find_references.py b/tests/test_languages/test_find_references.py index 979af23e3..83b589a82 100644 --- a/tests/test_languages/test_find_references.py +++ b/tests/test_languages/test_find_references.py @@ -14,7 +14,7 @@ import pytest -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.languages.base 
import Language, ReferenceInfo from codeflash.languages.javascript.find_references import ( ExportedFunction, diff --git a/tests/test_languages/test_java/test_context.py b/tests/test_languages/test_java/test_context.py index 17dc1ca25..ca1ede23e 100644 --- a/tests/test_languages/test_java/test_context.py +++ b/tests/test_languages/test_java/test_context.py @@ -1805,8 +1805,8 @@ def test_unicode_in_source(self, tmp_path: Path): def test_file_not_found(self, tmp_path: Path): """Test context extraction for missing file.""" - from codeflash.discovery.functions_to_optimize import FunctionToOptimize - from codeflash.models.function_types import FunctionParent + from codeflash_core.models import FunctionToOptimize + from codeflash_core.models import FunctionParent missing_file = tmp_path / "NonExistent.java" func = FunctionToOptimize( diff --git a/tests/test_languages/test_java/test_instrumentation.py b/tests/test_languages/test_java/test_instrumentation.py index e82194e17..25fa92c20 100644 --- a/tests/test_languages/test_java/test_instrumentation.py +++ b/tests/test_languages/test_java/test_instrumentation.py @@ -19,7 +19,7 @@ # Set API key for tests that instantiate Optimizer os.environ["CODEFLASH_API_KEY"] = "cf-test-key" -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.languages.base import Language from codeflash.languages.current import set_current_language from codeflash.languages.java.maven_strategy import MavenStrategy @@ -2064,7 +2064,7 @@ def test_run_and_parse_behavior_mode(self, java_project): """Test run_and_parse_tests in BEHAVIOR mode.""" from argparse import Namespace - from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash_core.models import FunctionToOptimize from codeflash.models.models import TestFile, TestFiles, TestingMode, TestType from codeflash.optimization.optimizer import Optimizer @@ -2129,7 +2129,7 @@ def 
test_run_and_parse_behavior_mode(self, java_project): disable_telemetry=True, tests_root=test_dir, test_project_root=project_root, - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, ) ) @@ -2180,7 +2180,7 @@ def test_run_and_parse_performance_mode(self, java_project): """ from argparse import Namespace - from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash_core.models import FunctionToOptimize from codeflash.models.models import TestFile, TestFiles, TestingMode, TestType from codeflash.optimization.optimizer import Optimizer @@ -2284,7 +2284,7 @@ def test_run_and_parse_performance_mode(self, java_project): disable_telemetry=True, tests_root=test_dir, test_project_root=project_root, - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, ) ) @@ -2343,7 +2343,7 @@ def test_run_and_parse_multiple_test_methods(self, java_project): """Test run_and_parse_tests with multiple test methods.""" from argparse import Namespace - from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash_core.models import FunctionToOptimize from codeflash.models.models import TestFile, TestFiles, TestingMode, TestType from codeflash.optimization.optimizer import Optimizer @@ -2416,7 +2416,7 @@ def test_run_and_parse_multiple_test_methods(self, java_project): disable_telemetry=True, tests_root=test_dir, test_project_root=project_root, - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, ) ) @@ -2458,7 +2458,7 @@ def test_run_and_parse_failing_test(self, java_project): """Test run_and_parse_tests correctly reports failing tests.""" from argparse import Namespace - from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash_core.models import FunctionToOptimize from codeflash.models.models import TestFile, TestFiles, TestingMode, TestType from codeflash.optimization.optimizer import Optimizer @@ -2522,7 +2522,7 @@ def test_run_and_parse_failing_test(self, 
java_project): disable_telemetry=True, tests_root=test_dir, test_project_root=project_root, - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, ) ) @@ -2563,7 +2563,7 @@ def test_behavior_mode_writes_to_sqlite(self, java_project): from argparse import Namespace from codeflash.code_utils.code_utils import get_run_tmp_file - from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash_core.models import FunctionToOptimize from codeflash.models.models import TestFile, TestFiles, TestingMode, TestType from codeflash.optimization.optimizer import Optimizer @@ -2709,7 +2709,7 @@ def test_behavior_mode_writes_to_sqlite(self, java_project): disable_telemetry=True, tests_root=test_dir, test_project_root=project_root, - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, ) ) diff --git a/tests/test_languages/test_java/test_java_tracer_integration.py b/tests/test_languages/test_java/test_java_tracer_integration.py index f6ffefdf2..8eb4418ed 100644 --- a/tests/test_languages/test_java/test_java_tracer_integration.py +++ b/tests/test_languages/test_java/test_java_tracer_integration.py @@ -59,7 +59,7 @@ def test_discover_functions_from_replay_tests(self, traced_workload: tuple) -> N _trace_db, _jfr_file, output_dir, _test_count = traced_workload from codeflash.discovery.functions_to_optimize import _get_java_replay_test_functions - from codeflash.verification.verification_utils import TestConfig + from codeflash_core.config import TestConfig replay_test_paths = list(output_dir.glob("*.java")) assert len(replay_test_paths) >= 1 @@ -67,8 +67,8 @@ def test_discover_functions_from_replay_tests(self, traced_workload: tuple) -> N test_cfg = TestConfig( tests_root=FIXTURE_DIR / "src" / "test" / "java", tests_project_rootdir=FIXTURE_DIR, - project_root_path=FIXTURE_DIR, - pytest_cmd="pytest", + project_root=FIXTURE_DIR, + test_command="pytest", ) functions, trace_file_path = _get_java_replay_test_functions(replay_test_paths, 
test_cfg, FIXTURE_DIR) @@ -189,14 +189,14 @@ def test_full_pipeline(self, compiled_workload: Path, tmp_path: Path) -> None: # Step 1: Discover functions from replay tests (like get_optimizable_functions) from codeflash.discovery.functions_to_optimize import _get_java_replay_test_functions - from codeflash.verification.verification_utils import TestConfig + from codeflash_core.config import TestConfig replay_test_paths = list(output_dir.glob("*.java")) test_cfg = TestConfig( tests_root=test_root, tests_project_rootdir=FIXTURE_DIR, - project_root_path=FIXTURE_DIR, - pytest_cmd="pytest", + project_root=FIXTURE_DIR, + test_command="pytest", ) file_to_funcs, trace_file_path = _get_java_replay_test_functions(replay_test_paths, test_cfg, FIXTURE_DIR) diff --git a/tests/test_languages/test_java/test_replacement.py b/tests/test_languages/test_java/test_replacement.py index 1bd4f7abb..656537c3f 100644 --- a/tests/test_languages/test_java/test_replacement.py +++ b/tests/test_languages/test_java/test_replacement.py @@ -1607,7 +1607,7 @@ def test_replace_specific_overload_by_line_number(self, tmp_path: Path, java_sup optimized_code = CodeStringsMarkdown.parse_markdown_code(optimized_markdown, expected_language="java") # Create FunctionToOptimize with line info for the 3-arg version (lines 13-18) - from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize + from codeflash_core.models import FunctionParent, FunctionToOptimize function_to_optimize = FunctionToOptimize( function_name="bytesToHexString", @@ -1615,7 +1615,6 @@ def test_replace_specific_overload_by_line_number(self, tmp_path: Path, java_sup starting_line=13, # Line where 3-arg version starts (1-indexed) ending_line=18, parents=[FunctionParent(name="Buffer", type="ClassDef")], - qualified_name="Buffer.bytesToHexString", is_method=True, ) @@ -1678,7 +1677,7 @@ def test_standalone_wrong_method_name_leaves_source_unchanged(self, tmp_path, ja Applying that would create a duplicate 
``unpackMap`` and delete ``unpackObjectMap``, causing compilation failures. """ - from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize + from codeflash_core.models import FunctionParent, FunctionToOptimize java_file = tmp_path / "Unpacker.java" original_code = """\ @@ -1710,7 +1709,6 @@ def test_standalone_wrong_method_name_leaves_source_unchanged(self, tmp_path, ja starting_line=2, ending_line=4, parents=[FunctionParent(name="Unpacker", type="ClassDef")], - qualified_name="Unpacker.unpackObjectMap", is_method=True, ) @@ -1733,7 +1731,7 @@ def test_class_wrapper_with_wrong_target_method_leaves_source_unchanged(self, tm contained only ``sizeTxn`` (a helper) and did not include ``estimateKeySize`` (the target). Applying it would duplicate ``sizeTxn`` in the source. """ - from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize + from codeflash_core.models import FunctionParent, FunctionToOptimize java_file = tmp_path / "Command.java" original_code = """\ @@ -1767,7 +1765,6 @@ def test_class_wrapper_with_wrong_target_method_leaves_source_unchanged(self, tm starting_line=2, ending_line=4, parents=[FunctionParent(name="Command", type="ClassDef")], - qualified_name="Command.estimateKeySize", is_method=True, ) @@ -1803,7 +1800,7 @@ def test_anonymous_iterator_methods_not_hoisted_to_class(self, tmp_path, java_su Those three methods must remain inside the anonymous class body and must NOT be added as top-level members of the outer class. 
""" - from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize + from codeflash_core.models import FunctionParent, FunctionToOptimize java_file = tmp_path / "LuaMap.java" original_code = """\ @@ -1876,7 +1873,6 @@ def test_anonymous_iterator_methods_not_hoisted_to_class(self, tmp_path, java_su starting_line=11, ending_line=13, parents=[FunctionParent(name="LuaMap", type="ClassDef")], - qualified_name="LuaMap.keySetIterator", is_method=True, ) diff --git a/tests/test_languages/test_java/test_run_and_parse.py b/tests/test_languages/test_java/test_run_and_parse.py index 7d093dbb3..48da387fb 100644 --- a/tests/test_languages/test_java/test_run_and_parse.py +++ b/tests/test_languages/test_java/test_run_and_parse.py @@ -13,7 +13,7 @@ import pytest -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.languages.base import Language from codeflash.languages.current import set_current_language from codeflash.languages.java.instrumentation import instrument_existing_test @@ -126,7 +126,7 @@ def _make_optimizer(project_root: Path, test_dir: Path, function_name: str, src_ disable_telemetry=True, tests_root=test_dir, test_project_root=project_root, - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, ) ) diff --git a/tests/test_languages/test_java/test_test_discovery.py b/tests/test_languages/test_java/test_test_discovery.py index d6830389c..d2c6300d7 100644 --- a/tests/test_languages/test_java/test_test_discovery.py +++ b/tests/test_languages/test_java/test_test_discovery.py @@ -165,7 +165,7 @@ def test_find_tests(self, tmp_path: Path): """) # Create source function - from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash_core.models import FunctionToOptimize func = FunctionToOptimize( function_name="reverse", diff --git a/tests/test_languages/test_javascript_instrumentation.py 
b/tests/test_languages/test_javascript_instrumentation.py index a700996c1..114af8e47 100644 --- a/tests/test_languages/test_javascript_instrumentation.py +++ b/tests/test_languages/test_javascript_instrumentation.py @@ -8,7 +8,7 @@ import tempfile from pathlib import Path -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.languages.base import FunctionInfo, Language from codeflash.languages.javascript.line_profiler import JavaScriptLineProfiler from codeflash.languages.javascript.tracer import JavaScriptTracer diff --git a/tests/test_languages/test_javascript_optimization_flow.py b/tests/test_languages/test_javascript_optimization_flow.py index 7ec447d06..3c576f4b2 100644 --- a/tests/test_languages/test_javascript_optimization_flow.py +++ b/tests/test_languages/test_javascript_optimization_flow.py @@ -13,10 +13,10 @@ import pytest -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.languages.base import Language from codeflash.models.models import CodeString, FunctionParent -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig def skip_if_js_not_supported(): @@ -322,8 +322,8 @@ def test_function_optimizer_instantiation_javascript(self, js_project): test_config = TestConfig( tests_root=js_project / "tests", tests_project_rootdir=js_project, - project_root_path=js_project, - pytest_cmd="jest", + project_root=js_project, + test_command="jest", ) optimizer = FunctionOptimizer( @@ -355,8 +355,8 @@ def test_function_optimizer_instantiation_typescript(self, ts_project): test_config = TestConfig( tests_root=ts_project / "tests", tests_project_rootdir=ts_project, - project_root_path=ts_project, - pytest_cmd="vitest", + project_root=ts_project, + test_command="vitest", ) optimizer = FunctionOptimizer( @@ -388,8 +388,8 @@ def 
test_get_code_optimization_context_javascript(self, js_project): test_config = TestConfig( tests_root=js_project / "tests", tests_project_rootdir=js_project, - project_root_path=js_project, - pytest_cmd="jest", + project_root=js_project, + test_command="jest", ) optimizer = JavaScriptFunctionOptimizer( @@ -425,8 +425,8 @@ def test_get_code_optimization_context_typescript(self, ts_project): test_config = TestConfig( tests_root=ts_project / "tests", tests_project_rootdir=ts_project, - project_root_path=ts_project, - pytest_cmd="vitest", + project_root=ts_project, + test_command="vitest", ) optimizer = JavaScriptFunctionOptimizer( @@ -477,7 +477,7 @@ def test_helper_functions_have_correct_language_javascript(self, tmp_path): ) test_config = TestConfig( - tests_root=tmp_path, tests_project_rootdir=tmp_path, project_root_path=tmp_path, pytest_cmd="jest" + tests_root=tmp_path, tests_project_rootdir=tmp_path, project_root=tmp_path, test_command="jest" ) optimizer = JavaScriptFunctionOptimizer( @@ -521,7 +521,7 @@ def test_helper_functions_have_correct_language_typescript(self, tmp_path): ) test_config = TestConfig( - tests_root=tmp_path, tests_project_rootdir=tmp_path, project_root_path=tmp_path, pytest_cmd="vitest" + tests_root=tmp_path, tests_project_rootdir=tmp_path, project_root=tmp_path, test_command="vitest" ) optimizer = JavaScriptFunctionOptimizer( diff --git a/tests/test_languages/test_javascript_run_and_parse.py b/tests/test_languages/test_javascript_run_and_parse.py index 3781cc637..7666b53a2 100644 --- a/tests/test_languages/test_javascript_run_and_parse.py +++ b/tests/test_languages/test_javascript_run_and_parse.py @@ -22,10 +22,10 @@ import pytest -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.languages.base import Language from codeflash.models.models import FunctionParent -from codeflash.verification.verification_utils import TestConfig +from 
codeflash_core.config import TestConfig def is_node_available(): @@ -373,8 +373,8 @@ def test_function_optimizer_run_and_parse_typescript(self, vitest_project): test_config = TestConfig( tests_root=vitest_project / "tests", tests_project_rootdir=vitest_project, - project_root_path=vitest_project, - pytest_cmd="vitest", + project_root=vitest_project, + test_command="vitest", test_framework="vitest", ) diff --git a/tests/test_languages/test_javascript_setup_test_config.py b/tests/test_languages/test_javascript_setup_test_config.py index 2455dacba..ba94f8bb3 100644 --- a/tests/test_languages/test_javascript_setup_test_config.py +++ b/tests/test_languages/test_javascript_setup_test_config.py @@ -4,7 +4,7 @@ import pytest from codeflash.languages.javascript.support import JavaScriptSupport -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig @pytest.fixture @@ -15,7 +15,7 @@ def js_support() -> JavaScriptSupport: def make_test_config(project_root: Path) -> TestConfig: return TestConfig( tests_root=project_root / "tests", - project_root_path=project_root, + project_root=project_root, tests_project_rootdir=project_root, ) diff --git a/tests/test_languages/test_js_code_extractor.py b/tests/test_languages/test_js_code_extractor.py index 8cab1dedd..2d7f10d27 100644 --- a/tests/test_languages/test_js_code_extractor.py +++ b/tests/test_languages/test_js_code_extractor.py @@ -9,13 +9,13 @@ import pytest -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.languages.base import Language from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport from codeflash.languages.registry import get_language_support from codeflash.models.models import FunctionParent -from codeflash.verification.verification_utils import 
TestConfig +from codeflash_core.config import TestConfig FIXTURES_DIR = Path(__file__).parent / "fixtures" @@ -1098,8 +1098,8 @@ def test_function_optimizer_workflow(self, cjs_project): test_config = TestConfig( tests_root=cjs_project / "tests", tests_project_rootdir=cjs_project, - project_root_path=cjs_project, - pytest_cmd="jest", + project_root=cjs_project, + test_command="jest", ) func_optimizer = JavaScriptFunctionOptimizer( diff --git a/tests/test_languages/test_js_code_replacer.py b/tests/test_languages/test_js_code_replacer.py index 38ced89a7..93a814edd 100644 --- a/tests/test_languages/test_js_code_replacer.py +++ b/tests/test_languages/test_js_code_replacer.py @@ -2434,7 +2434,7 @@ def test_arrow_function_replacement_after_new_global_const(self, ts_support, tem Uses function_to_optimize (the if-branch in replace_function_definitions_for_language) to reproduce the real pipeline where stale starting_line causes failures. """ - from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash_core.models import FunctionToOptimize from codeflash.models.models import CodeString, CodeStringsMarkdown original_source = """\ @@ -2525,7 +2525,7 @@ def test_arrow_function_with_new_set_declaration(self, ts_support, temp_project) introduces a new Set (NON_FILE_TYPES/MEDIA_TYPES) above the arrow function. Uses function_to_optimize to trigger the stale-line-number bug. 
""" - from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash_core.models import FunctionToOptimize from codeflash.models.models import CodeString, CodeStringsMarkdown original_source = """\ diff --git a/tests/test_languages/test_multi_file_code_replacer.py b/tests/test_languages/test_multi_file_code_replacer.py index 2e6bf5fb9..0f6f7f5db 100644 --- a/tests/test_languages/test_multi_file_code_replacer.py +++ b/tests/test_languages/test_multi_file_code_replacer.py @@ -81,11 +81,11 @@ from pathlib import Path from unittest.mock import MagicMock -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer from codeflash.languages.registry import get_language_support from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig class Args: @@ -132,8 +132,8 @@ def test_js_replcement() -> None: test_config = TestConfig( tests_root=root_dir / "code_to_optimize/js/code_to_optimize_js/tests", tests_project_rootdir=root_dir, - project_root_path=root_dir, - pytest_cmd="jest", + project_root=root_dir, + test_command="jest", ) func_optimizer = JavaScriptFunctionOptimizer( function_to_optimize=func, test_cfg=test_config, aiservice_client=MagicMock() diff --git a/tests/test_languages/test_python_support.py b/tests/test_languages/test_python_support.py index bd1106ab4..ee66519d3 100644 --- a/tests/test_languages/test_python_support.py +++ b/tests/test_languages/test_python_support.py @@ -567,7 +567,7 @@ def test_find_references_simple_function(python_support, tmp_path): This test specifically exercises the code path that was fixed in the regression where function.name was used instead of function.function_name. 
""" - from codeflash.models.function_types import FunctionToOptimize + from codeflash_core.models import FunctionToOptimize # Create source file with function definition source_file = tmp_path / "utils.py" @@ -597,7 +597,7 @@ def test_find_references_class_method(python_support, tmp_path): This verifies the class_name attribute is correctly used to disambiguate methods. """ - from codeflash.models.function_types import FunctionParent, FunctionToOptimize + from codeflash_core.models import FunctionParent, FunctionToOptimize # Create source file with class and method source_file = tmp_path / "calculator.py" @@ -633,7 +633,7 @@ def compute(): def test_find_references_no_references(python_support, tmp_path): """Test that find_references returns empty list when no references exist.""" - from codeflash.models.function_types import FunctionToOptimize + from codeflash_core.models import FunctionToOptimize source_file = tmp_path / "isolated.py" source_file.write_text("""def isolated_function(): @@ -649,7 +649,7 @@ def test_find_references_no_references(python_support, tmp_path): def test_find_references_nonexistent_function(python_support, tmp_path): """Test that find_references handles nonexistent functions gracefully.""" - from codeflash.models.function_types import FunctionToOptimize + from codeflash_core.models import FunctionToOptimize source_file = tmp_path / "source.py" source_file.write_text("""def existing_function(): diff --git a/tests/test_mock_candidate_replacement.py b/tests/test_mock_candidate_replacement.py index 4a2292a11..9ef270424 100644 --- a/tests/test_mock_candidate_replacement.py +++ b/tests/test_mock_candidate_replacement.py @@ -5,12 +5,12 @@ import pytest -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.languages.python.context.unused_definition_remover import detect_unused_helper_functions -from codeflash.models.function_types import FunctionParent +from 
codeflash_core.models import FunctionParent from codeflash.models.models import CodeStringsMarkdown from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig ORIGINAL_SOURCE = '''\ import contextlib @@ -628,9 +628,9 @@ def temp_project(): test_cfg = TestConfig( tests_root=temp_dir / "tests", tests_project_rootdir=temp_dir, - project_root_path=temp_dir, + project_root=temp_dir, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) yield temp_dir, source_file, test_cfg diff --git a/tests/test_multi_file_code_replacement.py b/tests/test_multi_file_code_replacement.py index 82256001a..342b8d461 100644 --- a/tests/test_multi_file_code_replacement.py +++ b/tests/test_multi_file_code_replacement.py @@ -1,9 +1,9 @@ from pathlib import Path -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig class Args: @@ -102,9 +102,9 @@ def _get_string_usage(text: str) -> Usage: test_config = TestConfig( tests_root=root_dir / "tests/pytest", tests_project_rootdir=root_dir, - project_root_path=root_dir, + project_root=root_dir, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config) code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap() diff --git a/tests/test_pickle_patcher.py b/tests/test_pickle_patcher.py index 804ff137b..f4ededfe3 100644 --- a/tests/test_pickle_patcher.py +++ b/tests/test_pickle_patcher.py @@ -15,7 +15,7 @@ from 
codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest from codeflash.benchmarking.utils import validate_and_format_benchmark_table from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.models.models import CodePosition, TestFile, TestFiles, TestingMode, TestsInFile, TestType from codeflash.optimization.optimizer import Optimizer from codeflash.verification.equivalence import compare_test_results @@ -361,7 +361,7 @@ def test_run_and_parse_picklepatch() -> None: project_root=project_root, disable_telemetry=True, tests_root=tests_root, - pytest_cmd="pytest", + test_command="pytest", experiment_id=None, test_project_root=project_root, ) diff --git a/tests/test_ranking_boost.py b/tests/test_ranking_boost.py index c3e6fcd80..8e14d7dc6 100644 --- a/tests/test_ranking_boost.py +++ b/tests/test_ranking_boost.py @@ -7,7 +7,7 @@ import pytest from codeflash.discovery.discover_unit_tests import existing_unit_test_count -from codeflash.models.function_types import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile from codeflash.models.test_type import TestType from codeflash.optimization.optimizer import Optimizer diff --git a/tests/test_test_runner.py b/tests/test_test_runner.py index eb20812d6..4dbdb3d11 100644 --- a/tests/test_test_runner.py +++ b/tests/test_test_runner.py @@ -6,7 +6,7 @@ from codeflash.languages import current_language_support from codeflash.models.models import TestFile, TestFiles, TestType from codeflash.verification.parse_test_output import parse_test_xml -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig def test_unittest_runner(): @@ -29,7 +29,7 @@ def test_sort(self): cur_dir_path = 
Path(__file__).resolve().parent config = TestConfig( tests_root=cur_dir_path, - project_root_path=cur_dir_path, + project_root=cur_dir_path, test_framework="unittest", tests_project_rootdir=cur_dir_path.parent, ) @@ -38,9 +38,9 @@ def test_sort(self): test_env["CODEFLASH_TEST_ITERATION"] = "0" test_env["CODEFLASH_TRACER_DISABLE"] = "1" if "PYTHONPATH" not in test_env: - test_env["PYTHONPATH"] = str(config.project_root_path) + test_env["PYTHONPATH"] = str(config.project_root) else: - test_env["PYTHONPATH"] += os.pathsep + str(config.project_root_path) + test_env["PYTHONPATH"] += os.pathsep + str(config.project_root) with tempfile.TemporaryDirectory(dir=cur_dir_path) as temp_dir: test_file_path = Path(temp_dir) / "test_xx.py" @@ -49,7 +49,7 @@ def test_sort(self): ) test_file_path.write_text(code, encoding="utf-8") result_file, process, _, _ = current_language_support().run_behavioral_tests( - test_paths=test_files, test_env=test_env, cwd=Path(config.project_root_path) + test_paths=test_files, test_env=test_env, cwd=Path(config.project_root) ) results = parse_test_xml(result_file, test_files, config, process) assert results[0].did_pass, "Test did not pass as expected" @@ -70,7 +70,7 @@ def test_sort(): cur_dir_path = Path(__file__).resolve().parent config = TestConfig( tests_root=cur_dir_path, - project_root_path=cur_dir_path, + project_root=cur_dir_path, test_framework="pytest", tests_project_rootdir=cur_dir_path.parent, ) @@ -79,9 +79,9 @@ def test_sort(): test_env["CODEFLASH_TEST_ITERATION"] = "0" test_env["CODEFLASH_TRACER_DISABLE"] = "1" if "PYTHONPATH" not in test_env: - test_env["PYTHONPATH"] = str(config.project_root_path) + test_env["PYTHONPATH"] = str(config.project_root) else: - test_env["PYTHONPATH"] += os.pathsep + str(config.project_root_path) + test_env["PYTHONPATH"] += os.pathsep + str(config.project_root) with tempfile.TemporaryDirectory(dir=cur_dir_path) as temp_dir: test_file_path = Path(temp_dir) / "test_xx.py" @@ -90,7 +90,7 @@ def test_sort(): ) 
test_file_path.write_text(code, encoding="utf-8") result_file, process, _, _ = current_language_support().run_behavioral_tests( - test_paths=test_files, test_env=test_env, cwd=Path(config.project_root_path), timeout=1 + test_paths=test_files, test_env=test_env, cwd=Path(config.project_root), timeout=1 ) results = parse_test_xml( test_xml_file_path=result_file, test_files=test_files, test_config=config, run_result=process @@ -112,7 +112,7 @@ def test_sort(): cur_dir_path = Path(__file__).resolve().parent config = TestConfig( tests_root=cur_dir_path, - project_root_path=cur_dir_path, + project_root=cur_dir_path, test_framework="pytest", tests_project_rootdir=cur_dir_path.parent, ) @@ -121,9 +121,9 @@ def test_sort(): test_env["CODEFLASH_TEST_ITERATION"] = "0" test_env["CODEFLASH_TRACER_DISABLE"] = "1" if "PYTHONPATH" not in test_env: - test_env["PYTHONPATH"] = str(config.project_root_path) + test_env["PYTHONPATH"] = str(config.project_root) else: - test_env["PYTHONPATH"] += os.pathsep + str(config.project_root_path) + test_env["PYTHONPATH"] += os.pathsep + str(config.project_root) with tempfile.TemporaryDirectory(dir=cur_dir_path) as temp_dir: test_file_path = Path(temp_dir) / "test_xx.py" @@ -132,7 +132,7 @@ def test_sort(): ) test_file_path.write_text(code, encoding="utf-8") result_file, process, _, _ = current_language_support().run_behavioral_tests( - test_paths=test_files, test_env=test_env, cwd=Path(config.project_root_path), timeout=1 + test_paths=test_files, test_env=test_env, cwd=Path(config.project_root), timeout=1 ) results = parse_test_xml( test_xml_file_path=result_file, test_files=test_files, test_config=config, run_result=process diff --git a/tests/test_unit_test_discovery.py b/tests/test_unit_test_discovery.py index e232a5b71..31969a0b2 100644 --- a/tests/test_unit_test_discovery.py +++ b/tests/test_unit_test_discovery.py @@ -7,9 +7,9 @@ discover_unit_tests, filter_test_files_by_imports, ) -from codeflash.discovery.functions_to_optimize import 
FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.models.models import FunctionParent, TestsInFile, TestType -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig def test_unit_test_discovery_pytest(): @@ -17,7 +17,7 @@ def test_unit_test_discovery_pytest(): tests_path = project_path / "tests" / "pytest" test_config = TestConfig( tests_root=tests_path, - project_root_path=project_path, + project_root=project_path, test_framework="pytest", tests_project_rootdir=tests_path.parent, ) @@ -30,7 +30,7 @@ def test_benchmark_test_discovery_pytest(): tests_path = project_path / "tests" / "pytest" / "benchmarks" test_config = TestConfig( tests_root=tests_path, - project_root_path=project_path, + project_root=project_path, test_framework="pytest", tests_project_rootdir=tests_path.parent, ) @@ -43,7 +43,7 @@ def test_unit_test_discovery_unittest(): test_path = project_path / "tests" / "unittest" test_config = TestConfig( tests_root=project_path, - project_root_path=project_path, + project_root=project_path, test_framework="unittest", tests_project_rootdir=project_path.parent, ) @@ -81,7 +81,7 @@ def sorter(arr): # Create a TestConfig with the temporary directory as the root test_config = TestConfig( tests_root=path_obj_tempdirname, - project_root_path=path_obj_tempdirname, + project_root=path_obj_tempdirname, test_framework="pytest", tests_project_rootdir=path_obj_tempdirname.parent, ) @@ -121,7 +121,7 @@ def test_discover_tests_pytest_with_temp_dir_root(): # Create a TestConfig with the temporary directory as the root test_config = TestConfig( tests_root=path_obj_tempdirname, - project_root_path=path_obj_tempdirname, + project_root=path_obj_tempdirname, test_framework="pytest", tests_project_rootdir=path_obj_tempdirname.parent, ) @@ -194,7 +194,7 @@ def test_discover_tests_pytest_with_multi_level_dirs(): # Create a TestConfig with the temporary directory as the root test_config = 
TestConfig( tests_root=path_obj_tmpdirname, - project_root_path=path_obj_tmpdirname, + project_root=path_obj_tmpdirname, test_framework="pytest", tests_project_rootdir=path_obj_tmpdirname.parent, ) @@ -286,7 +286,7 @@ def test_discover_tests_pytest_dirs(): # Create a TestConfig with the temporary directory as the root test_config = TestConfig( tests_root=path_obj_tmpdirname, - project_root_path=path_obj_tmpdirname, + project_root=path_obj_tmpdirname, test_framework="pytest", tests_project_rootdir=path_obj_tmpdirname.parent, ) @@ -334,7 +334,7 @@ def test_discover_tests_pytest_with_class(): # Create a TestConfig with the temporary directory as the root test_config = TestConfig( tests_root=path_obj_tmpdirname, - project_root_path=path_obj_tmpdirname, + project_root=path_obj_tmpdirname, test_framework="pytest", tests_project_rootdir=path_obj_tmpdirname.parent, ) @@ -375,7 +375,7 @@ def test_discover_tests_pytest_with_double_nested_directories(): # Create a TestConfig with the temporary directory as the root test_config = TestConfig( tests_root=path_obj_tmpdirname, - project_root_path=path_obj_tmpdirname, + project_root=path_obj_tmpdirname, test_framework="pytest", tests_project_rootdir=path_obj_tmpdirname.parent, ) @@ -423,7 +423,7 @@ def test_discover_tests_with_code_in_dir_and_test_in_subdir(): # Create a TestConfig with the code directory as the root test_config = TestConfig( tests_root=test_subdir, - project_root_path=path_obj_tmpdirname, + project_root=path_obj_tmpdirname, test_framework="pytest", tests_project_rootdir=test_subdir.parent, ) @@ -460,7 +460,7 @@ def test_discover_tests_pytest_with_nested_class(): # Create a TestConfig with the temporary directory as the root test_config = TestConfig( tests_root=path_obj_tmpdirname, - project_root_path=path_obj_tmpdirname, + project_root=path_obj_tmpdirname, test_framework="pytest", tests_project_rootdir=path_obj_tmpdirname.parent, ) @@ -500,7 +500,7 @@ def test_discover_tests_pytest_separate_moduledir(): # Create a 
TestConfig with the temporary directory as the root test_config = TestConfig( tests_root=testdir, - project_root_path=codedir.parent.resolve(), + project_root=codedir.parent.resolve(), test_framework="pytest", tests_project_rootdir=testdir.parent, ) @@ -543,7 +543,7 @@ def test_add(self): # Configure test discovery test_config = TestConfig( tests_root=path_obj_tmpdirname, - project_root_path=path_obj_tmpdirname, + project_root=path_obj_tmpdirname, test_framework="pytest", # Using pytest framework to discover unittest tests tests_project_rootdir=path_obj_tmpdirname.parent, ) @@ -611,7 +611,7 @@ def test_add(self): # Configure test discovery test_config = TestConfig( tests_root=path_obj_tmpdirname, - project_root_path=path_obj_tmpdirname, + project_root=path_obj_tmpdirname, test_framework="pytest", # Using pytest framework to discover unittest tests tests_project_rootdir=path_obj_tmpdirname.parent, ) @@ -657,7 +657,7 @@ def _test_add(self): # Private test method should not be discovered # Configure test discovery test_config = TestConfig( tests_root=path_obj_tmpdirname, - project_root_path=path_obj_tmpdirname, + project_root=path_obj_tmpdirname, test_framework="pytest", # Using pytest framework to discover unittest tests tests_project_rootdir=path_obj_tmpdirname.parent, ) @@ -709,7 +709,7 @@ def test_add_with_parameters(self): # Configure test discovery test_config = TestConfig( tests_root=path_obj_tmpdirname, - project_root_path=path_obj_tmpdirname, + project_root=path_obj_tmpdirname, test_framework="pytest", # Using pytest framework to discover unittest tests tests_project_rootdir=path_obj_tmpdirname.parent, ) @@ -769,7 +769,7 @@ def test_topological_sort(g): # Configure test discovery test_config = TestConfig( tests_root=path_obj_tmpdirname, - project_root_path=path_obj_tmpdirname, + project_root=path_obj_tmpdirname, test_framework="pytest", # Using pytest framework to discover unittest tests tests_project_rootdir=path_obj_tmpdirname.parent, ) @@ -919,7 +919,7 @@ 
def test_build_model_id_to_deployment_index_map(self, router): # Configure test discovery test_config = TestConfig( tests_root=path_obj_tmpdirname, - project_root_path=path_obj_tmpdirname, + project_root=path_obj_tmpdirname, test_framework="pytest", # Using pytest framework to discover unittest tests tests_project_rootdir=path_obj_tmpdirname.parent, ) @@ -1004,7 +1004,7 @@ def test_add_mixed(self, name, a, b, expected): # Configure test discovery test_config = TestConfig( tests_root=path_obj_tmpdirname, - project_root_path=path_obj_tmpdirname, + project_root=path_obj_tmpdirname, test_framework="pytest", tests_project_rootdir=path_obj_tmpdirname.parent, ) @@ -1373,7 +1373,7 @@ def test_other(): # Configure test discovery test_config = TestConfig( - tests_root=tmpdir, project_root_path=tmpdir, test_framework="pytest", tests_project_rootdir=tmpdir.parent + tests_root=tmpdir, project_root=tmpdir, test_framework="pytest", tests_project_rootdir=tmpdir.parent ) all_tests, _, _ = discover_unit_tests(test_config) @@ -1501,7 +1501,7 @@ def test_unrelated(): # Configure test discovery test_config = TestConfig( - tests_root=tmpdir, project_root_path=tmpdir, test_framework="pytest", tests_project_rootdir=tmpdir.parent + tests_root=tmpdir, project_root=tmpdir, test_framework="pytest", tests_project_rootdir=tmpdir.parent ) # Test without filtering @@ -2024,7 +2024,7 @@ def test_discover_unit_tests_caching(): test_config = TestConfig( tests_root=tests_root, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", tests_project_rootdir=project_root_path, use_cache=False, @@ -2035,7 +2035,7 @@ def test_discover_unit_tests_caching(): ) cache_config = TestConfig( tests_root=tests_root, - project_root_path=project_root_path, + project_root=project_root_path, test_framework="pytest", tests_project_rootdir=project_root_path, use_cache=True, diff --git a/tests/test_unused_helper_revert.py b/tests/test_unused_helper_revert.py index 
ba5740d5a..a38e7ee66 100644 --- a/tests/test_unused_helper_revert.py +++ b/tests/test_unused_helper_revert.py @@ -5,14 +5,14 @@ import pytest -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash_core.models import FunctionToOptimize from codeflash.languages.python.context.unused_definition_remover import ( detect_unused_helper_functions, revert_unused_helper_functions, ) from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer from codeflash.models.models import CodeStringsMarkdown -from codeflash.verification.verification_utils import TestConfig +from codeflash_core.config import TestConfig @pytest.fixture @@ -42,9 +42,9 @@ def helper_function_2(x): test_cfg = TestConfig( tests_root=temp_dir / "tests", tests_project_rootdir=temp_dir, - project_root_path=temp_dir, + project_root=temp_dir, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) yield temp_dir, main_file, test_cfg @@ -79,7 +79,7 @@ def helper_function_2(x): # Create FunctionToOptimize instance function_to_optimize = FunctionToOptimize( - file_path=main_file, function_name="entrypoint_function", qualified_name="entrypoint_function", parents=[] + file_path=main_file, function_name="entrypoint_function", parents=[] ) # Create function optimizer @@ -190,7 +190,7 @@ def helper_function_2(x): # Create FunctionToOptimize instance function_to_optimize = FunctionToOptimize( - file_path=main_file, function_name="entrypoint_function", qualified_name="entrypoint_function", parents=[] + file_path=main_file, function_name="entrypoint_function", parents=[] ) # Create function optimizer @@ -265,7 +265,7 @@ def helper_function_2(x): # Create FunctionToOptimize instance function_to_optimize = FunctionToOptimize( - file_path=main_file, function_name="entrypoint_function", qualified_name="entrypoint_function", parents=[] + file_path=main_file, function_name="entrypoint_function", parents=[] ) # Create function optimizer @@ -354,14 +354,14 
@@ def entrypoint_function(n): test_cfg = TestConfig( tests_root=temp_dir / "tests", tests_project_rootdir=temp_dir, - project_root_path=temp_dir, + project_root=temp_dir, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) # Create FunctionToOptimize instance function_to_optimize = FunctionToOptimize( - file_path=main_file, function_name="entrypoint_function", qualified_name="entrypoint_function", parents=[] + file_path=main_file, function_name="entrypoint_function", parents=[] ) # Create function optimizer @@ -543,9 +543,9 @@ def helper_method_2(self, x): test_cfg = TestConfig( tests_root=temp_dir / "tests", tests_project_rootdir=temp_dir, - project_root_path=temp_dir, + project_root=temp_dir, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) # Create FunctionToOptimize instance for class method @@ -554,7 +554,6 @@ def helper_method_2(self, x): function_to_optimize = FunctionToOptimize( file_path=main_file, function_name="entrypoint_method", - qualified_name="Calculator.entrypoint_method", parents=[FunctionParent(name="Calculator", type="ClassDef")], ) @@ -694,9 +693,9 @@ def process_data(self, n): test_cfg = TestConfig( tests_root=temp_dir / "tests", tests_project_rootdir=temp_dir, - project_root_path=temp_dir, + project_root=temp_dir, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) # Create FunctionToOptimize instance for class method @@ -705,7 +704,6 @@ def process_data(self, n): function_to_optimize = FunctionToOptimize( file_path=main_file, function_name="process_data", - qualified_name="Processor.process_data", parents=[FunctionParent(name="Processor", type="ClassDef")], ) @@ -875,9 +873,9 @@ def local_helper(self, x): test_cfg = TestConfig( tests_root=temp_dir / "tests", tests_project_rootdir=temp_dir, - project_root_path=temp_dir, + project_root=temp_dir, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) # Note: In practice, codeflash might not handle 
deeply nested classes, @@ -887,7 +885,6 @@ def local_helper(self, x): function_to_optimize = FunctionToOptimize( file_path=main_file, function_name="compute", - qualified_name="OuterClass.InnerProcessor.compute", parents=[ FunctionParent(name="OuterClass", type="ClassDef"), FunctionParent(name="InnerProcessor", type="ClassDef"), @@ -1040,14 +1037,14 @@ def entrypoint_function(n): test_cfg = TestConfig( tests_root=temp_dir / "tests", tests_project_rootdir=temp_dir, - project_root_path=temp_dir, + project_root=temp_dir, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) # Create FunctionToOptimize instance function_to_optimize = FunctionToOptimize( - file_path=main_file, function_name="entrypoint_function", qualified_name="entrypoint_function", parents=[] + file_path=main_file, function_name="entrypoint_function", parents=[] ) # Create function optimizer @@ -1204,14 +1201,14 @@ def entrypoint_function(n): test_cfg = TestConfig( tests_root=temp_dir / "tests", tests_project_rootdir=temp_dir, - project_root_path=temp_dir, + project_root=temp_dir, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) # Create FunctionToOptimize instance function_to_optimize = FunctionToOptimize( - file_path=main_file, function_name="entrypoint_function", qualified_name="entrypoint_function", parents=[] + file_path=main_file, function_name="entrypoint_function", parents=[] ) # Create function optimizer @@ -1426,9 +1423,9 @@ def calculate_class(cls, n): test_cfg = TestConfig( tests_root=temp_dir / "tests", tests_project_rootdir=temp_dir, - project_root_path=temp_dir, + project_root=temp_dir, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) # Test static method optimization @@ -1437,7 +1434,6 @@ def calculate_class(cls, n): function_to_optimize = FunctionToOptimize( file_path=main_file, function_name="calculate_static", - qualified_name="MathUtils.calculate_static", parents=[FunctionParent(name="MathUtils", 
type="ClassDef")], ) @@ -1565,9 +1561,9 @@ async def async_entrypoint(n): test_cfg = TestConfig( tests_root=temp_dir / "tests", tests_project_rootdir=temp_dir, - project_root_path=temp_dir, + project_root=temp_dir, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) # Create FunctionToOptimize instance for async function @@ -1655,9 +1651,9 @@ def sync_entrypoint(n): test_cfg = TestConfig( tests_root=temp_dir / "tests", tests_project_rootdir=temp_dir, - project_root_path=temp_dir, + project_root=temp_dir, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) # Create FunctionToOptimize instance for sync function @@ -1762,9 +1758,9 @@ async def mixed_entrypoint(n): test_cfg = TestConfig( tests_root=temp_dir / "tests", tests_project_rootdir=temp_dir, - project_root_path=temp_dir, + project_root=temp_dir, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) # Create FunctionToOptimize instance for async function @@ -1858,9 +1854,9 @@ def sync_helper_method(self, x): test_cfg = TestConfig( tests_root=temp_dir / "tests", tests_project_rootdir=temp_dir, - project_root_path=temp_dir, + project_root=temp_dir, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) # Create FunctionToOptimize instance for async class method @@ -1949,9 +1945,9 @@ async def async_entrypoint(n): test_cfg = TestConfig( tests_root=temp_dir / "tests", tests_project_rootdir=temp_dir, - project_root_path=temp_dir, + project_root=temp_dir, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) # Create FunctionToOptimize instance for async function @@ -2030,9 +2026,9 @@ def gcd_recursive(a: int, b: int) -> int: test_cfg = TestConfig( tests_root=temp_dir / "tests", tests_project_rootdir=temp_dir, - project_root_path=temp_dir, + project_root=temp_dir, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) # Create FunctionToOptimize instance @@ -2141,9 +2137,9 @@ async def 
async_entrypoint_with_generators(n): test_cfg = TestConfig( tests_root=temp_dir / "tests", tests_project_rootdir=temp_dir, - project_root_path=temp_dir, + project_root=temp_dir, test_framework="pytest", - pytest_cmd="pytest", + test_command="pytest", ) # Create FunctionToOptimize instance for async function diff --git a/tests/test_worktree.py b/tests/test_worktree.py index 75de860fd..d7ab1e7b1 100644 --- a/tests/test_worktree.py +++ b/tests/test_worktree.py @@ -36,7 +36,7 @@ def test_mirror_paths_for_worktree_mode(monkeypatch: pytest.MonkeyPatch): assert optimizer.args.file == worktree_dir / "src" / "app" / "main.py" assert optimizer.test_cfg.tests_root == worktree_dir / "src" / "tests" - assert optimizer.test_cfg.project_root_path == worktree_dir / "src" # same as project_root + assert optimizer.test_cfg.project_root == worktree_dir / "src" # same as project_root assert optimizer.test_cfg.tests_project_rootdir == worktree_dir / "src" # same as test_project_root # test on our repo @@ -65,5 +65,5 @@ def test_mirror_paths_for_worktree_mode(monkeypatch: pytest.MonkeyPatch): assert optimizer.args.file == worktree_dir / "codeflash/optimization/optimizer.py" assert optimizer.test_cfg.tests_root == worktree_dir / "tests" - assert optimizer.test_cfg.project_root_path == worktree_dir # same as project_root + assert optimizer.test_cfg.project_root == worktree_dir # same as project_root assert optimizer.test_cfg.tests_project_rootdir == worktree_dir # same as test_project_root From 55bc577811402d547b76017a053c37192ca6071c Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 24 Mar 2026 06:58:55 -0500 Subject: [PATCH 2/9] refactor: wire PythonPlugin into optimizer and remove dead code --- codeflash/languages/__init__.py | 5 - codeflash/languages/base.py | 13 +- .../languages/java/concurrency_analyzer.py | 7 +- codeflash/languages/java/line_profiler.py | 11 +- codeflash/optimization/optimizer.py | 320 +++--------------- src/codeflash_python/plugin.py | 7 + 
tests/test_languages/test_base.py | 44 +-- .../test_java/test_concurrency_analyzer.py | 33 +- .../test_java/test_line_profiler.py | 35 +- .../test_line_profiler_integration.py | 17 +- tests/test_languages/test_java_e2e.py | 13 +- tests/test_languages/test_javascript_e2e.py | 9 +- .../test_javascript_instrumentation.py | 10 +- .../test_languages/test_javascript_support.py | 20 +- tests/test_languages/test_language_parity.py | 20 +- tests/test_languages/test_python_support.py | 16 +- tests/test_languages/test_typescript_e2e.py | 13 +- tests/test_languages/test_vitest_e2e.py | 5 +- tests/test_ranking_boost.py | 148 +------- 19 files changed, 187 insertions(+), 559 deletions(-) diff --git a/codeflash/languages/__init__.py b/codeflash/languages/__init__.py index b66c3211b..3f2fbdf92 100644 --- a/codeflash/languages/__init__.py +++ b/codeflash/languages/__init__.py @@ -63,10 +63,6 @@ # Lazy imports to avoid circular imports def __getattr__(name: str): - if name == "FunctionInfo": - from codeflash_core.models import FunctionToOptimize - - return FunctionToOptimize if name == "JavaScriptSupport": from codeflash.languages.javascript.support import JavaScriptSupport @@ -90,7 +86,6 @@ def __getattr__(name: str): __all__ = [ "CodeContext", "DependencyResolver", - "FunctionInfo", "HelperFunction", "IndexResult", "Language", diff --git a/codeflash/languages/base.py b/codeflash/languages/base.py index 9c385a271..713b47b20 100644 --- a/codeflash/languages/base.py +++ b/codeflash/languages/base.py @@ -31,17 +31,6 @@ ParentInfo = FunctionParent -# Lazy import for FunctionInfo to avoid circular imports -# This allows `from codeflash.languages.base import FunctionInfo` to work at runtime -def __getattr__(name: str) -> Any: - if name == "FunctionInfo": - from codeflash_core.models import FunctionToOptimize - - return FunctionToOptimize - msg = f"module {__name__!r} has no attribute {name!r}" - raise AttributeError(msg) - - @dataclass(frozen=True) class IndexResult: file_path: Path @@ 
-259,7 +248,7 @@ class PythonSupport(LanguageSupport): def language(self) -> Language: return Language.PYTHON - def discover_functions(self, source: str, file_path: Path, ...) -> list[FunctionInfo]: + def discover_functions(self, source: str, file_path: Path, ...) -> list[FunctionToOptimize]: # Python-specific implementation using LibCST ... diff --git a/codeflash/languages/java/concurrency_analyzer.py b/codeflash/languages/java/concurrency_analyzer.py index d529a4265..653207c1d 100644 --- a/codeflash/languages/java/concurrency_analyzer.py +++ b/codeflash/languages/java/concurrency_analyzer.py @@ -19,7 +19,8 @@ if TYPE_CHECKING: from pathlib import Path - from codeflash.languages.base import FunctionInfo + from codeflash_core.models import FunctionToOptimize + logger = logging.getLogger(__name__) @@ -131,7 +132,7 @@ def __init__(self, analyzer=None) -> None: """ self.analyzer = analyzer - def analyze_function(self, func: FunctionInfo, source: str | None = None) -> ConcurrencyInfo: + def analyze_function(self, func: FunctionToOptimize, source: str | None = None) -> ConcurrencyInfo: """Analyze a function for concurrency patterns. Args: @@ -305,7 +306,7 @@ def get_optimization_suggestions(concurrency_info: ConcurrencyInfo) -> list[str] return suggestions -def analyze_function_concurrency(func: FunctionInfo, source: str | None = None, analyzer=None) -> ConcurrencyInfo: +def analyze_function_concurrency(func: FunctionToOptimize, source: str | None = None, analyzer=None) -> ConcurrencyInfo: """Analyze a function for concurrency patterns. Convenience function that creates a JavaConcurrencyAnalyzer and analyzes the function. 
diff --git a/codeflash/languages/java/line_profiler.py b/codeflash/languages/java/line_profiler.py index 854a8549d..ffa878d1d 100644 --- a/codeflash/languages/java/line_profiler.py +++ b/codeflash/languages/java/line_profiler.py @@ -22,7 +22,8 @@ if TYPE_CHECKING: from tree_sitter import Node - from codeflash.languages.base import FunctionInfo + from codeflash_core.models import FunctionToOptimize + logger = logging.getLogger(__name__) @@ -74,7 +75,7 @@ def __init__(self, output_file: Path, warmup_iterations: int = DEFAULT_WARMUP_IT # === Agent-based profiling (bytecode instrumentation) === def generate_agent_config( - self, source: str, file_path: Path, functions: list[FunctionInfo], config_output_path: Path + self, source: str, file_path: Path, functions: list[FunctionToOptimize], config_output_path: Path ) -> Path: """Generate config JSON for the profiler agent. @@ -141,7 +142,7 @@ def build_javaagent_arg(self, config_path: Path) -> str: # === Source-level instrumentation === def instrument_source( - self, source: str, file_path: Path, functions: list[FunctionInfo], analyzer: Any = None + self, source: str, file_path: Path, functions: list[FunctionToOptimize], analyzer: Any = None ) -> str: """Instrument Java source code with line profiling. @@ -326,7 +327,9 @@ class {self.profiler_class} {{ }} """ - def instrument_function(self, func: FunctionInfo, lines: list[str], file_path: Path, analyzer: Any) -> list[str]: + def instrument_function( + self, func: FunctionToOptimize, lines: list[str], file_path: Path, analyzer: Any + ) -> list[str]: """Instrument a single function with line profiling. 
Args: diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 6ff310f1e..2088f55e0 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -10,7 +10,7 @@ from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient from codeflash.api.cfapi import send_completion_email -from codeflash.cli_cmds.console import call_graph_live_display, call_graph_summary, console, logger, progress_bar +from codeflash.cli_cmds.console import call_graph_live_display, console, logger, progress_bar from codeflash.code_utils import env_utils from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file from codeflash.code_utils.config_consts import HIGH_EFFORT_TOP_N, EffortLevel @@ -23,34 +23,21 @@ remove_worktree, ) from codeflash.code_utils.time_utils import humanize_runtime -from codeflash.either import is_successful -from codeflash.languages import current_language_support, set_current_language +from codeflash.languages import set_current_language from codeflash.lsp.helpers import is_subagent_mode from codeflash.telemetry.posthog_cf import ph from codeflash_core.config import TestConfig +from codeflash_python.plugin import PythonPlugin if TYPE_CHECKING: import ast from argparse import Namespace - from codeflash.benchmarking.function_ranker import FunctionRanker from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint - from codeflash.languages.base import DependencyResolver - from codeflash.languages.function_optimizer import FunctionOptimizer - from codeflash.models.models import BenchmarkKey, FunctionCalledInTest, ValidCode from codeflash_core.models import FunctionToOptimize - - -def _extract_java_package_from_path(file_path: Path) -> str | None: - """Extract Java package from file path by finding src/main/java or src/test/java marker.""" - parts = file_path.parts - for i, part in enumerate(parts): - if part == "java" and i >= 2 and parts[i - 1] in ("main", "test") and 
parts[i - 2] == "src": - package_parts = parts[i + 1 : -1] # After java/, exclude filename - if package_parts: - return ".".join(package_parts) - return None - return None + from codeflash_python.context.types import DependencyResolver + from codeflash_python.function_optimizer import FunctionOptimizer + from codeflash_python.models.models import BenchmarkKey, FunctionCalledInTest, ValidCode class Optimizer: @@ -65,6 +52,7 @@ def __init__(self, args: Namespace) -> None: benchmark_tests_root=args.benchmarks_root if "benchmark" in args and "benchmarks_root" in args else None, ) + self.plugin = PythonPlugin(args.project_root) self.aiservice_client = AiServiceClient() self.experiment_id = os.getenv("CODEFLASH_EXPERIMENT_ID", None) self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None @@ -76,7 +64,6 @@ def __init__(self, args: Namespace) -> None: self.current_worktree: Path | None = None self.original_args_and_test_cfg: tuple[Namespace, TestConfig] | None = None self.patch_files: list[Path] = [] - self._cached_callee_counts: dict[tuple[Path, str], int] = {} def run_benchmarks( self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], num_optimizable_functions: int @@ -88,11 +75,11 @@ def run_benchmarks( if not (hasattr(self.args, "benchmark") and self.args.benchmark and num_optimizable_functions > 0): return function_benchmark_timings, total_benchmark_timings - from codeflash.benchmarking.instrument_codeflash_trace import instrument_codeflash_trace_decorator - from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin - from codeflash.benchmarking.replay_test import generate_replay_test - from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest - from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table + from codeflash_python.benchmarking.instrument_codeflash_trace import instrument_codeflash_trace_decorator + from 
codeflash_python.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin + from codeflash_python.benchmarking.replay_test import generate_replay_test + from codeflash_python.benchmarking.trace_benchmarks import trace_benchmarks_pytest + from codeflash_python.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table console.rule() with progress_bar( @@ -142,7 +129,7 @@ def run_benchmarks( def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]: """Discover functions to optimize.""" - from codeflash.discovery.functions_to_optimize import get_functions_to_optimize + from codeflash_python.discovery.functions_to_optimize import get_functions_to_optimize # In worktree mode for git-diff discovery, file paths come from the original repo # (via get_git_diff using cwd), but module_root/project_root have been mirrored to @@ -215,11 +202,9 @@ def create_function_optimizer( ): function_specific_timings = function_benchmark_timings[qualified_name_w_module] - cls = current_language_support().function_optimizer_class + from codeflash_python.function_optimizer import FunctionOptimizer as PythonFunctionOptimizer - # TODO: _resolve_function_ast re-parses source via ast.parse() per function, even when the caller already - # has a parsed module AST. Consider passing the pre-parsed AST through to avoid redundant parsing. 
- function_optimizer = cls( + function_optimizer = PythonFunctionOptimizer( function_to_optimize=function_to_optimize, test_cfg=self.test_cfg, function_to_optimize_source_code=function_to_optimize_source_code, @@ -249,14 +234,14 @@ def prepare_module_for_optimization( original_module_code: str = original_module_path.read_text(encoding="utf8") - return current_language_support().prepare_module( - original_module_code, original_module_path, self.args.project_root - ) + from codeflash_python.optimizer import prepare_python_module + + return prepare_python_module(original_module_code, original_module_path, self.args.project_root) def discover_tests( self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] ) -> tuple[dict[str, set[FunctionCalledInTest]], int]: - from codeflash.discovery.discover_unit_tests import discover_unit_tests + from codeflash_python.discovery.discover_unit_tests import discover_unit_tests console.rule() start_time = time.time() @@ -272,213 +257,6 @@ def discover_tests( ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests}) return function_to_tests, num_discovered_tests - def display_global_ranking( - self, globally_ranked: list[tuple[Path, FunctionToOptimize]], ranker: FunctionRanker, show_top_n: int = 15 - ) -> None: - from rich.table import Table - - if not globally_ranked: - return - - # Show top N functions - display_count = min(show_top_n, len(globally_ranked)) - - table = Table( - title=f"Function Ranking (Top {display_count} of {len(globally_ranked)})", - title_style="bold cyan", - border_style="cyan", - show_lines=False, - ) - - table.add_column("Priority", style="bold yellow", justify="center", width=8) - table.add_column("Function", style="cyan", width=40) - table.add_column("File", style="dim", width=25) - table.add_column("Addressable Time", justify="right", style="green", width=12) - table.add_column("Impact", justify="center", style="bold", width=8) - - # Get addressable time for display - for i, 
(file_path, func) in enumerate(globally_ranked[:display_count], 1): - addressable_time = ranker.get_function_addressable_time(func) - - # Format function name - func_name = func.qualified_name - if len(func_name) > 38: - func_name = func_name[:35] + "..." - - # Format file name - file_name = file_path.name - if len(file_name) > 23: - file_name = "..." + file_name[-20:] - - # Format addressable time - if addressable_time >= 1e9: - time_display = f"{addressable_time / 1e9:.2f}s" - elif addressable_time >= 1e6: - time_display = f"{addressable_time / 1e6:.1f}ms" - elif addressable_time >= 1e3: - time_display = f"{addressable_time / 1e3:.1f}µs" - else: - time_display = f"{addressable_time:.0f}ns" - - # Impact indicator - if i <= 5: - impact = "🔥" - impact_style = "bold red" - elif i <= 10: - impact = "⚡" - impact_style = "bold yellow" - else: - impact = "💡" - impact_style = "bold blue" - - table.add_row(f"#{i}", func_name, file_name, time_display, impact, style=impact_style if i <= 5 else None) - - console.print(table) - - if len(globally_ranked) > display_count: - console.print(f"[dim]... and {len(globally_ranked) - display_count} more functions[/dim]") - - def rank_all_functions_globally( - self, - file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], - trace_file_path: Path | None, - call_graph: DependencyResolver | None = None, - test_count_cache: dict[tuple[Path, str], int] | None = None, - ) -> list[tuple[Path, FunctionToOptimize]]: - """Rank all functions globally across all files based on trace data. - - This performs global ranking instead of per-file ranking, ensuring that - high-impact functions are optimized first regardless of which file they're in. - - Args: - file_to_funcs_to_optimize: Mapping of file paths to functions to optimize - trace_file_path: Path to trace file with performance data - - Returns: - List of (file_path, function) tuples in globally ranked order by addressable time. 
- If no trace file or ranking fails, returns functions in original file order. - - """ - all_functions: list[tuple[Path, FunctionToOptimize]] = [] - for file_path, functions in file_to_funcs_to_optimize.items(): - all_functions.extend((file_path, func) for func in functions) - - # If no trace file, rank by dependency count if call graph is available - if not trace_file_path or not trace_file_path.exists(): - if call_graph is not None: - return self.rank_by_dependency_count(all_functions, call_graph, test_count_cache=test_count_cache) - logger.debug("No trace file available, using original function order") - return all_functions - - try: - from codeflash.benchmarking.function_ranker import FunctionRanker, JavaFunctionRanker - - console.rule() - logger.info("loading|Ranking functions globally by performance impact...") - console.rule() - - # Extract just the functions for ranking (without file paths) - functions_only = [func for _, func in all_functions] - - # Detect if functions are Java and use appropriate ranker - if functions_only and functions_only[0].language == "java": - from codeflash.languages.java.jfr_parser import JfrProfile - - # JFR file is alongside the trace DB with .jfr extension - jfr_file_path = trace_file_path.with_suffix(".jfr") - if not jfr_file_path.exists(): - logger.warning(f"JFR file not found: {jfr_file_path}, falling back to original order") - return all_functions - - # Extract packages from file paths (e.g., src/main/java/com/example/Workload.java → "com.example") - packages = set() - for func in functions_only: - package = _extract_java_package_from_path(func.file_path) - if package: - # Use top two levels as filter prefix (e.g., "com.example" from "com.example.sub") - parts = package.split(".") - packages.add(".".join(parts[: min(2, len(parts))])) - - jfr_profile = JfrProfile(jfr_file_path, list(packages)) - ranker = JavaFunctionRanker(jfr_profile) - else: - # Python ranker with trace data - ranker = FunctionRanker(trace_file_path) - - # 
Rank globally - ranked_functions = ranker.rank_functions(functions_only) - - # Reconstruct with file paths by looking up original file for each ranked function - # Build reverse mapping: function -> file path - # Since FunctionToOptimize is unhashable (contains list), we compare by identity - func_to_file_map = {} - for file_path, func in all_functions: - # Use a tuple of unique identifiers as the key - key: tuple[Path, str, int | None] = (func.file_path, func.qualified_name, func.starting_line) - func_to_file_map[key] = file_path - ranked_with_metadata: list[tuple[Path, FunctionToOptimize, float, int]] = [] - for rank_index, func in enumerate(ranked_functions): - key = (func.file_path, func.qualified_name, func.starting_line) - file_path = func_to_file_map.get(key) - if file_path: - ranked_with_metadata.append( - (file_path, func, ranker.get_function_addressable_time(func), rank_index) - ) - - if test_count_cache: - ranked_with_metadata.sort( - key=lambda item: (-item[2], -test_count_cache.get((item[0], item[1].qualified_name), 0), item[3]) - ) - - globally_ranked = [ - (file_path, func) for file_path, func, _addressable_time, _rank_index in ranked_with_metadata - ] - - console.rule() - logger.info( - f"Globally ranked {len(ranked_functions)} functions by addressable time " - f"(filtered {len(functions_only) - len(ranked_functions)} low-importance functions)" - ) - - # Display ranking table for user visibility - self.display_global_ranking(globally_ranked, ranker) - console.rule() - - except Exception as e: - logger.warning(f"Could not perform global ranking: {e}") - logger.debug("Falling back to original function order") - return all_functions - else: - return globally_ranked - - def rank_by_dependency_count( - self, - all_functions: list[tuple[Path, FunctionToOptimize]], - call_graph: DependencyResolver, - test_count_cache: dict[tuple[Path, str], int] | None = None, - ) -> list[tuple[Path, FunctionToOptimize]]: - file_to_qns: dict[Path, set[str]] = 
defaultdict(set) - for file_path, func in all_functions: - file_to_qns[file_path].add(func.qualified_name) - callee_counts = call_graph.count_callees_per_function(dict(file_to_qns)) - self._cached_callee_counts = callee_counts - - if test_count_cache: - ranked = sorted( - enumerate(all_functions), - key=lambda x: ( - -callee_counts.get((x[1][0], x[1][1].qualified_name), 0), - -test_count_cache.get((x[1][0], x[1][1].qualified_name), 0), - x[0], - ), - ) - else: - ranked = sorted( - enumerate(all_functions), key=lambda x: (-callee_counts.get((x[1][0], x[1][1].qualified_name), 0), x[0]) - ) - logger.debug(f"Ranked {len(ranked)} functions by dependency count (most complex first)") - return [item for _, item in ranked] - def run(self) -> None: from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint @@ -500,17 +278,17 @@ def run(self) -> None: logger.debug(f"Cleaning up {len(leftover_trace_files)} leftover trace file(s) from previous runs") cleanup_paths(leftover_trace_files) - cleanup_paths(Optimizer.find_leftover_instrumented_test_files(self.test_cfg.tests_root)) + self.plugin.cleanup_run(self.test_cfg.tests_root) function_optimizer = None file_to_funcs_to_optimize, num_optimizable_functions, trace_file_path = self.get_optimizable_functions() # Set language global singleton based on discovered functions if file_to_funcs_to_optimize: - for file_path, funcs in file_to_funcs_to_optimize.items(): + for funcs in file_to_funcs_to_optimize.values(): if funcs and funcs[0].language: set_current_language(funcs[0].language) - current_language_support().setup_test_config(self.test_cfg, file_path, self.current_worktree) + self.plugin.setup_test_config(self.test_cfg) break if self.args.all: @@ -527,20 +305,13 @@ def run(self) -> None: file_to_funcs_to_optimize, num_optimizable_functions ) - # Create a language-specific dependency resolver (e.g. 
Jedi-based call graph for Python) - # Skip in CI — the cache DB doesn't persist between runs on ephemeral runners - lang_support = current_language_support() - resolver = None - if lang_support and not env_utils.is_ci(): - resolver = lang_support.create_dependency_resolver(self.args.project_root) - - if resolver is not None and lang_support is not None and file_to_funcs_to_optimize: - supported_exts = lang_support.file_extensions - source_files = [f for f in file_to_funcs_to_optimize if f.suffix in supported_exts] - with call_graph_live_display(len(source_files), project_root=self.args.project_root) as on_progress: - resolver.build_index(source_files, on_progress=on_progress) - console.rule() - call_graph_summary(resolver, file_to_funcs_to_optimize) + # Build call graph index via the Python plugin (skips CI internally) + if file_to_funcs_to_optimize: + source_files = [f for f in file_to_funcs_to_optimize if f.suffix == ".py"] + if source_files: + with call_graph_live_display(len(source_files), project_root=self.args.project_root) as on_progress: + self.plugin.build_index(source_files, on_progress=on_progress) + console.rule() optimizations_found: int = 0 self.test_cfg.concolic_test_root_dir = Path( @@ -559,7 +330,7 @@ def run(self) -> None: # Pre-compute test counts once for ranking and logging test_count_cache: dict[tuple[Path, str], int] if function_to_tests: - from codeflash.discovery.discover_unit_tests import existing_unit_test_count + from codeflash_python.discovery.discover_unit_tests import existing_unit_test_count test_count_cache = { (fp, fn.qualified_name): existing_unit_test_count(fn, self.args.project_root, function_to_tests) @@ -569,20 +340,22 @@ def run(self) -> None: else: test_count_cache = {} - # GLOBAL RANKING: Rank all functions together before optimizing - globally_ranked_functions = self.rank_all_functions_globally( - file_to_funcs_to_optimize, trace_file_path, call_graph=resolver, test_count_cache=test_count_cache + # GLOBAL RANKING: Rank all 
functions via the plugin (trace → dependency count → original order) + all_functions = [func for funcs in file_to_funcs_to_optimize.values() for func in funcs] + ranked_functions = self.plugin.rank_functions( + all_functions, trace_file=trace_file_path, test_counts=test_count_cache ) + globally_ranked_functions: list[tuple[Path, FunctionToOptimize]] = [ + (func.file_path, func) for func in ranked_functions + ] # Cache for module preparation (avoid re-parsing same files) prepared_modules: dict[Path, tuple[dict[Path, ValidCode], ast.Module | None]] = {} - # Reuse callee counts from rank_by_dependency_count if available, otherwise compute - callee_counts = self._cached_callee_counts - if not callee_counts and resolver is not None: - file_to_qns: dict[Path, set[str]] = defaultdict(set) - for fp, fn in globally_ranked_functions: - file_to_qns[fp].add(fn.qualified_name) - callee_counts = resolver.count_callees_per_function(dict(file_to_qns)) + # Get dependency counts from the plugin (populated during rank_functions) + dep_counts = self.plugin.get_dependency_counts() + callee_counts: dict[tuple[Path, str], int] = { + (fp, fn.qualified_name): dep_counts.get(fn.qualified_name, 0) for fp, fn in globally_ranked_functions + } # Optimize functions in globally ranked order for i, (original_module_path, function_to_optimize) in enumerate(globally_ranked_functions): @@ -626,7 +399,7 @@ def run(self) -> None: function_to_optimize_source_code=validated_original_code[original_module_path].source_code, function_benchmark_timings=function_benchmark_timings, total_benchmark_timings=total_benchmark_timings, - call_graph=resolver, + call_graph=self.plugin.get_call_graph_index(), effort_override=effort_override, ) if function_optimizer is None: @@ -640,7 +413,7 @@ def run(self) -> None: self.functions_checkpoint.add_function_to_checkpoint( function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root) ) - if is_successful(best_optimization): + if 
best_optimization.is_ok(): optimizations_found += 1 # create a diff patch for successful optimization if self.current_worktree and not is_subagent_mode(): @@ -659,7 +432,7 @@ def run(self) -> None: self.current_worktree, f"Optimizing {next_func.qualified_name}" ) else: - logger.warning(best_optimization.failure()) + logger.warning(best_optimization.error) console.rule() continue finally: @@ -691,8 +464,7 @@ def run(self) -> None: else: logger.warning("⚠️ Failed to send completion email. Status") finally: - if resolver is not None: - resolver.close() + self.plugin.cleanup_run(self.test_cfg.tests_root) if function_optimizer: function_optimizer.cleanup_generated_files() diff --git a/src/codeflash_python/plugin.py b/src/codeflash_python/plugin.py index e6be25fd3..0c0131fbf 100644 --- a/src/codeflash_python/plugin.py +++ b/src/codeflash_python/plugin.py @@ -63,6 +63,13 @@ def get_ai_client(self) -> AiServiceClient: self.ai_client = AiServiceClient() return self.ai_client + # -- setup ---------------------------------------------------------------- + + def setup_test_config(self, test_cfg: TestConfig) -> None: + from codeflash_python.verification.test_runner import setup_pytest_cmd + + setup_pytest_cmd(test_cfg.test_command) + # -- cleanup, comparison, environment validation -------------------------- def cleanup_run(self, tests_root: Path) -> None: diff --git a/tests/test_languages/test_base.py b/tests/test_languages/test_base.py index 96cd7ddd5..a8b316e9d 100644 --- a/tests/test_languages/test_base.py +++ b/tests/test_languages/test_base.py @@ -11,7 +11,6 @@ from codeflash.languages.base import ( CodeContext, FunctionFilterCriteria, - FunctionInfo, HelperFunction, Language, ParentInfo, @@ -19,6 +18,7 @@ TestResult, convert_parents_to_tuple, ) +from codeflash_core.models import FunctionToOptimize class TestLanguageEnum: @@ -89,12 +89,12 @@ def test_parent_info_hash(self): assert len(s) == 1 -class TestFunctionInfo: - """Tests for the FunctionInfo dataclass (alias for 
FunctionToOptimize).""" +class TestFunctionToOptimize: + """Tests for the FunctionToOptimize dataclass.""" def test_function_info_creation_minimal(self): - """Test creating FunctionInfo with minimal args.""" - func = FunctionInfo(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3) + """Test creating FunctionToOptimize with minimal args.""" + func = FunctionToOptimize(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3) assert func.function_name == "add" assert func.file_path == Path("/test/example.py") assert func.starting_line == 1 @@ -105,9 +105,9 @@ def test_function_info_creation_minimal(self): assert func.language == "python" def test_function_info_creation_full(self): - """Test creating FunctionInfo with all args.""" + """Test creating FunctionToOptimize with all args.""" parents = [ParentInfo(name="Calculator", type="ClassDef")] - func = FunctionInfo( + func = FunctionToOptimize( function_name="add", file_path=Path("/test/example.py"), starting_line=10, @@ -126,20 +126,20 @@ def test_function_info_creation_full(self): assert func.starting_col == 4 assert func.ending_col == 20 - def test_function_info_frozen(self): - """Test that FunctionInfo is immutable.""" - func = FunctionInfo(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3) - with pytest.raises(AttributeError): - func.function_name = "new_name" + def test_function_info_mutable(self): + """Test that FunctionToOptimize fields can be reassigned (stdlib dataclass, not frozen).""" + func = FunctionToOptimize(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3) + func.function_name = "new_name" + assert func.function_name == "new_name" def test_qualified_name_no_parents(self): """Test qualified_name without parents.""" - func = FunctionInfo(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3) + 
func = FunctionToOptimize(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3) assert func.qualified_name == "add" def test_qualified_name_with_class(self): """Test qualified_name with class parent.""" - func = FunctionInfo( + func = FunctionToOptimize( function_name="add", file_path=Path("/test/example.py"), starting_line=1, @@ -150,7 +150,7 @@ def test_qualified_name_with_class(self): def test_qualified_name_nested(self): """Test qualified_name with nested parents.""" - func = FunctionInfo( + func = FunctionToOptimize( function_name="inner", file_path=Path("/test/example.py"), starting_line=1, @@ -161,7 +161,7 @@ def test_qualified_name_nested(self): def test_class_name_with_class(self): """Test class_name property with class parent.""" - func = FunctionInfo( + func = FunctionToOptimize( function_name="add", file_path=Path("/test/example.py"), starting_line=1, @@ -172,12 +172,12 @@ def test_class_name_with_class(self): def test_class_name_without_class(self): """Test class_name property without class parent.""" - func = FunctionInfo(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3) + func = FunctionToOptimize(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3) assert func.class_name is None def test_class_name_nested_function(self): """Test class_name for function nested in another function.""" - func = FunctionInfo( + func = FunctionToOptimize( function_name="inner", file_path=Path("/test/example.py"), starting_line=1, @@ -188,7 +188,7 @@ def test_class_name_nested_function(self): def test_class_name_method_in_nested_class(self): """Test class_name for method in nested class.""" - func = FunctionInfo( + func = FunctionToOptimize( function_name="method", file_path=Path("/test/example.py"), starting_line=1, @@ -200,12 +200,12 @@ def test_class_name_method_in_nested_class(self): def test_top_level_parent_name_no_parents(self): """Test 
top_level_parent_name without parents.""" - func = FunctionInfo(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3) + func = FunctionToOptimize(function_name="add", file_path=Path("/test/example.py"), starting_line=1, ending_line=3) assert func.top_level_parent_name == "add" def test_top_level_parent_name_with_parents(self): """Test top_level_parent_name with parents.""" - func = FunctionInfo( + func = FunctionToOptimize( function_name="method", file_path=Path("/test/example.py"), starting_line=1, @@ -216,7 +216,7 @@ def test_top_level_parent_name_with_parents(self): def test_function_info_str(self): """Test string representation.""" - func = FunctionInfo( + func = FunctionToOptimize( function_name="add", file_path=Path("/test/example.py"), starting_line=1, diff --git a/tests/test_languages/test_java/test_concurrency_analyzer.py b/tests/test_languages/test_java/test_concurrency_analyzer.py index 252b8a975..771f8f50b 100644 --- a/tests/test_languages/test_java/test_concurrency_analyzer.py +++ b/tests/test_languages/test_java/test_concurrency_analyzer.py @@ -3,9 +3,10 @@ import tempfile from pathlib import Path -from codeflash.languages.base import FunctionInfo + from codeflash.languages.java.concurrency_analyzer import JavaConcurrencyAnalyzer, analyze_function_concurrency from codeflash.languages.language_enum import Language +from codeflash_core.models import FunctionToOptimize class TestCompletableFutureDetection: @@ -25,7 +26,7 @@ def test_detect_completable_future(self): file_path = Path(tmpdir) / "AsyncService.java" file_path.write_text(source, encoding="utf-8") - func = FunctionInfo( + func = FunctionToOptimize( function_name="fetchData", file_path=file_path, starting_line=2, @@ -59,7 +60,7 @@ def test_detect_completable_future_chain(self): file_path = Path(tmpdir) / "AsyncService.java" file_path.write_text(source, encoding="utf-8") - func = FunctionInfo( + func = FunctionToOptimize( function_name="process", 
file_path=file_path, starting_line=2, @@ -98,7 +99,7 @@ def test_detect_parallel_stream(self): file_path = Path(tmpdir) / "DataProcessor.java" file_path.write_text(source, encoding="utf-8") - func = FunctionInfo( + func = FunctionToOptimize( function_name="processData", file_path=file_path, starting_line=2, @@ -129,7 +130,7 @@ def test_detect_parallel_method(self): file_path = Path(tmpdir) / "DataProcessor.java" file_path.write_text(source, encoding="utf-8") - func = FunctionInfo( + func = FunctionToOptimize( function_name="count", file_path=file_path, starting_line=2, @@ -165,7 +166,7 @@ def test_detect_executor_service(self): file_path = Path(tmpdir) / "TaskRunner.java" file_path.write_text(source, encoding="utf-8") - func = FunctionInfo( + func = FunctionToOptimize( function_name="runTasks", file_path=file_path, starting_line=2, @@ -201,7 +202,7 @@ def test_detect_virtual_threads(self): file_path = Path(tmpdir) / "VirtualThreadExample.java" file_path.write_text(source, encoding="utf-8") - func = FunctionInfo( + func = FunctionToOptimize( function_name="runWithVirtualThreads", file_path=file_path, starting_line=2, @@ -236,7 +237,7 @@ def test_detect_synchronized_method(self): file_path = Path(tmpdir) / "Counter.java" file_path.write_text(source, encoding="utf-8") - func = FunctionInfo( + func = FunctionToOptimize( function_name="increment", file_path=file_path, starting_line=2, @@ -268,7 +269,7 @@ def test_detect_synchronized_block(self): file_path = Path(tmpdir) / "Counter.java" file_path.write_text(source, encoding="utf-8") - func = FunctionInfo( + func = FunctionToOptimize( function_name="increment", file_path=file_path, starting_line=2, @@ -304,7 +305,7 @@ def test_detect_concurrent_hashmap(self): file_path = Path(tmpdir) / "Cache.java" file_path.write_text(source, encoding="utf-8") - func = FunctionInfo( + func = FunctionToOptimize( function_name="put", file_path=file_path, starting_line=4, @@ -342,7 +343,7 @@ def test_detect_atomic_integer(self): file_path 
= Path(tmpdir) / "Counter.java" file_path.write_text(source, encoding="utf-8") - func = FunctionInfo( + func = FunctionToOptimize( function_name="increment", file_path=file_path, starting_line=4, @@ -375,7 +376,7 @@ def test_non_concurrent_function(self): file_path = Path(tmpdir) / "Calculator.java" file_path.write_text(source, encoding="utf-8") - func = FunctionInfo( + func = FunctionToOptimize( function_name="add", file_path=file_path, starting_line=2, @@ -412,7 +413,7 @@ def test_should_measure_throughput_for_async(self): file_path = Path(tmpdir) / "AsyncService.java" file_path.write_text(source, encoding="utf-8") - func = FunctionInfo( + func = FunctionToOptimize( function_name="fetchData", file_path=file_path, starting_line=2, @@ -441,7 +442,7 @@ def test_should_not_measure_throughput_for_sync(self): file_path = Path(tmpdir) / "Calculator.java" file_path.write_text(source, encoding="utf-8") - func = FunctionInfo( + func = FunctionToOptimize( function_name="add", file_path=file_path, starting_line=2, @@ -474,7 +475,7 @@ def test_suggestions_for_completable_future(self): file_path = Path(tmpdir) / "AsyncService.java" file_path.write_text(source, encoding="utf-8") - func = FunctionInfo( + func = FunctionToOptimize( function_name="fetchData", file_path=file_path, starting_line=2, @@ -505,7 +506,7 @@ def test_suggestions_for_parallel_stream(self): file_path = Path(tmpdir) / "DataProcessor.java" file_path.write_text(source, encoding="utf-8") - func = FunctionInfo( + func = FunctionToOptimize( function_name="processData", file_path=file_path, starting_line=2, diff --git a/tests/test_languages/test_java/test_line_profiler.py b/tests/test_languages/test_java/test_line_profiler.py index 9a1e677e4..9f9cd21b8 100644 --- a/tests/test_languages/test_java/test_line_profiler.py +++ b/tests/test_languages/test_java/test_line_profiler.py @@ -15,6 +15,7 @@ format_line_profile_results, resolve_internal_class_name, ) +from codeflash_core.models import FunctionToOptimize class 
TestAgentConfigGeneration: @@ -22,7 +23,7 @@ class TestAgentConfigGeneration: def test_simple_method(self): """Test config generation for a simple method.""" - from codeflash.languages.base import FunctionInfo, Language + from codeflash.languages.base import Language source = """package com.example; @@ -34,7 +35,7 @@ def test_simple_method(self): } """ file_path = Path("/tmp/Calculator.java") - func = FunctionInfo( + func = FunctionToOptimize( function_name="add", file_path=file_path, starting_line=4, @@ -76,7 +77,7 @@ def test_simple_method(self): def test_line_contents_extraction(self): """Test that line contents are extracted correctly.""" - from codeflash.languages.base import FunctionInfo, Language + from codeflash.languages.base import Language source = """public class Test { public void method() { @@ -87,7 +88,7 @@ def test_line_contents_extraction(self): } """ file_path = Path("/tmp/Test.java") - func = FunctionInfo( + func = FunctionToOptimize( function_name="method", file_path=file_path, starting_line=2, @@ -118,7 +119,7 @@ def test_line_contents_extraction(self): def test_multiple_functions(self): """Test config with multiple target functions.""" - from codeflash.languages.base import FunctionInfo, Language + from codeflash.languages.base import Language source = """public class Test { public void method1() { @@ -131,7 +132,7 @@ def test_multiple_functions(self): } """ file_path = Path("/tmp/Test.java") - func1 = FunctionInfo( + func1 = FunctionToOptimize( function_name="method1", file_path=file_path, starting_line=2, @@ -143,7 +144,7 @@ def test_multiple_functions(self): is_method=True, language=Language.JAVA, ) - func2 = FunctionInfo( + func2 = FunctionToOptimize( function_name="method2", file_path=file_path, starting_line=6, @@ -259,11 +260,11 @@ def test_custom_warmup_iterations(self): def test_warmup_disabled(self): """Test warmup can be disabled by setting to 0.""" - from codeflash.languages.base import FunctionInfo, Language + from 
codeflash.languages.base import Language source = "public class Test {\n public void method() {\n return;\n }\n}" file_path = Path("/tmp/Test.java") - func = FunctionInfo( + func = FunctionToOptimize( function_name="method", file_path=file_path, starting_line=2, @@ -288,11 +289,11 @@ def test_warmup_disabled(self): def test_warmup_in_config_json(self): """Test that warmupIterations appears in the generated config JSON.""" - from codeflash.languages.base import FunctionInfo, Language + from codeflash.languages.base import Language source = "package com.example;\npublic class Calc {\n public int add(int a, int b) {\n return a + b;\n }\n}" file_path = Path("/tmp/Calc.java") - func = FunctionInfo( + func = FunctionToOptimize( function_name="add", file_path=file_path, starting_line=3, @@ -321,7 +322,7 @@ class TestAgentConfigBoundaryConditions: def test_start_line_beyond_end_line(self): """When starting_line > ending_line, no lines are extracted but config is still valid.""" - from codeflash.languages.base import FunctionInfo, Language + from codeflash.languages.base import Language source = "public class Test {\n public void foo() { return; }\n}\n" file_path = Path("/tmp/Test.java") @@ -330,7 +331,7 @@ def test_start_line_beyond_end_line(self): output_file = Path(tmpdir) / "profile.json" config_path = Path(tmpdir) / "config.json" - func = FunctionInfo( + func = FunctionToOptimize( function_name="foo", file_path=file_path, starting_line=5, @@ -354,7 +355,7 @@ def test_start_line_beyond_end_line(self): def test_line_numbers_beyond_source_length(self): """Line numbers beyond the source length are silently skipped.""" - from codeflash.languages.base import FunctionInfo, Language + from codeflash.languages.base import Language source = "public class Test {\n public void foo() { return; }\n}\n" file_path = Path("/tmp/Test.java") @@ -363,7 +364,7 @@ def test_line_numbers_beyond_source_length(self): output_file = Path(tmpdir) / "profile.json" config_path = Path(tmpdir) / 
"config.json" - func = FunctionInfo( + func = FunctionToOptimize( function_name="foo", file_path=file_path, starting_line=100, @@ -396,7 +397,7 @@ def test_line_numbers_beyond_source_length(self): def test_negative_line_numbers(self): """Negative line numbers produce no line contents (range is empty or out of bounds).""" - from codeflash.languages.base import FunctionInfo, Language + from codeflash.languages.base import Language source = "public class Test {\n public void foo() { return; }\n}\n" file_path = Path("/tmp/Test.java") @@ -405,7 +406,7 @@ def test_negative_line_numbers(self): output_file = Path(tmpdir) / "profile.json" config_path = Path(tmpdir) / "config.json" - func = FunctionInfo( + func = FunctionToOptimize( function_name="foo", file_path=file_path, starting_line=-5, diff --git a/tests/test_languages/test_java/test_line_profiler_integration.py b/tests/test_languages/test_java/test_line_profiler_integration.py index 9ffd095b3..4840b1a08 100644 --- a/tests/test_languages/test_java/test_line_profiler_integration.py +++ b/tests/test_languages/test_java/test_line_profiler_integration.py @@ -9,9 +9,10 @@ import pytest -from codeflash.languages.base import FunctionInfo, Language +from codeflash.languages.base import Language from codeflash.languages.java.line_profiler import DEFAULT_WARMUP_ITERATIONS, JavaLineProfiler, find_agent_jar from codeflash.languages.java.support import get_java_support +from codeflash_core.models import FunctionToOptimize class TestLineProfilerInstrumentation: @@ -35,7 +36,7 @@ def test_instrument_with_package(self): profile_output = tmppath / "profile.json" - func = FunctionInfo( + func = FunctionToOptimize( function_name="add", file_path=java_file, starting_line=4, @@ -114,7 +115,7 @@ def test_instrument_without_package(self): profile_output = tmppath / "profile.json" - func = FunctionInfo( + func = FunctionToOptimize( function_name="sort", file_path=java_file, starting_line=2, @@ -201,7 +202,7 @@ def 
test_instrument_multiple_methods(self): profile_output = tmppath / "profile.json" - func_reverse = FunctionInfo( + func_reverse = FunctionToOptimize( function_name="reverse", file_path=java_file, starting_line=2, @@ -213,7 +214,7 @@ def test_instrument_multiple_methods(self): is_method=True, language=Language.JAVA, ) - func_palindrome = FunctionInfo( + func_palindrome = FunctionToOptimize( function_name="isPalindrome", file_path=java_file, starting_line=16, @@ -302,7 +303,7 @@ def test_instrument_nested_package(self): profile_output = tmppath / "profile.json" - func = FunctionInfo( + func = FunctionToOptimize( function_name="isEmpty", file_path=java_file, starting_line=4, @@ -374,7 +375,7 @@ def test_instrument_verifies_line_contents(self): profile_output = tmppath / "profile.json" - func = FunctionInfo( + func = FunctionToOptimize( function_name="fib", file_path=java_file, starting_line=2, @@ -439,7 +440,7 @@ def run_spin_timer_profiled(tmppath: Path, spin_durations_ns: list[int]) -> dict profile_output = tmppath / "profile.json" config_path = profile_output.with_suffix(".config.json") - func = FunctionInfo( + func = FunctionToOptimize( function_name="spinWait", file_path=java_file, starting_line=2, diff --git a/tests/test_languages/test_java_e2e.py b/tests/test_languages/test_java_e2e.py index bce2f64f2..ef18fbd29 100644 --- a/tests/test_languages/test_java_e2e.py +++ b/tests/test_languages/test_java_e2e.py @@ -14,6 +14,7 @@ from codeflash.discovery.functions_to_optimize import find_all_functions_in_file, get_files_for_language from codeflash.languages.base import Language +from codeflash_core.models import FunctionToOptimize class TestJavaFunctionDiscovery: @@ -128,7 +129,7 @@ class TestJavaCodeReplacement: def test_replace_method_in_java_file(self): """Test replacing a method in a Java file.""" from codeflash.languages import get_language_support - from codeflash.languages.base import FunctionInfo, Language, ParentInfo + from codeflash.languages.base import 
Language, ParentInfo original_source = """package com.example; @@ -150,8 +151,8 @@ def test_replace_method_in_java_file(self): java_support = get_language_support(Language.JAVA) - # Create FunctionInfo for the add method with parent class - func_info = FunctionInfo( + # Create FunctionToOptimize for the add method with parent class + func_info = FunctionToOptimize( function_name="add", file_path=Path("/tmp/Calculator.java"), starting_line=4, @@ -182,7 +183,7 @@ def java_project_dir(self): def test_discover_junit_tests(self, java_project_dir): """Test discovering JUnit tests for Java methods.""" from codeflash.languages import get_language_support - from codeflash.languages.base import FunctionInfo, Language, ParentInfo + from codeflash.languages.base import Language, ParentInfo java_support = get_language_support(Language.JAVA) test_root = java_project_dir / "src" / "test" / "java" @@ -190,9 +191,9 @@ def test_discover_junit_tests(self, java_project_dir): if not test_root.exists(): pytest.skip("test directory not found") - # Create FunctionInfo for bubbleSort method with parent class + # Create FunctionToOptimize for bubbleSort method with parent class sort_file = java_project_dir / "src" / "main" / "java" / "com" / "example" / "BubbleSort.java" - func_info = FunctionInfo( + func_info = FunctionToOptimize( function_name="bubbleSort", file_path=sort_file, starting_line=14, diff --git a/tests/test_languages/test_javascript_e2e.py b/tests/test_languages/test_javascript_e2e.py index c5bb722bc..758872286 100644 --- a/tests/test_languages/test_javascript_e2e.py +++ b/tests/test_languages/test_javascript_e2e.py @@ -16,6 +16,7 @@ import pytest from codeflash.languages.base import Language +from codeflash_core.models import FunctionToOptimize def skip_if_js_not_supported(): @@ -154,7 +155,7 @@ def test_replace_function_in_javascript_file(self): """Test replacing a function in a JavaScript file.""" skip_if_js_not_supported() from codeflash.languages import 
get_language_support - from codeflash.languages.base import FunctionInfo + original_source = """ export function add(a, b) { @@ -173,7 +174,7 @@ def test_replace_function_in_javascript_file(self): js_support = get_language_support(Language.JAVASCRIPT) - func_info = FunctionInfo( + func_info = FunctionToOptimize( function_name="add", file_path=Path("/tmp/test.js"), starting_line=2, ending_line=4, language="javascript" ) @@ -208,7 +209,7 @@ def test_discover_jest_tests(self, js_project_dir): """Test discovering Jest tests for JavaScript functions.""" skip_if_js_not_supported() from codeflash.languages import get_language_support - from codeflash.languages.base import FunctionInfo + js_support = get_language_support(Language.JAVASCRIPT) test_root = js_project_dir / "tests" @@ -217,7 +218,7 @@ def test_discover_jest_tests(self, js_project_dir): pytest.skip("tests directory not found") fib_file = js_project_dir / "fibonacci.js" - func_info = FunctionInfo( + func_info = FunctionToOptimize( function_name="fibonacci", file_path=fib_file, starting_line=11, ending_line=16, language="javascript" ) diff --git a/tests/test_languages/test_javascript_instrumentation.py b/tests/test_languages/test_javascript_instrumentation.py index 114af8e47..df32455d1 100644 --- a/tests/test_languages/test_javascript_instrumentation.py +++ b/tests/test_languages/test_javascript_instrumentation.py @@ -9,7 +9,7 @@ from pathlib import Path from codeflash_core.models import FunctionToOptimize -from codeflash.languages.base import FunctionInfo, Language +from codeflash.languages.base import Language from codeflash.languages.javascript.line_profiler import JavaScriptLineProfiler from codeflash.languages.javascript.tracer import JavaScriptTracer from codeflash.models.models import FunctionParent @@ -60,7 +60,7 @@ def test_line_profiler_instruments_simple_function(self): f.flush() file_path = Path(f.name) - func_info = FunctionInfo( + func_info = FunctionToOptimize( function_name="add", 
file_path=file_path, starting_line=2, ending_line=5, language="javascript" ) @@ -121,7 +121,7 @@ def test_tracer_instruments_simple_function(self): f.flush() file_path = Path(f.name) - func_info = FunctionInfo( + func_info = FunctionToOptimize( function_name="multiply", file_path=file_path, starting_line=2, ending_line=4, language="javascript" ) @@ -165,7 +165,7 @@ def test_javascript_support_instrument_for_behavior(self): f.flush() file_path = Path(f.name) - func_info = FunctionInfo( + func_info = FunctionToOptimize( function_name="greet", file_path=file_path, starting_line=2, ending_line=4, language="javascript" ) @@ -196,7 +196,7 @@ def test_javascript_support_instrument_for_line_profiling(self): f.flush() file_path = Path(f.name) - func_info = FunctionInfo( + func_info = FunctionToOptimize( function_name="square", file_path=file_path, starting_line=2, ending_line=5, language="javascript" ) diff --git a/tests/test_languages/test_javascript_support.py b/tests/test_languages/test_javascript_support.py index 5d5943151..f63fefca2 100644 --- a/tests/test_languages/test_javascript_support.py +++ b/tests/test_languages/test_javascript_support.py @@ -9,7 +9,7 @@ import pytest -from codeflash.languages.base import FunctionFilterCriteria, FunctionInfo, Language, ParentInfo +from codeflash.languages.base import FunctionFilterCriteria, Language, ParentInfo from codeflash.languages.javascript.support import JavaScriptSupport @@ -316,7 +316,7 @@ def test_replace_simple_function(self, js_support): return a * b; } """ - func = FunctionInfo(function_name="add", file_path=Path("/test.js"), starting_line=1, ending_line=3) + func = FunctionToOptimize(function_name="add", file_path=Path("/test.js"), starting_line=1, ending_line=3) new_code = """function add(a, b) { // Optimized return (a + b) | 0; @@ -343,7 +343,7 @@ def test_replace_preserves_surrounding_code(self, js_support): // Footer """ - func = FunctionInfo(function_name="target", file_path=Path("/test.js"), starting_line=4, 
ending_line=6) + func = FunctionToOptimize(function_name="target", file_path=Path("/test.js"), starting_line=4, ending_line=6) new_code = """function target() { return 42; } @@ -364,7 +364,7 @@ def test_replace_with_indentation_adjustment(self, js_support): } } """ - func = FunctionInfo( + func = FunctionToOptimize( function_name="add", file_path=Path("/test.js"), starting_line=2, @@ -391,7 +391,7 @@ def test_replace_arrow_function(self, js_support): const multiply = (x, y) => x * y; """ - func = FunctionInfo(function_name="add", file_path=Path("/test.js"), starting_line=1, ending_line=3) + func = FunctionToOptimize(function_name="add", file_path=Path("/test.js"), starting_line=1, ending_line=3) new_code = """const add = (a, b) => { return (a + b) | 0; }; @@ -507,7 +507,7 @@ def test_extract_simple_function(self, js_support): f.flush() file_path = Path(f.name) - func = FunctionInfo(function_name="add", file_path=file_path, starting_line=1, ending_line=3) + func = FunctionToOptimize(function_name="add", file_path=file_path, starting_line=1, ending_line=3) context = js_support.extract_code_context(func, file_path.parent, file_path.parent) @@ -950,7 +950,7 @@ def test_replace_class_method_preserves_class_structure(self, js_support): } } """ - func = FunctionInfo( + func = FunctionToOptimize( function_name="add", file_path=Path("/test.js"), starting_line=2, @@ -991,7 +991,7 @@ def test_replace_class_method_with_jsdoc(self, js_support): } } """ - func = FunctionInfo( + func = FunctionToOptimize( function_name="add", file_path=Path("/test.js"), starting_line=5, # Method starts here @@ -1028,7 +1028,7 @@ def test_replace_multiple_class_methods_sequentially(self, js_support): } """ # Replace add first - add_func = FunctionInfo( + add_func = FunctionToOptimize( function_name="add", file_path=Path("/test.js"), starting_line=2, @@ -1060,7 +1060,7 @@ def test_replace_class_method_indentation_adjustment(self, js_support): } } """ - func = FunctionInfo( + func = 
FunctionToOptimize( function_name="innerMethod", file_path=Path("/test.js"), starting_line=2, diff --git a/tests/test_languages/test_language_parity.py b/tests/test_languages/test_language_parity.py index 2747e6892..b5c936786 100644 --- a/tests/test_languages/test_language_parity.py +++ b/tests/test_languages/test_language_parity.py @@ -14,7 +14,7 @@ import pytest -from codeflash.languages.base import FunctionFilterCriteria, FunctionInfo, Language, ParentInfo +from codeflash.languages.base import FunctionFilterCriteria, Language, ParentInfo from codeflash.languages.javascript.support import JavaScriptSupport from codeflash.languages.python.support import PythonSupport @@ -595,8 +595,8 @@ def multiply(a, b): return a * b; } """ - py_func = FunctionInfo(function_name="add", file_path=Path("/test.py"), starting_line=1, ending_line=2) - js_func = FunctionInfo(function_name="add", file_path=Path("/test.js"), starting_line=1, ending_line=3) + py_func = FunctionToOptimize(function_name="add", file_path=Path("/test.py"), starting_line=1, ending_line=2) + js_func = FunctionToOptimize(function_name="add", file_path=Path("/test.js"), starting_line=1, ending_line=3) py_new = """def add(a, b): return (a + b) | 0 @@ -642,8 +642,8 @@ def other(): // Footer """ - py_func = FunctionInfo(function_name="target", file_path=Path("/test.py"), starting_line=4, ending_line=5) - js_func = FunctionInfo(function_name="target", file_path=Path("/test.js"), starting_line=4, ending_line=6) + py_func = FunctionToOptimize(function_name="target", file_path=Path("/test.py"), starting_line=4, ending_line=5) + js_func = FunctionToOptimize(function_name="target", file_path=Path("/test.js"), starting_line=4, ending_line=6) py_new = """def target(): return 42 @@ -683,14 +683,14 @@ def add(self, a, b): } } """ - py_func = FunctionInfo( + py_func = FunctionToOptimize( function_name="add", file_path=Path("/test.py"), starting_line=2, ending_line=3, parents=[ParentInfo(name="Calculator", type="ClassDef")], ) 
- js_func = FunctionInfo( + js_func = FunctionToOptimize( function_name="add", file_path=Path("/test.js"), starting_line=2, @@ -863,8 +863,8 @@ def test_simple_function_context(self, python_support, js_support): ".js", ) - py_func = FunctionInfo(function_name="add", file_path=py_file, starting_line=1, ending_line=2) - js_func = FunctionInfo(function_name="add", file_path=js_file, starting_line=1, ending_line=3) + py_func = FunctionToOptimize(function_name="add", file_path=py_file, starting_line=1, ending_line=2) + js_func = FunctionToOptimize(function_name="add", file_path=js_file, starting_line=1, ending_line=3) py_context = python_support.extract_code_context(py_func, py_file.parent, py_file.parent) js_context = js_support.extract_code_context(js_func, js_file.parent, js_file.parent) @@ -956,7 +956,7 @@ class TestFeatureGaps: """Tests to detect gaps in JavaScript implementation vs Python.""" def test_function_info_fields_populated(self, python_support, js_support): - """Both should populate all FunctionInfo fields consistently.""" + """Both should populate all FunctionToOptimize fields consistently.""" py_file = write_temp_file(CLASS_METHODS.python, ".py") js_file = write_temp_file(CLASS_METHODS.javascript, ".js") diff --git a/tests/test_languages/test_python_support.py b/tests/test_languages/test_python_support.py index ee66519d3..8eb8d7261 100644 --- a/tests/test_languages/test_python_support.py +++ b/tests/test_languages/test_python_support.py @@ -9,7 +9,7 @@ import pytest -from codeflash.languages.base import FunctionFilterCriteria, FunctionInfo, Language, ParentInfo +from codeflash.languages.base import FunctionFilterCriteria, Language, ParentInfo from codeflash.languages.python.support import PythonSupport @@ -262,7 +262,7 @@ def test_replace_simple_function(self, python_support): def multiply(a, b): return a * b """ - func = FunctionInfo(function_name="add", file_path=Path("/test.py"), starting_line=1, ending_line=2) + func = 
FunctionToOptimize(function_name="add", file_path=Path("/test.py"), starting_line=1, ending_line=2) new_code = """def add(a, b): # Optimized return (a + b) | 0 @@ -286,7 +286,7 @@ def other(): # Footer """ - func = FunctionInfo(function_name="target", file_path=Path("/test.py"), starting_line=4, ending_line=5) + func = FunctionToOptimize(function_name="target", file_path=Path("/test.py"), starting_line=4, ending_line=5) new_code = """def target(): return 42 """ @@ -304,7 +304,7 @@ def test_replace_with_indentation_adjustment(self, python_support): def add(self, a, b): return a + b """ - func = FunctionInfo( + func = FunctionToOptimize( function_name="add", file_path=Path("/test.py"), starting_line=2, @@ -330,7 +330,7 @@ def test_replace_first_function(self, python_support): def second(): return 2 """ - func = FunctionInfo(function_name="first", file_path=Path("/test.py"), starting_line=1, ending_line=2) + func = FunctionToOptimize(function_name="first", file_path=Path("/test.py"), starting_line=1, ending_line=2) new_code = """def first(): return 100 """ @@ -347,7 +347,7 @@ def test_replace_last_function(self, python_support): def last(): return 999 """ - func = FunctionInfo(function_name="last", file_path=Path("/test.py"), starting_line=4, ending_line=5) + func = FunctionToOptimize(function_name="last", file_path=Path("/test.py"), starting_line=4, ending_line=5) new_code = """def last(): return 1000 """ @@ -361,7 +361,7 @@ def test_replace_only_function(self, python_support): source = """def only(): return 42 """ - func = FunctionInfo(function_name="only", file_path=Path("/test.py"), starting_line=1, ending_line=2) + func = FunctionToOptimize(function_name="only", file_path=Path("/test.py"), starting_line=1, ending_line=2) new_code = """def only(): return 100 """ @@ -473,7 +473,7 @@ def test_extract_simple_function(self, python_support): f.flush() file_path = Path(f.name) - func = FunctionInfo(function_name="add", file_path=file_path, starting_line=1, 
ending_line=2) + func = FunctionToOptimize(function_name="add", file_path=file_path, starting_line=1, ending_line=2) context = python_support.extract_code_context(func, file_path.parent, file_path.parent) diff --git a/tests/test_languages/test_typescript_e2e.py b/tests/test_languages/test_typescript_e2e.py index 432b1b7ef..5a4e1761f 100644 --- a/tests/test_languages/test_typescript_e2e.py +++ b/tests/test_languages/test_typescript_e2e.py @@ -17,6 +17,7 @@ import pytest from codeflash.languages.base import Language +from codeflash_core.models import FunctionToOptimize def skip_if_ts_not_supported(): @@ -157,7 +158,7 @@ def test_replace_function_in_typescript_file(self): """Test replacing a function in a TypeScript file.""" skip_if_ts_not_supported() from codeflash.languages import get_language_support - from codeflash.languages.base import FunctionInfo + original_source = """ function add(a: number, b: number): number { @@ -176,7 +177,7 @@ def test_replace_function_in_typescript_file(self): ts_support = get_language_support(Language.TYPESCRIPT) - func_info = FunctionInfo( + func_info = FunctionToOptimize( function_name="add", file_path=Path("/tmp/test.ts"), starting_line=2, ending_line=4, language="typescript" ) @@ -198,7 +199,7 @@ def test_replace_function_preserves_types(self): """Test that replacing a function preserves TypeScript type annotations.""" skip_if_ts_not_supported() from codeflash.languages import get_language_support - from codeflash.languages.base import FunctionInfo + original_source = r""" interface Config { @@ -219,7 +220,7 @@ def test_replace_function_preserves_types(self): ts_support = get_language_support(Language.TYPESCRIPT) - func_info = FunctionInfo( + func_info = FunctionToOptimize( function_name="processConfig", file_path=Path("/tmp/test.ts"), starting_line=7, @@ -251,7 +252,7 @@ def test_discover_vitest_tests_for_typescript(self, ts_project_dir): """Test discovering Vitest tests for TypeScript functions.""" skip_if_ts_not_supported() 
from codeflash.languages import get_language_support - from codeflash.languages.base import FunctionInfo + ts_support = get_language_support(Language.TYPESCRIPT) test_root = ts_project_dir / "tests" @@ -260,7 +261,7 @@ def test_discover_vitest_tests_for_typescript(self, ts_project_dir): pytest.skip("tests directory not found") fib_file = ts_project_dir / "fibonacci.ts" - func_info = FunctionInfo( + func_info = FunctionToOptimize( function_name="fibonacci", file_path=fib_file, starting_line=1, ending_line=7, language="typescript" ) diff --git a/tests/test_languages/test_vitest_e2e.py b/tests/test_languages/test_vitest_e2e.py index bdc8a8a80..94dee4b88 100644 --- a/tests/test_languages/test_vitest_e2e.py +++ b/tests/test_languages/test_vitest_e2e.py @@ -15,6 +15,7 @@ import pytest from codeflash.code_utils.config_js import detect_test_runner, get_package_json_data +from codeflash_core.models import FunctionToOptimize def skip_if_js_not_supported(): @@ -169,7 +170,7 @@ def test_discover_vitest_tests(self, vitest_project_dir): """Test discovering Vitest tests for TypeScript functions.""" skip_if_js_not_supported() from codeflash.languages import get_language_support - from codeflash.languages.base import FunctionInfo, Language + from codeflash.languages.base import Language ts_support = get_language_support(Language.TYPESCRIPT) test_root = vitest_project_dir / "tests" @@ -178,7 +179,7 @@ def test_discover_vitest_tests(self, vitest_project_dir): pytest.skip("tests directory not found") fib_file = vitest_project_dir / "fibonacci.ts" - func_info = FunctionInfo( + func_info = FunctionToOptimize( function_name="fibonacci", file_path=fib_file, starting_line=11, ending_line=16, language="typescript" ) diff --git a/tests/test_ranking_boost.py b/tests/test_ranking_boost.py index 8e14d7dc6..77d59eaaa 100644 --- a/tests/test_ranking_boost.py +++ b/tests/test_ranking_boost.py @@ -1,16 +1,13 @@ from __future__ import annotations -from argparse import Namespace from pathlib import 
Path -from unittest.mock import patch import pytest from codeflash.discovery.discover_unit_tests import existing_unit_test_count -from codeflash_core.models import FunctionToOptimize from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile from codeflash.models.test_type import TestType -from codeflash.optimization.optimizer import Optimizer +from codeflash_core.models import FunctionToOptimize def make_func(name: str, project_root: Path) -> FunctionToOptimize: @@ -26,25 +23,6 @@ def make_test(test_type: TestType, test_name: str = "test_something") -> Functio ) -def build_test_count_cache( - funcs: list[FunctionToOptimize], project_root: Path, function_to_tests: dict[str, set[FunctionCalledInTest]] -) -> dict[tuple[Path, str], int]: - return { - (func.file_path, func.qualified_name): existing_unit_test_count(func, project_root, function_to_tests) - for func in funcs - } - - -def make_optimizer(project_root: Path) -> Optimizer: - def _noop_display_global_ranking(*_args: object, **_kwargs: object) -> None: - return None - - optimizer = Optimizer.__new__(Optimizer) - optimizer.args = Namespace(project_root=project_root) - optimizer.display_global_ranking = _noop_display_global_ranking - return optimizer - - @pytest.fixture def project_root(tmp_path: Path) -> Path: root = tmp_path / "project" @@ -158,127 +136,3 @@ def test_parametrized_tests_deduplication(project_root: Path) -> None: assert existing_unit_test_count(func, project_root, tests) == 2 -def test_trace_ranking_keeps_addressable_time_primary_over_test_count(project_root: Path, tmp_path: Path) -> None: - optimizer = make_optimizer(project_root) - funcs = [make_func(name, project_root) for name in ("foo", "bar", "baz")] - trace_file = tmp_path / "trace.db" - trace_file.touch() - - ranked_functions = [funcs[0], funcs[1], funcs[2]] - addressable_times = {"foo": 100.0, "bar": 20.0, "baz": 5.0} - function_to_tests: dict[str, set[FunctionCalledInTest]] = { - 
funcs[1].qualified_name_with_modules_from_root(project_root): { - make_test(TestType.EXISTING_UNIT_TEST, "test_one"), - make_test(TestType.EXISTING_UNIT_TEST, "test_two"), - make_test(TestType.EXISTING_UNIT_TEST, "test_three"), - } - } - - class FakeRanker: - def __init__(self, _trace_file: Path) -> None: - pass - - def rank_functions(self, _functions: list[FunctionToOptimize]) -> list[FunctionToOptimize]: - return ranked_functions - - def get_function_addressable_time(self, function: FunctionToOptimize) -> float: - return addressable_times[function.function_name] - - with patch("codeflash.benchmarking.function_ranker.FunctionRanker", FakeRanker): - ranked = optimizer.rank_all_functions_globally( - {project_root / "mod.py": funcs}, - trace_file, - test_count_cache=build_test_count_cache(funcs, project_root, function_to_tests), - ) - - assert [func.function_name for _, func in ranked] == ["foo", "bar", "baz"] - - -def test_trace_ranking_uses_test_count_as_tiebreaker(project_root: Path, tmp_path: Path) -> None: - optimizer = make_optimizer(project_root) - funcs = [make_func(name, project_root) for name in ("foo", "bar", "baz")] - trace_file = tmp_path / "trace.db" - trace_file.touch() - - ranked_functions = [funcs[0], funcs[1], funcs[2]] - addressable_times = {"foo": 100.0, "bar": 100.0, "baz": 5.0} - function_to_tests: dict[str, set[FunctionCalledInTest]] = { - funcs[0].qualified_name_with_modules_from_root(project_root): { - make_test(TestType.EXISTING_UNIT_TEST, "test_one") - }, - funcs[1].qualified_name_with_modules_from_root(project_root): { - make_test(TestType.EXISTING_UNIT_TEST, "test_one"), - make_test(TestType.EXISTING_UNIT_TEST, "test_two"), - make_test(TestType.EXISTING_UNIT_TEST, "test_three"), - }, - } - - class FakeRanker: - def __init__(self, _trace_file: Path) -> None: - pass - - def rank_functions(self, _functions: list[FunctionToOptimize]) -> list[FunctionToOptimize]: - return ranked_functions - - def get_function_addressable_time(self, function: 
FunctionToOptimize) -> float: - return addressable_times[function.function_name] - - with patch("codeflash.benchmarking.function_ranker.FunctionRanker", FakeRanker): - ranked = optimizer.rank_all_functions_globally( - {project_root / "mod.py": funcs}, - trace_file, - test_count_cache=build_test_count_cache(funcs, project_root, function_to_tests), - ) - - assert [func.function_name for _, func in ranked] == ["bar", "foo", "baz"] - - -def test_dependency_count_ranking_keeps_callee_count_primary(project_root: Path) -> None: - optimizer = make_optimizer(project_root) - funcs = [make_func(name, project_root) for name in ("foo", "bar")] - function_to_tests: dict[str, set[FunctionCalledInTest]] = { - funcs[1].qualified_name_with_modules_from_root(project_root): { - make_test(TestType.EXISTING_UNIT_TEST, "test_one"), - make_test(TestType.EXISTING_UNIT_TEST, "test_two"), - make_test(TestType.EXISTING_UNIT_TEST, "test_three"), - } - } - - class FakeResolver: - def count_callees_per_function(self, _mapping: dict[Path, set[str]]) -> dict[tuple[Path, str], int]: - return {(project_root / "mod.py", "foo"): 5, (project_root / "mod.py", "bar"): 1} - - ranked = optimizer.rank_by_dependency_count( - [(project_root / "mod.py", funcs[0]), (project_root / "mod.py", funcs[1])], - FakeResolver(), - test_count_cache=build_test_count_cache(funcs, project_root, function_to_tests), - ) - - assert [func.function_name for _, func in ranked] == ["foo", "bar"] - - -def test_dependency_count_ranking_uses_test_count_as_tiebreaker(project_root: Path) -> None: - optimizer = make_optimizer(project_root) - funcs = [make_func(name, project_root) for name in ("foo", "bar")] - function_to_tests: dict[str, set[FunctionCalledInTest]] = { - funcs[0].qualified_name_with_modules_from_root(project_root): { - make_test(TestType.EXISTING_UNIT_TEST, "test_one") - }, - funcs[1].qualified_name_with_modules_from_root(project_root): { - make_test(TestType.EXISTING_UNIT_TEST, "test_one"), - 
make_test(TestType.EXISTING_UNIT_TEST, "test_two"), - make_test(TestType.EXISTING_UNIT_TEST, "test_three"), - }, - } - - class FakeResolver: - def count_callees_per_function(self, _mapping: dict[Path, set[str]]) -> dict[tuple[Path, str], int]: - return {(project_root / "mod.py", "foo"): 2, (project_root / "mod.py", "bar"): 2} - - ranked = optimizer.rank_by_dependency_count( - [(project_root / "mod.py", funcs[0]), (project_root / "mod.py", funcs[1])], - FakeResolver(), - test_count_cache=build_test_count_cache(funcs, project_root, function_to_tests), - ) - - assert [func.function_name for _, func in ranked] == ["bar", "foo"] From 61b3f6a985a38ade61f2af0a22aa5c8a733f4b92 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 24 Mar 2026 07:47:12 -0500 Subject: [PATCH 3/9] fix: add missing version.py to codeflash_python package --- src/codeflash_python/version.py | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 src/codeflash_python/version.py diff --git a/src/codeflash_python/version.py b/src/codeflash_python/version.py new file mode 100644 index 000000000..db671620c --- /dev/null +++ b/src/codeflash_python/version.py @@ -0,0 +1,3 @@ +from codeflash.version import __version__ + +__all__ = ["__version__"] From b2f9f22e774555c82ae85f50e600eb78fe288d22 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 24 Mar 2026 07:53:16 -0500 Subject: [PATCH 4/9] fix: pass TestConfig to generate_tests instead of Path The codeflash_python verifier's generate_tests expected a Path but callers pass TestConfig. Match the old verifier's signature. 
--- src/codeflash_python/verification/verifier.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/codeflash_python/verification/verifier.py b/src/codeflash_python/verification/verifier.py index c03b6e5d0..363103313 100644 --- a/src/codeflash_python/verification/verifier.py +++ b/src/codeflash_python/verification/verifier.py @@ -11,6 +11,7 @@ from codeflash_python.verification.verification_utils import ModifyInspiredTests, delete_multiple_if_name_main if TYPE_CHECKING: + from codeflash_core.config import TestConfig from codeflash_core.models import FunctionToOptimize from codeflash_python.api.aiservice import AiServiceClient @@ -24,7 +25,7 @@ def generate_tests( function_to_optimize: FunctionToOptimize, helper_function_names: list[str], module_path: Path, - test_cfg_project_root: Path, + test_cfg: TestConfig, test_timeout: int, function_trace_id: str, test_index: int, @@ -38,7 +39,7 @@ def generate_tests( the returned test strings. """ start_time = time.perf_counter() - test_module_path = Path(module_name_from_file_path(test_path, test_cfg_project_root)) + test_module_path = Path(module_name_from_file_path(test_path, test_cfg.tests_project_rootdir)) response = aiservice_client.generate_regression_tests( source_code_being_tested=source_code_being_tested, From 08ba2913120d1e08fedd2a90954c77b621d351fc Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 24 Mar 2026 07:56:41 -0500 Subject: [PATCH 5/9] style: remove stale type: ignore comment in test_generation.py --- src/codeflash_python/optimizer_mixins/test_generation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/codeflash_python/optimizer_mixins/test_generation.py b/src/codeflash_python/optimizer_mixins/test_generation.py index 97ecae2ae..11d3a32bb 100644 --- a/src/codeflash_python/optimizer_mixins/test_generation.py +++ b/src/codeflash_python/optimizer_mixins/test_generation.py @@ -123,7 +123,7 @@ def submit_test_generation_tasks( self.function_to_optimize, 
helper_function_names, Path(self.original_module_path), - self.test_cfg, # type: ignore[arg-type] + self.test_cfg, INDIVIDUAL_TESTCASE_TIMEOUT, self.function_trace_id, test_index, From 0166a4999f354abcb83a8b9d924274d0f8d077e9 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 24 Mar 2026 08:05:17 -0500 Subject: [PATCH 6/9] fix: use codeflash_python AiServiceClient to avoid cross-package type mismatch The old AiServiceClient creates OptimizedCandidate from codeflash.models but the new OptimizationSet expects them from codeflash_python.models. --- codeflash/optimization/optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 2088f55e0..c9a6db6ad 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -8,7 +8,6 @@ from pathlib import Path from typing import TYPE_CHECKING -from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient from codeflash.api.cfapi import send_completion_email from codeflash.cli_cmds.console import call_graph_live_display, console, logger, progress_bar from codeflash.code_utils import env_utils @@ -27,6 +26,7 @@ from codeflash.lsp.helpers import is_subagent_mode from codeflash.telemetry.posthog_cf import ph from codeflash_core.config import TestConfig +from codeflash_python.api.aiservice import AiServiceClient, LocalAiServiceClient from codeflash_python.plugin import PythonPlugin if TYPE_CHECKING: From 5ac82b44eee53b1b071e6783573d7bf133324788 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 24 Mar 2026 08:12:37 -0500 Subject: [PATCH 7/9] refactor: remove duplicate models.py from codeflash_python MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Delete src/codeflash_python/models/models.py and update all 53 files to import from codeflash.models.models — single source of truth. 
--- codeflash/optimization/optimizer.py | 2 +- src/codeflash_python/api/aiservice.py | 4 +- .../api/aiservice_optimize.py | 4 +- src/codeflash_python/api/types.py | 2 +- .../benchmarking/plugin/plugin.py | 6 +- src/codeflash_python/benchmarking/utils.py | 4 +- src/codeflash_python/context/ast_helpers.py | 7 +- .../context/call_graph_index.py | 2 +- .../context/class_extraction.py | 2 +- .../context/code_context_extractor.py | 8 +- src/codeflash_python/context/jedi_helpers.py | 8 +- .../context/type_extraction.py | 2 +- src/codeflash_python/context/types.py | 2 +- .../context/unused_helper_detection.py | 4 +- .../discovery/discover_unit_tests.py | 2 +- .../discovery/function_filtering.py | 2 +- .../discovery/import_analyzer.py | 2 +- src/codeflash_python/discovery/tests_cache.py | 2 +- src/codeflash_python/function_optimizer.py | 16 +- src/codeflash_python/models/call_graph.py | 4 +- src/codeflash_python/models/models.py | 817 ------------------ .../optimization/optimizer.py | 2 +- src/codeflash_python/optimizer.py | 2 +- .../optimizer_mixins/_protocol.py | 14 +- .../optimizer_mixins/baseline.py | 4 +- .../optimizer_mixins/candidate_evaluation.py | 20 +- .../optimizer_mixins/candidate_structures.py | 2 +- .../optimizer_mixins/code_replacement.py | 2 +- .../optimizer_mixins/refinement.py | 4 +- .../optimizer_mixins/result_processing.py | 2 +- .../optimizer_mixins/test_execution.py | 8 +- .../optimizer_mixins/test_generation.py | 8 +- .../optimizer_mixins/test_review.py | 4 +- src/codeflash_python/plugin.py | 2 +- src/codeflash_python/plugin_helpers.py | 4 +- src/codeflash_python/plugin_results.py | 2 +- src/codeflash_python/result/create_pr.py | 2 +- src/codeflash_python/result/critic.py | 7 +- src/codeflash_python/result/explanation.py | 2 +- src/codeflash_python/result/pr_comment.py | 2 +- .../static_analysis/code_replacer.py | 2 +- .../static_analysis/code_replacer_base.py | 2 +- .../static_analysis/coverage_utils.py | 2 +- .../static_analysis/import_analysis.py | 
2 +- .../static_analysis/line_profile_utils.py | 2 +- .../verification/async_instrumentation.py | 2 +- .../verification/coverage_utils.py | 4 +- .../verification/edit_generated_tests.py | 4 +- .../verification/equivalence.py | 4 +- .../verification/instrument_existing_tests.py | 4 +- .../verification/parse_test_output.py | 4 +- .../verification/parse_xml.py | 4 +- .../verification/test_output_utils.py | 2 +- .../verification/test_runner.py | 2 +- .../verification/wrapper_generation.py | 2 +- 55 files changed, 111 insertions(+), 926 deletions(-) delete mode 100644 src/codeflash_python/models/models.py diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index c9a6db6ad..55fd371db 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -34,10 +34,10 @@ from argparse import Namespace from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint + from codeflash.models.models import BenchmarkKey, FunctionCalledInTest, ValidCode from codeflash_core.models import FunctionToOptimize from codeflash_python.context.types import DependencyResolver from codeflash_python.function_optimizer import FunctionOptimizer - from codeflash_python.models.models import BenchmarkKey, FunctionCalledInTest, ValidCode class Optimizer: diff --git a/src/codeflash_python/api/aiservice.py b/src/codeflash_python/api/aiservice.py index a862aa414..3bb881bc6 100644 --- a/src/codeflash_python/api/aiservice.py +++ b/src/codeflash_python/api/aiservice.py @@ -9,15 +9,15 @@ import requests from pydantic.json import pydantic_encoder +from codeflash.models.models import CodeStringsMarkdown, OptimizedCandidate from codeflash_python.api.aiservice_optimize import AiServiceOptimizeMixin from codeflash_python.api.aiservice_results import AiServiceResultsMixin from codeflash_python.api.aiservice_testgen import AiServiceTestgenMixin from codeflash_python.code_utils.config_consts import PYTHON_LANGUAGE_VERSION from 
codeflash_python.code_utils.env_utils import get_codeflash_api_key -from codeflash_python.models.models import CodeStringsMarkdown, OptimizedCandidate if TYPE_CHECKING: - from codeflash_python.models.models import OptimizedCandidateSource + from codeflash.models.models import OptimizedCandidateSource logger = logging.getLogger("codeflash_python") diff --git a/src/codeflash_python/api/aiservice_optimize.py b/src/codeflash_python/api/aiservice_optimize.py index c91e15bea..eb2dce443 100644 --- a/src/codeflash_python/api/aiservice_optimize.py +++ b/src/codeflash_python/api/aiservice_optimize.py @@ -8,20 +8,20 @@ import requests +from codeflash.models.models import OptimizedCandidateSource from codeflash_python.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name from codeflash_python.code_utils.time_utils import humanize_runtime -from codeflash_python.models.models import OptimizedCandidateSource from codeflash_python.telemetry.posthog_cf import ph from codeflash_python.version import __version__ as codeflash_version if TYPE_CHECKING: + from codeflash.models.models import OptimizedCandidate from codeflash_python.api.types import ( AIServiceAdaptiveOptimizeRequest, AIServiceCodeRepairRequest, AIServiceRefinerRequest, ) from codeflash_python.models.experiment_metadata import ExperimentMetadata - from codeflash_python.models.models import OptimizedCandidate else: _Base = object diff --git a/src/codeflash_python/api/types.py b/src/codeflash_python/api/types.py index 027c46620..2cc269399 100644 --- a/src/codeflash_python/api/types.py +++ b/src/codeflash_python/api/types.py @@ -5,7 +5,7 @@ from pydantic.dataclasses import dataclass -from codeflash_python.models.models import OptimizedCandidateSource +from codeflash.models.models import OptimizedCandidateSource @dataclass(frozen=True) diff --git a/src/codeflash_python/benchmarking/plugin/plugin.py b/src/codeflash_python/benchmarking/plugin/plugin.py index 48866e627..e16b452e2 100644 --- 
a/src/codeflash_python/benchmarking/plugin/plugin.py +++ b/src/codeflash_python/benchmarking/plugin/plugin.py @@ -14,7 +14,7 @@ from codeflash_python.code_utils.code_utils import module_name_from_file_path if TYPE_CHECKING: - from codeflash_python.models.models import BenchmarkKey + from codeflash.models.models import BenchmarkKey PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None @@ -85,7 +85,7 @@ def close(self) -> None: @staticmethod def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[BenchmarkKey, int]]: - from codeflash_python.models.models import BenchmarkKey + from codeflash.models.models import BenchmarkKey """Process the trace file and extract timing data for all functions. @@ -147,7 +147,7 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark @staticmethod def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]: - from codeflash_python.models.models import BenchmarkKey + from codeflash.models.models import BenchmarkKey """Extract total benchmark timings from trace files. 
diff --git a/src/codeflash_python/benchmarking/utils.py b/src/codeflash_python/benchmarking/utils.py index 4fe58f893..ddc0b50b9 100644 --- a/src/codeflash_python/benchmarking/utils.py +++ b/src/codeflash_python/benchmarking/utils.py @@ -3,12 +3,12 @@ import logging from typing import TYPE_CHECKING +from codeflash.models.models import BenchmarkDetail, ProcessedBenchmarkInfo from codeflash_python.code_utils.time_utils import humanize_runtime -from codeflash_python.models.models import BenchmarkDetail, ProcessedBenchmarkInfo from codeflash_python.result.critic import performance_gain if TYPE_CHECKING: - from codeflash_python.models.models import BenchmarkKey + from codeflash.models.models import BenchmarkKey logger = logging.getLogger("codeflash_python") diff --git a/src/codeflash_python/context/ast_helpers.py b/src/codeflash_python/context/ast_helpers.py index e5dd5193a..c92feab20 100644 --- a/src/codeflash_python/context/ast_helpers.py +++ b/src/codeflash_python/context/ast_helpers.py @@ -2,9 +2,12 @@ import ast import os -from pathlib import Path +from typing import TYPE_CHECKING -from codeflash_python.models.models import CodeStringsMarkdown +if TYPE_CHECKING: + from pathlib import Path + + from codeflash.models.models import CodeStringsMarkdown def parse_and_collect_imports(code_context: CodeStringsMarkdown) -> tuple[ast.Module, dict[str, str]] | None: diff --git a/src/codeflash_python/context/call_graph_index.py b/src/codeflash_python/context/call_graph_index.py index 24eca72b0..a97bbb1f0 100644 --- a/src/codeflash_python/context/call_graph_index.py +++ b/src/codeflash_python/context/call_graph_index.py @@ -8,10 +8,10 @@ from pathlib import Path from typing import TYPE_CHECKING +from codeflash.models.models import FunctionSource from codeflash_python.code_utils.code_utils import path_belongs_to_site_packages from codeflash_python.context.types import IndexResult from codeflash_python.context.utils import get_qualified_name -from codeflash_python.models.models 
import FunctionSource if TYPE_CHECKING: from collections.abc import Callable, Iterable diff --git a/src/codeflash_python/context/class_extraction.py b/src/codeflash_python/context/class_extraction.py index 1e21f84ee..2722de97e 100644 --- a/src/codeflash_python/context/class_extraction.py +++ b/src/codeflash_python/context/class_extraction.py @@ -5,6 +5,7 @@ import os from typing import TYPE_CHECKING +from codeflash.models.models import CodeString, CodeStringsMarkdown from codeflash_python.context.ast_helpers import ( MAX_RAW_PROJECT_CLASS_BODY_ITEMS, MAX_RAW_PROJECT_CLASS_LINES, @@ -24,7 +25,6 @@ parse_and_collect_imports, ) from codeflash_python.context.jedi_helpers import get_jedi_project -from codeflash_python.models.models import CodeString, CodeStringsMarkdown if TYPE_CHECKING: from pathlib import Path diff --git a/src/codeflash_python/context/code_context_extractor.py b/src/codeflash_python/context/code_context_extractor.py index a2087a842..99ab213f9 100644 --- a/src/codeflash_python/context/code_context_extractor.py +++ b/src/codeflash_python/context/code_context_extractor.py @@ -5,10 +5,9 @@ import logging from collections import defaultdict from itertools import chain -from pathlib import Path from typing import TYPE_CHECKING -from codeflash_core.models import FunctionToOptimize # noqa: TC001 +from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown from codeflash_python.code_utils.code_utils import encoded_tokens_len from codeflash_python.code_utils.config_consts import ( OPTIMIZATION_CONTEXT_TOKEN_LIMIT, @@ -26,11 +25,14 @@ from codeflash_python.context.type_extraction import extract_parameter_type_constructors from codeflash_python.context.types import CodeContextType from codeflash_python.context.unused_definition_remover import remove_unused_definitions_by_function_names -from codeflash_python.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown, FunctionSource from 
codeflash_python.static_analysis.code_extractor import find_preexisting_objects from codeflash_python.static_analysis.import_analysis import add_needed_imports_from_module if TYPE_CHECKING: + from pathlib import Path + + from codeflash.models.models import FunctionSource + from codeflash_core.models import FunctionToOptimize from codeflash_python.context.types import DependencyResolver logger = logging.getLogger("codeflash_python") diff --git a/src/codeflash_python/context/jedi_helpers.py b/src/codeflash_python/context/jedi_helpers.py index c5ef9242f..2e311cdb0 100644 --- a/src/codeflash_python/context/jedi_helpers.py +++ b/src/codeflash_python/context/jedi_helpers.py @@ -4,17 +4,19 @@ import os from collections import defaultdict from functools import cache -from pathlib import Path from typing import TYPE_CHECKING -from codeflash_core.models import FunctionToOptimize # noqa: TC001 +from codeflash.models.models import FunctionSource from codeflash_python.code_utils.code_utils import path_belongs_to_site_packages from codeflash_python.context.utils import get_qualified_name -from codeflash_python.models.models import FunctionSource if TYPE_CHECKING: + from pathlib import Path + from jedi.api.classes import Name + from codeflash_core.models import FunctionToOptimize + logger = logging.getLogger("codeflash_python") diff --git a/src/codeflash_python/context/type_extraction.py b/src/codeflash_python/context/type_extraction.py index 8f4840397..ed0c91751 100644 --- a/src/codeflash_python/context/type_extraction.py +++ b/src/codeflash_python/context/type_extraction.py @@ -7,6 +7,7 @@ from pathlib import Path from typing import TYPE_CHECKING +from codeflash.models.models import CodeString, CodeStringsMarkdown from codeflash_python.code_utils.code_utils import path_belongs_to_site_packages from codeflash_python.context.ast_helpers import ( BUILTIN_AND_TYPING_NAMES, @@ -23,7 +24,6 @@ should_use_raw_project_class_context, ) from codeflash_python.context.jedi_helpers import 
get_jedi_project -from codeflash_python.models.models import CodeString, CodeStringsMarkdown if TYPE_CHECKING: from codeflash_core.models import FunctionToOptimize diff --git a/src/codeflash_python/context/types.py b/src/codeflash_python/context/types.py index 4cad143d1..e7e8e9ae2 100644 --- a/src/codeflash_python/context/types.py +++ b/src/codeflash_python/context/types.py @@ -8,7 +8,7 @@ from collections.abc import Callable, Iterable from pathlib import Path - from codeflash_python.models.models import FunctionSource + from codeflash.models.models import FunctionSource from codeflash_core.models import HelperFunction diff --git a/src/codeflash_python/context/unused_helper_detection.py b/src/codeflash_python/context/unused_helper_detection.py index e7a375f92..763473e60 100644 --- a/src/codeflash_python/context/unused_helper_detection.py +++ b/src/codeflash_python/context/unused_helper_detection.py @@ -9,12 +9,12 @@ from pathlib import Path from typing import TYPE_CHECKING -from codeflash_python.models.models import CodeString, CodeStringsMarkdown +from codeflash.models.models import CodeString, CodeStringsMarkdown from codeflash_python.static_analysis.code_replacer import replace_function_definitions_in_module if TYPE_CHECKING: + from codeflash.models.models import CodeOptimizationContext, FunctionSource from codeflash_core.models import FunctionToOptimize - from codeflash_python.models.models import CodeOptimizationContext, FunctionSource logger = logging.getLogger("codeflash_python") diff --git a/src/codeflash_python/discovery/discover_unit_tests.py b/src/codeflash_python/discovery/discover_unit_tests.py index c823f9610..dd4b992be 100644 --- a/src/codeflash_python/discovery/discover_unit_tests.py +++ b/src/codeflash_python/discovery/discover_unit_tests.py @@ -17,12 +17,12 @@ from codeflash_core.models import FunctionToOptimize from pydantic.dataclasses import dataclass +from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType 
from codeflash_python.code_utils.code_utils import ImportErrorPattern, get_run_tmp_file, module_name_from_file_path from codeflash_python.code_utils.compat import SAFE_SYS_EXECUTABLE from codeflash_python.code_utils.shell_utils import get_cross_platform_subprocess_run_args from codeflash_python.discovery.import_analyzer import filter_test_files_by_imports from codeflash_python.discovery.tests_cache import TestsCache -from codeflash_python.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType from codeflash_python.verification.addopts import custom_addopts if TYPE_CHECKING: diff --git a/src/codeflash_python/discovery/function_filtering.py b/src/codeflash_python/discovery/function_filtering.py index 452bb4802..d88241604 100644 --- a/src/codeflash_python/discovery/function_filtering.py +++ b/src/codeflash_python/discovery/function_filtering.py @@ -21,8 +21,8 @@ if TYPE_CHECKING: from argparse import Namespace + from codeflash.models.models import CodeOptimizationContext from codeflash_core.models import FunctionToOptimize - from codeflash_python.models.models import CodeOptimizationContext logger = logging.getLogger("codeflash_python") diff --git a/src/codeflash_python/discovery/import_analyzer.py b/src/codeflash_python/discovery/import_analyzer.py index 8a29f7288..03c766a9f 100644 --- a/src/codeflash_python/discovery/import_analyzer.py +++ b/src/codeflash_python/discovery/import_analyzer.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from codeflash_python.models.models import TestsInFile + from codeflash.models.models import TestsInFile logger = logging.getLogger("codeflash_python") diff --git a/src/codeflash_python/discovery/tests_cache.py b/src/codeflash_python/discovery/tests_cache.py index 2cfbaaf97..850810d5e 100644 --- a/src/codeflash_python/discovery/tests_cache.py +++ b/src/codeflash_python/discovery/tests_cache.py @@ -8,8 +8,8 @@ from collections import defaultdict from pathlib import Path +from 
codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType from codeflash_python.code_utils.compat import codeflash_cache_db -from codeflash_python.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType logger = logging.getLogger("codeflash_python") diff --git a/src/codeflash_python/function_optimizer.py b/src/codeflash_python/function_optimizer.py index 3a95d9e77..ec8bcae8c 100644 --- a/src/codeflash_python/function_optimizer.py +++ b/src/codeflash_python/function_optimizer.py @@ -10,6 +10,7 @@ from pathlib import Path from typing import TYPE_CHECKING, cast +from codeflash.models.models import OptimizationSet, TestFiles, TestingMode, TestResults from codeflash_core.danom import Err, Ok from codeflash_python.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path, unified_diff_strings from codeflash_python.code_utils.config_consts import ( @@ -27,7 +28,6 @@ ) from codeflash_python.discovery.function_filtering import was_function_previously_optimized from codeflash_python.models.experiment_metadata import ExperimentMetadata -from codeflash_python.models.models import OptimizationSet, TestFiles, TestingMode, TestResults from codeflash_python.optimizer import resolve_python_function_ast from codeflash_python.optimizer_mixins import ( BaselineEstablishmentMixin, @@ -52,13 +52,7 @@ from argparse import Namespace from typing import Any - from codeflash_core.config import TestConfig - from codeflash_core.danom import Result - from codeflash_core.models import FunctionToOptimize - from codeflash_python.api.aiservice import AiServiceClient - from codeflash_python.api.types import TestDiff, TestFileReview - from codeflash_python.context.types import DependencyResolver - from codeflash_python.models.models import ( + from codeflash.models.models import ( BenchmarkKey, BestOptimization, CodeOptimizationContext, @@ -69,6 +63,12 @@ GeneratedTestsList, OriginalCodeBaseline, ) + from codeflash_core.config 
import TestConfig + from codeflash_core.danom import Result + from codeflash_core.models import FunctionToOptimize + from codeflash_python.api.aiservice import AiServiceClient + from codeflash_python.api.types import TestDiff, TestFileReview + from codeflash_python.context.types import DependencyResolver logger = logging.getLogger("codeflash_python") diff --git a/src/codeflash_python/models/call_graph.py b/src/codeflash_python/models/call_graph.py index 93cc2d936..b44ba125f 100644 --- a/src/codeflash_python/models/call_graph.py +++ b/src/codeflash_python/models/call_graph.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from pathlib import Path - from codeflash_python.models.models import FunctionSource + from codeflash.models.models import FunctionSource class FunctionNode(NamedTuple): @@ -197,7 +197,7 @@ def augment_with_trace(graph: CallGraph, trace_db_path: Path) -> CallGraph: def callees_from_graph(graph: CallGraph) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]: - from codeflash_python.models.models import FunctionSource + from codeflash.models.models import FunctionSource file_path_to_function_source: dict[Path, set[FunctionSource]] = defaultdict(set) function_source_list: list[FunctionSource] = [] diff --git a/src/codeflash_python/models/models.py b/src/codeflash_python/models/models.py deleted file mode 100644 index ead112341..000000000 --- a/src/codeflash_python/models/models.py +++ /dev/null @@ -1,817 +0,0 @@ -from __future__ import annotations - -import os -from collections import Counter, defaultdict -from collections.abc import Collection -from functools import lru_cache -from re import Pattern -from typing import TYPE_CHECKING - -from codeflash_core.models import FunctionParent -from codeflash_python.models.test_type import TestType - -if TYPE_CHECKING: - from collections.abc import Iterator - -import enum -import logging -import re -import sys -from enum import Enum -from pathlib import Path -from typing import Any, cast - -from pydantic 
import BaseModel, ConfigDict, PrivateAttr, ValidationError, model_validator -from pydantic.dataclasses import dataclass - -from codeflash_python.code_utils.code_utils import module_name_from_file_path, validate_python_code - -logger = logging.getLogger("codeflash_python") - -DEBUG_MODE = os.environ.get("CODEFLASH_DEBUG", "").lower() in ("1", "true") - - -# If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully -# qualified name of the method is foo.eggs.Ham.spam, its qualified name is Ham.spam, and its name is spam. The full name -# of the module is foo.eggs. - - -class ValidCode(BaseModel): - model_config = ConfigDict(frozen=True) - - source_code: str - normalized_code: str - - -@dataclass(frozen=True) -class FunctionSource: - file_path: Path - qualified_name: str - fully_qualified_name: str - only_function_name: str - source_code: str - definition_type: str | None = None # e.g. "function", "class"; None for non-Python languages - - def __eq__(self, other: object) -> bool: - if not isinstance(other, FunctionSource): - return False - return ( - self.file_path == other.file_path - and self.qualified_name == other.qualified_name - and self.fully_qualified_name == other.fully_qualified_name - and self.only_function_name == other.only_function_name - and self.source_code == other.source_code - ) - - def __hash__(self) -> int: - return hash( - (self.file_path, self.qualified_name, self.fully_qualified_name, self.only_function_name, self.source_code) - ) - - -class BestOptimization(BaseModel): - candidate: OptimizedCandidate - explanation_v2: str | None = None - helper_functions: list[FunctionSource] - code_context: CodeOptimizationContext - runtime: int - replay_performance_gain: dict[BenchmarkKey, float] | None = None - winning_behavior_test_results: TestResults - winning_benchmarking_test_results: TestResults - winning_replay_benchmarking_test_results: TestResults | None = None - line_profiler_test_results: 
dict[Any, Any] - async_throughput: int | None = None - concurrency_metrics: ConcurrencyMetrics | None = None - - -@dataclass(frozen=True) -class BenchmarkKey: - module_path: str - function_name: str - - def __str__(self) -> str: - return f"{self.module_path}::{self.function_name}" - - -@dataclass -class ConcurrencyMetrics: - sequential_time_ns: int - concurrent_time_ns: int - concurrency_factor: int - concurrency_ratio: float # sequential_time / concurrent_time - - -@dataclass -class BenchmarkDetail: - benchmark_name: str - test_function: str - original_timing: str - expected_new_timing: str - speedup_percent: float - - def to_string(self) -> str: - return ( - f"Original timing for {self.benchmark_name}::{self.test_function}: {self.original_timing}\n" - f"Expected new timing for {self.benchmark_name}::{self.test_function}: {self.expected_new_timing}\n" - f"Benchmark speedup for {self.benchmark_name}::{self.test_function}: {self.speedup_percent:.2f}%\n" - ) - - def to_dict(self) -> dict[str, Any]: - return { - "benchmark_name": self.benchmark_name, - "test_function": self.test_function, - "original_timing": self.original_timing, - "expected_new_timing": self.expected_new_timing, - "speedup_percent": self.speedup_percent, - } - - -@dataclass -class ProcessedBenchmarkInfo: - benchmark_details: list[BenchmarkDetail] - - def to_string(self) -> str: - if not self.benchmark_details: - return "" - - result = "Benchmark Performance Details:\n" - for detail in self.benchmark_details: - result += detail.to_string() + "\n" - return result - - def to_dict(self) -> dict[str, list[dict[str, Any]]]: - return {"benchmark_details": [detail.to_dict() for detail in self.benchmark_details]} - - -class CodeString(BaseModel): - code: str - file_path: Path | None = None - language: str = "python" # Language for validation - - @model_validator(mode="after") - def validate_code_syntax(self) -> CodeString: - """Validate code syntax for the specified language.""" - if self.language == 
"python": - validate_python_code(self.code) - else: - try: - compile(self.code, "", "exec") - except SyntaxError: - msg = f"Invalid {self.language.title()} code" - raise ValueError(msg) from None - return self - - -def get_comment_prefix(file_path: Path) -> str: - """Get the comment prefix for a given language.""" - return "#" - - -def get_code_block_splitter(file_path: Path | None) -> str: - if file_path is None: - return "" - comment_prefix = get_comment_prefix(file_path) - return f"{comment_prefix} file: {file_path.as_posix()}" - - -# Pattern to match markdown code blocks with optional language tag and file path -# Matches: ```language:filepath\ncode\n``` or ```language\ncode\n``` -markdown_pattern = re.compile(r"```(\w+)(?::([^\n]+))?\n(.*?)\n```", re.DOTALL) - - -class CodeStringsMarkdown(BaseModel): - code_strings: list[CodeString] = [] - language: str = "python" # Language for markdown code block tags - _cache: dict[str, Any] = PrivateAttr(default_factory=dict) - - @property - def flat(self) -> str: - """Returns the combined source code module from all code blocks. - - Each block is prefixed by a file path comment to indicate its origin. - The comment prefix is determined by the language attribute. - - Returns: - str: The concatenated code of all blocks with file path annotations. - - !! Important !!: - Avoid parsing the flat code with multiple files, - parsing may result in unexpected behavior. - - - """ - if self._cache.get("flat") is not None: - return self._cache["flat"] - self._cache["flat"] = "\n".join( - get_code_block_splitter(block.file_path) + "\n" + block.code for block in self.code_strings - ) - return self._cache["flat"] - - @property - def markdown(self) -> str: - """Returns a Markdown-formatted string containing all code blocks. - - Each block is enclosed in a triple-backtick code block with an optional - file path suffix (e.g., ```python:filename.py). - - The language tag is determined by the `language` attribute. 
- - Returns: - str: Markdown representation of the code blocks. - - """ - return "\n".join( - [ - f"```{self.language}{':' + code_string.file_path.as_posix() if code_string.file_path else ''}\n{code_string.code.strip()}\n```" - for code_string in self.code_strings - ] - ) - - def file_to_path(self) -> dict[str, str]: - """Return a dictionary mapping file paths to their corresponding code blocks. - - Returns: - dict[str, str]: Mapping from file path (as string) to code. - - """ - if self._cache.get("file_to_path") is not None: - return self._cache["file_to_path"] - self._cache["file_to_path"] = { - str(code_string.file_path): code_string.code for code_string in self.code_strings - } - return self._cache["file_to_path"] - - @staticmethod - def parse_markdown_code(markdown_code: str, expected_language: str = "python") -> CodeStringsMarkdown: - """Parse a Markdown string into a CodeStringsMarkdown object. - - Extracts code blocks and their associated file paths and constructs a new CodeStringsMarkdown instance. - - Args: - markdown_code (str): The Markdown-formatted string to parse. - expected_language (str): The expected language of code blocks (default: "python"). - - Returns: - CodeStringsMarkdown: Parsed object containing code blocks. 
- - """ - matches = markdown_pattern.findall(markdown_code) - code_string_list = [] - detected_language = expected_language - try: - for language, file_path, code in matches: - # Use the first detected language or the expected language - if language: - detected_language = language - if file_path: - path = file_path.strip() - code_string_list.append(CodeString(code=code, file_path=Path(path), language=detected_language)) - else: - # No file path specified - skip this block or create with None - code_string_list.append(CodeString(code=code, file_path=None, language=detected_language)) - return CodeStringsMarkdown(code_strings=code_string_list, language=detected_language) - except ValidationError: - # if any file is invalid, return an empty CodeStringsMarkdown for the entire context - return CodeStringsMarkdown(language=expected_language) - - -class CodeOptimizationContext(BaseModel): - testgen_context: CodeStringsMarkdown - read_writable_code: CodeStringsMarkdown - read_only_context_code: str = "" - hashing_code_context: str = "" - hashing_code_context_hash: str = "" - helper_functions: list[FunctionSource] - testgen_helper_fqns: list[str] = [] - preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] - - -class OptimizedCandidateResult(BaseModel): - max_loop_count: int - best_test_runtime: int - behavior_test_results: TestResults - benchmarking_test_results: TestResults - replay_benchmarking_test_results: dict[BenchmarkKey, TestResults] | None = None - optimization_candidate_index: int - total_candidate_timing: int - async_throughput: int | None = None - concurrency_metrics: ConcurrencyMetrics | None = None - - -class GeneratedTests(BaseModel): - generated_original_test_source: str - instrumented_behavior_test_source: str - instrumented_perf_test_source: str - raw_generated_test_source: str | None = None - behavior_file_path: Path - perf_file_path: Path - - -class GeneratedTestsList(BaseModel): - generated_tests: list[GeneratedTests] - - -class 
TestFile(BaseModel): - instrumented_behavior_file_path: Path - benchmarking_file_path: Path | None = None - original_file_path: Path | None = None - original_source: str | None = None - test_type: TestType - tests_in_file: list[TestsInFile] | None = None - - -class TestFiles(BaseModel): - test_files: list[TestFile] - - def get_by_type(self, test_type: TestType) -> TestFiles: - return TestFiles(test_files=[test_file for test_file in self.test_files if test_file.test_type == test_type]) - - def add(self, test_file: TestFile) -> None: - if test_file not in self.test_files: - self.test_files.append(test_file) - else: - msg = "Test file already exists in the list" - raise ValueError(msg) - - def get_by_original_file_path(self, file_path: Path) -> TestFile | None: - normalized = self._normalize_path_for_comparison(file_path) - for test_file in self.test_files: - if test_file.original_file_path is None: - continue - normalized_test_path = self._normalize_path_for_comparison(test_file.original_file_path) - if normalized == normalized_test_path: - return test_file - return None - - def get_test_type_by_instrumented_file_path(self, file_path: Path) -> TestType | None: - normalized = self._normalize_path_for_comparison(file_path) - for test_file in self.test_files: - normalized_behavior_path = self._normalize_path_for_comparison(test_file.instrumented_behavior_file_path) - if normalized == normalized_behavior_path: - return test_file.test_type - if test_file.benchmarking_file_path is not None: - normalized_benchmark_path = self._normalize_path_for_comparison(test_file.benchmarking_file_path) - if normalized == normalized_benchmark_path: - return test_file.test_type - - # Fallback: try filename-only matching when normalized paths don't match - file_name = file_path.name - for test_file in self.test_files: - if ( - test_file.instrumented_behavior_file_path - and test_file.instrumented_behavior_file_path.name == file_name - ): - return test_file.test_type - if 
test_file.benchmarking_file_path and test_file.benchmarking_file_path.name == file_name: - return test_file.test_type - - return None - - def get_test_type_by_original_file_path(self, file_path: Path) -> TestType | None: - normalized = self._normalize_path_for_comparison(file_path) - for test_file in self.test_files: - if test_file.original_file_path is None: - continue - normalized_test_path = self._normalize_path_for_comparison(test_file.original_file_path) - if normalized == normalized_test_path: - return test_file.test_type - return None - - @staticmethod - @lru_cache(maxsize=4096) - def _normalize_path_for_comparison(path: Path) -> str: - """Normalize a path for cross-platform comparison. - - Resolves the path to an absolute path and handles Windows case-insensitivity. - """ - try: - resolved = str(path.resolve()) - except (OSError, RuntimeError): - # If resolve fails (e.g., file doesn't exist), use absolute path - resolved = str(path.absolute()) - # Only lowercase on Windows where filesystem is case-insensitive - return resolved.lower() if sys.platform == "win32" else resolved - - def __iter__(self) -> Iterator[TestFile]: # type: ignore[override] - return iter(self.test_files) - - def __len__(self) -> int: - return len(self.test_files) - - -class OptimizationSet(BaseModel): - control: list[OptimizedCandidate] - experiment: list[OptimizedCandidate] | None - - -@dataclass(frozen=True) -class TestsInFile: - test_file: Path - test_class: str | None - test_function: str - test_type: TestType - - -class OptimizedCandidateSource(str, Enum): - OPTIMIZE = "OPTIMIZE" - OPTIMIZE_LP = "OPTIMIZE_LP" - REFINE = "REFINE" - REPAIR = "REPAIR" - ADAPTIVE = "ADAPTIVE" - JIT_REWRITE = "JIT_REWRITE" - - -@dataclass(frozen=True) -class OptimizedCandidate: - source_code: CodeStringsMarkdown - explanation: str - optimization_id: str - source: OptimizedCandidateSource - parent_id: str | None = None - model: str | None = None # Which LLM model generated this candidate - - 
-@dataclass(frozen=True) -class FunctionCalledInTest: - tests_in_file: TestsInFile - position: CodePosition - - -@dataclass(frozen=True) -class CodePosition: - line_no: int - col_no: int - - -class OriginalCodeBaseline(BaseModel): - behavior_test_results: TestResults - benchmarking_test_results: TestResults - replay_benchmarking_test_results: dict[BenchmarkKey, TestResults] | None = None - line_profile_results: dict - runtime: int - coverage_results: CoverageData | None - async_throughput: int | None = None - concurrency_metrics: ConcurrencyMetrics | None = None - - -class CoverageStatus(Enum): - NOT_FOUND = "Coverage Data Not Found" - PARSED_SUCCESSFULLY = "Parsed Successfully" - - -@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) -class CoverageData: - """Represents the coverage data for a specific function in a source file, using one or more test files.""" - - file_path: Path - coverage: float - function_name: str - functions_being_tested: list[str] - graph: dict[str, dict[str, Collection[object]]] - code_context: CodeOptimizationContext - main_func_coverage: FunctionCoverage - dependent_func_coverage: FunctionCoverage | None - status: CoverageStatus - blank_re: Pattern[str] = re.compile(r"\s*(#|$)") - else_re: Pattern[str] = re.compile(r"\s*else\s*:\s*(#|$)") - - def build_message(self) -> str: - if self.status == CoverageStatus.NOT_FOUND: - return f"No coverage data found for {self.function_name}" - return f"{self.coverage:.1f}%" - - def log_coverage(self) -> None: - lines = ["Test Coverage Results", f" Main Function: {self.main_func_coverage.name}: {self.coverage:.2f}%"] - if self.dependent_func_coverage: - lines.append( - f" Dependent Function: {self.dependent_func_coverage.name}: {self.dependent_func_coverage.coverage:.2f}%" - ) - lines.append(f" Total Coverage: {self.coverage:.2f}%") - logger.info("\n".join(lines)) - - if not self.coverage: - logger.debug(self.graph) - - @classmethod - def create_empty(cls, file_path: Path, function_name: str, 
code_context: CodeOptimizationContext) -> CoverageData: - return cls( - file_path=file_path, - coverage=0.0, - function_name=function_name, - functions_being_tested=[function_name], - graph={ - function_name: { - "executed_lines": set(), - "unexecuted_lines": set(), - "executed_branches": [], - "unexecuted_branches": [], - } - }, - code_context=code_context, - main_func_coverage=FunctionCoverage( - name=function_name, - coverage=0.0, - executed_lines=[], - unexecuted_lines=[], - executed_branches=[], - unexecuted_branches=[], - ), - dependent_func_coverage=None, - status=CoverageStatus.NOT_FOUND, - ) - - -@dataclass -class FunctionCoverage: - """Represents the coverage data for a specific function in a source file.""" - - name: str - coverage: float - executed_lines: list[int] - unexecuted_lines: list[int] - executed_branches: list[list[int]] - unexecuted_branches: list[list[int]] - - -class TestingMode(enum.Enum): - BEHAVIOR = "behavior" - PERFORMANCE = "performance" - LINE_PROFILE = "line_profile" - CONCURRENCY = "concurrency" - - -# Intentionally duplicated in codeflash_capture (runs in subprocess, can't import from here) -class VerificationType(str, Enum): - FUNCTION_CALL = ( - "function_call" # Correctness verification for a test function, checks input values and output values) - ) - INIT_STATE_FTO = "init_state_fto" # Correctness verification for fto class instance attributes after init - INIT_STATE_HELPER = "init_state_helper" # Correctness verification for helper class instance attributes after init - - -@dataclass(frozen=True) -class InvocationId: - test_module_path: str # The fully qualified name of the test module - test_class_name: str | None # The name of the class where the test is defined - test_function_name: str | None # The name of the test_function. 
Does not include the components of the file_name - function_getting_tested: str - iteration_id: str | None - - # test_module_path:TestSuiteClass.test_function_name:function_tested:iteration_id - def id(self) -> str: - class_prefix = f"{self.test_class_name}." if self.test_class_name else "" - return ( - f"{self.test_module_path}:{class_prefix}{self.test_function_name}:" - f"{self.function_getting_tested}:{self.iteration_id}" - ) - - # TestSuiteClass.test_function_name - def test_fn_qualified_name(self) -> str: - # Use f-string with inline conditional to reduce string concatenation operations - return ( - f"{self.test_class_name}.{self.test_function_name}" - if self.test_class_name - else str(self.test_function_name) - ) - - @staticmethod - def from_str_id(string_id: str, iteration_id: str | None = None) -> InvocationId: - components = string_id.split(":") - assert len(components) == 4 - second_components = components[1].split(".") - if len(second_components) == 1: - test_class_name = None - test_function_name = second_components[0] - else: - test_class_name = second_components[0] - test_function_name = second_components[1] - return InvocationId( - test_module_path=components[0], - test_class_name=test_class_name, - test_function_name=test_function_name, - function_getting_tested=components[2], - iteration_id=iteration_id if iteration_id else components[3], - ) - - -@dataclass(frozen=True) -class FunctionTestInvocation: - loop_index: int # The loop index of the function invocation, starts at 1 - id: InvocationId # The fully qualified name of the function invocation (id) - file_name: Path # The file where the test is defined - did_pass: bool # Whether the test this function invocation was part of, passed or failed - runtime: int | None # Time in nanoseconds - test_framework: str # unittest or pytest - test_type: TestType - return_value: object | None # The return value of the function invocation - timed_out: bool | None - verification_type: str | None = 
VerificationType.FUNCTION_CALL - stdout: str | None = None - - @property - def unique_invocation_loop_id(self) -> str: - return f"{self.loop_index}:{self.id.id()}" - - -class TestResults(BaseModel): # noqa: PLW1641 - # don't modify these directly, use the add method - # also we don't support deletion of test results elements - caution is advised - test_results: list[FunctionTestInvocation] = [] - test_result_idx: dict[str, int] = {} - - perf_stdout: str | None = None - # mapping between test function name and stdout failure message - test_failures: dict[str, str] | None = None - - def add(self, function_test_invocation: FunctionTestInvocation) -> None: - unique_id = function_test_invocation.unique_invocation_loop_id - test_result_idx = self.test_result_idx - if unique_id in test_result_idx: - if DEBUG_MODE: - logger.warning("Test result with id %s already exists. SKIPPING", unique_id) - return - test_results = self.test_results - test_result_idx[unique_id] = len(test_results) - test_results.append(function_test_invocation) - - def merge(self, other: TestResults) -> None: - original_len = len(self.test_results) - self.test_results.extend(other.test_results) - for k, v in other.test_result_idx.items(): - if k in self.test_result_idx: - msg = f"Test result with id {k} already exists." 
- raise ValueError(msg) - self.test_result_idx[k] = v + original_len - - def group_by_benchmarks( - self, benchmark_keys: list[BenchmarkKey], benchmark_replay_test_dir: Path, project_root: Path - ) -> dict[BenchmarkKey, TestResults]: - """Group TestResults by benchmark for calculating improvements for each benchmark.""" - test_results_by_benchmark = defaultdict(TestResults) - benchmark_module_path = {} - for benchmark_key in benchmark_keys: - benchmark_module_path[benchmark_key] = module_name_from_file_path( - benchmark_replay_test_dir.resolve() - / f"test_{benchmark_key.module_path.replace('.', '_')}__replay_test_", - project_root, - traverse_up=True, - ) - for test_result in self.test_results: - if test_result.test_type == TestType.REPLAY_TEST: - for benchmark_key, module_path in benchmark_module_path.items(): - if test_result.id.test_module_path.startswith(module_path): - test_results_by_benchmark[benchmark_key].add(test_result) - - return test_results_by_benchmark - - def get_by_unique_invocation_loop_id(self, unique_invocation_loop_id: str) -> FunctionTestInvocation | None: - try: - return self.test_results[self.test_result_idx[unique_invocation_loop_id]] - except (IndexError, KeyError): - return None - - def get_all_ids(self) -> set[InvocationId]: - return {test_result.id for test_result in self.test_results} - - def get_all_unique_invocation_loop_ids(self) -> set[str]: - return {test_result.unique_invocation_loop_id for test_result in self.test_results} - - def number_of_loops(self) -> int: - if not self.test_results: - return 0 - return max(test_result.loop_index for test_result in self.test_results) - - def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]: - report: dict[TestType, dict[str, int]] = {tt: {"passed": 0, "failed": 0} for tt in TestType} - for test_result in self.test_results: - if test_result.loop_index != 1: - continue - if test_result.did_pass: - report[test_result.test_type]["passed"] += 1 - else: - 
report[test_result.test_type]["failed"] += 1 - return report - - @staticmethod - def report_to_string(report: dict[TestType, dict[str, int]]) -> str: - return " ".join( - [ - f"{test_type.to_name()}- (Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']})" - for test_type in TestType - ] - ) - - @staticmethod - def report_to_tree(report: dict[TestType, dict[str, int]], title: str) -> str: - lines = [title] - for test_type in TestType: - if test_type is TestType.INIT_STATE_TEST: - continue - lines.append( - f" {test_type.to_name()} - Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']}" - ) - return "\n".join(lines) - - def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]: - # Efficient single traversal, directly accumulating into a dict. - # can track mins here and only sums can be return in total_passed_runtime - by_id: dict[InvocationId, list[int]] = {} - for result in self.test_results: - if result.did_pass: - if result.runtime: - by_id.setdefault(result.id, []).append(result.runtime) - else: - msg = ( - f"Ignoring test case that passed but had no runtime -> {result.id}, " - f"Loop # {result.loop_index}, Test Type: {result.test_type}, " - f"Verification Type: {result.verification_type}" - ) - logger.debug(msg) - return by_id - - def total_passed_runtime(self) -> int: - """Calculate the sum of runtimes of all test cases that passed. - - A testcase runtime is the minimum value of all looped execution runtimes. - - :return: The runtime in nanoseconds. - """ - # TODO this doesn't look at the intersection of tests of baseline and original - return sum( - [min(usable_runtime_data) for _, usable_runtime_data in self.usable_runtime_data_by_test_case().items()] - ) - - def effective_loop_count(self) -> int: - """Calculate the effective number of complete loops. - - Returns the maximum loop_index seen across all test results. This represents - the number of timing iterations that were performed. 
- - :return: The effective loop count, or 0 if no test results. - """ - if not self.test_results: - return 0 - # Get all loop indices from results that have timing data - loop_indices = {result.loop_index for result in self.test_results if result.runtime is not None} - if not loop_indices: - # Fallback: use all loop indices even without runtime - loop_indices = {result.loop_index for result in self.test_results} - return max(loop_indices) if loop_indices else 0 - - def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Path]: - map_gen_test_file_to_no_of_tests = Counter() - for gen_test_result in self.test_results: - if ( - gen_test_result.test_type == TestType.GENERATED_REGRESSION - and gen_test_result.id.test_function_name not in test_functions_to_remove - ): - map_gen_test_file_to_no_of_tests[gen_test_result.file_name] += 1 - return map_gen_test_file_to_no_of_tests - - def __iter__(self) -> Iterator[FunctionTestInvocation]: # type: ignore[override] - return iter(self.test_results) - - def __len__(self) -> int: - return len(self.test_results) - - def __getitem__(self, index: int) -> FunctionTestInvocation: - return self.test_results[index] - - def __setitem__(self, index: int, value: FunctionTestInvocation) -> None: - self.test_results[index] = value - - def __contains__(self, value: FunctionTestInvocation) -> bool: - return value in self.test_results - - def __bool__(self) -> bool: - return bool(self.test_results) - - def __eq__(self, other: object) -> bool: - # Unordered comparison - if type(self) is not type(other): - return False - if len(self) != len(other): # type: ignore[arg-type] - return False - from codeflash_python.verification.comparator import comparator - - original_recursion_limit = sys.getrecursionlimit() - cast("TestResults", other) - for test_result in self: - other_test_result = other.get_by_unique_invocation_loop_id(test_result.unique_invocation_loop_id) # type: ignore[attr-defined] - if other_test_result is None: - 
return False - - if original_recursion_limit < 5000: - sys.setrecursionlimit(5000) - if ( - test_result.file_name != other_test_result.file_name - or test_result.did_pass != other_test_result.did_pass - or test_result.runtime != other_test_result.runtime - or test_result.test_framework != other_test_result.test_framework - or test_result.test_type != other_test_result.test_type - or not comparator(test_result.return_value, other_test_result.return_value) - ): - sys.setrecursionlimit(original_recursion_limit) - return False - sys.setrecursionlimit(original_recursion_limit) - return True diff --git a/src/codeflash_python/optimization/optimizer.py b/src/codeflash_python/optimization/optimizer.py index f54237e7a..adad3a91c 100644 --- a/src/codeflash_python/optimization/optimizer.py +++ b/src/codeflash_python/optimization/optimizer.py @@ -17,10 +17,10 @@ import ast from argparse import Namespace + from codeflash.models.models import BenchmarkKey, FunctionCalledInTest from codeflash_core.models import FunctionToOptimize from codeflash_python.context.types import DependencyResolver from codeflash_python.function_optimizer import FunctionOptimizer - from codeflash_python.models.models import BenchmarkKey, FunctionCalledInTest try: from codeflash_core.config import TestConfig diff --git a/src/codeflash_python/optimizer.py b/src/codeflash_python/optimizer.py index d455b04f1..c5a7c1971 100644 --- a/src/codeflash_python/optimizer.py +++ b/src/codeflash_python/optimizer.py @@ -4,7 +4,7 @@ import logging from typing import TYPE_CHECKING -from codeflash_python.models.models import ValidCode +from codeflash.models.models import ValidCode if TYPE_CHECKING: from pathlib import Path diff --git a/src/codeflash_python/optimizer_mixins/_protocol.py b/src/codeflash_python/optimizer_mixins/_protocol.py index bd1d57c3f..b39242ed5 100644 --- a/src/codeflash_python/optimizer_mixins/_protocol.py +++ b/src/codeflash_python/optimizer_mixins/_protocol.py @@ -14,13 +14,7 @@ from argparse import 
Namespace from pathlib import Path - from codeflash_core.config import TestConfig - from codeflash_core.danom import Err, Result - from codeflash_core.models import FunctionToOptimize - from codeflash_python.api.aiservice import AiServiceClient - from codeflash_python.api.types import TestDiff, TestFileReview - from codeflash_python.context.types import DependencyResolver - from codeflash_python.models.models import ( + from codeflash.models.models import ( BenchmarkKey, BestOptimization, CodeOptimizationContext, @@ -38,6 +32,12 @@ TestingMode, TestResults, ) + from codeflash_core.config import TestConfig + from codeflash_core.danom import Err, Result + from codeflash_core.models import FunctionToOptimize + from codeflash_python.api.aiservice import AiServiceClient + from codeflash_python.api.types import TestDiff, TestFileReview + from codeflash_python.context.types import DependencyResolver from codeflash_python.optimizer_mixins.candidate_structures import CandidateEvaluationContext, CandidateNode from codeflash_python.result.explanation import Explanation diff --git a/src/codeflash_python/optimizer_mixins/baseline.py b/src/codeflash_python/optimizer_mixins/baseline.py index 6d6ffab44..8b34fb903 100644 --- a/src/codeflash_python/optimizer_mixins/baseline.py +++ b/src/codeflash_python/optimizer_mixins/baseline.py @@ -4,18 +4,18 @@ from collections import defaultdict from typing import TYPE_CHECKING, cast +from codeflash.models.models import OriginalCodeBaseline, TestingMode, TestResults, TestType from codeflash_core.danom import Err, Ok from codeflash_python.code_utils.code_utils import cleanup_paths from codeflash_python.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE from codeflash_python.code_utils.time_utils import humanize_runtime -from codeflash_python.models.models import OriginalCodeBaseline, TestingMode, TestType from codeflash_python.result.critic import coverage_critic, quantity_of_tests_critic if TYPE_CHECKING: from pathlib import Path + 
from codeflash.models.models import CodeOptimizationContext, CoverageData, FunctionCalledInTest from codeflash_core.danom import Result - from codeflash_python.models.models import CodeOptimizationContext, CoverageData, FunctionCalledInTest, TestResults from codeflash_python.optimizer_mixins._protocol import FunctionOptimizerProtocol as _Base else: _Base = object diff --git a/src/codeflash_python/optimizer_mixins/candidate_evaluation.py b/src/codeflash_python/optimizer_mixins/candidate_evaluation.py index c9da5c5aa..c72a0876c 100644 --- a/src/codeflash_python/optimizer_mixins/candidate_evaluation.py +++ b/src/codeflash_python/optimizer_mixins/candidate_evaluation.py @@ -7,6 +7,13 @@ import libcst as cst +from codeflash.models.models import ( + BestOptimization, + OptimizedCandidate, + OptimizedCandidateResult, + OptimizedCandidateSource, + TestingMode, +) from codeflash_core.danom import Ok from codeflash_python.api.types import AIServiceRefinerRequest from codeflash_python.code_utils.code_utils import get_run_tmp_file, unified_diff_strings @@ -16,21 +23,14 @@ EffortKeys, get_effort_value, ) -from codeflash_python.models.models import ( - BestOptimization, - OptimizedCandidate, - OptimizedCandidateResult, - OptimizedCandidateSource, - TestingMode, -) from codeflash_python.optimizer_mixins.candidate_structures import CandidateEvaluationContext, CandidateProcessor from codeflash_python.optimizer_mixins.scoring import create_rank_dictionary_compact, diff_length from codeflash_python.result.critic import performance_gain, quantity_of_tests_critic, speedup_critic if TYPE_CHECKING: + from codeflash.models.models import CodeOptimizationContext, OriginalCodeBaseline from codeflash_core.danom import Result from codeflash_python.api.aiservice import AiServiceClient - from codeflash_python.models.models import CodeOptimizationContext, OriginalCodeBaseline from codeflash_python.optimizer_mixins._protocol import FunctionOptimizerProtocol as _Base from 
codeflash_python.optimizer_mixins.candidate_structures import CandidateNode else: @@ -243,7 +243,7 @@ def run_optimized_candidate( ) finally: self.write_code_and_helpers(candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path) - from codeflash_python.models.models import TestResults + from codeflash.models.models import TestResults assert isinstance(candidate_behavior_results, TestResults) match, diffs = self.compare_candidate_results( @@ -276,7 +276,7 @@ def run_optimized_candidate( candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path ) # Use effective_loop_count which represents the number of timing samples across all test cases. - from codeflash_python.models.models import TestResults as TestResultsModel + from codeflash.models.models import TestResults as TestResultsModel assert isinstance(candidate_benchmarking_results, TestResultsModel) loop_count = candidate_benchmarking_results.effective_loop_count() diff --git a/src/codeflash_python/optimizer_mixins/candidate_structures.py b/src/codeflash_python/optimizer_mixins/candidate_structures.py index 03bcce3c9..6c1bd6963 100644 --- a/src/codeflash_python/optimizer_mixins/candidate_structures.py +++ b/src/codeflash_python/optimizer_mixins/candidate_structures.py @@ -17,7 +17,7 @@ ) if TYPE_CHECKING: - from codeflash_python.models.models import CodeOptimizationContext, OptimizedCandidate + from codeflash.models.models import CodeOptimizationContext, OptimizedCandidate logger = logging.getLogger("codeflash_python") diff --git a/src/codeflash_python/optimizer_mixins/code_replacement.py b/src/codeflash_python/optimizer_mixins/code_replacement.py index 6f8faad55..4050af1dd 100644 --- a/src/codeflash_python/optimizer_mixins/code_replacement.py +++ b/src/codeflash_python/optimizer_mixins/code_replacement.py @@ -8,7 +8,7 @@ from codeflash_python.code_utils.formatter import format_code, sort_imports if TYPE_CHECKING: - from codeflash_python.models.models import 
CodeOptimizationContext, CodeStringsMarkdown, FunctionSource + from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionSource from codeflash_python.optimizer_mixins._protocol import FunctionOptimizerProtocol as _Base else: _Base = object diff --git a/src/codeflash_python/optimizer_mixins/refinement.py b/src/codeflash_python/optimizer_mixins/refinement.py index d52311296..c095f528f 100644 --- a/src/codeflash_python/optimizer_mixins/refinement.py +++ b/src/codeflash_python/optimizer_mixins/refinement.py @@ -3,15 +3,15 @@ import logging from typing import TYPE_CHECKING +from codeflash.models.models import OptimizedCandidateSource from codeflash_python.code_utils.config_consts import MIN_CORRECT_CANDIDATES, EffortKeys, get_effort_value -from codeflash_python.models.models import OptimizedCandidateSource if TYPE_CHECKING: import concurrent.futures + from codeflash.models.models import CodeOptimizationContext, OptimizedCandidate from codeflash_python.api.aiservice import AiServiceClient from codeflash_python.api.types import TestDiff - from codeflash_python.models.models import CodeOptimizationContext, OptimizedCandidate from codeflash_python.optimizer_mixins._protocol import FunctionOptimizerProtocol as _Base from codeflash_python.optimizer_mixins.candidate_structures import CandidateEvaluationContext else: diff --git a/src/codeflash_python/optimizer_mixins/result_processing.py b/src/codeflash_python/optimizer_mixins/result_processing.py index 6d4b88264..f0ac1b509 100644 --- a/src/codeflash_python/optimizer_mixins/result_processing.py +++ b/src/codeflash_python/optimizer_mixins/result_processing.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: from pathlib import Path - from codeflash_python.models.models import ( + from codeflash.models.models import ( BestOptimization, CodeOptimizationContext, FunctionCalledInTest, diff --git a/src/codeflash_python/optimizer_mixins/test_execution.py b/src/codeflash_python/optimizer_mixins/test_execution.py 
index c3fbf5aa9..7a3720b56 100644 --- a/src/codeflash_python/optimizer_mixins/test_execution.py +++ b/src/codeflash_python/optimizer_mixins/test_execution.py @@ -6,8 +6,8 @@ from pathlib import Path from typing import TYPE_CHECKING +from codeflash.models.models import TestingMode, TestResults, TestType from codeflash_python.code_utils.config_consts import INDIVIDUAL_TESTCASE_TIMEOUT, TOTAL_LOOPING_TIME_EFFECTIVE -from codeflash_python.models.models import TestingMode, TestResults, TestType from codeflash_python.verification.instrument_existing_tests import inject_profiling_into_existing_test from codeflash_python.verification.parse_test_output import parse_test_results from codeflash_python.verification.test_output_utils import parse_concurrency_metrics @@ -18,7 +18,7 @@ ) if TYPE_CHECKING: - from codeflash_python.models.models import ( + from codeflash.models.models import ( CodeOptimizationContext, ConcurrencyMetrics, CoverageData, @@ -135,8 +135,8 @@ def run_behavioral_validation( return None def instrument_existing_tests(self, function_to_all_tests: dict[str, set[FunctionCalledInTest]]) -> set[Path]: + from codeflash.models.models import TestFile from codeflash_python.models.function_types import qualified_name_with_modules_from_root - from codeflash_python.models.models import TestFile assert self.project_root is not None existing_test_files_count = 0 @@ -307,7 +307,7 @@ def run_concurrency_benchmark( ) # Parse concurrency metrics from stdout - from codeflash_python.models.models import TestResults as TestResultsInternal + from codeflash.models.models import TestResults as TestResultsInternal if ( concurrency_results diff --git a/src/codeflash_python/optimizer_mixins/test_generation.py b/src/codeflash_python/optimizer_mixins/test_generation.py index 11d3a32bb..2f0be1a2a 100644 --- a/src/codeflash_python/optimizer_mixins/test_generation.py +++ b/src/codeflash_python/optimizer_mixins/test_generation.py @@ -5,19 +5,19 @@ from pathlib import Path from typing 
import TYPE_CHECKING +from codeflash.models.models import GeneratedTests, GeneratedTestsList from codeflash_core.danom import Err, Ok from codeflash_python.code_utils.config_consts import INDIVIDUAL_TESTCASE_TIMEOUT, EffortKeys, get_effort_value -from codeflash_python.models.models import GeneratedTests, GeneratedTestsList from codeflash_python.verification.verifier import generate_tests if TYPE_CHECKING: - from codeflash_core.danom import Result - from codeflash_python.models.models import ( + from codeflash.models.models import ( CodeOptimizationContext, CodeStringsMarkdown, FunctionCalledInTest, FunctionSource, ) + from codeflash_core.danom import Result from codeflash_python.optimizer_mixins._protocol import FunctionOptimizerProtocol as _Base else: _Base = object @@ -151,8 +151,8 @@ def generate_and_instrument_tests( str, ]: """Generate and instrument tests for the function.""" + from codeflash.models.models import TestFile, TestType from codeflash_python.code_utils.code_utils import get_run_tmp_file - from codeflash_python.models.models import TestFile, TestType from codeflash_python.verification.verification_utils import get_test_file_path n_tests = get_effort_value(EffortKeys.N_GENERATED_TESTS, self.effort) diff --git a/src/codeflash_python/optimizer_mixins/test_review.py b/src/codeflash_python/optimizer_mixins/test_review.py index c27b3a631..52962af28 100644 --- a/src/codeflash_python/optimizer_mixins/test_review.py +++ b/src/codeflash_python/optimizer_mixins/test_review.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import TYPE_CHECKING +from codeflash.models.models import TestType from codeflash_core.danom import Err, Ok from codeflash_python.code_utils.code_utils import encoded_tokens_len, module_name_from_file_path from codeflash_python.code_utils.config_consts import ( @@ -13,7 +14,6 @@ MAX_TEST_REPAIR_CYCLES, OPTIMIZATION_CONTEXT_TOKEN_LIMIT, ) -from codeflash_python.models.models import TestType from codeflash_python.telemetry.posthog_cf 
import ph from codeflash_python.verification.edit_generated_tests import remove_test_functions from codeflash_python.verification.test_runner import process_generated_test_strings @@ -21,8 +21,8 @@ if TYPE_CHECKING: from typing import Any + from codeflash.models.models import CodeOptimizationContext, CoverageData, GeneratedTestsList, TestResults from codeflash_core.danom import Result - from codeflash_python.models.models import CodeOptimizationContext, CoverageData, GeneratedTestsList, TestResults from codeflash_python.optimizer_mixins._protocol import FunctionOptimizerProtocol as _Base else: _Base = object diff --git a/src/codeflash_python/plugin.py b/src/codeflash_python/plugin.py index 0c0131fbf..37033e5ee 100644 --- a/src/codeflash_python/plugin.py +++ b/src/codeflash_python/plugin.py @@ -394,11 +394,11 @@ def replace_function_full(self, function: FunctionToOptimize, internal_ctx: obje """Port of FunctionOptimizer.replace_function_and_helpers_with_optimized_code.""" from collections import defaultdict + from codeflash.models.models import CodeStringsMarkdown from codeflash_python.context.unused_helper_detection import ( detect_unused_helper_functions, revert_unused_helper_functions, ) - from codeflash_python.models.models import CodeStringsMarkdown from codeflash_python.static_analysis.code_replacer import replace_function_definitions_in_module optimized_code = CodeStringsMarkdown.parse_markdown_code(code_markdown) diff --git a/src/codeflash_python/plugin_helpers.py b/src/codeflash_python/plugin_helpers.py index 15da19e06..27961cd58 100644 --- a/src/codeflash_python/plugin_helpers.py +++ b/src/codeflash_python/plugin_helpers.py @@ -10,8 +10,8 @@ if TYPE_CHECKING: from typing import Any + from codeflash.models.models import OptimizedCandidateSource from codeflash_core.models import CoverageData, FunctionToOptimize - from codeflash_python.models.models import OptimizedCandidateSource logger = logging.getLogger(__name__) @@ -84,7 +84,7 @@ def 
read_return_values(test_iteration: int) -> dict[str, list[object]]: def map_candidate_source(source: str) -> OptimizedCandidateSource: """Map core Candidate.source string to OptimizedCandidateSource enum value.""" - from codeflash_python.models.models import OptimizedCandidateSource + from codeflash.models.models import OptimizedCandidateSource mapping = { "optimize": OptimizedCandidateSource.OPTIMIZE, diff --git a/src/codeflash_python/plugin_results.py b/src/codeflash_python/plugin_results.py index d436d66c4..eb16746bb 100644 --- a/src/codeflash_python/plugin_results.py +++ b/src/codeflash_python/plugin_results.py @@ -83,7 +83,7 @@ def create_pr( trace_id: str = "", generated_tests: GeneratedTestSuite | None = None, ) -> str | None: - from codeflash_python.models.models import TestResults as InternalTestResults + from codeflash.models.models import TestResults as InternalTestResults from codeflash_python.result.create_pr import check_create_pr from codeflash_python.result.explanation import Explanation diff --git a/src/codeflash_python/result/create_pr.py b/src/codeflash_python/result/create_pr.py index fd9041a1a..91a1aa8a1 100644 --- a/src/codeflash_python/result/create_pr.py +++ b/src/codeflash_python/result/create_pr.py @@ -18,8 +18,8 @@ from codeflash_python.static_analysis.code_replacer import is_zero_diff if TYPE_CHECKING: + from codeflash.models.models import FunctionCalledInTest, InvocationId, TestFiles from codeflash_core.config import TestConfig - from codeflash_python.models.models import FunctionCalledInTest, InvocationId, TestFiles from codeflash_python.result.explanation import Explanation diff --git a/src/codeflash_python/result/critic.py b/src/codeflash_python/result/critic.py index a49ed3054..305d9ee24 100644 --- a/src/codeflash_python/result/critic.py +++ b/src/codeflash_python/result/critic.py @@ -14,12 +14,7 @@ from codeflash_python.models.test_type import TestType if TYPE_CHECKING: - from codeflash_python.models.models import ( - 
ConcurrencyMetrics, - CoverageData, - OptimizedCandidateResult, - OriginalCodeBaseline, - ) + from codeflash.models.models import ConcurrencyMetrics, CoverageData, OptimizedCandidateResult, OriginalCodeBaseline class AcceptanceReason(Enum): diff --git a/src/codeflash_python/result/explanation.py b/src/codeflash_python/result/explanation.py index 93ada3726..7b7b160c6 100644 --- a/src/codeflash_python/result/explanation.py +++ b/src/codeflash_python/result/explanation.py @@ -4,8 +4,8 @@ from pydantic.dataclasses import dataclass +from codeflash.models.models import BenchmarkDetail, ConcurrencyMetrics, TestResults from codeflash_python.code_utils.time_utils import humanize_runtime -from codeflash_python.models.models import BenchmarkDetail, ConcurrencyMetrics, TestResults from codeflash_python.result.critic import AcceptanceReason, concurrency_gain, throughput_gain diff --git a/src/codeflash_python/result/pr_comment.py b/src/codeflash_python/result/pr_comment.py index 815bd6463..f440dd23a 100644 --- a/src/codeflash_python/result/pr_comment.py +++ b/src/codeflash_python/result/pr_comment.py @@ -3,8 +3,8 @@ from pydantic import BaseModel from pydantic.dataclasses import dataclass +from codeflash.models.models import BenchmarkDetail, TestResults from codeflash_python.code_utils.time_utils import humanize_runtime -from codeflash_python.models.models import BenchmarkDetail, TestResults @dataclass(frozen=True, config={"arbitrary_types_allowed": True}) diff --git a/src/codeflash_python/static_analysis/code_replacer.py b/src/codeflash_python/static_analysis/code_replacer.py index 3c74ec593..ce05ba063 100644 --- a/src/codeflash_python/static_analysis/code_replacer.py +++ b/src/codeflash_python/static_analysis/code_replacer.py @@ -24,7 +24,7 @@ if TYPE_CHECKING: from pathlib import Path - from codeflash_python.models.models import CodeStringsMarkdown + from codeflash.models.models import CodeStringsMarkdown logger = logging.getLogger("codeflash_python") diff --git 
a/src/codeflash_python/static_analysis/code_replacer_base.py b/src/codeflash_python/static_analysis/code_replacer_base.py index a26315cfe..872113e53 100644 --- a/src/codeflash_python/static_analysis/code_replacer_base.py +++ b/src/codeflash_python/static_analysis/code_replacer_base.py @@ -9,7 +9,7 @@ logger = logging.getLogger(__name__) if TYPE_CHECKING: - from codeflash_python.models.models import CodeStringsMarkdown + from codeflash.models.models import CodeStringsMarkdown def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStringsMarkdown) -> str: diff --git a/src/codeflash_python/static_analysis/coverage_utils.py b/src/codeflash_python/static_analysis/coverage_utils.py index 8adc3858c..4e8292aa7 100644 --- a/src/codeflash_python/static_analysis/coverage_utils.py +++ b/src/codeflash_python/static_analysis/coverage_utils.py @@ -7,7 +7,7 @@ from codeflash_python.code_utils.code_utils import get_run_tmp_file if TYPE_CHECKING: - from codeflash_python.models.models import CodeOptimizationContext + from codeflash.models.models import CodeOptimizationContext def extract_dependent_function(main_function: str, code_context: CodeOptimizationContext) -> str | Literal[False]: diff --git a/src/codeflash_python/static_analysis/import_analysis.py b/src/codeflash_python/static_analysis/import_analysis.py index 6cea2de73..a3e744082 100644 --- a/src/codeflash_python/static_analysis/import_analysis.py +++ b/src/codeflash_python/static_analysis/import_analysis.py @@ -14,7 +14,7 @@ from libcst.helpers import ModuleNameAndPackage - from codeflash_python.models.models import FunctionSource + from codeflash.models.models import FunctionSource logger = logging.getLogger("codeflash_python") diff --git a/src/codeflash_python/static_analysis/line_profile_utils.py b/src/codeflash_python/static_analysis/line_profile_utils.py index 5d8ff603d..a080c58cf 100644 --- a/src/codeflash_python/static_analysis/line_profile_utils.py +++ 
b/src/codeflash_python/static_analysis/line_profile_utils.py @@ -13,8 +13,8 @@ from codeflash_python.code_utils.formatter import sort_imports if TYPE_CHECKING: + from codeflash.models.models import CodeOptimizationContext from codeflash_core.models import FunctionToOptimize - from codeflash_python.models.models import CodeOptimizationContext # Known JIT decorators organized by module # Format: {module_path: {decorator_name, ...}} diff --git a/src/codeflash_python/verification/async_instrumentation.py b/src/codeflash_python/verification/async_instrumentation.py index 6046e1c3e..ace828d40 100644 --- a/src/codeflash_python/verification/async_instrumentation.py +++ b/src/codeflash_python/verification/async_instrumentation.py @@ -5,8 +5,8 @@ import libcst as cst +from codeflash.models.models import TestingMode from codeflash_python.code_utils.formatter import sort_imports -from codeflash_python.models.models import TestingMode if TYPE_CHECKING: from pathlib import Path diff --git a/src/codeflash_python/verification/coverage_utils.py b/src/codeflash_python/verification/coverage_utils.py index c126611c8..3a13c0cf4 100644 --- a/src/codeflash_python/verification/coverage_utils.py +++ b/src/codeflash_python/verification/coverage_utils.py @@ -7,7 +7,7 @@ import sentry_sdk from coverage.exceptions import NoDataError -from codeflash_python.models.models import CoverageData, CoverageStatus, FunctionCoverage +from codeflash.models.models import CoverageData, CoverageStatus, FunctionCoverage from codeflash_python.static_analysis.coverage_utils import ( build_fully_qualified_name, extract_dependent_function, @@ -18,7 +18,7 @@ from collections.abc import Collection from pathlib import Path - from codeflash_python.models.models import CodeOptimizationContext + from codeflash.models.models import CodeOptimizationContext logger = logging.getLogger("codeflash_python") diff --git a/src/codeflash_python/verification/edit_generated_tests.py 
b/src/codeflash_python/verification/edit_generated_tests.py index aad437d25..6f0004c6c 100644 --- a/src/codeflash_python/verification/edit_generated_tests.py +++ b/src/codeflash_python/verification/edit_generated_tests.py @@ -11,12 +11,12 @@ from libcst import MetadataWrapper from libcst.metadata import PositionProvider +from codeflash.models.models import GeneratedTests, GeneratedTestsList from codeflash_python.code_utils.time_utils import format_perf, format_time -from codeflash_python.models.models import GeneratedTests, GeneratedTestsList from codeflash_python.result.critic import performance_gain if TYPE_CHECKING: - from codeflash_python.models.models import InvocationId + from codeflash.models.models import InvocationId logger = logging.getLogger("codeflash_python") diff --git a/src/codeflash_python/verification/equivalence.py b/src/codeflash_python/verification/equivalence.py index 825d20a06..989d47981 100644 --- a/src/codeflash_python/verification/equivalence.py +++ b/src/codeflash_python/verification/equivalence.py @@ -8,14 +8,14 @@ import libcst as cst +from codeflash.models.models import TestResults, TestType, VerificationType from codeflash_python.api.types import TestDiff, TestDiffScope -from codeflash_python.models.models import TestResults, TestType, VerificationType from codeflash_python.verification.comparator import comparator if TYPE_CHECKING: from pathlib import Path - from codeflash_python.models.models import InvocationId, TestResults + from codeflash.models.models import InvocationId, TestResults logger = logging.getLogger("codeflash_python") diff --git a/src/codeflash_python/verification/instrument_existing_tests.py b/src/codeflash_python/verification/instrument_existing_tests.py index 1ea968ae7..a17ac932d 100644 --- a/src/codeflash_python/verification/instrument_existing_tests.py +++ b/src/codeflash_python/verification/instrument_existing_tests.py @@ -6,17 +6,17 @@ from pathlib import Path from typing import TYPE_CHECKING +from 
codeflash.models.models import TestingMode from codeflash_core.models import FunctionParent, FunctionToOptimize from codeflash_python.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path from codeflash_python.code_utils.formatter import sort_imports -from codeflash_python.models.models import TestingMode from codeflash_python.verification.device_sync import detect_frameworks_from_code from codeflash_python.verification.wrapper_generation import create_wrapper_function if TYPE_CHECKING: from collections.abc import Iterable - from codeflash_python.models.models import CodePosition + from codeflash.models.models import CodePosition logger = logging.getLogger("codeflash_python") diff --git a/src/codeflash_python/verification/parse_test_output.py b/src/codeflash_python/verification/parse_test_output.py index 6650bf797..250939ef3 100644 --- a/src/codeflash_python/verification/parse_test_output.py +++ b/src/codeflash_python/verification/parse_test_output.py @@ -10,16 +10,16 @@ import dill as pickle from lxml.etree import XMLParser, parse # type: ignore[import-not-found] +from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType, VerificationType from codeflash_python.code_utils.code_utils import get_run_tmp_file -from codeflash_python.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType, VerificationType from codeflash_python.verification.path_utils import file_path_from_module_name from codeflash_python.verification.test_output_utils import merge_test_results, parse_test_failures_from_stdout if TYPE_CHECKING: import subprocess + from codeflash.models.models import CodeOptimizationContext, CoverageData, TestFiles from codeflash_core.config import TestConfig - from codeflash_python.models.models import CodeOptimizationContext, CoverageData, TestFiles logger = logging.getLogger("codeflash_python") DEBUG_MODE = os.environ.get("CODEFLASH_DEBUG", "").lower() in ("1", "true") diff --git 
a/src/codeflash_python/verification/parse_xml.py b/src/codeflash_python/verification/parse_xml.py index 232b8c153..fcd325a41 100644 --- a/src/codeflash_python/verification/parse_xml.py +++ b/src/codeflash_python/verification/parse_xml.py @@ -14,8 +14,8 @@ from junitparser.xunit2 import JUnitXml +from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults from codeflash_python.code_utils.code_utils import module_name_from_file_path -from codeflash_python.models.models import FunctionTestInvocation, InvocationId, TestResults from codeflash_python.verification.path_utils import file_path_from_module_name logger = logging.getLogger("codeflash_python") @@ -26,8 +26,8 @@ from lxml import etree # type: ignore[import-not-found] + from codeflash.models.models import TestFiles from codeflash_core.config import TestConfig - from codeflash_python.models.models import TestFiles matches_re_start = re.compile( r"!\$######([^:]*)" # group 1: module path diff --git a/src/codeflash_python/verification/test_output_utils.py b/src/codeflash_python/verification/test_output_utils.py index 52e0c8965..541923396 100644 --- a/src/codeflash_python/verification/test_output_utils.py +++ b/src/codeflash_python/verification/test_output_utils.py @@ -7,8 +7,8 @@ from collections import defaultdict from typing import TYPE_CHECKING +from codeflash.models.models import ConcurrencyMetrics, FunctionTestInvocation, TestResults, VerificationType from codeflash_python.discovery.discover_unit_tests import discover_parameters_unittest -from codeflash_python.models.models import ConcurrencyMetrics, FunctionTestInvocation, TestResults, VerificationType from codeflash_python.verification.path_utils import file_name_from_test_module_name if TYPE_CHECKING: diff --git a/src/codeflash_python/verification/test_runner.py b/src/codeflash_python/verification/test_runner.py index 32a619d72..c1b0bca95 100644 --- a/src/codeflash_python/verification/test_runner.py +++ 
b/src/codeflash_python/verification/test_runner.py @@ -256,10 +256,10 @@ def run_behavioral_tests( import shlex import sys + from codeflash.models.models import TestType from codeflash_python.code_utils.code_utils import get_run_tmp_file from codeflash_python.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE from codeflash_python.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE - from codeflash_python.models.models import TestType from codeflash_python.static_analysis.coverage_utils import prepare_coverage_files blocklisted_plugins = ["benchmark", "codspeed", "xdist", "sugar"] diff --git a/src/codeflash_python/verification/wrapper_generation.py b/src/codeflash_python/verification/wrapper_generation.py index af24638c4..2650911e7 100644 --- a/src/codeflash_python/verification/wrapper_generation.py +++ b/src/codeflash_python/verification/wrapper_generation.py @@ -3,7 +3,7 @@ import ast import logging -from codeflash_python.models.models import TestingMode, VerificationType +from codeflash.models.models import TestingMode, VerificationType from codeflash_python.verification.device_sync import ( create_device_sync_precompute_statements, create_device_sync_statements, From f9188ed296f5e64c2e922390a9791dc3614cf6a9 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 24 Mar 2026 08:23:31 -0500 Subject: [PATCH 8/9] style: fix lint errors and add per-file-ignores for codeflash_python Add ruff per-file-ignores for pre-existing PTH110, PTH123, PD011, E721 in src/codeflash_python/. Fix TC003 in addopts.py. 
--- pyproject.toml | 8 ++++++++ src/codeflash_python/verification/addopts.py | 5 ++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3bfb4399d..03ff09934 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -304,6 +304,14 @@ ignore = [ "PTH119", # os.path.basename — faster than Path().name for string paths ] +[tool.ruff.lint.per-file-ignores] +"src/codeflash_python/**" = [ + "PTH110", # os.path.exists — used deliberately for performance in hot paths + "PTH123", # open() — pre-existing, not worth churning + "PD011", # .values — false positive on non-pandas objects + "E721", # type() == — intentional for MappingProxyType checks +] + [tool.ruff.lint.flake8-type-checking] strict = true runtime-evaluated-base-classes = ["pydantic.BaseModel"] diff --git a/src/codeflash_python/verification/addopts.py b/src/codeflash_python/verification/addopts.py index c42999402..505c5227a 100644 --- a/src/codeflash_python/verification/addopts.py +++ b/src/codeflash_python/verification/addopts.py @@ -2,12 +2,15 @@ import configparser import logging -from collections.abc import Generator from contextlib import contextmanager from pathlib import Path +from typing import TYPE_CHECKING import tomlkit +if TYPE_CHECKING: + from collections.abc import Generator + from codeflash_python.code_utils.config_parser import get_all_closest_config_files logger = logging.getLogger("codeflash_python") From 938abc0d3c5a0487d59e6fbc49e24bddd5297f0c Mon Sep 17 00:00:00 2001 From: "codeflash-ai[bot]" <148906541+codeflash-ai[bot]@users.noreply.github.com> Date: Tue, 24 Mar 2026 18:24:05 +0000 Subject: [PATCH 9/9] Optimize existing_tests_source_for The hot loop that processes invocation IDs now hoists three expensive operations outside the loop: `current_language_support()` (which imports and instantiates a registry lookup costing ~29 ms), `tests_root.resolve()` (filesystem stat calls adding ~1 ms), and constructing the Jest extensions tuple (repeated 
allocation overhead). Profiler data confirms `current_language_support()` consumed 99.8% of its 28.8 ms call time in a registry import, and moving it before the loop eliminates 17 redundant calls. Additionally, the optimized version skips `tabulate()` calls when row lists are empty, saving ~6-13 ms per empty table (three tables checked per invocation). These changes reduce the function's total time from 54.9 ms to 48.7 ms with no regressions. --- codeflash/result/create_pr.py | 113 ++++++++++++++++++---------------- 1 file changed, 60 insertions(+), 53 deletions(-) diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py index 3fd6dc31a..b827b3339 100644 --- a/codeflash/result/create_pr.py +++ b/codeflash/result/create_pr.py @@ -88,29 +88,39 @@ def existing_tests_source_for( logger.debug(f"[PR-DEBUG] Processing {len(all_invocation_ids)} invocation_ids") matched_count = 0 skipped_count = 0 + + # Precompute some costly or repeated values + # current_language_support may be somewhat expensive; call once and reuse + lang = current_language_support() + # resolve tests_root once + try: + tests_root_resolved = tests_root.resolve() + except Exception: + tests_root_resolved = tests_root + # tuple of jest extensions for quick endswith checks + jest_test_extensions = ( + ".test.ts", + ".test.js", + ".test.tsx", + ".test.jsx", + ".spec.ts", + ".spec.js", + ".spec.tsx", + ".spec.jsx", + ".ts", + ".js", + ".tsx", + ".jsx", + ".mjs", + ".mts", + ) + for invocation_id in all_invocation_ids: # For JavaScript/TypeScript, test_module_path could be: # - A module-style path with dots: "tests.fibonacci.test.ts" # - A file path: "tests/fibonacci.test.ts" # For Python, it's a module name (e.g., "tests.test_example") that needs conversion test_module_path = invocation_id.test_module_path - # Jest test file extensions (including .test.ts, .spec.ts patterns) - jest_test_extensions = ( - ".test.ts", - ".test.js", - ".test.tsx", - ".test.jsx", - ".spec.ts", - ".spec.js", - 
".spec.tsx", - ".spec.jsx", - ".ts", - ".js", - ".tsx", - ".jsx", - ".mjs", - ".mts", - ) # Find the appropriate extension matched_ext = None for ext in jest_test_extensions: @@ -140,7 +150,6 @@ def existing_tests_source_for( else: logger.debug(f"[PR-DEBUG] No mapping found for {instrumented_abs_path.name}") else: - lang = current_language_support() # Let language-specific resolution handle non-Python module paths lang_result = lang.resolve_test_module_path_for_pr( test_module_path, test_cfg.tests_project_rootdir or test_cfg.project_root, non_generated_tests @@ -189,26 +198,20 @@ def existing_tests_source_for( ].keys() # both will have the same keys as some default values are assigned in the previous loop for qualified_name in sorted(all_qualified_names): # if not present in optimized output nan - if ( - original_tests_to_runtimes[filename][qualified_name] != 0 - and optimized_tests_to_runtimes[filename][qualified_name] != 0 - ): - print_optimized_runtime = format_time(optimized_tests_to_runtimes[filename][qualified_name]) - print_original_runtime = format_time(original_tests_to_runtimes[filename][qualified_name]) - print_filename = filename.resolve().relative_to(tests_root.resolve()).as_posix() - greater = ( - optimized_tests_to_runtimes[filename][qualified_name] - > original_tests_to_runtimes[filename][qualified_name] - ) + orig_val = original_tests_to_runtimes[filename][qualified_name] + opt_val = optimized_tests_to_runtimes[filename][qualified_name] + if orig_val != 0 and opt_val != 0: + print_optimized_runtime = format_time(opt_val) + print_original_runtime = format_time(orig_val) + # Reuse resolved tests_root for relative computation + print_filename = filename.resolve().relative_to(tests_root_resolved).as_posix() + print_filename_str = str(print_filename) + greater = opt_val > orig_val perf_gain = format_perf( - performance_gain( - original_runtime_ns=original_tests_to_runtimes[filename][qualified_name], - 
optimized_runtime_ns=optimized_tests_to_runtimes[filename][qualified_name], - ) - * 100 + performance_gain(original_runtime_ns=orig_val, optimized_runtime_ns=opt_val) * 100 ) if greater: - if "__replay_test_" in str(print_filename): + if "__replay_test_" in print_filename_str: rows_replay.append( [ f"`{print_filename}::{qualified_name}`", @@ -217,7 +220,7 @@ def existing_tests_source_for( f"{perf_gain}%⚠️", ] ) - elif "codeflash_concolic" in str(print_filename): + elif "codeflash_concolic" in print_filename_str: rows_concolic.append( [ f"`{print_filename}::{qualified_name}`", @@ -235,7 +238,7 @@ def existing_tests_source_for( f"{perf_gain}%⚠️", ] ) - elif "__replay_test_" in str(print_filename): + elif "__replay_test_" in print_filename_str: rows_replay.append( [ f"`{print_filename}::{qualified_name}`", @@ -244,7 +247,7 @@ def existing_tests_source_for( f"{perf_gain}%✅", ] ) - elif "codeflash_concolic" in str(print_filename): + elif "codeflash_concolic" in print_filename_str: rows_concolic.append( [ f"`{print_filename}::{qualified_name}`", @@ -262,23 +265,27 @@ def existing_tests_source_for( f"{perf_gain}%✅", ] ) - output_existing += tabulate( - headers=headers, tabular_data=rows_existing, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True - ) - output_existing += "\n" - if len(rows_existing) == 0: + # Only call tabulate if we have rows to format (avoid expensive tabulate calls for empty lists) + if rows_existing: + output_existing += tabulate( + headers=headers, tabular_data=rows_existing, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True + ) + output_existing += "\n" + else: output_existing = "" - output_concolic += tabulate( - headers=headers, tabular_data=rows_concolic, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True - ) - output_concolic += "\n" - if len(rows_concolic) == 0: + if rows_concolic: + output_concolic += tabulate( + headers=headers, tabular_data=rows_concolic, tablefmt="pipe", colglobalalign=None, 
preserve_whitespace=True + ) + output_concolic += "\n" + else: output_concolic = "" - output_replay += tabulate( - headers=headers, tabular_data=rows_replay, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True - ) - output_replay += "\n" - if len(rows_replay) == 0: + if rows_replay: + output_replay += tabulate( + headers=headers, tabular_data=rows_replay, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True + ) + output_replay += "\n" + else: output_replay = "" return output_existing, output_replay, output_concolic