Skip to content
Merged
3 changes: 2 additions & 1 deletion codeflash/languages/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,11 +544,12 @@ def instrument_for_benchmarking(self, test_source: str, target_function: Functio

# === Validation ===

def validate_syntax(self, source: str) -> bool:
def validate_syntax(self, source: str, file_path: Path | None = None) -> bool:
"""Check if source code is syntactically valid.
Args:
source: Source code to validate.
file_path: Optional file path for parser selection (e.g., .tsx vs .ts).
Returns:
True if valid, False otherwise.
Expand Down
2 changes: 1 addition & 1 deletion codeflash/languages/java/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def instrument_for_benchmarking(self, test_source: str, target_function: Functio

# === Validation ===

def validate_syntax(self, source: str) -> bool:
def validate_syntax(self, source: str, file_path: Path | None = None) -> bool:
"""Check if Java source code is syntactically valid."""
return self._analyzer.validate_syntax(source)

Expand Down
9 changes: 6 additions & 3 deletions codeflash/languages/javascript/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def extract_code_context(self, function: FunctionToOptimize, project_root: Path,

# Validate that the extracted code is syntactically valid
# If not, raise an error to fail the optimization early
if target_code and not self.validate_syntax(target_code):
if target_code and not self.validate_syntax(target_code, file_path=function.file_path):
error_msg = (
f"Extracted code for {function.function_name} is not syntactically valid JavaScript. "
f"Cannot proceed with optimization."
Expand Down Expand Up @@ -1712,10 +1712,13 @@ def instrument_for_benchmarking(self, test_source: str, target_function: Functio
def treesitter_language(self) -> TreeSitterLanguage:
return TreeSitterLanguage.JAVASCRIPT

def validate_syntax(self, source: str) -> bool:
def validate_syntax(self, source: str, file_path: Path | None = None) -> bool:
"""Check if source code is syntactically valid using tree-sitter."""
try:
analyzer = TreeSitterAnalyzer(self.treesitter_language)
if file_path is not None:
analyzer = get_analyzer_for_file(file_path)
else:
analyzer = TreeSitterAnalyzer(self.treesitter_language)
tree = analyzer.parse(source)
return not tree.root_node.has_error
except Exception:
Expand Down
28 changes: 22 additions & 6 deletions codeflash/languages/python/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId, ValidCode
from codeflash.verification.verification_utils import TestConfig

_CACHE: dict[str, bool] = {}

_CACHE_MAX: int = 4096

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -673,7 +677,7 @@ def instrument_for_benchmarking(self, test_source: str, target_function: Functio

# === Validation ===

def validate_syntax(self, source: str) -> bool:
def validate_syntax(self, source: str, file_path: Path | None = None) -> bool:
"""Check if Python source code is syntactically valid.

Uses Python's compile() to validate syntax.
Expand All @@ -685,11 +689,7 @@ def validate_syntax(self, source: str) -> bool:
True if valid, False otherwise.

"""
try:
compile(source, "<string>", "exec")
return True
except SyntaxError:
return False
return _compile_ok(source)

def normalize_code(self, source: str) -> str:
from codeflash.languages.python.normalizer import normalize_python_code
Expand Down Expand Up @@ -1361,3 +1361,19 @@ def generate_concolic_tests(
end_time = time.perf_counter()
logger.debug("Generated concolic tests in %.2f seconds", end_time - start_time)
return function_to_concolic_tests, concolic_test_suite_code


def _compile_ok(source: str) -> bool:
try:
cached = _CACHE.get(source)
if cached is not None:
return cached

compile(source, "<string>", "exec")
if len(_CACHE) < _CACHE_MAX:
_CACHE[source] = True
return True
except SyntaxError:
if len(_CACHE) < _CACHE_MAX:
_CACHE[source] = False
return False
2 changes: 1 addition & 1 deletion codeflash/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def validate_code_syntax(self) -> CodeString:
from codeflash.languages.registry import get_language_support

lang_support = get_language_support(self.language)
if not lang_support.validate_syntax(self.code):
if not lang_support.validate_syntax(self.code, file_path=self.file_path):
msg = f"Invalid {self.language.title()} code"
raise ValueError(msg)
return self
Expand Down
38 changes: 38 additions & 0 deletions tests/test_languages/test_javascript_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,44 @@ def test_syntax_error_types(self, js_support):
assert js_support.validate_syntax("function foo() {") is False


def test_tsx_jsx_syntax_valid_with_file_path(self):
"""Test that TSX/JSX syntax is valid when file_path with .tsx extension is provided."""
from codeflash.languages.javascript.support import TypeScriptSupport

ts_support = TypeScriptSupport()

tsx_code = """
function VersionHeader({ version }) {
return (
<div className="header">
<h1>{version.name}</h1>
<span>{version.date}</span>
</div>
);
}
"""
# Without file_path, TypeScriptSupport uses TYPESCRIPT parser which can't handle JSX
assert ts_support.validate_syntax(tsx_code) is False

# With .tsx file_path, it should use TSX parser and pass
tsx_path = Path("/tmp/test.tsx")
assert ts_support.validate_syntax(tsx_code, file_path=tsx_path) is True

def test_tsx_jsx_syntax_valid_with_jsx_file_path(self, js_support):
"""Test that JSX syntax is valid when file_path with .jsx extension is provided."""
jsx_code = """
function Button({ label, onClick }) {
return <button onClick={onClick}>{label}</button>;
}
"""
# JavaScript parser handles JSX natively
assert js_support.validate_syntax(jsx_code) is True

# Explicit .jsx path should also work
jsx_path = Path("/tmp/test.jsx")
assert js_support.validate_syntax(jsx_code, file_path=jsx_path) is True


class TestNormalizeCode:
"""Tests for normalize_code method using tree-sitter normalizer."""

Expand Down
Loading