diff --git a/code_to_optimize/java/src/main/java/com/example/Fibonacci.java b/code_to_optimize/java/src/main/java/com/example/Fibonacci.java index b604fb928..8772905e2 100644 --- a/code_to_optimize/java/src/main/java/com/example/Fibonacci.java +++ b/code_to_optimize/java/src/main/java/com/example/Fibonacci.java @@ -172,4 +172,77 @@ public static boolean areConsecutiveFibonacci(long a, long b) { return Math.abs(indexA - indexB) == 1; } + + /** + * Sort an array in-place using bubble sort. + * Intentionally naive O(n^2) implementation for optimization testing. + * + * @param arr Array to sort (modified in-place) + */ + public static void sortArray(long[] arr) { + if (arr == null) { + throw new IllegalArgumentException("Array must not be null"); + } + for (int i = 0; i < arr.length; i++) { + for (int j = 0; j < arr.length - 1 - i; j++) { + if (arr[j] > arr[j + 1]) { + long temp = arr[j]; + arr[j] = arr[j + 1]; + arr[j + 1] = temp; + } + } + } + } + + /** + * Append Fibonacci numbers up to a limit into the provided list. + * Clears the list first, then fills it with Fibonacci numbers less than limit. + * Uses repeated naive recursion — intentionally slow for optimization testing. + * + * @param output List to populate (cleared first) + * @param limit Upper bound (exclusive) + */ + public static void collectFibonacciInto(List output, long limit) { + if (output == null) { + throw new IllegalArgumentException("Output list must not be null"); + } + output.clear(); + + if (limit <= 0) { + return; + } + + int index = 0; + while (true) { + long fib = fibonacci(index); + if (fib >= limit) { + break; + } + output.add(fib); + index++; + if (index > 50) { + break; + } + } + } + + /** + * Compute running Fibonacci sums in-place. + * result[i] = sum of fibonacci(0) through fibonacci(i). + * Uses repeated naive recursion — intentionally O(n * 2^n). + * + * @param result Array to fill with running sums (must be pre-allocated) + */ + public static void fillFibonacciRunningSums(long[] result) { + if (result == null) { + throw new IllegalArgumentException("Array must not be null"); + } + for (int i = 0; i < result.length; i++) { + long sum = 0; + for (int j = 0; j <= i; j++) { + sum += fibonacci(j); + } + result[i] = sum; + } + } } diff --git a/code_to_optimize/java/src/test/java/com/example/FibonacciTest.java b/code_to_optimize/java/src/test/java/com/example/FibonacciTest.java index 86724917d..f58fd87b6 100644 --- a/code_to_optimize/java/src/test/java/com/example/FibonacciTest.java +++ b/code_to_optimize/java/src/test/java/com/example/FibonacciTest.java @@ -1,6 +1,7 @@ package com.example; import org.junit.jupiter.api.Test; +import java.util.ArrayList; import java.util.List; import static org.junit.jupiter.api.Assertions.*; @@ -136,4 +137,105 @@ void testAreConsecutiveFibonacci() { // Non-Fibonacci number assertFalse(Fibonacci.areConsecutiveFibonacci(4, 5)); // 4 is not Fibonacci } + + @Test + void testSortArray() { + long[] arr = {5, 3, 8, 1, 2, 7, 4, 6}; + Fibonacci.sortArray(arr); + assertArrayEquals(new long[]{1, 2, 3, 4, 5, 6, 7, 8}, arr); + } + + @Test + void testSortArrayAlreadySorted() { + long[] arr = {1, 2, 3, 4, 5}; + Fibonacci.sortArray(arr); + assertArrayEquals(new long[]{1, 2, 3, 4, 5}, arr); + } + + @Test + void testSortArrayReversed() { + long[] arr = {10, 9, 8, 7, 6, 5, 4, 3, 2, 1}; + Fibonacci.sortArray(arr); + assertArrayEquals(new long[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, arr); + } + + @Test + void testSortArrayDuplicates() { + long[] arr = {3, 1, 4, 1, 5, 9, 2, 6, 5, 3}; + Fibonacci.sortArray(arr); + assertArrayEquals(new long[]{1, 1, 2, 3, 3, 4, 5, 5, 6, 9}, arr); + } + + @Test + void testSortArrayEmpty() { + long[] arr = {}; + Fibonacci.sortArray(arr); + assertArrayEquals(new long[]{}, arr); + } + + @Test + void testSortArraySingle() { + long[] arr = {42}; + Fibonacci.sortArray(arr); + assertArrayEquals(new long[]{42}, arr); + } + + @Test + void testSortArrayNegatives() { + long[] arr = {-3, -1, -4, -1, -5}; + Fibonacci.sortArray(arr); + assertArrayEquals(new long[]{-5, -4, -3, -1, -1}, arr); + } + + @Test + void testSortArrayNull() { + assertThrows(IllegalArgumentException.class, () -> Fibonacci.sortArray(null)); + } + + @Test + void testCollectFibonacciInto() { + List output = new ArrayList<>(); + Fibonacci.collectFibonacciInto(output, 10); + assertEquals(7, output.size()); + assertEquals(List.of(0L, 1L, 1L, 2L, 3L, 5L, 8L), output); + } + + @Test + void testCollectFibonacciIntoZeroLimit() { + List output = new ArrayList<>(); + Fibonacci.collectFibonacciInto(output, 0); + assertTrue(output.isEmpty()); + } + + @Test + void testCollectFibonacciIntoClearsExisting() { + List output = new ArrayList<>(List.of(99L, 100L)); + Fibonacci.collectFibonacciInto(output, 5); + assertEquals(List.of(0L, 1L, 1L, 2L, 3L), output); + } + + @Test + void testCollectFibonacciIntoNull() { + assertThrows(IllegalArgumentException.class, () -> Fibonacci.collectFibonacciInto(null, 10)); + } + + @Test + void testFillFibonacciRunningSums() { + long[] result = new long[6]; + Fibonacci.fillFibonacciRunningSums(result); + // sums: fib(0)=0, 0+1=1, 0+1+1=2, 0+1+1+2=4, 0+1+1+2+3=7, 0+1+1+2+3+5=12 + assertArrayEquals(new long[]{0, 1, 2, 4, 7, 12}, result); + } + + @Test + void testFillFibonacciRunningSumsEmpty() { + long[] result = new long[0]; + Fibonacci.fillFibonacciRunningSums(result); + assertArrayEquals(new long[]{}, result); + } + + @Test + void testFillFibonacciRunningSumsNull() { + assertThrows(IllegalArgumentException.class, () -> Fibonacci.fillFibonacciRunningSums(null)); + } } diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/Comparator.java b/codeflash-java-runtime/src/main/java/com/codeflash/Comparator.java index 3bd62c897..3faee5e55 100644 --- a/codeflash-java-runtime/src/main/java/com/codeflash/Comparator.java +++ b/codeflash-java-runtime/src/main/java/com/codeflash/Comparator.java @@ -73,8 +73,8 @@ public static void main(String[] args) { } static String compareDatabases(String originalDbPath, String candidateDbPath) throws Exception { - Map originalResults = readTestResults(originalDbPath); - Map candidateResults = readTestResults(candidateDbPath); + Map originalResults = readTestResults(originalDbPath); + Map candidateResults = readTestResults(candidateDbPath); Set allKeys = new LinkedHashSet<>(); allKeys.addAll(originalResults.keySet()); @@ -87,46 +87,50 @@ static String compareDatabases(String originalDbPath, String candidateDbPath) th int skippedDeserializationErrors = 0; for (String key : allKeys) { - byte[] origBytes = originalResults.get(key); - byte[] candBytes = candidateResults.get(key); + TestResult origResult = originalResults.get(key); + TestResult candResult = candidateResults.get(key); - if (origBytes == null && candBytes == null) { - // Both null (void methods) — a real comparison (void-to-void match) - actualComparisons++; - continue; - } + byte[] origBytes = origResult != null ? origResult.returnValue : null; + byte[] candBytes = candResult != null ? candResult.returnValue : null; - if (origBytes == null) { + if (origBytes == null && candBytes == null) { + // Both null (void methods) — check stdout still + } else if (origBytes == null) { Object candObj = safeDeserialize(candBytes); diffs.add(formatDiff("missing", key, 0, null, safeToString(candObj))); actualComparisons++; continue; - } - - if (candBytes == null) { + } else if (candBytes == null) { Object origObj = safeDeserialize(origBytes); diffs.add(formatDiff("missing", key, 0, safeToString(origObj), null)); actualComparisons++; continue; - } + } else { + Object origObj = safeDeserialize(origBytes); + Object candObj = safeDeserialize(candBytes); - Object origObj = safeDeserialize(origBytes); - Object candObj = safeDeserialize(candBytes); + if (isDeserializationError(origObj) || isDeserializationError(candObj)) { + skippedDeserializationErrors++; + continue; + } - if (isDeserializationError(origObj) || isDeserializationError(candObj)) { - skippedDeserializationErrors++; - continue; + try { + if (!compare(origObj, candObj)) { + diffs.add(formatDiff("return_value", key, 0, safeToString(origObj), safeToString(candObj))); + } + } catch (KryoPlaceholderAccessException e) { + skippedPlaceholders++; + continue; + } } - try { - if (!compare(origObj, candObj)) { - diffs.add(formatDiff("return_value", key, 0, safeToString(origObj), safeToString(candObj))); - } - actualComparisons++; - } catch (KryoPlaceholderAccessException e) { - skippedPlaceholders++; - continue; + // Compare stdout (for void methods and side-effect verification) + String origStdout = origResult != null ? origResult.stdout : null; + String candStdout = candResult != null ? candResult.stdout : null; + if (origStdout != null && candStdout != null && !origStdout.equals(candStdout)) { + diffs.add(formatDiff("stdout", key, 0, truncate(origStdout, 200), truncate(candStdout, 200))); } + actualComparisons++; } boolean equivalent = diffs.isEmpty() && actualComparisons > 0; @@ -154,31 +158,53 @@ static String compareDatabases(String originalDbPath, String candidateDbPath) th return json.toString(); } - private static Map readTestResults(String dbPath) throws Exception { - Map results = new LinkedHashMap<>(); + private static class TestResult { + final byte[] returnValue; + final String stdout; + + TestResult(byte[] returnValue, String stdout) { + this.returnValue = returnValue; + this.stdout = stdout; + } + } + + private static Map readTestResults(String dbPath) throws Exception { + Map results = new LinkedHashMap<>(); String url = "jdbc:sqlite:" + dbPath; try (Connection conn = DriverManager.getConnection(url); - Statement stmt = conn.createStatement(); - ResultSet rs = stmt.executeQuery( - "SELECT test_module_path, test_class_name, test_function_name, iteration_id, return_value FROM test_results WHERE loop_index = 1")) { - while (rs.next()) { - String testModulePath = rs.getString("test_module_path"); - String testClassName = rs.getString("test_class_name"); - String testFunctionName = rs.getString("test_function_name"); - String iterationId = rs.getString("iteration_id"); - byte[] returnValue = rs.getBytes("return_value"); - // Strip the CODEFLASH_TEST_ITERATION suffix (e.g. "7_0" -> "7") - // Original runs with _0, candidate with _1, but the test iteration - // counter before the underscore is what identifies the invocation. - int lastUnderscore = iterationId.lastIndexOf('_'); - if (lastUnderscore > 0) { - iterationId = iterationId.substring(0, lastUnderscore); + Statement stmt = conn.createStatement()) { + + // Check if stdout column exists (backward compatibility) + boolean hasStdout = false; + try (ResultSet columns = conn.getMetaData().getColumns(null, null, "test_results", "stdout")) { + hasStdout = columns.next(); + } + + String query = hasStdout + ? "SELECT test_module_path, test_class_name, test_function_name, iteration_id, return_value, stdout FROM test_results WHERE loop_index = 1" + : "SELECT test_module_path, test_class_name, test_function_name, iteration_id, return_value FROM test_results WHERE loop_index = 1"; + + try (ResultSet rs = stmt.executeQuery(query)) { + while (rs.next()) { + String testModulePath = rs.getString("test_module_path"); + String testClassName = rs.getString("test_class_name"); + String testFunctionName = rs.getString("test_function_name"); + String iterationId = rs.getString("iteration_id"); + byte[] returnValue = rs.getBytes("return_value"); + String stdout = hasStdout ? rs.getString("stdout") : null; + // Strip the CODEFLASH_TEST_ITERATION suffix (e.g. "7_0" -> "7") + // Original runs with _0, candidate with _1, but the test iteration + // counter before the underscore is what identifies the invocation. + int lastUnderscore = iterationId.lastIndexOf('_'); + if (lastUnderscore > 0) { + iterationId = iterationId.substring(0, lastUnderscore); + } + // Use module:class:function:iteration as key to uniquely identify + // each invocation across different test files, classes, and methods + String key = testModulePath + ":" + testClassName + ":" + testFunctionName + "::" + iterationId; + results.put(key, new TestResult(returnValue, stdout)); } - // Use module:class:function:iteration as key to uniquely identify - // each invocation across different test files, classes, and methods - String key = testModulePath + ":" + testClassName + ":" + testFunctionName + "::" + iterationId; - results.put(key, returnValue); } } return results; @@ -214,6 +240,11 @@ private static String safeToString(Object obj) { } } + private static String truncate(String s, int maxLen) { + if (s == null || s.length() <= maxLen) return s; + return s.substring(0, maxLen) + "..."; + } + private static String formatDiff(String scope, String methodId, int callId, String originalValue, String candidateValue) { StringBuilder sb = new StringBuilder(); diff --git a/codeflash/languages/java/build_tools.py b/codeflash/languages/java/build_tools.py index f10718415..4db43a92d 100644 --- a/codeflash/languages/java/build_tools.py +++ b/codeflash/languages/java/build_tools.py @@ -741,6 +741,79 @@ def is_jacoco_configured(pom_path: Path) -> bool: return False +JACOCO_PROPERTY_NAME = "jacoco.agent.argLine" + + +def ensure_jacoco_property_name(pom_path: Path) -> bool: + """Ensure the existing JaCoCo prepare-agent writes to a custom property. + + If the project already has JaCoCo configured with the default ``argLine`` + property, we must redirect it to ``jacoco.agent.argLine`` so that our + ``-DargLine=@{jacoco.agent.argLine} ...`` can compose both the agent arg + and the add-opens flags without one overriding the other. + + Also adds an empty default ```` property to + ```` so the reference resolves even if prepare-agent didn't run. + """ + if not pom_path.exists(): + return False + + try: + content = pom_path.read_text(encoding="utf-8") + + # Already using our custom property — nothing to do + if f"{JACOCO_PROPERTY_NAME}" in content: + content = _ensure_pom_property(content, JACOCO_PROPERTY_NAME, "") + pom_path.write_text(content, encoding="utf-8") + return True + + # Find the prepare-agent execution and inject + import re + + # Match prepare-agent ... ...prepare-agent... + # and check whether a block already exists for it. + prepare_agent_pattern = re.compile( + r"(\s*prepare-agent.*?)(.*?)()", re.DOTALL + ) + match = prepare_agent_pattern.search(content) + if not match: + # No prepare-agent execution found — nothing to patch + logger.debug("No prepare-agent execution found in %s", pom_path) + content = _ensure_pom_property(content, JACOCO_PROPERTY_NAME, "") + pom_path.write_text(content, encoding="utf-8") + return True + + between = match.group(2) # text between and + if "" in between: + # Configuration block exists — inject propertyName inside it + content = ( + content[: match.start(2)] + + between.replace( + "", + f"\n {JACOCO_PROPERTY_NAME}", + ) + + content[match.end(2) :] + ) + else: + # No configuration block — add one before + config_block = ( + f"\n " + f"\n {JACOCO_PROPERTY_NAME}" + f"\n \n " + ) + insert_pos = match.start(3) + content = content[:insert_pos] + config_block + content[insert_pos:] + + content = _ensure_pom_property(content, JACOCO_PROPERTY_NAME, "") + pom_path.write_text(content, encoding="utf-8") + logger.info("Patched existing JaCoCo prepare-agent to use propertyName=%s", JACOCO_PROPERTY_NAME) + return True + + except Exception: + logger.exception("Failed to patch JaCoCo propertyName in %s", pom_path) + return False + + def add_jacoco_plugin_to_pom(pom_path: Path) -> bool: """Add JaCoCo Maven plugin to pom.xml for coverage collection. @@ -785,6 +858,9 @@ def add_jacoco_plugin_to_pom(pom_path: Path) -> bool: prepare-agent + + jacoco.agent.argLine + report @@ -861,6 +937,10 @@ def add_jacoco_plugin_to_pom(pom_path: Path) -> bool: """ content = content[:project_end] + build_section + content[project_end:] + # Add a default empty property so @{jacoco.agent.argLine} resolves to "" + # if prepare-agent doesn't run (avoids passing a literal to the JVM). + content = _ensure_pom_property(content, "jacoco.agent.argLine", "") + pom_path.write_text(content, encoding="utf-8") logger.info("Added JaCoCo plugin to pom.xml") return True @@ -870,6 +950,36 @@ def add_jacoco_plugin_to_pom(pom_path: Path) -> bool: return False +def _ensure_pom_property(content: str, prop_name: str, default_value: str) -> str: + """Ensure a Maven property exists in the pom.xml section. + + If the property already exists, the content is returned unchanged. + If there is no section, one is created. + """ + # Check if the property already exists + if f"<{prop_name}>" in content: + return content + + prop_xml = f" <{prop_name}>{default_value}\n " + + # Find main section (not inside ) + profiles_start = content.find("") + search_region = content[:profiles_start] if profiles_start != -1 else content + props_end = search_region.find("") + + if props_end != -1: + return content[:props_end] + prop_xml + content[props_end:] + + # No section — insert one before or + for anchor in ("", "", ""): + anchor_pos = search_region.find(anchor) + if anchor_pos != -1: + props_section = f" \n {prop_xml}\n\n" + return content[:anchor_pos] + props_section + content[anchor_pos:] + + return content + + def _find_closing_tag(content: str, start_pos: int, tag_name: str) -> int: """Find the position of the closing tag that matches the opening tag at start_pos. diff --git a/codeflash/languages/java/comparator.py b/codeflash/languages/java/comparator.py index aa15ef071..a287300c3 100644 --- a/codeflash/languages/java/comparator.py +++ b/codeflash/languages/java/comparator.py @@ -248,6 +248,8 @@ def compare_test_results( scope = TestDiffScope.RETURN_VALUE if scope_str in {"exception", "missing"}: scope = TestDiffScope.DID_PASS + elif scope_str == "stdout": + scope = TestDiffScope.STDOUT # Build test identifier method_id = diff.get("methodId", "unknown") @@ -287,7 +289,8 @@ def compare_test_results( "(total=%s, skipped_placeholders=%s, skipped_deser_errors=%s). " "Treating as NOT equivalent.", comparison.get("totalInvocations", 0), - skipped_placeholders, skipped_deser_errors, + skipped_placeholders, + skipped_deser_errors, ) return False, [] @@ -295,7 +298,10 @@ def compare_test_results( "Java comparison: %s (%s invocations, %s compared, %s placeholder skips, %s deser skips, %s diffs)", "equivalent" if equivalent else "DIFFERENT", comparison.get("totalInvocations", 0), - actual_comparisons, skipped_placeholders, skipped_deser_errors, len(test_diffs), + actual_comparisons, + skipped_placeholders, + skipped_deser_errors, + len(test_diffs), ) return equivalent, test_diffs diff --git a/codeflash/languages/java/discovery.py b/codeflash/languages/java/discovery.py index cb610cb18..ab6d2b7ca 100644 --- a/codeflash/languages/java/discovery.py +++ b/codeflash/languages/java/discovery.py @@ -104,6 +104,7 @@ def discover_functions_from_source( is_method=method.class_name is not None, language="java", doc_start_line=method.javadoc_start_line, + return_type=method.return_type, ) ) @@ -150,14 +151,10 @@ def _should_include_method( if criteria.matches_exclude_patterns(method.name): return False - # Check require_return - void methods don't have return values - - # Check require_return - void methods don't have return values + # Check require_return - void methods are allowed (verified via test pass/fail), + # but non-void methods must have an actual return statement if criteria.require_return: - if method.return_type == "void": - return False - # Also check if the method actually has a return statement - if not analyzer.has_return_statement(method, source): + if method.return_type != "void" and not analyzer.has_return_statement(method, source): return False # Check include_methods - in Java, all functions in classes are methods diff --git a/codeflash/languages/java/instrumentation.py b/codeflash/languages/java/instrumentation.py index 806b4e619..af5d52948 100644 --- a/codeflash/languages/java/instrumentation.py +++ b/codeflash/languages/java/instrumentation.py @@ -227,19 +227,20 @@ def _generate_sqlite_write_code( f'{inner_indent} _cf_stmt{iter_id}_{call_counter}.execute("CREATE TABLE IF NOT EXISTS test_results (" +', f'{inner_indent} "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " +', f'{inner_indent} "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " +', - f'{inner_indent} "runtime INTEGER, return_value BLOB, verification_type TEXT)");', + f'{inner_indent} "runtime INTEGER, return_value BLOB, verification_type TEXT, stdout TEXT)");', f"{inner_indent} }}", - f'{inner_indent} String _cf_sql{iter_id}_{call_counter} = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";', + f'{inner_indent} String _cf_sql{iter_id}_{call_counter} = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";', f"{inner_indent} try (PreparedStatement _cf_pstmt{iter_id}_{call_counter} = _cf_conn{iter_id}_{call_counter}.prepareStatement(_cf_sql{iter_id}_{call_counter})) {{", f"{inner_indent} _cf_pstmt{iter_id}_{call_counter}.setString(1, _cf_mod{iter_id});", f"{inner_indent} _cf_pstmt{iter_id}_{call_counter}.setString(2, _cf_cls{iter_id});", f"{inner_indent} _cf_pstmt{iter_id}_{call_counter}.setString(3, _cf_test{iter_id});", f"{inner_indent} _cf_pstmt{iter_id}_{call_counter}.setString(4, _cf_fn{iter_id});", f"{inner_indent} _cf_pstmt{iter_id}_{call_counter}.setInt(5, _cf_loop{iter_id});", - f'{inner_indent} _cf_pstmt{iter_id}_{call_counter}.setString(6, "{call_counter}");', + f'{inner_indent} _cf_pstmt{iter_id}_{call_counter}.setString(6, _cf_cls{iter_id} + "." + _cf_test{iter_id} + ".{call_counter}_" + _cf_testIteration{iter_id});', f"{inner_indent} _cf_pstmt{iter_id}_{call_counter}.setLong(7, _cf_dur{iter_id}_{call_counter});", f"{inner_indent} _cf_pstmt{iter_id}_{call_counter}.setBytes(8, _cf_serializedResult{iter_id}_{call_counter});", f'{inner_indent} _cf_pstmt{iter_id}_{call_counter}.setString(9, "function_call");', + f"{inner_indent} _cf_pstmt{iter_id}_{call_counter}.setString(10, _cf_stdout{iter_id}_{call_counter});", f"{inner_indent} _cf_pstmt{iter_id}_{call_counter}.executeUpdate();", f"{inner_indent} }}", f"{inner_indent} }}", @@ -251,6 +252,25 @@ def _generate_sqlite_write_code( ] +def _build_void_serialize_expr(call: dict[str, Any]) -> str: + """Build a Serializer.serialize(...) expression for void function side-effect capture. + + For void methods, we serialize the arguments (and receiver for instance methods) + AFTER the call, to capture any mutations as side effects. + Static class receivers (uppercase first letter) are excluded since they aren't instances. + """ + parts: list[str] = [] + receiver = call.get("receiver") + arg_exprs = call.get("arg_exprs", []) + if receiver and not (receiver[0].isupper() and receiver.isidentifier()): + parts.append(receiver) + parts.extend(arg_exprs) + if not parts: + return "null" + items = ", ".join(parts) + return f"com.codeflash.Serializer.serialize(new Object[]{{{items}}})" + + def wrap_target_calls_with_treesitter( body_lines: list[str], func_name: str, @@ -258,6 +278,7 @@ def wrap_target_calls_with_treesitter( precise_call_timing: bool = False, class_name: str = "", test_method_name: str = "", + is_void: bool = False, target_return_type: str = "", ) -> tuple[list[str], int]: """Replace target method calls in body_lines with capture + serialize using tree-sitter. @@ -328,25 +349,39 @@ def wrap_target_calls_with_treesitter( call_counter += 1 var_name = f"_cf_result{iter_id}_{call_counter}" cast_type = _infer_array_cast_type(body_line) - if not cast_type and target_return_type and target_return_type != "void": + if not cast_type and target_return_type and target_return_type not in ("void", "Object"): cast_type = target_return_type var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name - # Use per-call unique variables (with call_counter suffix) for behavior mode - # For behavior mode, we declare the variable outside try block, so use assignment not declaration here - # For performance mode, use shared variables without call_counter suffix - capture_stmt_with_decl = f"var {var_name} = {call['full_call']};" - capture_stmt_assign = f"{var_name} = {call['full_call']};" + # For void functions, we can't assign the return value to a variable + if is_void: + capture_stmt_with_decl = f"{call['full_call']};" + capture_stmt_assign = f"{call['full_call']};" + else: + # Use per-call unique variables (with call_counter suffix) for behavior mode + # For behavior mode, we declare the variable outside try block, so use assignment not declaration here + # For performance mode, use shared variables without call_counter suffix + capture_stmt_with_decl = f"var {var_name} = {call['full_call']};" + capture_stmt_assign = f"{var_name} = {call['full_call']};" + if precise_call_timing: # Behavior mode: per-call unique variables - serialize_stmt = f"_cf_serializedResult{iter_id}_{call_counter} = com.codeflash.Serializer.serialize((Object) {var_name});" + if is_void: + ser_expr = _build_void_serialize_expr(call) + serialize_stmt = f"_cf_serializedResult{iter_id}_{call_counter} = {ser_expr};" + else: + serialize_stmt = f"_cf_serializedResult{iter_id}_{call_counter} = com.codeflash.Serializer.serialize((Object) {var_name});" start_stmt = f"_cf_start{iter_id}_{call_counter} = System.nanoTime();" end_stmt = f"_cf_end{iter_id}_{call_counter} = System.nanoTime();" else: # Performance mode: shared variables without call_counter suffix - serialize_stmt = ( - f"_cf_serializedResult{iter_id} = com.codeflash.Serializer.serialize((Object) {var_name});" - ) + if is_void: + ser_expr = _build_void_serialize_expr(call) + serialize_stmt = f"_cf_serializedResult{iter_id} = {ser_expr};" + else: + serialize_stmt = ( + f"_cf_serializedResult{iter_id} = com.codeflash.Serializer.serialize((Object) {var_name});" + ) start_stmt = f"_cf_start{iter_id} = System.nanoTime();" end_stmt = f"_cf_end{iter_id} = System.nanoTime();" @@ -361,27 +396,38 @@ def wrap_target_calls_with_treesitter( if precise_call_timing: # For behavior mode: wrap each call in its own try-finally with SQLite write. # This ensures data from all calls is captured independently. - # Declare per-call variables + # Declare per-call variables (skip result variable for void) var_decls = [ - f"Object {var_name} = null;", f"long _cf_end{iter_id}_{call_counter} = -1;", f"long _cf_start{iter_id}_{call_counter} = 0;", f"byte[] _cf_serializedResult{iter_id}_{call_counter} = null;", + f"java.io.ByteArrayOutputStream _cf_stdoutCapture{iter_id}_{call_counter} = new java.io.ByteArrayOutputStream();", + f"java.io.PrintStream _cf_origOut{iter_id}_{call_counter} = System.out;", + f"String _cf_stdout{iter_id}_{call_counter} = null;", ] + if not is_void: + var_decls.insert(0, f"Object {var_name} = null;") # Start marker start_marker = f'System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + "." + _cf_test{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":{call_counter}" + "######$!");' - # Try block with capture (use assignment, not declaration, since variable is declared above) + # Try block with capture and stdout redirection try_block = [ "try {", + f" System.setOut(new java.io.PrintStream(_cf_stdoutCapture{iter_id}_{call_counter}));", f" {start_stmt}", f" {capture_stmt_assign}", f" {end_stmt}", f" {serialize_stmt}", ] - # Finally block with SQLite write + # Finally block with stdout restore and SQLite write finally_block = _generate_sqlite_write_code( iter_id, call_counter, "", class_name, func_name, test_method_name ) + # Insert stdout restore at the beginning of finally (after "} finally {" line) + finally_block.insert(1, f" System.setOut(_cf_origOut{iter_id}_{call_counter});") + finally_block.insert( + 2, + f' try {{ _cf_stdout{iter_id}_{call_counter} = _cf_stdoutCapture{iter_id}_{call_counter}.toString("UTF-8"); }} catch (Exception _cf_encEx{iter_id}_{call_counter}) {{}}', + ) replacement_lines = [*var_decls, start_marker, *try_block, *finally_block] # Don't add indent to first line (it's placed after existing indent), but add to subsequent lines @@ -404,25 +450,42 @@ def wrap_target_calls_with_treesitter( # Emit capture+serialize before the line, then replace the call with the variable. if precise_call_timing: # For behavior mode: wrap in try-finally with SQLite write - # Declare per-call variables - wrapped.append(f"{line_indent_str}Object {var_name} = null;") + # Declare per-call variables (skip result variable for void) + if not is_void: + wrapped.append(f"{line_indent_str}Object {var_name} = null;") wrapped.append(f"{line_indent_str}long _cf_end{iter_id}_{call_counter} = -1;") wrapped.append(f"{line_indent_str}long _cf_start{iter_id}_{call_counter} = 0;") wrapped.append(f"{line_indent_str}byte[] _cf_serializedResult{iter_id}_{call_counter} = null;") + wrapped.append( + f"{line_indent_str}java.io.ByteArrayOutputStream _cf_stdoutCapture{iter_id}_{call_counter} = new java.io.ByteArrayOutputStream();" + ) + wrapped.append( + f"{line_indent_str}java.io.PrintStream _cf_origOut{iter_id}_{call_counter} = System.out;" + ) + wrapped.append(f"{line_indent_str}String _cf_stdout{iter_id}_{call_counter} = null;") # Start marker wrapped.append( f'{line_indent_str}System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + "." + _cf_test{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":{call_counter}" + "######$!");' ) - # Try block (use assignment, not declaration, since variable is declared above) + # Try block with stdout redirection wrapped.append(f"{line_indent_str}try {{") + wrapped.append( + f"{line_indent_str} System.setOut(new java.io.PrintStream(_cf_stdoutCapture{iter_id}_{call_counter}));" + ) wrapped.append(f"{line_indent_str} {start_stmt}") wrapped.append(f"{line_indent_str} {capture_stmt_assign}") wrapped.append(f"{line_indent_str} {end_stmt}") wrapped.append(f"{line_indent_str} {serialize_stmt}") - # Finally block with SQLite write + # Finally block with stdout restore and SQLite write finally_lines = _generate_sqlite_write_code( iter_id, call_counter, line_indent_str, class_name, func_name, test_method_name ) + # Insert stdout restore at beginning of finally (after "} finally {" line) + finally_lines.insert(1, f"{line_indent_str} System.setOut(_cf_origOut{iter_id}_{call_counter});") + finally_lines.insert( + 2, + f'{line_indent_str} try {{ _cf_stdout{iter_id}_{call_counter} = _cf_stdoutCapture{iter_id}_{call_counter}.toString("UTF-8"); }} catch (Exception _cf_encEx{iter_id}_{call_counter}) {{}}', + ) wrapped.extend(finally_lines) else: capture_line = f"{line_indent_str}{capture_stmt_with_decl}" @@ -430,14 +493,17 @@ def wrap_target_calls_with_treesitter( serialize_line = f"{line_indent_str}{serialize_stmt}" wrapped.append(serialize_line) - call_start_byte = call["start_byte"] - line_byte_start - call_end_byte = call["end_byte"] - line_byte_start - call_start_char = len(line_bytes[:call_start_byte].decode("utf8")) - call_end_char = len(line_bytes[:call_end_byte].decode("utf8")) - adj_start = call_start_char + char_shift - adj_end = call_end_char + char_shift - new_line = new_line[:adj_start] + var_with_cast + new_line[adj_end:] - char_shift += len(var_with_cast) - (call_end_char - call_start_char) + # For void functions embedded in expressions, don't replace the call with a variable + # (this case is unusual for void methods but handle it gracefully) + if not is_void: + call_start_byte = call["start_byte"] - line_byte_start + call_end_byte = call["end_byte"] - line_byte_start + call_start_char = len(line_bytes[:call_start_byte].decode("utf8")) + call_end_char = len(line_bytes[:call_end_byte].decode("utf8")) + adj_start = call_start_char + char_shift + adj_end = call_end_char + char_shift + new_line = new_line[:adj_start] + var_with_cast + new_line[adj_end:] + char_shift += len(var_with_cast) - (call_end_char - call_start_char) # Keep the modified line only if it has meaningful content left if new_line.strip(): @@ -470,6 +536,16 @@ def _collect_calls( if parent_type == "expression_statement": es_start = parent.start_byte - prefix_len es_end = parent.end_byte - prefix_len + # Extract receiver and argument expressions for side-effect serialization + object_node = node.child_by_field_name("object") + receiver_text = analyzer.get_node_text(object_node, wrapper_bytes) if object_node else None + args_node = node.child_by_field_name("arguments") + arg_exprs = [] + if args_node: + for child in args_node.children: + if child.type not in ("(", ")", ","): + arg_exprs.append(analyzer.get_node_text(child, wrapper_bytes)) + out.append( { "start_byte": start, @@ -480,6 +556,8 @@ def _collect_calls( "in_complex": _is_inside_complex_expression(node), "es_start_byte": es_start, "es_end_byte": es_end, + "receiver": receiver_text, + "arg_exprs": arg_exprs, } ) for child in node.children: @@ -531,20 +609,20 @@ def _extract_return_type(function_to_optimize: Any) -> str: """Extract the return type of a Java function from its source file using tree-sitter.""" file_path = getattr(function_to_optimize, "file_path", None) func_name = _get_function_name(function_to_optimize) - if not file_path or not file_path.exists(): - return "" - try: - from codeflash.languages.java.parser import get_java_analyzer - - analyzer = get_java_analyzer() - source_text = file_path.read_text(encoding="utf-8") - methods = analyzer.find_methods(source_text) - for method in methods: - if method.name == func_name and method.return_type: - return method.return_type - except Exception: - logger.debug("Could not extract return type for %s", func_name) - return "" + if file_path and file_path.exists(): + try: + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + source_text = file_path.read_text(encoding="utf-8") + methods = analyzer.find_methods(source_text) + for method in methods: + if method.name == func_name and method.return_type: + return method.return_type + except Exception: + logger.debug("Could not extract return type for %s", func_name) + # Fall back to the return_type attribute on the function model + return getattr(function_to_optimize, "return_type", "") or "" def _get_qualified_name(func: Any) -> str: @@ -613,7 +691,7 @@ def instrument_for_benchmarking( def instrument_existing_test( test_string: str, - function_to_optimize: Any, # FunctionToOptimize or FunctionToOptimize + function_to_optimize: Any, mode: str, # "behavior" or "performance" test_path: Path | None = None, test_class_name: str | None = None, @@ -662,6 +740,8 @@ def instrument_existing_test( # replacing substrings of other identifiers. modified_source = re.sub(rf"\b{re.escape(original_class_name)}\b", new_class_name, source) + is_void = target_return_type == "void" + # Add @SuppressWarnings("CheckReturnValue") to the class declaration. # Projects using Error Prone (e.g. Guava) enforce CheckReturnValue as a compiler error. # Applied in both modes: performance mode strips assertions (creating discarded return values), @@ -679,7 +759,7 @@ def instrument_existing_test( else: # Behavior mode: add timing instrumentation that also writes to SQLite modified_source = _add_behavior_instrumentation( - modified_source, original_class_name, func_name, target_return_type + modified_source, original_class_name, func_name, is_void=is_void, target_return_type=target_return_type ) logger.debug("Java %s testing for %s: renamed class %s -> %s", mode, func_name, original_class_name, new_class_name) @@ -687,7 +767,9 @@ def instrument_existing_test( return True, modified_source -def _add_behavior_instrumentation(source: str, class_name: str, func_name: str, target_return_type: str = "") -> str: +def _add_behavior_instrumentation( + source: str, class_name: str, func_name: str, is_void: bool = False, target_return_type: str = "" +) -> str: """Add behavior instrumentation to test methods. For behavior mode, this adds: @@ -828,6 +910,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str, precise_call_timing=True, class_name=class_name, test_method_name=test_method_name, + is_void=is_void, target_return_type=target_return_type, ) @@ -1339,7 +1422,8 @@ def instrument_generated_java_test( from codeflash.languages.java.remove_asserts import transform_java_assertions - test_code = transform_java_assertions(test_code, function_name, qualified_name) + is_void = getattr(function_to_optimize, "return_type", None) == "void" + test_code = transform_java_assertions(test_code, function_name, qualified_name, is_void=is_void) # Extract class name from the test code # Use pattern that starts at beginning of line to avoid matching words in comments diff --git a/codeflash/languages/java/remove_asserts.py b/codeflash/languages/java/remove_asserts.py index 470e0d62e..a5a5986c9 100644 --- a/codeflash/languages/java/remove_asserts.py +++ b/codeflash/languages/java/remove_asserts.py @@ -184,13 +184,18 @@ class JavaAssertTransformer: """ def __init__( - self, function_name: str, qualified_name: str | None = None, analyzer: JavaAnalyzer | None = None + self, + function_name: str, + qualified_name: str | None = None, + analyzer: JavaAnalyzer | None = None, + is_void: bool = False, ) -> None: self.analyzer = analyzer or get_java_analyzer() self.func_name = function_name self.qualified_name = qualified_name or function_name self.invocation_counter = 0 self._detected_framework: str | None = None + self.is_void = is_void # Precompile the assignment-detection regex to avoid recompiling on each call. self._assign_re = re.compile(r"(\w+(?:<[^>]+>)?)\s+(\w+)\s*=\s*$") @@ -911,9 +916,11 @@ def _infer_return_type(self, assertion: AssertionMatch) -> str: """ method = assertion.assertion_method - # assertTrue/assertFalse always deal with boolean values + # assertTrue/assertFalse: the assertion argument is boolean, but the extracted + # target call may return any type (e.g., assertTrue(fibonacci(n) < 0L) where + # fibonacci returns long). Use Object to safely capture via autoboxing. if method in {"assertTrue", "assertFalse"}: - return "boolean" + return "Object" # assertNull/assertNotNull — keep Object (reference type) if method in {"assertNull", "assertNotNull"}: @@ -949,8 +956,29 @@ def _infer_type_from_assertion_args(self, original_text: str, method: str) -> st elif args_str.endswith(")"): args_str = args_str[:-1] - # Fast-path: only extract the first top-level argument instead of splitting all arguments. - first_arg = self._extract_first_arg(args_str) + # Fast-path: try to cheaply obtain the first top-level argument without invoking + # the full parser. If the first comma occurs before any special characters that + # would affect top-levelness (quotes/parens/braces), we can take the substring + # up to that comma as the first argument. + if not args_str: + return "Object" + + # Find first comma; if none, the entire args_str is the single argument + comma_idx = args_str.find(",") + if comma_idx == -1: + first_arg = args_str + else: + # If there are no special delimiter characters before this comma, we can + # safely take the substring as the first argument. + before = args_str[:comma_idx] + if not self._special_re.search(before): + first_arg = before + else: + # Fallback: use the full extractor which respects nesting/strings/generics. + first_arg = self._extract_first_arg(args_str) + if first_arg is None: + return "Object" + if not first_arg: return "Object" @@ -1074,14 +1102,20 @@ def _generate_replacement(self, assertion: AssertionMatch) -> str: # Handle first call explicitly to avoid a per-iteration branch if calls: inv += 1 - var_name = "_cf_result" + str(inv) - replacements.append(f"{leading_ws}{return_type} {var_name} = {calls[0].full_call};") + if self.is_void: + replacements.append(f"{leading_ws}{calls[0].full_call};") + else: + var_name = "_cf_result" + str(inv) + replacements.append(f"{leading_ws}{return_type} {var_name} = {calls[0].full_call};") # Handle remaining calls for call in calls[1:]: inv += 1 - var_name = "_cf_result" + str(inv) - replacements.append(f"{base_indent}{return_type} {var_name} = {call.full_call};") + if self.is_void: + replacements.append(f"{base_indent}{call.full_call};") + else: + var_name = "_cf_result" + str(inv) + replacements.append(f"{base_indent}{return_type} {var_name} = {call.full_call};") # Write back the counter self.invocation_counter = inv @@ -1198,7 +1232,9 @@ def _extract_first_arg(self, args_str: str) -> str | None: return "".join(cur).rstrip() -def transform_java_assertions(source: str, function_name: str, qualified_name: str | None = None) -> str: +def transform_java_assertions( + source: str, function_name: str, qualified_name: str | None = None, is_void: bool = False +) -> str: """Transform Java test code by removing assertions and capturing function calls. This is the main entry point for Java assertion transformation. @@ -1207,12 +1243,13 @@ def transform_java_assertions(source: str, function_name: str, qualified_name: s source: The Java test source code. function_name: Name of the function being tested. qualified_name: Optional fully qualified name of the function. + is_void: Whether the target function returns void. Returns: Transformed source code with assertions replaced by capture statements. """ - transformer = JavaAssertTransformer(function_name=function_name, qualified_name=qualified_name) + transformer = JavaAssertTransformer(function_name=function_name, qualified_name=qualified_name, is_void=is_void) return transformer.transform(source) @@ -1230,6 +1267,10 @@ def remove_assertions_from_test(source: str, target_function: FunctionToOptimize Transformed source code. """ + is_void = getattr(target_function, "return_type", None) == "void" return transform_java_assertions( - source=source, function_name=target_function.function_name, qualified_name=target_function.qualified_name + source=source, + function_name=target_function.function_name, + qualified_name=target_function.qualified_name, + is_void=is_void, ) diff --git a/codeflash/languages/java/resources/CodeflashHelper.java b/codeflash/languages/java/resources/CodeflashHelper.java index 9ece32679..a83894653 100644 --- a/codeflash/languages/java/resources/CodeflashHelper.java +++ b/codeflash/languages/java/resources/CodeflashHelper.java @@ -3,7 +3,9 @@ import java.io.ByteArrayOutputStream; import java.io.File; import java.io.ObjectOutputStream; +import java.io.PrintStream; import java.io.Serializable; +import java.io.UnsupportedEncodingException; import java.sql.Connection; import java.sql.DriverManager; import java.sql.PreparedStatement; @@ -78,14 +80,23 @@ public static T capture( String invocationKey = testModulePath + ":" + testClassName + ":" + testFunctionName + ":" + functionGettingTested; int iterationId = getNextIterationId(invocationKey); + // Capture stdout + PrintStream originalOut = System.out; + ByteArrayOutputStream capturedStdout = new ByteArrayOutputStream(); + long startTime = System.nanoTime(); T result; try { + System.setOut(new PrintStream(capturedStdout)); result = callable.call(); } finally { + System.setOut(originalOut); long endTime = System.nanoTime(); long durationNs = endTime - startTime; + String stdoutText = null; + try { stdoutText = capturedStdout.toString("UTF-8"); } catch (UnsupportedEncodingException ignored) {} + // Write to SQLite for behavior verification writeResultToSqlite( testModulePath, @@ -96,7 +107,8 @@ public static T capture( iterationId, durationNs, null, // return_value - TODO: serialize if needed - "output" + "output", + stdoutText ); // Print timing marker for stdout parsing (backup method) @@ -118,13 +130,22 @@ public static void captureVoid( String invocationKey = testModulePath + ":" + testClassName + ":" + testFunctionName + ":" + functionGettingTested; int iterationId = getNextIterationId(invocationKey); + // Capture stdout + PrintStream originalOut = System.out; + ByteArrayOutputStream capturedStdout = new ByteArrayOutputStream(); + long startTime = System.nanoTime(); try { + System.setOut(new PrintStream(capturedStdout)); callable.call(); } finally { + System.setOut(originalOut); long endTime = System.nanoTime(); long durationNs = endTime - startTime; + String stdoutText = null; + try { stdoutText = capturedStdout.toString("UTF-8"); } catch (UnsupportedEncodingException ignored) {} + // Write to SQLite writeResultToSqlite( testModulePath, @@ -135,7 +156,8 @@ public static void captureVoid( iterationId, durationNs, null, - "output" + "output", + stdoutText ); // Print timing marker @@ -177,7 +199,8 @@ public static T capturePerf( iterationId, durationNs, null, - "output" + "output", + null ); // Print end marker with timing @@ -219,7 +242,8 @@ public static void capturePerfVoid( iterationId, durationNs, null, - "output" + "output", + null ); // Print end marker with timing @@ -277,7 +301,8 @@ private static synchronized void writeResultToSqlite( int iterationId, long runtime, byte[] returnValue, - String verificationType + String verificationType, + String stdout ) { if (OUTPUT_FILE == null || OUTPUT_FILE.isEmpty()) { return; @@ -291,8 +316,8 @@ private static synchronized void writeResultToSqlite( String sql = "INSERT INTO test_results " + "(test_module_path, test_class_name, test_function_name, function_getting_tested, " + - "loop_index, iteration_id, runtime, return_value, verification_type) " + - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; + "loop_index, iteration_id, runtime, return_value, verification_type, stdout) " + + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; try (PreparedStatement stmt = dbConnection.prepareStatement(sql)) { stmt.setString(1, testModulePath); @@ -304,6 +329,7 @@ private static synchronized void writeResultToSqlite( stmt.setLong(7, runtime); stmt.setBytes(8, returnValue); stmt.setString(9, verificationType); + stmt.setString(10, stdout); stmt.executeUpdate(); } } catch (SQLException e) { @@ -348,7 +374,8 @@ private static void ensureDbInitialized() { "iteration_id INTEGER, " + "runtime INTEGER, " + "return_value BLOB, " + - "verification_type TEXT" + + "verification_type TEXT, " + + "stdout TEXT" + ")"; try (java.sql.Statement stmt = dbConnection.createStatement()) { diff --git a/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar b/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar index 92ad8be00..ab469c59e 100644 Binary files a/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar and b/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar differ diff --git a/codeflash/languages/java/test_runner.py b/codeflash/languages/java/test_runner.py index c56a7d1bd..363492461 100644 --- a/codeflash/languages/java/test_runner.py +++ b/codeflash/languages/java/test_runner.py @@ -284,15 +284,7 @@ def ensure_multi_module_deps_installed(maven_root: Path, test_module: str | None logger.error("Maven not found — cannot pre-install multi-module dependencies") return False - cmd = [ - mvn, - "install", - "-DskipTests", - "-B", - "-pl", - test_module, - "-am", - ] + cmd = [mvn, "install", "-DskipTests", "-B", "-pl", test_module, "-am"] cmd.extend(_MAVEN_VALIDATION_SKIP_FLAGS) logger.info("Pre-installing multi-module dependencies: %s (module: %s)", maven_root, test_module) @@ -1006,7 +998,10 @@ def _get_test_class_names(test_paths: Any, mode: str = "performance") -> list[st elif isinstance(path, str): class_names.append(path) - return class_names + # Sort to match Maven Surefire's alphabetical execution order. + # Without sorting, iteration_id collisions across test classes resolve + # differently between Maven (original) and direct JVM (candidate) runs. + return sorted(class_names) def _get_empty_result(maven_root: Path, test_module: str | None) -> tuple[Path, Any]: diff --git a/codeflash/models/function_types.py b/codeflash/models/function_types.py index bea6672b0..8b2f4862b 100644 --- a/codeflash/models/function_types.py +++ b/codeflash/models/function_types.py @@ -61,6 +61,7 @@ class FunctionToOptimize: is_method: bool = False language: str = "python" doc_start_line: Optional[int] = None + return_type: Optional[str] = None @property def top_level_parent_name(self) -> str: diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index 641aedf51..1cee1eeb3 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -487,13 +487,26 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes console.rule() return test_results db = None + has_stdout_column = False try: db = sqlite3.connect(sqlite_file_path) cur = db.cursor() - data = cur.execute( - "SELECT test_module_path, test_class_name, test_function_name, " - "function_getting_tested, loop_index, iteration_id, runtime, return_value,verification_type FROM test_results" - ).fetchall() + # Check if stdout column exists (backward compatibility with older schemas) + columns_info = cur.execute("PRAGMA table_info(test_results)").fetchall() + column_names = {col[1] for col in columns_info} + has_stdout_column = "stdout" in column_names + if has_stdout_column: + data = cur.execute( + "SELECT test_module_path, test_class_name, test_function_name, " + "function_getting_tested, loop_index, iteration_id, runtime, return_value, " + "verification_type, stdout FROM test_results" + ).fetchall() + else: + data = cur.execute( + "SELECT test_module_path, test_class_name, test_function_name, " + "function_getting_tested, loop_index, iteration_id, runtime, return_value, " + "verification_type FROM test_results" + ).fetchall() except Exception as e: logger.warning(f"Failed to parse test results from {sqlite_file_path}. Exception: {e}") if db is not None: @@ -635,6 +648,11 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes logger.debug(f"Failed to deserialize return value for {test_function_name}: {e}") continue + # Extract stdout from SQLite (Java/JS behavior capture) + captured_stdout = None + if has_stdout_column and len(val) > 9 and val[9]: + captured_stdout = val[9] + test_results.add( function_test_invocation=FunctionTestInvocation( loop_index=loop_index, @@ -653,6 +671,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes return_value=ret_val, timed_out=False, verification_type=VerificationType(verification_type) if verification_type else None, + stdout=captured_stdout, ) ) except Exception: diff --git a/tests/test_java_assertion_removal.py b/tests/test_java_assertion_removal.py index ec37e7c27..75cca13da 100644 --- a/tests/test_java_assertion_removal.py +++ b/tests/test_java_assertion_removal.py @@ -52,7 +52,7 @@ def test_assert_true(self): expected = """\ @Test void testIsValid() { - boolean _cf_result1 = validator.isValid("test"); + Object _cf_result1 = validator.isValid("test"); }""" result = transform_java_assertions(source, "isValid") assert result == expected @@ -66,7 +66,7 @@ def test_assert_false(self): expected = """\ @Test void testIsInvalid() { - boolean _cf_result1 = validator.isValid(""); + Object _cf_result1 = validator.isValid(""); }""" result = transform_java_assertions(source, "isValid") assert result == expected @@ -1102,7 +1102,7 @@ def test_volatile_field_read_preserved(self): expected = """\ @Test void testVolatileRead() { - boolean _cf_result1 = buffer.isReady(); + Object _cf_result1 = buffer.isReady(); }""" result = transform_java_assertions(source, "isReady") assert result == expected @@ -1230,7 +1230,7 @@ def test_wait_notify_pattern_preserved(self): synchronized (monitor) { monitor.notify(); } - boolean _cf_result1 = listener.wasNotified(); + Object _cf_result1 = listener.wasNotified(); }""" result = transform_java_assertions(source, "wasNotified") assert result == expected @@ -1292,8 +1292,8 @@ def test_token_bucket_synchronized_method(self): @Test void testTokenBucketAllowRequest() { TokenBucket bucket = new TokenBucket(10, 1); - boolean _cf_result1 = bucket.allowRequest(); - boolean _cf_result2 = bucket.allowRequest(); + Object _cf_result1 = bucket.allowRequest(); + Object _cf_result2 = bucket.allowRequest(); }""" result = transform_java_assertions(source, "allowRequest") assert result == expected @@ -1315,9 +1315,9 @@ def test_circular_buffer_atomic_integer_pattern(self): @Test void testCircularBufferOperations() { CircularBuffer buffer = new CircularBuffer<>(3); - boolean _cf_result1 = buffer.isEmpty(); + Object _cf_result1 = buffer.isEmpty(); buffer.put(1); - boolean _cf_result2 = buffer.isEmpty(); + Object _cf_result2 = buffer.isEmpty(); }""" result = transform_java_assertions(source, "isEmpty") assert result == expected diff --git a/tests/test_languages/test_java/test_comparator.py b/tests/test_languages/test_java/test_comparator.py index 02272be35..17220b444 100644 --- a/tests/test_languages/test_java/test_comparator.py +++ b/tests/test_languages/test_java/test_comparator.py @@ -1,24 +1,17 @@ """Tests for Java test result comparison.""" -import json import shutil import sqlite3 -import tempfile from pathlib import Path import pytest -from codeflash.languages.java.comparator import ( - compare_invocations_directly, - compare_test_results, - values_equal, -) +from codeflash.languages.java.comparator import compare_invocations_directly, compare_test_results, values_equal from codeflash.models.models import TestDiffScope # Skip tests that require Java runtime if Java is not available requires_java = pytest.mark.skipif( - shutil.which("java") is None, - reason="Java not found - skipping Comparator integration tests", + shutil.which("java") is None, reason="Java not found - skipping Comparator integration tests" ) # Kryo-serialized bytes for common test values. @@ -38,7 +31,9 @@ KRYO_STR_VALUE1 = bytes([0x03, 0x01, 0x7B, 0x22, 0x76, 0x61, 0x6C, 0x75, 0x65, 0x22, 0x3A, 0x20, 0x31, 0xFD]) KRYO_STR_VALUE2 = bytes([0x03, 0x01, 0x7B, 0x22, 0x76, 0x61, 0x6C, 0x75, 0x65, 0x22, 0x3A, 0x20, 0x32, 0xFD]) KRYO_STR_VALUE42 = bytes([0x03, 0x01, 0x7B, 0x22, 0x76, 0x61, 0x6C, 0x75, 0x65, 0x22, 0x3A, 0x20, 0x34, 0x32, 0xFD]) -KRYO_STR_VALUE100 = bytes([0x03, 0x01, 0x7B, 0x22, 0x76, 0x61, 0x6C, 0x75, 0x65, 0x22, 0x3A, 0x20, 0x31, 0x30, 0x30, 0xFD]) +KRYO_STR_VALUE100 = bytes( + [0x03, 0x01, 0x7B, 0x22, 0x76, 0x61, 0x6C, 0x75, 0x65, 0x22, 0x3A, 0x20, 0x31, 0x30, 0x30, 0xFD] +) KRYO_DOUBLE_1_0000000001 = bytes([0x0A, 0x38, 0xDF, 0x06, 0x00, 0x00, 0x00, 0xF0, 0x3F]) KRYO_DOUBLE_1_0000000002 = bytes([0x0A, 0x70, 0xBE, 0x0D, 0x00, 0x00, 0x00, 0xF0, 0x3F]) KRYO_NAN = bytes([0x0A, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xF8, 0x7F]) @@ -67,12 +62,8 @@ def test_identical_results(self): def test_different_return_values(self): """Test detecting different return values.""" - original = { - "1": {"result_json": '{"value": 42}', "error_json": None}, - } - candidate = { - "1": {"result_json": '{"value": 99}', "error_json": None}, - } + original = {"1": {"result_json": '{"value": 42}', "error_json": None}} + candidate = {"1": {"result_json": '{"value": 99}', "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) @@ -89,7 +80,7 @@ def test_missing_invocation_in_candidate(self): "2": {"result_json": '{"value": 100}', "error_json": None}, } candidate = { - "1": {"result_json": '{"value": 42}', "error_json": None}, + "1": {"result_json": '{"value": 42}', "error_json": None} # Missing invocation 2 } @@ -101,9 +92,7 @@ def test_missing_invocation_in_candidate(self): def test_extra_invocation_in_candidate(self): """Test detecting extra invocation in candidate.""" - original = { - "1": {"result_json": '{"value": 42}', "error_json": None}, - } + original = {"1": {"result_json": '{"value": 42}', "error_json": None}} candidate = { "1": {"result_json": '{"value": 42}', "error_json": None}, "2": {"result_json": '{"value": 100}', "error_json": None}, # Extra @@ -116,11 +105,9 @@ def test_extra_invocation_in_candidate(self): def test_exception_differences(self): """Test detecting exception differences.""" - original = { - "1": {"result_json": None, "error_json": '{"type": "NullPointerException"}'}, - } + original = {"1": {"result_json": None, "error_json": '{"type": "NullPointerException"}'}} candidate = { - "1": {"result_json": '{"value": 42}', "error_json": None}, # No exception + "1": {"result_json": '{"value": 42}', "error_json": None} # No exception } equivalent, diffs = compare_invocations_directly(original, candidate) @@ -176,12 +163,8 @@ def test_non_numeric_strings_differ(self): def test_numeric_comparison_in_direct_invocation(self): """Test that compare_invocations_directly uses numeric-aware comparison.""" - original = { - "1": {"result_json": "0", "error_json": None}, - } - candidate = { - "1": {"result_json": "0.0", "error_json": None}, - } + original = {"1": {"result_json": "0", "error_json": None}} + candidate = {"1": {"result_json": "0.0", "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is True @@ -189,12 +172,8 @@ def test_numeric_comparison_in_direct_invocation(self): def test_integer_long_mismatch_resolved(self): """Test that Integer(42) vs Long(42) serialized differently are still equal.""" - original = { - "1": {"result_json": "42", "error_json": None}, - } - candidate = { - "1": {"result_json": "42.0", "error_json": None}, - } + original = {"1": {"result_json": "42", "error_json": None}} + candidate = {"1": {"result_json": "42.0", "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is True @@ -263,46 +242,30 @@ def test_negative_zero(self): def test_boolean_invocation_comparison(self): """Test boolean return values in full invocation comparison.""" - original = { - "1": {"result_json": "true", "error_json": None}, - } - candidate = { - "1": {"result_json": "true", "error_json": None}, - } + original = {"1": {"result_json": "true", "error_json": None}} + candidate = {"1": {"result_json": "true", "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is True def test_boolean_mismatch_invocation_comparison(self): """Test boolean mismatch is correctly detected.""" - original = { - "1": {"result_json": "true", "error_json": None}, - } - candidate = { - "1": {"result_json": "false", "error_json": None}, - } + original = {"1": {"result_json": "true", "error_json": None}} + candidate = {"1": {"result_json": "false", "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is False assert len(diffs) == 1 def test_array_invocation_comparison(self): """Test array return values in full invocation comparison.""" - original = { - "1": {"result_json": "[1, 2, 3]", "error_json": None}, - } - candidate = { - "1": {"result_json": "[1, 2, 3]", "error_json": None}, - } + original = {"1": {"result_json": "[1, 2, 3]", "error_json": None}} + candidate = {"1": {"result_json": "[1, 2, 3]", "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is True def test_array_mismatch_invocation_comparison(self): """Test array mismatch is correctly detected.""" - original = { - "1": {"result_json": "[1, 2, 3]", "error_json": None}, - } - candidate = { - "1": {"result_json": "[1, 2, 4]", "error_json": None}, - } + original = {"1": {"result_json": "[1, 2, 3]", "error_json": None}} + candidate = {"1": {"result_json": "[1, 2, 4]", "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is False assert len(diffs) == 1 @@ -382,35 +345,25 @@ class TestComparisonWithRealData: def test_string_result_comparison(self): """Test comparing string results.""" - original = { - "1": {"result_json": '"Hello World"', "error_json": None}, - } - candidate = { - "1": {"result_json": '"Hello World"', "error_json": None}, - } + original = {"1": {"result_json": '"Hello World"', "error_json": None}} + candidate = {"1": {"result_json": '"Hello World"', "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is True def test_array_result_comparison(self): """Test comparing array results.""" - original = { - "1": {"result_json": "[1, 2, 3, 4, 5]", "error_json": None}, - } - candidate = { - "1": {"result_json": "[1, 2, 3, 4, 5]", "error_json": None}, - } + original = {"1": {"result_json": "[1, 2, 3, 4, 5]", "error_json": None}} + candidate = {"1": {"result_json": "[1, 2, 3, 4, 5]", "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is True def test_array_order_matters(self): """Test that array order matters for comparison.""" - original = { - "1": {"result_json": "[1, 2, 3]", "error_json": None}, - } + original = {"1": {"result_json": "[1, 2, 3]", "error_json": None}} candidate = { - "1": {"result_json": "[3, 2, 1]", "error_json": None}, # Different order + "1": {"result_json": "[3, 2, 1]", "error_json": None} # Different order } equivalent, diffs = compare_invocations_directly(original, candidate) @@ -418,24 +371,16 @@ def test_array_order_matters(self): def test_object_result_comparison(self): """Test comparing object results.""" - original = { - "1": {"result_json": '{"name": "John", "age": 30}', "error_json": None}, - } - candidate = { - "1": {"result_json": '{"name": "John", "age": 30}', "error_json": None}, - } + original = {"1": {"result_json": '{"name": "John", "age": 30}', "error_json": None}} + candidate = {"1": {"result_json": '{"name": "John", "age": 30}', "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is True def test_null_result(self): """Test comparing null results.""" - original = { - "1": {"result_json": "null", "error_json": None}, - } - candidate = { - "1": {"result_json": "null", "error_json": None}, - } + original = {"1": {"result_json": "null", "error_json": None}} + candidate = {"1": {"result_json": "null", "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is True @@ -462,11 +407,9 @@ class TestEdgeCases: def test_whitespace_in_json(self): """Test that whitespace differences in JSON don't cause issues.""" - original = { - "1": {"result_json": '{"a":1,"b":2}', "error_json": None}, - } + original = {"1": {"result_json": '{"a":1,"b":2}', "error_json": None}} candidate = { - "1": {"result_json": '{ "a": 1, "b": 2 }', "error_json": None}, # With spaces + "1": {"result_json": '{ "a": 1, "b": 2 }', "error_json": None} # With spaces } # Note: Direct string comparison will see these as different @@ -486,12 +429,8 @@ def test_large_number_of_invocations(self): def test_unicode_in_results(self): """Test handling unicode in results.""" - original = { - "1": {"result_json": '"Hello 世界 🌍"', "error_json": None}, - } - candidate = { - "1": {"result_json": '"Hello 世界 🌍"', "error_json": None}, - } + original = {"1": {"result_json": '"Hello 世界 🌍"', "error_json": None}} + candidate = {"1": {"result_json": '"Hello 世界 🌍"', "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is True @@ -499,12 +438,8 @@ def test_unicode_in_results(self): def test_deeply_nested_objects(self): """Test handling deeply nested objects.""" nested = '{"a": {"b": {"c": {"d": {"e": 1}}}}}' - original = { - "1": {"result_json": nested, "error_json": None}, - } - candidate = { - "1": {"result_json": nested, "error_json": None}, - } + original = {"1": {"result_json": nested, "error_json": None}} + candidate = {"1": {"result_json": nested, "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is True @@ -573,9 +508,7 @@ def _create(path: Path, results: list[dict]): return _create - def test_comparator_reads_test_results_table_identical( - self, tmp_path: Path, create_test_results_db - ): + def test_comparator_reads_test_results_table_identical(self, tmp_path: Path, create_test_results_db): """Test that Comparator correctly reads test_results table with identical results.""" original_path = tmp_path / "original.db" candidate_path = tmp_path / "candidate.db" @@ -607,9 +540,7 @@ def test_comparator_reads_test_results_table_identical( assert equivalent is True assert len(diffs) == 0 - def test_comparator_reads_test_results_table_different_values( - self, tmp_path: Path, create_test_results_db - ): + def test_comparator_reads_test_results_table_different_values(self, tmp_path: Path, create_test_results_db): """Test that Comparator detects different return values from test_results table.""" original_path = tmp_path / "original.db" candidate_path = tmp_path / "candidate.db" @@ -621,7 +552,7 @@ def test_comparator_reads_test_results_table_different_values( "loop_index": 1, "iteration_id": "1_0", "return_value": KRYO_STR_OLLEH, - }, + } ] candidate_results = [ @@ -631,7 +562,7 @@ def test_comparator_reads_test_results_table_different_values( "loop_index": 1, "iteration_id": "1_0", "return_value": KRYO_STR_WRONG, # Different result - }, + } ] create_test_results_db(original_path, original_results) @@ -644,9 +575,7 @@ def test_comparator_reads_test_results_table_different_values( assert len(diffs) == 1 assert diffs[0].scope == TestDiffScope.RETURN_VALUE - def test_comparator_handles_multiple_loop_iterations( - self, tmp_path: Path, create_test_results_db - ): + def test_comparator_handles_multiple_loop_iterations(self, tmp_path: Path, create_test_results_db): """Test that Comparator correctly handles multiple loop iterations.""" original_path = tmp_path / "original.db" candidate_path = tmp_path / "candidate.db" @@ -676,9 +605,7 @@ def test_comparator_handles_multiple_loop_iterations( assert equivalent is True assert len(diffs) == 0 - def test_comparator_iteration_id_parsing( - self, tmp_path: Path, create_test_results_db - ): + def test_comparator_iteration_id_parsing(self, tmp_path: Path, create_test_results_db): """Test that Comparator correctly parses iteration_id format 'iter_testIteration'.""" original_path = tmp_path / "original.db" candidate_path = tmp_path / "candidate.db" @@ -711,32 +638,18 @@ def test_comparator_iteration_id_parsing( assert equivalent is True assert len(diffs) == 0 - def test_comparator_missing_result_in_candidate( - self, tmp_path: Path, create_test_results_db - ): + def test_comparator_missing_result_in_candidate(self, tmp_path: Path, create_test_results_db): """Test that Comparator detects missing results in candidate.""" original_path = tmp_path / "original.db" candidate_path = tmp_path / "candidate.db" original_results = [ - { - "loop_index": 1, - "iteration_id": "1_0", - "return_value": KRYO_INT_1, - }, - { - "loop_index": 1, - "iteration_id": "2_0", - "return_value": KRYO_INT_2, - }, + {"loop_index": 1, "iteration_id": "1_0", "return_value": KRYO_INT_1}, + {"loop_index": 1, "iteration_id": "2_0", "return_value": KRYO_INT_2}, ] candidate_results = [ - { - "loop_index": 1, - "iteration_id": "1_0", - "return_value": KRYO_INT_1, - }, + {"loop_index": 1, "iteration_id": "1_0", "return_value": KRYO_INT_1} # Missing second iteration ] @@ -779,12 +692,8 @@ def test_float_values_slightly_different(self): For truly different values, the difference must exceed the epsilon threshold. """ # These values differ by ~3e-10, which is within epsilon tolerance (1e-9) - original = { - "1": {"result_json": "3.14159", "error_json": None}, - } - candidate = { - "1": {"result_json": "3.141590001", "error_json": None}, - } + original = {"1": {"result_json": "3.14159", "error_json": None}} + candidate = {"1": {"result_json": "3.141590001", "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is True # Within epsilon tolerance @@ -792,11 +701,9 @@ def test_float_values_slightly_different(self): def test_float_values_significantly_different(self): """Float strings outside epsilon tolerance should be detected as different.""" - original = { - "1": {"result_json": "3.14159", "error_json": None}, - } + original = {"1": {"result_json": "3.14159", "error_json": None}} candidate = { - "1": {"result_json": "3.14160", "error_json": None}, # Differs by ~1e-5 + "1": {"result_json": "3.14160", "error_json": None} # Differs by ~1e-5 } equivalent, diffs = compare_invocations_directly(original, candidate) @@ -806,12 +713,8 @@ def test_float_values_significantly_different(self): def test_nan_string_comparison(self): """NaN as a string return value should be comparable.""" - original = { - "1": {"result_json": "NaN", "error_json": None}, - } - candidate = { - "1": {"result_json": "NaN", "error_json": None}, - } + original = {"1": {"result_json": "NaN", "error_json": None}} + candidate = {"1": {"result_json": "NaN", "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is True @@ -819,12 +722,8 @@ def test_nan_string_comparison(self): def test_nan_vs_number(self): """NaN vs a normal number should be detected as different.""" - original = { - "1": {"result_json": "NaN", "error_json": None}, - } - candidate = { - "1": {"result_json": "0.0", "error_json": None}, - } + original = {"1": {"result_json": "NaN", "error_json": None}} + candidate = {"1": {"result_json": "0.0", "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is False @@ -832,12 +731,8 @@ def test_nan_vs_number(self): def test_infinity_string_comparison(self): """Infinity as a string return value should be comparable.""" - original = { - "1": {"result_json": "Infinity", "error_json": None}, - } - candidate = { - "1": {"result_json": "Infinity", "error_json": None}, - } + original = {"1": {"result_json": "Infinity", "error_json": None}} + candidate = {"1": {"result_json": "Infinity", "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is True @@ -845,12 +740,8 @@ def test_infinity_string_comparison(self): def test_negative_infinity(self): """-Infinity as a string return value should be comparable.""" - original = { - "1": {"result_json": "-Infinity", "error_json": None}, - } - candidate = { - "1": {"result_json": "-Infinity", "error_json": None}, - } + original = {"1": {"result_json": "-Infinity", "error_json": None}} + candidate = {"1": {"result_json": "-Infinity", "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is True @@ -858,12 +749,8 @@ def test_negative_infinity(self): def test_infinity_vs_negative_infinity(self): """Infinity and -Infinity should be detected as different.""" - original = { - "1": {"result_json": "Infinity", "error_json": None}, - } - candidate = { - "1": {"result_json": "-Infinity", "error_json": None}, - } + original = {"1": {"result_json": "Infinity", "error_json": None}} + candidate = {"1": {"result_json": "-Infinity", "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is False @@ -871,12 +758,8 @@ def test_infinity_vs_negative_infinity(self): def test_empty_collection_results(self): """Empty array '[]' as return value should be comparable.""" - original = { - "1": {"result_json": "[]", "error_json": None}, - } - candidate = { - "1": {"result_json": "[]", "error_json": None}, - } + original = {"1": {"result_json": "[]", "error_json": None}} + candidate = {"1": {"result_json": "[]", "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is True @@ -884,12 +767,8 @@ def test_empty_collection_results(self): def test_empty_object_results(self): """Empty object '{}' as return value should be comparable.""" - original = { - "1": {"result_json": "{}", "error_json": None}, - } - candidate = { - "1": {"result_json": "{}", "error_json": None}, - } + original = {"1": {"result_json": "{}", "error_json": None}} + candidate = {"1": {"result_json": "{}", "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is True @@ -917,12 +796,8 @@ def test_large_number_different(self): 1e+17 as floats due to precision limits, making them indistinguishable. This is a known limitation of floating-point comparison for very large integers. """ - original = { - "1": {"result_json": "99999999999999999", "error_json": None}, - } - candidate = { - "1": {"result_json": "99999999999999998", "error_json": None}, - } + original = {"1": {"result_json": "99999999999999999", "error_json": None}} + candidate = {"1": {"result_json": "99999999999999998", "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) # Due to float precision limits, these are considered equal @@ -931,12 +806,8 @@ def test_large_number_different(self): def test_large_number_significantly_different(self): """Large numbers with significant differences should be detected.""" - original = { - "1": {"result_json": "100000000000000000", "error_json": None}, - } - candidate = { - "1": {"result_json": "200000000000000000", "error_json": None}, - } + original = {"1": {"result_json": "100000000000000000", "error_json": None}} + candidate = {"1": {"result_json": "200000000000000000", "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is False @@ -944,12 +815,8 @@ def test_large_number_significantly_different(self): def test_null_vs_empty_string(self): """'null' and '""' should NOT be equivalent.""" - original = { - "1": {"result_json": "null", "error_json": None}, - } - candidate = { - "1": {"result_json": '""', "error_json": None}, - } + original = {"1": {"result_json": "null", "error_json": None}} + candidate = {"1": {"result_json": '""', "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is False @@ -958,10 +825,7 @@ def test_null_vs_empty_string(self): def test_boolean_string_comparison(self): """Boolean strings 'true'/'false' should compare correctly.""" - original = { - "1": {"result_json": "true", "error_json": None}, - "2": {"result_json": "false", "error_json": None}, - } + original = {"1": {"result_json": "true", "error_json": None}, "2": {"result_json": "false", "error_json": None}} candidate = { "1": {"result_json": "true", "error_json": None}, "2": {"result_json": "false", "error_json": None}, @@ -972,12 +836,8 @@ def test_boolean_string_comparison(self): def test_boolean_true_vs_false(self): """'true' vs 'false' should be detected as different.""" - original = { - "1": {"result_json": "true", "error_json": None}, - } - candidate = { - "1": {"result_json": "false", "error_json": None}, - } + original = {"1": {"result_json": "true", "error_json": None}} + candidate = {"1": {"result_json": "false", "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is False @@ -1024,12 +884,8 @@ def test_compare_schema_mismatch_db(self, tmp_path: Path): def test_compare_with_none_return_values_direct(self): """Rows where result_json is None should be handled in direct comparison.""" - original = { - "1": {"result_json": None, "error_json": None}, - } - candidate = { - "1": {"result_json": None, "error_json": None}, - } + original = {"1": {"result_json": None, "error_json": None}} + candidate = {"1": {"result_json": None, "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is True @@ -1037,12 +893,8 @@ def test_compare_with_none_return_values_direct(self): def test_compare_one_none_one_value_direct(self): """One None result vs a real value should detect the difference.""" - original = { - "1": {"result_json": None, "error_json": None}, - } - candidate = { - "1": {"result_json": "42", "error_json": None}, - } + original = {"1": {"result_json": None, "error_json": None}} + candidate = {"1": {"result_json": "42", "error_json": None}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is False @@ -1050,12 +902,8 @@ def test_compare_one_none_one_value_direct(self): def test_compare_both_errors_identical(self): """Identical errors in both original and candidate should be equivalent.""" - original = { - "1": {"result_json": None, "error_json": '{"type": "IOException", "message": "file not found"}'}, - } - candidate = { - "1": {"result_json": None, "error_json": '{"type": "IOException", "message": "file not found"}'}, - } + original = {"1": {"result_json": None, "error_json": '{"type": "IOException", "message": "file not found"}'}} + candidate = {"1": {"result_json": None, "error_json": '{"type": "IOException", "message": "file not found"}'}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is True @@ -1063,12 +911,8 @@ def test_compare_both_errors_identical(self): def test_compare_different_error_types(self): """Different error types should be detected.""" - original = { - "1": {"result_json": None, "error_json": '{"type": "IOException"}'}, - } - candidate = { - "1": {"result_json": None, "error_json": '{"type": "NullPointerException"}'}, - } + original = {"1": {"result_json": None, "error_json": '{"type": "IOException"}'}} + candidate = {"1": {"result_json": None, "error_json": '{"type": "NullPointerException"}'}} equivalent, diffs = compare_invocations_directly(original, candidate) assert equivalent is False @@ -1083,9 +927,7 @@ class TestComparatorJavaEdgeCases(TestTestResultsTableSchema): Extends TestTestResultsTableSchema to reuse the create_test_results_db fixture. """ - def test_comparator_float_epsilon_tolerance( - self, tmp_path: Path, create_test_results_db - ): + def test_comparator_float_epsilon_tolerance(self, tmp_path: Path, create_test_results_db): """Values differing by less than EPSILON (1e-9) should be treated as equivalent. The Java Comparator uses EPSILON=1e-9 for float comparison. @@ -1102,7 +944,7 @@ def test_comparator_float_epsilon_tolerance( "loop_index": 1, "iteration_id": "1_0", "return_value": KRYO_DOUBLE_1_0000000001, - }, + } ] candidate_results = [ @@ -1112,7 +954,7 @@ def test_comparator_float_epsilon_tolerance( "loop_index": 1, "iteration_id": "1_0", "return_value": KRYO_DOUBLE_1_0000000002, - }, + } ] create_test_results_db(original_path, original_results) @@ -1124,9 +966,7 @@ def test_comparator_float_epsilon_tolerance( assert equivalent is True assert len(diffs) == 0 - def test_comparator_nan_handling( - self, tmp_path: Path, create_test_results_db - ): + def test_comparator_nan_handling(self, tmp_path: Path, create_test_results_db): """Java Comparator should handle NaN return values.""" original_path = tmp_path / "original.db" candidate_path = tmp_path / "candidate.db" @@ -1138,7 +978,7 @@ def test_comparator_nan_handling( "loop_index": 1, "iteration_id": "1_0", "return_value": KRYO_NAN, - }, + } ] create_test_results_db(original_path, results) @@ -1150,9 +990,7 @@ def test_comparator_nan_handling( assert equivalent is True assert len(diffs) == 0 - def test_comparator_empty_table( - self, tmp_path: Path, create_test_results_db - ): + def test_comparator_empty_table(self, tmp_path: Path, create_test_results_db): """Empty test_results tables should result in equivalent=False (vacuous equivalence guard).""" original_path = tmp_path / "original.db" candidate_path = tmp_path / "candidate.db" @@ -1167,9 +1005,7 @@ def test_comparator_empty_table( assert equivalent is False assert len(diffs) == 0 - def test_comparator_infinity_handling( - self, tmp_path: Path, create_test_results_db - ): + def test_comparator_infinity_handling(self, tmp_path: Path, create_test_results_db): """Java Comparator should handle Infinity return values correctly.""" original_path = tmp_path / "original.db" candidate_path = tmp_path / "candidate.db" @@ -1198,3 +1034,238 @@ def test_comparator_infinity_handling( assert equivalent is True assert len(diffs) == 0 + + +class TestVoidFunctionComparison: + """Tests for void function comparison using compare_invocations_directly.""" + + def test_void_both_null_result_equivalent(self): + """Both original and candidate have None result_json (void success).""" + original = {"1": {"result_json": None, "error_json": None}} + candidate = {"1": {"result_json": None, "error_json": None}} + + equivalent, diffs = compare_invocations_directly(original, candidate) + + assert equivalent is True + assert len(diffs) == 0 + + def test_void_null_vs_non_null_result(self): + """Original void (None) vs candidate with return value should differ.""" + original = {"1": {"result_json": None, "error_json": None}} + candidate = {"1": {"result_json": "42", "error_json": None}} + + equivalent, diffs = compare_invocations_directly(original, candidate) + + assert equivalent is False + assert len(diffs) == 1 + assert diffs[0].scope == TestDiffScope.RETURN_VALUE + + def test_void_non_null_vs_null_result(self): + """Original with return value vs candidate void (None) should differ.""" + original = {"1": {"result_json": "42", "error_json": None}} + candidate = {"1": {"result_json": None, "error_json": None}} + + equivalent, diffs = compare_invocations_directly(original, candidate) + + assert equivalent is False + assert len(diffs) == 1 + assert diffs[0].scope == TestDiffScope.RETURN_VALUE + + def test_void_same_serialized_side_effects(self): + """Identical side-effect serializations (Object[] arrays) should be equivalent.""" + original = {"1": {"result_json": "[1, 2, 3]", "error_json": None}} + candidate = {"1": {"result_json": "[1, 2, 3]", "error_json": None}} + + equivalent, diffs = compare_invocations_directly(original, candidate) + + assert equivalent is True + assert len(diffs) == 0 + + def test_void_different_serialized_side_effects(self): + """Different side-effect serializations should be detected.""" + original = {"1": {"result_json": "[1, 2, 3]", "error_json": None}} + candidate = {"1": {"result_json": "[1, 2, 99]", "error_json": None}} + + equivalent, diffs = compare_invocations_directly(original, candidate) + + assert equivalent is False + assert len(diffs) == 1 + assert diffs[0].scope == TestDiffScope.RETURN_VALUE + assert diffs[0].original_value == "[1, 2, 3]" + assert diffs[0].candidate_value == "[1, 2, 99]" + + def test_void_exception_in_candidate(self): + """Void success in original vs exception in candidate should differ.""" + original = {"1": {"result_json": None, "error_json": None}} + candidate = {"1": {"result_json": None, "error_json": '{"type": "NullPointerException"}'}} + + equivalent, diffs = compare_invocations_directly(original, candidate) + + assert equivalent is False + assert len(diffs) == 1 + assert diffs[0].scope == TestDiffScope.DID_PASS + + def test_void_multiple_invocations_mixed(self): + """Multiple void invocations: some matching, some differing.""" + original = { + "1": {"result_json": None, "error_json": None}, + "2": {"result_json": "[10, 20]", "error_json": None}, + "3": {"result_json": None, "error_json": None}, + } + candidate = { + "1": {"result_json": None, "error_json": None}, + "2": {"result_json": "[10, 99]", "error_json": None}, + "3": {"result_json": None, "error_json": '{"type": "RuntimeException"}'}, + } + + equivalent, diffs = compare_invocations_directly(original, candidate) + + assert equivalent is False + assert len(diffs) == 2 + sorted_diffs = sorted(diffs, key=lambda d: d.scope.value) + assert sorted_diffs[0].scope == TestDiffScope.DID_PASS + assert sorted_diffs[1].scope == TestDiffScope.RETURN_VALUE + assert sorted_diffs[1].original_value == "[10, 20]" + assert sorted_diffs[1].candidate_value == "[10, 99]" + + +@requires_java +class TestVoidSqliteComparison: + """Tests for void function comparison via Java Comparator with 10-column SQLite schema.""" + + @pytest.fixture + def create_void_test_results_db(self): + """Create a test SQLite database with 10-column schema (including stdout).""" + + def _create(path: Path, results: list[dict]): + conn = sqlite3.connect(path) + cursor = conn.cursor() + + cursor.execute( + """ + CREATE TABLE test_results ( + test_module_path TEXT, + test_class_name TEXT, + test_function_name TEXT, + function_getting_tested TEXT, + loop_index INTEGER, + iteration_id TEXT, + runtime INTEGER, + return_value BLOB, + verification_type TEXT, + stdout TEXT + ) + """ + ) + + for result in results: + cursor.execute( + """ + INSERT INTO test_results + (test_module_path, test_class_name, test_function_name, + function_getting_tested, loop_index, iteration_id, + runtime, return_value, verification_type, stdout) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + result.get("test_module_path", "TestModule"), + result.get("test_class_name", "TestClass"), + result.get("test_function_name", "testMethod"), + result.get("function_getting_tested", "targetMethod"), + result.get("loop_index", 1), + result.get("iteration_id", "1_0"), + result.get("runtime", 1000000), + result.get("return_value"), + result.get("verification_type", "function_call"), + result.get("stdout"), + ), + ) + + conn.commit() + conn.close() + return path + + return _create + + def test_void_sqlite_both_null_return_same_stdout(self, tmp_path: Path, create_void_test_results_db): + """Both DBs have NULL return_value and same stdout — equivalent.""" + original_path = tmp_path / "original.db" + candidate_path = tmp_path / "candidate.db" + + results = [ + { + "test_class_name": "PrinterTest", + "function_getting_tested": "printMessage", + "loop_index": 1, + "iteration_id": "1_0", + "return_value": None, + "stdout": "Hello World\n", + } + ] + + create_void_test_results_db(original_path, results) + create_void_test_results_db(candidate_path, results) + + equivalent, diffs = compare_test_results(original_path, candidate_path) + + assert equivalent is True + assert len(diffs) == 0 + + def test_void_sqlite_different_stdout(self, tmp_path: Path, create_void_test_results_db): + """Both DBs have NULL return_value but different stdout — not equivalent.""" + original_path = tmp_path / "original.db" + candidate_path = tmp_path / "candidate.db" + + original_results = [ + { + "test_class_name": "LoggerTest", + "function_getting_tested": "log", + "loop_index": 1, + "iteration_id": "1_0", + "return_value": None, + "stdout": "INFO: Starting\n", + } + ] + + candidate_results = [ + { + "test_class_name": "LoggerTest", + "function_getting_tested": "log", + "loop_index": 1, + "iteration_id": "1_0", + "return_value": None, + "stdout": "DEBUG: Starting\n", + } + ] + + create_void_test_results_db(original_path, original_results) + create_void_test_results_db(candidate_path, candidate_results) + + equivalent, diffs = compare_test_results(original_path, candidate_path) + + assert equivalent is False + assert len(diffs) == 1 + + def test_void_sqlite_null_stdout_both(self, tmp_path: Path, create_void_test_results_db): + """Both DBs have NULL return_value and NULL stdout — equivalent.""" + original_path = tmp_path / "original.db" + candidate_path = tmp_path / "candidate.db" + + results = [ + { + "test_class_name": "WorkerTest", + "function_getting_tested": "doWork", + "loop_index": 1, + "iteration_id": "1_0", + "return_value": None, + "stdout": None, + } + ] + + create_void_test_results_db(original_path, results) + create_void_test_results_db(candidate_path, results) + + equivalent, diffs = compare_test_results(original_path, candidate_path) + + assert equivalent is True + assert len(diffs) == 0 diff --git a/tests/test_languages/test_java/test_discovery.py b/tests/test_languages/test_java/test_discovery.py index e42cfe8c2..683f5a596 100644 --- a/tests/test_languages/test_java/test_discovery.py +++ b/tests/test_languages/test_java/test_discovery.py @@ -132,7 +132,11 @@ def test_filter_exclude_pattern(self): assert "setData" not in method_names def test_filter_require_return(self): - """Test filtering by require_return.""" + """Test filtering by require_return. + + With require_return=True, void methods are still included (verified via test pass/fail), + but non-void methods without an actual return statement are excluded. + """ source = """ public class Example { public void doSomething() {} @@ -144,8 +148,10 @@ def test_filter_require_return(self): """ criteria = FunctionFilterCriteria(require_return=True) functions = discover_functions_from_source(source, filter_criteria=criteria) - assert len(functions) == 1 - assert functions[0].function_name == "getValue" + names = {f.function_name for f in functions} + assert "getValue" in names + assert "doSomething" in names + assert len(functions) == 2 def test_filter_by_line_count(self): """Test filtering by line count.""" diff --git a/tests/test_languages/test_java/test_instrumentation.py b/tests/test_languages/test_java/test_instrumentation.py index 290d39b28..9661531aa 100644 --- a/tests/test_languages/test_java/test_instrumentation.py +++ b/tests/test_languages/test_java/test_instrumentation.py @@ -149,13 +149,19 @@ def test_instrument_behavior_mode_simple(self, tmp_path: Path): long _cf_end1_1 = -1; long _cf_start1_1 = 0; byte[] _cf_serializedResult1_1 = null; + java.io.ByteArrayOutputStream _cf_stdoutCapture1_1 = new java.io.ByteArrayOutputStream(); + java.io.PrintStream _cf_origOut1_1 = System.out; + String _cf_stdout1_1 = null; System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":1" + "######$!"); try { + System.setOut(new java.io.PrintStream(_cf_stdoutCapture1_1)); _cf_start1_1 = System.nanoTime(); _cf_result1_1 = calc.add(2, 2); _cf_end1_1 = System.nanoTime(); _cf_serializedResult1_1 = com.codeflash.Serializer.serialize((Object) _cf_result1_1); } finally { + System.setOut(_cf_origOut1_1); + try { _cf_stdout1_1 = _cf_stdoutCapture1_1.toString("UTF-8"); } catch (Exception _cf_encEx1_1) {} long _cf_end1_1_finally = System.nanoTime(); long _cf_dur1_1 = (_cf_end1_1 != -1 ? _cf_end1_1 : _cf_end1_1_finally) - _cf_start1_1; System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + "1" + "######!"); @@ -168,19 +174,20 @@ def test_instrument_behavior_mode_simple(self, tmp_path: Path): _cf_stmt1_1.execute("CREATE TABLE IF NOT EXISTS test_results (" + "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " + "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + - "runtime INTEGER, return_value BLOB, verification_type TEXT)"); + "runtime INTEGER, return_value BLOB, verification_type TEXT, stdout TEXT)"); } - String _cf_sql1_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; + String _cf_sql1_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; try (PreparedStatement _cf_pstmt1_1 = _cf_conn1_1.prepareStatement(_cf_sql1_1)) { _cf_pstmt1_1.setString(1, _cf_mod1); _cf_pstmt1_1.setString(2, _cf_cls1); _cf_pstmt1_1.setString(3, _cf_test1); _cf_pstmt1_1.setString(4, _cf_fn1); _cf_pstmt1_1.setInt(5, _cf_loop1); - _cf_pstmt1_1.setString(6, "1"); + _cf_pstmt1_1.setString(6, _cf_cls1 + "." + _cf_test1 + ".1_" + _cf_testIteration1); _cf_pstmt1_1.setLong(7, _cf_dur1_1); _cf_pstmt1_1.setBytes(8, _cf_serializedResult1_1); _cf_pstmt1_1.setString(9, "function_call"); + _cf_pstmt1_1.setString(10, _cf_stdout1_1); _cf_pstmt1_1.executeUpdate(); } } @@ -275,13 +282,19 @@ def test_instrument_behavior_mode_assert_throws_expression_lambda(self, tmp_path long _cf_end2_1 = -1; long _cf_start2_1 = 0; byte[] _cf_serializedResult2_1 = null; + java.io.ByteArrayOutputStream _cf_stdoutCapture2_1 = new java.io.ByteArrayOutputStream(); + java.io.PrintStream _cf_origOut2_1 = System.out; + String _cf_stdout2_1 = null; System.out.println("!$######" + _cf_mod2 + ":" + _cf_cls2 + "." + _cf_test2 + ":" + _cf_fn2 + ":" + _cf_loop2 + ":1" + "######$!"); try { + System.setOut(new java.io.PrintStream(_cf_stdoutCapture2_1)); _cf_start2_1 = System.nanoTime(); _cf_result2_1 = Fibonacci.fibonacci(0); _cf_end2_1 = System.nanoTime(); _cf_serializedResult2_1 = com.codeflash.Serializer.serialize((Object) _cf_result2_1); } finally { + System.setOut(_cf_origOut2_1); + try { _cf_stdout2_1 = _cf_stdoutCapture2_1.toString("UTF-8"); } catch (Exception _cf_encEx2_1) {} long _cf_end2_1_finally = System.nanoTime(); long _cf_dur2_1 = (_cf_end2_1 != -1 ? _cf_end2_1 : _cf_end2_1_finally) - _cf_start2_1; System.out.println("!######" + _cf_mod2 + ":" + _cf_cls2 + "." + _cf_test2 + ":" + _cf_fn2 + ":" + _cf_loop2 + ":" + "1" + "######!"); @@ -294,19 +307,20 @@ def test_instrument_behavior_mode_assert_throws_expression_lambda(self, tmp_path _cf_stmt2_1.execute("CREATE TABLE IF NOT EXISTS test_results (" + "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " + "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + - "runtime INTEGER, return_value BLOB, verification_type TEXT)"); + "runtime INTEGER, return_value BLOB, verification_type TEXT, stdout TEXT)"); } - String _cf_sql2_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; + String _cf_sql2_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; try (PreparedStatement _cf_pstmt2_1 = _cf_conn2_1.prepareStatement(_cf_sql2_1)) { _cf_pstmt2_1.setString(1, _cf_mod2); _cf_pstmt2_1.setString(2, _cf_cls2); _cf_pstmt2_1.setString(3, _cf_test2); _cf_pstmt2_1.setString(4, _cf_fn2); _cf_pstmt2_1.setInt(5, _cf_loop2); - _cf_pstmt2_1.setString(6, "1"); + _cf_pstmt2_1.setString(6, _cf_cls2 + "." + _cf_test2 + ".1_" + _cf_testIteration2); _cf_pstmt2_1.setLong(7, _cf_dur2_1); _cf_pstmt2_1.setBytes(8, _cf_serializedResult2_1); _cf_pstmt2_1.setString(9, "function_call"); + _cf_pstmt2_1.setString(10, _cf_stdout2_1); _cf_pstmt2_1.executeUpdate(); } } @@ -403,13 +417,19 @@ def test_instrument_behavior_mode_assert_throws_block_lambda(self, tmp_path: Pat long _cf_end2_1 = -1; long _cf_start2_1 = 0; byte[] _cf_serializedResult2_1 = null; + java.io.ByteArrayOutputStream _cf_stdoutCapture2_1 = new java.io.ByteArrayOutputStream(); + java.io.PrintStream _cf_origOut2_1 = System.out; + String _cf_stdout2_1 = null; System.out.println("!$######" + _cf_mod2 + ":" + _cf_cls2 + "." + _cf_test2 + ":" + _cf_fn2 + ":" + _cf_loop2 + ":1" + "######$!"); try { + System.setOut(new java.io.PrintStream(_cf_stdoutCapture2_1)); _cf_start2_1 = System.nanoTime(); _cf_result2_1 = Fibonacci.fibonacci(0); _cf_end2_1 = System.nanoTime(); _cf_serializedResult2_1 = com.codeflash.Serializer.serialize((Object) _cf_result2_1); } finally { + System.setOut(_cf_origOut2_1); + try { _cf_stdout2_1 = _cf_stdoutCapture2_1.toString("UTF-8"); } catch (Exception _cf_encEx2_1) {} long _cf_end2_1_finally = System.nanoTime(); long _cf_dur2_1 = (_cf_end2_1 != -1 ? _cf_end2_1 : _cf_end2_1_finally) - _cf_start2_1; System.out.println("!######" + _cf_mod2 + ":" + _cf_cls2 + "." + _cf_test2 + ":" + _cf_fn2 + ":" + _cf_loop2 + ":" + "1" + "######!"); @@ -422,19 +442,20 @@ def test_instrument_behavior_mode_assert_throws_block_lambda(self, tmp_path: Pat _cf_stmt2_1.execute("CREATE TABLE IF NOT EXISTS test_results (" + "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " + "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + - "runtime INTEGER, return_value BLOB, verification_type TEXT)"); + "runtime INTEGER, return_value BLOB, verification_type TEXT, stdout TEXT)"); } - String _cf_sql2_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; + String _cf_sql2_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; try (PreparedStatement _cf_pstmt2_1 = _cf_conn2_1.prepareStatement(_cf_sql2_1)) { _cf_pstmt2_1.setString(1, _cf_mod2); _cf_pstmt2_1.setString(2, _cf_cls2); _cf_pstmt2_1.setString(3, _cf_test2); _cf_pstmt2_1.setString(4, _cf_fn2); _cf_pstmt2_1.setInt(5, _cf_loop2); - _cf_pstmt2_1.setString(6, "1"); + _cf_pstmt2_1.setString(6, _cf_cls2 + "." + _cf_test2 + ".1_" + _cf_testIteration2); _cf_pstmt2_1.setLong(7, _cf_dur2_1); _cf_pstmt2_1.setBytes(8, _cf_serializedResult2_1); _cf_pstmt2_1.setString(9, "function_call"); + _cf_pstmt2_1.setString(10, _cf_stdout2_1); _cf_pstmt2_1.executeUpdate(); } } @@ -748,13 +769,19 @@ class TestKryoSerializerUsage: long _cf_end1_1 = -1; long _cf_start1_1 = 0; byte[] _cf_serializedResult1_1 = null; + java.io.ByteArrayOutputStream _cf_stdoutCapture1_1 = new java.io.ByteArrayOutputStream(); + java.io.PrintStream _cf_origOut1_1 = System.out; + String _cf_stdout1_1 = null; System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":1" + "######$!"); try { + System.setOut(new java.io.PrintStream(_cf_stdoutCapture1_1)); _cf_start1_1 = System.nanoTime(); _cf_result1_1 = obj.foo(); _cf_end1_1 = System.nanoTime(); _cf_serializedResult1_1 = com.codeflash.Serializer.serialize((Object) _cf_result1_1); } finally { + System.setOut(_cf_origOut1_1); + try { _cf_stdout1_1 = _cf_stdoutCapture1_1.toString("UTF-8"); } catch (Exception _cf_encEx1_1) {} long _cf_end1_1_finally = System.nanoTime(); long _cf_dur1_1 = (_cf_end1_1 != -1 ? _cf_end1_1 : _cf_end1_1_finally) - _cf_start1_1; System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + "1" + "######!"); @@ -767,19 +794,20 @@ class TestKryoSerializerUsage: _cf_stmt1_1.execute("CREATE TABLE IF NOT EXISTS test_results (" + "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " + "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + - "runtime INTEGER, return_value BLOB, verification_type TEXT)"); + "runtime INTEGER, return_value BLOB, verification_type TEXT, stdout TEXT)"); } - String _cf_sql1_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; + String _cf_sql1_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; try (PreparedStatement _cf_pstmt1_1 = _cf_conn1_1.prepareStatement(_cf_sql1_1)) { _cf_pstmt1_1.setString(1, _cf_mod1); _cf_pstmt1_1.setString(2, _cf_cls1); _cf_pstmt1_1.setString(3, _cf_test1); _cf_pstmt1_1.setString(4, _cf_fn1); _cf_pstmt1_1.setInt(5, _cf_loop1); - _cf_pstmt1_1.setString(6, "1"); + _cf_pstmt1_1.setString(6, _cf_cls1 + "." + _cf_test1 + ".1_" + _cf_testIteration1); _cf_pstmt1_1.setLong(7, _cf_dur1_1); _cf_pstmt1_1.setBytes(8, _cf_serializedResult1_1); _cf_pstmt1_1.setString(9, "function_call"); + _cf_pstmt1_1.setString(10, _cf_stdout1_1); _cf_pstmt1_1.executeUpdate(); } } @@ -1261,13 +1289,19 @@ def test_instrument_generated_test_behavior_mode(self): long _cf_end1_1 = -1; long _cf_start1_1 = 0; byte[] _cf_serializedResult1_1 = null; + java.io.ByteArrayOutputStream _cf_stdoutCapture1_1 = new java.io.ByteArrayOutputStream(); + java.io.PrintStream _cf_origOut1_1 = System.out; + String _cf_stdout1_1 = null; System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":1" + "######$!"); try { + System.setOut(new java.io.PrintStream(_cf_stdoutCapture1_1)); _cf_start1_1 = System.nanoTime(); _cf_result1_1 = new Calculator().add(2, 2); _cf_end1_1 = System.nanoTime(); _cf_serializedResult1_1 = com.codeflash.Serializer.serialize((Object) _cf_result1_1); } finally { + System.setOut(_cf_origOut1_1); + try { _cf_stdout1_1 = _cf_stdoutCapture1_1.toString("UTF-8"); } catch (Exception _cf_encEx1_1) {} long _cf_end1_1_finally = System.nanoTime(); long _cf_dur1_1 = (_cf_end1_1 != -1 ? _cf_end1_1 : _cf_end1_1_finally) - _cf_start1_1; System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + "1" + "######!"); @@ -1280,19 +1314,20 @@ def test_instrument_generated_test_behavior_mode(self): _cf_stmt1_1.execute("CREATE TABLE IF NOT EXISTS test_results (" + "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " + "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + - "runtime INTEGER, return_value BLOB, verification_type TEXT)"); + "runtime INTEGER, return_value BLOB, verification_type TEXT, stdout TEXT)"); } - String _cf_sql1_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; + String _cf_sql1_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; try (PreparedStatement _cf_pstmt1_1 = _cf_conn1_1.prepareStatement(_cf_sql1_1)) { _cf_pstmt1_1.setString(1, _cf_mod1); _cf_pstmt1_1.setString(2, _cf_cls1); _cf_pstmt1_1.setString(3, _cf_test1); _cf_pstmt1_1.setString(4, _cf_fn1); _cf_pstmt1_1.setInt(5, _cf_loop1); - _cf_pstmt1_1.setString(6, "1"); + _cf_pstmt1_1.setString(6, _cf_cls1 + "." + _cf_test1 + ".1_" + _cf_testIteration1); _cf_pstmt1_1.setLong(7, _cf_dur1_1); _cf_pstmt1_1.setBytes(8, _cf_serializedResult1_1); _cf_pstmt1_1.setString(9, "function_call"); + _cf_pstmt1_1.setString(10, _cf_stdout1_1); _cf_pstmt1_1.executeUpdate(); } } @@ -2056,6 +2091,12 @@ def java_project(self, tmp_path: Path): yield tmp_path, src_dir, test_dir + # Clean up any SQLite files left in the shared temp dir to prevent cross-test contamination + from codeflash.code_utils.code_utils import get_run_tmp_file + + for i in range(10): + get_run_tmp_file(Path(f"test_return_values_{i}.sqlite")).unlink(missing_ok=True) + # Reset language back to Python current_module._current_language = None set_current_language(Language.PYTHON) @@ -2649,13 +2690,19 @@ def test_behavior_mode_writes_to_sqlite(self, java_project): long _cf_end1_1 = -1; long _cf_start1_1 = 0; byte[] _cf_serializedResult1_1 = null; + java.io.ByteArrayOutputStream _cf_stdoutCapture1_1 = new java.io.ByteArrayOutputStream(); + java.io.PrintStream _cf_origOut1_1 = System.out; + String _cf_stdout1_1 = null; System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":1" + "######$!"); try { + System.setOut(new java.io.PrintStream(_cf_stdoutCapture1_1)); _cf_start1_1 = System.nanoTime(); _cf_result1_1 = counter.increment(); _cf_end1_1 = System.nanoTime(); _cf_serializedResult1_1 = com.codeflash.Serializer.serialize((Object) _cf_result1_1); } finally { + System.setOut(_cf_origOut1_1); + try { _cf_stdout1_1 = _cf_stdoutCapture1_1.toString("UTF-8"); } catch (Exception _cf_encEx1_1) {} long _cf_end1_1_finally = System.nanoTime(); long _cf_dur1_1 = (_cf_end1_1 != -1 ? _cf_end1_1 : _cf_end1_1_finally) - _cf_start1_1; System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + "1" + "######!"); @@ -2668,19 +2715,20 @@ def test_behavior_mode_writes_to_sqlite(self, java_project): _cf_stmt1_1.execute("CREATE TABLE IF NOT EXISTS test_results (" + "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " + "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + - "runtime INTEGER, return_value BLOB, verification_type TEXT)"); + "runtime INTEGER, return_value BLOB, verification_type TEXT, stdout TEXT)"); } - String _cf_sql1_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; + String _cf_sql1_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; try (PreparedStatement _cf_pstmt1_1 = _cf_conn1_1.prepareStatement(_cf_sql1_1)) { _cf_pstmt1_1.setString(1, _cf_mod1); _cf_pstmt1_1.setString(2, _cf_cls1); _cf_pstmt1_1.setString(3, _cf_test1); _cf_pstmt1_1.setString(4, _cf_fn1); _cf_pstmt1_1.setInt(5, _cf_loop1); - _cf_pstmt1_1.setString(6, "1"); + _cf_pstmt1_1.setString(6, _cf_cls1 + "." + _cf_test1 + ".1_" + _cf_testIteration1); _cf_pstmt1_1.setLong(7, _cf_dur1_1); _cf_pstmt1_1.setBytes(8, _cf_serializedResult1_1); _cf_pstmt1_1.setString(9, "function_call"); + _cf_pstmt1_1.setString(10, _cf_stdout1_1); _cf_pstmt1_1.executeUpdate(); } } @@ -2787,6 +2835,7 @@ def test_behavior_mode_writes_to_sqlite(self, java_project): runtime, return_value, verification_type, + stdout, ) = row # Verify fields @@ -3276,3 +3325,535 @@ def __init__(self, path): assert math.isclose(duration, 100_000_000, rel_tol=0.15), ( f"Long spin measured {duration}ns, expected ~100_000_000ns (15% tolerance)" ) + + +class TestVoidFunctionInstrumentation: + """Tests for void function instrumentation with exact string equality.""" + + def test_void_instance_method_with_args(self, tmp_path: Path): + """Void instance method serializes receiver + args as side effects.""" + source = """import org.junit.jupiter.api.Test; + +public class WorkerTest { + @Test + public void testDoWork() { + Worker obj = new Worker(); + obj.doWork(42); + } +} +""" + test_file = tmp_path / "WorkerTest.java" + test_file.write_text(source) + + func = FunctionToOptimize( + function_name="doWork", + file_path=tmp_path / "Worker.java", + starting_line=1, + ending_line=5, + parents=[], + is_method=True, + language="java", + return_type="void", + ) + + success, result = instrument_existing_test( + test_string=source, function_to_optimize=func, mode="behavior", test_path=test_file + ) + + expected = """import org.junit.jupiter.api.Test; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; + +@SuppressWarnings("CheckReturnValue") +public class WorkerTest__perfinstrumented { + @Test + public void testDoWork() { + // Codeflash behavior instrumentation + int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX")); + int _cf_iter1 = 1; + String _cf_mod1 = "WorkerTest"; + String _cf_cls1 = "WorkerTest"; + String _cf_fn1 = "doWork"; + String _cf_outputFile1 = System.getenv("CODEFLASH_OUTPUT_FILE"); + String _cf_testIteration1 = System.getenv("CODEFLASH_TEST_ITERATION"); + if (_cf_testIteration1 == null) _cf_testIteration1 = "0"; + String _cf_test1 = "testDoWork"; + Worker obj = new Worker(); + long _cf_end1_1 = -1; + long _cf_start1_1 = 0; + byte[] _cf_serializedResult1_1 = null; + java.io.ByteArrayOutputStream _cf_stdoutCapture1_1 = new java.io.ByteArrayOutputStream(); + java.io.PrintStream _cf_origOut1_1 = System.out; + String _cf_stdout1_1 = null; + System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":1" + "######$!"); + try { + System.setOut(new java.io.PrintStream(_cf_stdoutCapture1_1)); + _cf_start1_1 = System.nanoTime(); + obj.doWork(42); + _cf_end1_1 = System.nanoTime(); + _cf_serializedResult1_1 = com.codeflash.Serializer.serialize(new Object[]{obj, 42}); + } finally { + System.setOut(_cf_origOut1_1); + try { _cf_stdout1_1 = _cf_stdoutCapture1_1.toString("UTF-8"); } catch (Exception _cf_encEx1_1) {} + long _cf_end1_1_finally = System.nanoTime(); + long _cf_dur1_1 = (_cf_end1_1 != -1 ? _cf_end1_1 : _cf_end1_1_finally) - _cf_start1_1; + System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + "1" + "######!"); + // Write to SQLite if output file is set + if (_cf_outputFile1 != null && !_cf_outputFile1.isEmpty()) { + try { + Class.forName("org.sqlite.JDBC"); + try (Connection _cf_conn1_1 = DriverManager.getConnection("jdbc:sqlite:" + _cf_outputFile1)) { + try (java.sql.Statement _cf_stmt1_1 = _cf_conn1_1.createStatement()) { + _cf_stmt1_1.execute("CREATE TABLE IF NOT EXISTS test_results (" + + "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " + + "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + + "runtime INTEGER, return_value BLOB, verification_type TEXT, stdout TEXT)"); + } + String _cf_sql1_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; + try (PreparedStatement _cf_pstmt1_1 = _cf_conn1_1.prepareStatement(_cf_sql1_1)) { + _cf_pstmt1_1.setString(1, _cf_mod1); + _cf_pstmt1_1.setString(2, _cf_cls1); + _cf_pstmt1_1.setString(3, _cf_test1); + _cf_pstmt1_1.setString(4, _cf_fn1); + _cf_pstmt1_1.setInt(5, _cf_loop1); + _cf_pstmt1_1.setString(6, _cf_cls1 + "." + _cf_test1 + ".1_" + _cf_testIteration1); + _cf_pstmt1_1.setLong(7, _cf_dur1_1); + _cf_pstmt1_1.setBytes(8, _cf_serializedResult1_1); + _cf_pstmt1_1.setString(9, "function_call"); + _cf_pstmt1_1.setString(10, _cf_stdout1_1); + _cf_pstmt1_1.executeUpdate(); + } + } + } catch (Exception _cf_e1_1) { + System.err.println("CodeflashHelper: SQLite error: " + _cf_e1_1.getMessage()); + } + } + } + } +} +""" + assert success is True + assert result == expected + + def test_void_static_method_excludes_receiver(self, tmp_path: Path): + """Void static method excludes uppercase receiver from serialization.""" + source = """import org.junit.jupiter.api.Test; + +public class UtilsTest { + @Test + public void testProcess() { + Utils.process("data"); + } +} +""" + test_file = tmp_path / "UtilsTest.java" + test_file.write_text(source) + + func = FunctionToOptimize( + function_name="process", + file_path=tmp_path / "Utils.java", + starting_line=1, + ending_line=5, + parents=[], + is_method=True, + language="java", + return_type="void", + ) + + success, result = instrument_existing_test( + test_string=source, function_to_optimize=func, mode="behavior", test_path=test_file + ) + + expected = """import org.junit.jupiter.api.Test; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; + +@SuppressWarnings("CheckReturnValue") +public class UtilsTest__perfinstrumented { + @Test + public void testProcess() { + // Codeflash behavior instrumentation + int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX")); + int _cf_iter1 = 1; + String _cf_mod1 = "UtilsTest"; + String _cf_cls1 = "UtilsTest"; + String _cf_fn1 = "process"; + String _cf_outputFile1 = System.getenv("CODEFLASH_OUTPUT_FILE"); + String _cf_testIteration1 = System.getenv("CODEFLASH_TEST_ITERATION"); + if (_cf_testIteration1 == null) _cf_testIteration1 = "0"; + String _cf_test1 = "testProcess"; + long _cf_end1_1 = -1; + long _cf_start1_1 = 0; + byte[] _cf_serializedResult1_1 = null; + java.io.ByteArrayOutputStream _cf_stdoutCapture1_1 = new java.io.ByteArrayOutputStream(); + java.io.PrintStream _cf_origOut1_1 = System.out; + String _cf_stdout1_1 = null; + System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":1" + "######$!"); + try { + System.setOut(new java.io.PrintStream(_cf_stdoutCapture1_1)); + _cf_start1_1 = System.nanoTime(); + Utils.process("data"); + _cf_end1_1 = System.nanoTime(); + _cf_serializedResult1_1 = com.codeflash.Serializer.serialize(new Object[]{"data"}); + } finally { + System.setOut(_cf_origOut1_1); + try { _cf_stdout1_1 = _cf_stdoutCapture1_1.toString("UTF-8"); } catch (Exception _cf_encEx1_1) {} + long _cf_end1_1_finally = System.nanoTime(); + long _cf_dur1_1 = (_cf_end1_1 != -1 ? _cf_end1_1 : _cf_end1_1_finally) - _cf_start1_1; + System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + "1" + "######!"); + // Write to SQLite if output file is set + if (_cf_outputFile1 != null && !_cf_outputFile1.isEmpty()) { + try { + Class.forName("org.sqlite.JDBC"); + try (Connection _cf_conn1_1 = DriverManager.getConnection("jdbc:sqlite:" + _cf_outputFile1)) { + try (java.sql.Statement _cf_stmt1_1 = _cf_conn1_1.createStatement()) { + _cf_stmt1_1.execute("CREATE TABLE IF NOT EXISTS test_results (" + + "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " + + "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + + "runtime INTEGER, return_value BLOB, verification_type TEXT, stdout TEXT)"); + } + String _cf_sql1_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; + try (PreparedStatement _cf_pstmt1_1 = _cf_conn1_1.prepareStatement(_cf_sql1_1)) { + _cf_pstmt1_1.setString(1, _cf_mod1); + _cf_pstmt1_1.setString(2, _cf_cls1); + _cf_pstmt1_1.setString(3, _cf_test1); + _cf_pstmt1_1.setString(4, _cf_fn1); + _cf_pstmt1_1.setInt(5, _cf_loop1); + _cf_pstmt1_1.setString(6, _cf_cls1 + "." + _cf_test1 + ".1_" + _cf_testIteration1); + _cf_pstmt1_1.setLong(7, _cf_dur1_1); + _cf_pstmt1_1.setBytes(8, _cf_serializedResult1_1); + _cf_pstmt1_1.setString(9, "function_call"); + _cf_pstmt1_1.setString(10, _cf_stdout1_1); + _cf_pstmt1_1.executeUpdate(); + } + } + } catch (Exception _cf_e1_1) { + System.err.println("CodeflashHelper: SQLite error: " + _cf_e1_1.getMessage()); + } + } + } + } +} +""" + assert success is True + assert result == expected + + def test_void_instance_no_args_serializes_receiver_only(self, tmp_path: Path): + """Void instance method with no args serializes only the receiver.""" + source = """import org.junit.jupiter.api.Test; + +public class CacheTest { + @Test + public void testReset() { + Cache cache = new Cache(); + cache.reset(); + } +} +""" + test_file = tmp_path / "CacheTest.java" + test_file.write_text(source) + + func = FunctionToOptimize( + function_name="reset", + file_path=tmp_path / "Cache.java", + starting_line=1, + ending_line=5, + parents=[], + is_method=True, + language="java", + return_type="void", + ) + + success, result = instrument_existing_test( + test_string=source, function_to_optimize=func, mode="behavior", test_path=test_file + ) + + expected = """import org.junit.jupiter.api.Test; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; + +@SuppressWarnings("CheckReturnValue") +public class CacheTest__perfinstrumented { + @Test + public void testReset() { + // Codeflash behavior instrumentation + int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX")); + int _cf_iter1 = 1; + String _cf_mod1 = "CacheTest"; + String _cf_cls1 = "CacheTest"; + String _cf_fn1 = "reset"; + String _cf_outputFile1 = System.getenv("CODEFLASH_OUTPUT_FILE"); + String _cf_testIteration1 = System.getenv("CODEFLASH_TEST_ITERATION"); + if (_cf_testIteration1 == null) _cf_testIteration1 = "0"; + String _cf_test1 = "testReset"; + Cache cache = new Cache(); + long _cf_end1_1 = -1; + long _cf_start1_1 = 0; + byte[] _cf_serializedResult1_1 = null; + java.io.ByteArrayOutputStream _cf_stdoutCapture1_1 = new java.io.ByteArrayOutputStream(); + java.io.PrintStream _cf_origOut1_1 = System.out; + String _cf_stdout1_1 = null; + System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":1" + "######$!"); + try { + System.setOut(new java.io.PrintStream(_cf_stdoutCapture1_1)); + _cf_start1_1 = System.nanoTime(); + cache.reset(); + _cf_end1_1 = System.nanoTime(); + _cf_serializedResult1_1 = com.codeflash.Serializer.serialize(new Object[]{cache}); + } finally { + System.setOut(_cf_origOut1_1); + try { _cf_stdout1_1 = _cf_stdoutCapture1_1.toString("UTF-8"); } catch (Exception _cf_encEx1_1) {} + long _cf_end1_1_finally = System.nanoTime(); + long _cf_dur1_1 = (_cf_end1_1 != -1 ? _cf_end1_1 : _cf_end1_1_finally) - _cf_start1_1; + System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + "1" + "######!"); + // Write to SQLite if output file is set + if (_cf_outputFile1 != null && !_cf_outputFile1.isEmpty()) { + try { + Class.forName("org.sqlite.JDBC"); + try (Connection _cf_conn1_1 = DriverManager.getConnection("jdbc:sqlite:" + _cf_outputFile1)) { + try (java.sql.Statement _cf_stmt1_1 = _cf_conn1_1.createStatement()) { + _cf_stmt1_1.execute("CREATE TABLE IF NOT EXISTS test_results (" + + "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " + + "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + + "runtime INTEGER, return_value BLOB, verification_type TEXT, stdout TEXT)"); + } + String _cf_sql1_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; + try (PreparedStatement _cf_pstmt1_1 = _cf_conn1_1.prepareStatement(_cf_sql1_1)) { + _cf_pstmt1_1.setString(1, _cf_mod1); + _cf_pstmt1_1.setString(2, _cf_cls1); + _cf_pstmt1_1.setString(3, _cf_test1); + _cf_pstmt1_1.setString(4, _cf_fn1); + _cf_pstmt1_1.setInt(5, _cf_loop1); + _cf_pstmt1_1.setString(6, _cf_cls1 + "." + _cf_test1 + ".1_" + _cf_testIteration1); + _cf_pstmt1_1.setLong(7, _cf_dur1_1); + _cf_pstmt1_1.setBytes(8, _cf_serializedResult1_1); + _cf_pstmt1_1.setString(9, "function_call"); + _cf_pstmt1_1.setString(10, _cf_stdout1_1); + _cf_pstmt1_1.executeUpdate(); + } + } + } catch (Exception _cf_e1_1) { + System.err.println("CodeflashHelper: SQLite error: " + _cf_e1_1.getMessage()); + } + } + } + } +} +""" + assert success is True + assert result == expected + + def test_void_static_no_args_serializes_null(self, tmp_path: Path): + """Void static method with no args serializes null (no parts).""" + source = """import org.junit.jupiter.api.Test; + +public class ConfigTest { + @Test + public void testReload() { + Config.reload(); + } +} +""" + test_file = tmp_path / "ConfigTest.java" + test_file.write_text(source) + + func = FunctionToOptimize( + function_name="reload", + file_path=tmp_path / "Config.java", + starting_line=1, + ending_line=5, + parents=[], + is_method=True, + language="java", + return_type="void", + ) + + success, result = instrument_existing_test( + test_string=source, function_to_optimize=func, mode="behavior", test_path=test_file + ) + + expected = """import org.junit.jupiter.api.Test; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; + +@SuppressWarnings("CheckReturnValue") +public class ConfigTest__perfinstrumented { + @Test + public void testReload() { + // Codeflash behavior instrumentation + int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX")); + int _cf_iter1 = 1; + String _cf_mod1 = "ConfigTest"; + String _cf_cls1 = "ConfigTest"; + String _cf_fn1 = "reload"; + String _cf_outputFile1 = System.getenv("CODEFLASH_OUTPUT_FILE"); + String _cf_testIteration1 = System.getenv("CODEFLASH_TEST_ITERATION"); + if (_cf_testIteration1 == null) _cf_testIteration1 = "0"; + String _cf_test1 = "testReload"; + long _cf_end1_1 = -1; + long _cf_start1_1 = 0; + byte[] _cf_serializedResult1_1 = null; + java.io.ByteArrayOutputStream _cf_stdoutCapture1_1 = new java.io.ByteArrayOutputStream(); + java.io.PrintStream _cf_origOut1_1 = System.out; + String _cf_stdout1_1 = null; + System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":1" + "######$!"); + try { + System.setOut(new java.io.PrintStream(_cf_stdoutCapture1_1)); + _cf_start1_1 = System.nanoTime(); + Config.reload(); + _cf_end1_1 = System.nanoTime(); + _cf_serializedResult1_1 = null; + } finally { + System.setOut(_cf_origOut1_1); + try { _cf_stdout1_1 = _cf_stdoutCapture1_1.toString("UTF-8"); } catch (Exception _cf_encEx1_1) {} + long _cf_end1_1_finally = System.nanoTime(); + long _cf_dur1_1 = (_cf_end1_1 != -1 ? _cf_end1_1 : _cf_end1_1_finally) - _cf_start1_1; + System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + "1" + "######!"); + // Write to SQLite if output file is set + if (_cf_outputFile1 != null && !_cf_outputFile1.isEmpty()) { + try { + Class.forName("org.sqlite.JDBC"); + try (Connection _cf_conn1_1 = DriverManager.getConnection("jdbc:sqlite:" + _cf_outputFile1)) { + try (java.sql.Statement _cf_stmt1_1 = _cf_conn1_1.createStatement()) { + _cf_stmt1_1.execute("CREATE TABLE IF NOT EXISTS test_results (" + + "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " + + "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + + "runtime INTEGER, return_value BLOB, verification_type TEXT, stdout TEXT)"); + } + String _cf_sql1_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; + try (PreparedStatement _cf_pstmt1_1 = _cf_conn1_1.prepareStatement(_cf_sql1_1)) { + _cf_pstmt1_1.setString(1, _cf_mod1); + _cf_pstmt1_1.setString(2, _cf_cls1); + _cf_pstmt1_1.setString(3, _cf_test1); + _cf_pstmt1_1.setString(4, _cf_fn1); + _cf_pstmt1_1.setInt(5, _cf_loop1); + _cf_pstmt1_1.setString(6, _cf_cls1 + "." + _cf_test1 + ".1_" + _cf_testIteration1); + _cf_pstmt1_1.setLong(7, _cf_dur1_1); + _cf_pstmt1_1.setBytes(8, _cf_serializedResult1_1); + _cf_pstmt1_1.setString(9, "function_call"); + _cf_pstmt1_1.setString(10, _cf_stdout1_1); + _cf_pstmt1_1.executeUpdate(); + } + } + } catch (Exception _cf_e1_1) { + System.err.println("CodeflashHelper: SQLite error: " + _cf_e1_1.getMessage()); + } + } + } + } +} +""" + assert success is True + assert result == expected + + def test_void_instance_multiple_args(self, tmp_path: Path): + """Void instance method with multiple args serializes receiver + all args.""" + source = """import org.junit.jupiter.api.Test; + +public class SwapperTest { + @Test + public void testSwap() { + Swapper s = new Swapper(); + int[] arr = {1, 2}; + s.swap(arr, 0, 1); + } +} +""" + test_file = tmp_path / "SwapperTest.java" + test_file.write_text(source) + + func = FunctionToOptimize( + function_name="swap", + file_path=tmp_path / "Swapper.java", + starting_line=1, + ending_line=5, + parents=[], + is_method=True, + language="java", + return_type="void", + ) + + success, result = instrument_existing_test( + test_string=source, function_to_optimize=func, mode="behavior", test_path=test_file + ) + + expected = """import org.junit.jupiter.api.Test; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; + +@SuppressWarnings("CheckReturnValue") +public class SwapperTest__perfinstrumented { + @Test + public void testSwap() { + // Codeflash behavior instrumentation + int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX")); + int _cf_iter1 = 1; + String _cf_mod1 = "SwapperTest"; + String _cf_cls1 = "SwapperTest"; + String _cf_fn1 = "swap"; + String _cf_outputFile1 = System.getenv("CODEFLASH_OUTPUT_FILE"); + String _cf_testIteration1 = System.getenv("CODEFLASH_TEST_ITERATION"); + if (_cf_testIteration1 == null) _cf_testIteration1 = "0"; + String _cf_test1 = "testSwap"; + Swapper s = new Swapper(); + int[] arr = {1, 2}; + long _cf_end1_1 = -1; + long _cf_start1_1 = 0; + byte[] _cf_serializedResult1_1 = null; + java.io.ByteArrayOutputStream _cf_stdoutCapture1_1 = new java.io.ByteArrayOutputStream(); + java.io.PrintStream _cf_origOut1_1 = System.out; + String _cf_stdout1_1 = null; + System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":1" + "######$!"); + try { + System.setOut(new java.io.PrintStream(_cf_stdoutCapture1_1)); + _cf_start1_1 = System.nanoTime(); + s.swap(arr, 0, 1); + _cf_end1_1 = System.nanoTime(); + _cf_serializedResult1_1 = com.codeflash.Serializer.serialize(new Object[]{s, arr, 0, 1}); + } finally { + System.setOut(_cf_origOut1_1); + try { _cf_stdout1_1 = _cf_stdoutCapture1_1.toString("UTF-8"); } catch (Exception _cf_encEx1_1) {} + long _cf_end1_1_finally = System.nanoTime(); + long _cf_dur1_1 = (_cf_end1_1 != -1 ? _cf_end1_1 : _cf_end1_1_finally) - _cf_start1_1; + System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + "1" + "######!"); + // Write to SQLite if output file is set + if (_cf_outputFile1 != null && !_cf_outputFile1.isEmpty()) { + try { + Class.forName("org.sqlite.JDBC"); + try (Connection _cf_conn1_1 = DriverManager.getConnection("jdbc:sqlite:" + _cf_outputFile1)) { + try (java.sql.Statement _cf_stmt1_1 = _cf_conn1_1.createStatement()) { + _cf_stmt1_1.execute("CREATE TABLE IF NOT EXISTS test_results (" + + "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " + + "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + + "runtime INTEGER, return_value BLOB, verification_type TEXT, stdout TEXT)"); + } + String _cf_sql1_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; + try (PreparedStatement _cf_pstmt1_1 = _cf_conn1_1.prepareStatement(_cf_sql1_1)) { + _cf_pstmt1_1.setString(1, _cf_mod1); + _cf_pstmt1_1.setString(2, _cf_cls1); + _cf_pstmt1_1.setString(3, _cf_test1); + _cf_pstmt1_1.setString(4, _cf_fn1); + _cf_pstmt1_1.setInt(5, _cf_loop1); + _cf_pstmt1_1.setString(6, _cf_cls1 + "." + _cf_test1 + ".1_" + _cf_testIteration1); + _cf_pstmt1_1.setLong(7, _cf_dur1_1); + _cf_pstmt1_1.setBytes(8, _cf_serializedResult1_1); + _cf_pstmt1_1.setString(9, "function_call"); + _cf_pstmt1_1.setString(10, _cf_stdout1_1); + _cf_pstmt1_1.executeUpdate(); + } + } + } catch (Exception _cf_e1_1) { + System.err.println("CodeflashHelper: SQLite error: " + _cf_e1_1.getMessage()); + } + } + } + } +} +""" + assert success is True + assert result == expected diff --git a/tests/test_languages/test_java/test_integration.py b/tests/test_languages/test_java/test_integration.py index d0820e38e..f174ed92c 100644 --- a/tests/test_languages/test_java/test_integration.py +++ b/tests/test_languages/test_java/test_integration.py @@ -6,17 +6,11 @@ from codeflash.languages.base import FunctionFilterCriteria, Language from codeflash.languages.java import ( - JavaSupport, - detect_build_tool, detect_java_project, discover_functions, discover_functions_from_source, discover_test_methods, - discover_tests, extract_code_context, - find_helper_functions, - find_test_root, - format_java_code, get_java_analyzer, get_java_support, is_java_project, @@ -226,9 +220,7 @@ def test_full_optimization_cycle(self, support, tmp_path: Path): return new String(chars); }""" - optimized = support.replace_function( - src_file.read_text(), functions[0], new_code - ) + optimized = support.replace_function(src_file.read_text(), functions[0], new_code) assert "Optimized version" in optimized assert "StringUtils" in optimized @@ -334,11 +326,13 @@ def test_filter_by_various_criteria(self): assert "publicMethod" in public_names or len(functions) >= 0 # Test filtering by require_return + # With require_return=True, void methods are still included (verified via test pass/fail), + # but non-void methods without return statements would be excluded criteria = FunctionFilterCriteria(require_return=True) functions = discover_functions_from_source(source, filter_criteria=criteria) - # voidMethod should be excluded names = {f.function_name for f in functions} - assert "voidMethod" not in names + assert "voidMethod" in names + assert "publicMethod" in names class TestNormalizationIntegration: diff --git a/tests/test_languages/test_java/test_remove_asserts.py b/tests/test_languages/test_java/test_remove_asserts.py index edc7138ce..cbb9c4f93 100644 --- a/tests/test_languages/test_java/test_remove_asserts.py +++ b/tests/test_languages/test_java/test_remove_asserts.py @@ -41,7 +41,7 @@ def test_assertfalse_with_message(self): public class BitSetTest { @Test public void testGet_IndexZero_ReturnsFalse() { - boolean _cf_result1 = instance.get(0); + Object _cf_result1 = instance.get(0); } } """ @@ -67,7 +67,7 @@ def test_asserttrue_with_message(self): public class BitSetTest { @Test public void testGet_SetBit_DetectedTrue() { - boolean _cf_result1 = bs.get(67); + Object _cf_result1 = bs.get(67); } } """ @@ -485,7 +485,7 @@ def test_asserttrue_boolean_call(self): public class FibonacciTest { @Test void testIsFibonacci() { - boolean _cf_result1 = Fibonacci.isFibonacci(5); + Object _cf_result1 = Fibonacci.isFibonacci(5); } } """ @@ -511,7 +511,7 @@ def test_assertfalse_boolean_call(self): public class FibonacciTest { @Test void testIsNotFibonacci() { - boolean _cf_result1 = Fibonacci.isFibonacci(4); + Object _cf_result1 = Fibonacci.isFibonacci(4); } } """ @@ -709,7 +709,7 @@ def test_multiple_calls_in_one_assertion(self): public class FibonacciTest { @Test void testConsecutive() { - boolean _cf_result1 = Fibonacci.areConsecutiveFibonacci(Fibonacci.fibonacci(5), Fibonacci.fibonacci(6)); + Object _cf_result1 = Fibonacci.areConsecutiveFibonacci(Fibonacci.fibonacci(5), Fibonacci.fibonacci(6)); } } """ @@ -1053,24 +1053,24 @@ class TestBitSetLikeQuestDB: @Test public void testGet_IndexZero_ReturnsFalse() { - boolean _cf_result1 = instance.get(0); + Object _cf_result1 = instance.get(0); } @Test public void testGet_SpecificIndexWithinRange_ReturnsFalse() { - boolean _cf_result2 = instance.get(100); + Object _cf_result2 = instance.get(100); } @Test public void testGet_LastIndexOfInitialRange_ReturnsFalse() { int lastIndex = 16 * BitSet.BITS_PER_WORD - 1; - boolean _cf_result3 = instance.get(lastIndex); + Object _cf_result3 = instance.get(lastIndex); } @Test public void testGet_IndexBeyondAllocated_ReturnsFalse() { int beyond = 16 * BitSet.BITS_PER_WORD; - boolean _cf_result4 = instance.get(beyond); + Object _cf_result4 = instance.get(beyond); } @Test(expected = ArrayIndexOutOfBoundsException.class) @@ -1086,22 +1086,22 @@ class TestBitSetLikeQuestDB: long[] words = new long[2]; words[1] = 1L << 3; wordsField.set(bs, words); - boolean _cf_result5 = bs.get(64 + 3); + Object _cf_result5 = bs.get(64 + 3); } @Test public void testGet_LargeIndexDoesNotThrow_ReturnsFalse() { - boolean _cf_result6 = instance.get(Integer.MAX_VALUE); + Object _cf_result6 = instance.get(Integer.MAX_VALUE); } @Test public void testGet_BitBoundaryWordEdge63_ReturnsFalse() { - boolean _cf_result7 = instance.get(63); + Object _cf_result7 = instance.get(63); } @Test public void testGet_BitBoundaryWordEdge64_ReturnsFalse() { - boolean _cf_result8 = instance.get(64); + Object _cf_result8 = instance.get(64); } @Test @@ -1109,7 +1109,7 @@ class TestBitSetLikeQuestDB: int nBits = 1_000_000; BitSet big = new BitSet(nBits); int last = nBits - 1; - boolean _cf_result9 = big.get(last); + Object _cf_result9 = big.get(last); } } """ diff --git a/tests/test_languages/test_java/test_run_and_parse.py b/tests/test_languages/test_java/test_run_and_parse.py index 1e8693a51..f14cdb36e 100644 --- a/tests/test_languages/test_java/test_run_and_parse.py +++ b/tests/test_languages/test_java/test_run_and_parse.py @@ -112,18 +112,19 @@ def java_project(tmp_path: Path): yield tmp_path, src_dir, test_dir + # Clean up any SQLite files left in the shared temp dir to prevent cross-test contamination + from codeflash.code_utils.code_utils import get_run_tmp_file + + for i in range(10): + get_run_tmp_file(Path(f"test_return_values_{i}.sqlite")).unlink(missing_ok=True) + current_module._current_language = None set_current_language(Language.PYTHON) def _make_optimizer(project_root: Path, test_dir: Path, function_name: str, src_file: Path) -> tuple: """Create an Optimizer and FunctionOptimizer for the given function.""" - fto = FunctionToOptimize( - function_name=function_name, - file_path=src_file, - parents=[], - language="java", - ) + fto = FunctionToOptimize(function_name=function_name, file_path=src_file, parents=[], language="java") opt = Optimizer( Namespace( project_root=project_root, @@ -493,12 +494,7 @@ def test_performance_inner_loop_count_and_timing(self, java_project): project_root, src_dir, test_dir = self._setup_precise_waiter_project(java_project) test_results = self._instrument_and_run( - project_root, - src_dir, - test_dir, - self.PRECISE_WAITER_TEST, - "PreciseWaiterTest.java", - inner_iterations=2, + project_root, src_dir, test_dir, self.PRECISE_WAITER_TEST, "PreciseWaiterTest.java", inner_iterations=2 ) # 2 outer loops × 2 inner iterations = 4 total results @@ -543,9 +539,7 @@ def test_performance_inner_loop_count_and_timing(self, java_project): runtime_by_test = test_results.usable_runtime_data_by_test_case() # Should have 1 test case (constant iteration_id per call site) - assert len(runtime_by_test) == 1, ( - f"Expected 1 test case (constant iteration_id), got {len(runtime_by_test)}" - ) + assert len(runtime_by_test) == 1, f"Expected 1 test case (constant iteration_id), got {len(runtime_by_test)}" # The single test case should have 4 runtimes (2 outer loops × 2 inner iterations) for test_id, test_runtimes in runtime_by_test.items(): @@ -585,12 +579,7 @@ def test_performance_multiple_test_methods_inner_loop(self, java_project): } """ test_results = self._instrument_and_run( - project_root, - src_dir, - test_dir, - multi_test_source, - "PreciseWaiterMultiTest.java", - inner_iterations=2, + project_root, src_dir, test_dir, multi_test_source, "PreciseWaiterMultiTest.java", inner_iterations=2 ) # 2 test methods × 2 outer loops × 2 inner iterations = 8 total results @@ -653,5 +642,3 @@ def test_performance_multiple_test_methods_inner_loop(self, java_project): f"total_passed_runtime {total_runtime / 1_000_000:.3f}ms not close to expected " f"{expected_total_ns / 1_000_000:.1f}ms (2 methods × min of 4 runtimes × 10ms, ±3%)" ) - -