diff --git a/codeflash/languages/base.py b/codeflash/languages/base.py index bcdabeb8d..b5fd583c8 100644 --- a/codeflash/languages/base.py +++ b/codeflash/languages/base.py @@ -544,11 +544,12 @@ def instrument_for_benchmarking(self, test_source: str, target_function: Functio # === Validation === - def validate_syntax(self, source: str) -> bool: + def validate_syntax(self, source: str, file_path: Path | None = None) -> bool: """Check if source code is syntactically valid. Args: source: Source code to validate. + file_path: Optional file path for parser selection (e.g., .tsx vs .ts). Returns: True if valid, False otherwise. diff --git a/codeflash/languages/java/support.py b/codeflash/languages/java/support.py index 825c7e7da..9e6149e1b 100644 --- a/codeflash/languages/java/support.py +++ b/codeflash/languages/java/support.py @@ -275,7 +275,7 @@ def instrument_for_benchmarking(self, test_source: str, target_function: Functio # === Validation === - def validate_syntax(self, source: str) -> bool: + def validate_syntax(self, source: str, file_path: Path | None = None) -> bool: """Check if Java source code is syntactically valid.""" return self._analyzer.validate_syntax(source) diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index 039d1ce98..d891f7aed 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -413,7 +413,7 @@ def extract_code_context(self, function: FunctionToOptimize, project_root: Path, # Validate that the extracted code is syntactically valid # If not, raise an error to fail the optimization early - if target_code and not self.validate_syntax(target_code): + if target_code and not self.validate_syntax(target_code, file_path=function.file_path): error_msg = ( f"Extracted code for {function.function_name} is not syntactically valid JavaScript. " f"Cannot proceed with optimization." 
@@ -1712,10 +1712,13 @@ def instrument_for_benchmarking(self, test_source: str, target_function: Functio def treesitter_language(self) -> TreeSitterLanguage: return TreeSitterLanguage.JAVASCRIPT - def validate_syntax(self, source: str) -> bool: + def validate_syntax(self, source: str, file_path: Path | None = None) -> bool: """Check if source code is syntactically valid using tree-sitter.""" try: - analyzer = TreeSitterAnalyzer(self.treesitter_language) + if file_path is not None: + analyzer = get_analyzer_for_file(file_path) + else: + analyzer = TreeSitterAnalyzer(self.treesitter_language) tree = analyzer.parse(source) return not tree.root_node.has_error except Exception: diff --git a/codeflash/languages/python/support.py b/codeflash/languages/python/support.py index ccf74ea86..606292977 100644 --- a/codeflash/languages/python/support.py +++ b/codeflash/languages/python/support.py @@ -33,6 +33,10 @@ from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId, ValidCode from codeflash.verification.verification_utils import TestConfig +_CACHE: dict[str, bool] = {} + +_CACHE_MAX: int = 4096 + logger = logging.getLogger(__name__) @@ -673,7 +677,7 @@ def instrument_for_benchmarking(self, test_source: str, target_function: Functio # === Validation === - def validate_syntax(self, source: str) -> bool: + def validate_syntax(self, source: str, file_path: Path | None = None) -> bool: """Check if Python source code is syntactically valid. Uses Python's compile() to validate syntax. @@ -685,11 +689,7 @@ def validate_syntax(self, source: str) -> bool: True if valid, False otherwise. 
def _compile_ok(source: str) -> bool:
    """Return whether ``source`` compiles as a Python module, memoizing the verdict.

    Verdicts are stored in the module-level ``_CACHE`` dict, keyed by the exact
    source text and capped at ``_CACHE_MAX`` entries. When the cap is reached
    the oldest entry is evicted (dicts preserve insertion order), so the cache
    keeps serving recently seen sources instead of freezing on the first
    ``_CACHE_MAX`` inputs it ever saw.

    Args:
        source: Python source code to check.

    Returns:
        True if ``compile()`` accepts the source, False if it rejects it.
    """
    # Cached verdicts are always real booleans, so ``None`` unambiguously
    # means "not seen yet" (a cached ``False`` is still a hit).
    cached = _CACHE.get(source)
    if cached is not None:
        return cached

    # Keep the try body to the single call that can raise.
    try:
        compile(source, "", "exec")
    except (SyntaxError, ValueError):
        # compile() is documented to raise ValueError (rather than SyntaxError)
        # for source containing null bytes; either way the source is not valid,
        # and a validator should answer False instead of propagating.
        is_valid = False
    else:
        is_valid = True

    # Evict the oldest entry instead of refusing to cache once full.
    if len(_CACHE) >= _CACHE_MAX:
        del _CACHE[next(iter(_CACHE))]
    _CACHE[source] = is_valid
    return is_valid
VersionHeader({ version }) { + return ( +
+

{version.name}

+ {version.date} +
+ ); +} +""" + # Without file_path, TypeScriptSupport uses TYPESCRIPT parser which can't handle JSX + assert ts_support.validate_syntax(tsx_code) is False + + # With .tsx file_path, it should use TSX parser and pass + tsx_path = Path("/tmp/test.tsx") + assert ts_support.validate_syntax(tsx_code, file_path=tsx_path) is True + + def test_tsx_jsx_syntax_valid_with_jsx_file_path(self, js_support): + """Test that JSX syntax is valid when file_path with .jsx extension is provided.""" + jsx_code = """ +function Button({ label, onClick }) { + return ; +} +""" + # JavaScript parser handles JSX natively + assert js_support.validate_syntax(jsx_code) is True + + # Explicit .jsx path should also work + jsx_path = Path("/tmp/test.jsx") + assert js_support.validate_syntax(jsx_code, file_path=jsx_path) is True + + class TestNormalizeCode: """Tests for normalize_code method using tree-sitter normalizer."""