From 5c86bbaabdd8729c26fec906ee86a5ef6150870f Mon Sep 17 00:00:00 2001 From: dak2 Date: Sun, 5 Apr 2026 15:50:37 +0900 Subject: [PATCH] Propagate predicate types to pattern variables in case...in Previously, all pattern variables (e.g., `in x`, `in [x, y]`) were assigned `Type::Bot`, losing the predicate's type information. Now the predicate VertexId is threaded through the pattern processing chain so variables inherit the correct type. Why: enables type error detection in pattern match bodies (e.g., `case 42; in x; x.upcase; end` now correctly reports an error). Co-Authored-By: Claude Opus 4.6 (1M context) --- core/src/analyzer/conditionals.rs | 70 ++++++++++++++++++++----------- test/pattern_matching_test.rb | 60 ++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 25 deletions(-) diff --git a/core/src/analyzer/conditionals.rs b/core/src/analyzer/conditionals.rs index d7d616d..9c69621 100644 --- a/core/src/analyzer/conditionals.rs +++ b/core/src/analyzer/conditionals.rs @@ -150,15 +150,15 @@ pub(crate) fn process_case_match_node( source: &str, node: &CaseMatchNode, ) -> Option { - if let Some(pred) = node.predicate() { - install_node(genv, lenv, changes, source, &pred); - } + let predicate_vtx = node + .predicate() + .and_then(|pred| install_node(genv, lenv, changes, source, &pred)); let result_vtx = genv.new_vertex(); for condition in &node.conditions() { if let Some(in_node) = condition.as_in_node() { - let vtx = process_in_clause(genv, lenv, changes, source, &in_node); + let vtx = process_in_clause(genv, lenv, changes, source, &in_node, predicate_vtx); if let Some(vtx) = vtx { genv.add_edge(vtx, result_vtx); } @@ -242,8 +242,9 @@ fn process_in_clause( changes: &mut ChangeSet, source: &str, in_node: &InNode, + predicate_vtx: Option, ) -> Option { - process_pattern(genv, lenv, changes, source, &in_node.pattern()); + process_pattern(genv, lenv, changes, source, &in_node.pattern(), predicate_vtx); in_node .statements() .and_then(|s| install_statements(genv, lenv, changes, source, &s)) @@ -256,12 +257,13 @@ fn process_pattern( changes: &mut ChangeSet, source: &str, pattern: &Node, + predicate_vtx: Option, ) { // Guard pattern (in x if condition) if let Some(if_node) = pattern.as_if_node() { if let Some(stmts) = if_node.statements() { for stmt in &stmts.body() { - process_pattern(genv, lenv, changes, source, &stmt); + process_pattern(genv, lenv, changes, source, &stmt, predicate_vtx); } } install_node(genv, lenv, changes, source, &if_node.predicate()); @@ -275,37 +277,37 @@ fn process_pattern( // ImplicitNode: hash shorthand pattern { name: } wraps LocalVariableTargetNode if let Some(implicit) = pattern.as_implicit_node() { - process_pattern(genv, lenv, changes, source, &implicit.value()); + process_pattern(genv, lenv, changes, source, &implicit.value(), predicate_vtx); return; } // LocalVariableTargetNode: single variable binding (in x) if let Some(target) = pattern.as_local_variable_target_node() { let var_name = bytes_to_name(target.name().as_slice()); - let bot_vtx = genv.new_source(Type::Bot); - install_local_var_write(genv, lenv, changes, var_name, bot_vtx); + let type_vtx = predicate_vtx.unwrap_or_else(|| genv.new_source(Type::Bot)); + install_local_var_write(genv, lenv, changes, var_name, type_vtx); return; } if let Some(arr) = pattern.as_array_pattern_node() { - process_array_pattern(genv, lenv, changes, source, &arr); + process_array_pattern(genv, lenv, changes, source, &arr, predicate_vtx); return; } if let Some(find) = pattern.as_find_pattern_node() { - process_find_pattern(genv, lenv, changes, source, &find); + process_find_pattern(genv, lenv, changes, source, &find, predicate_vtx); return; } if let Some(hash) = pattern.as_hash_pattern_node() { - process_hash_pattern(genv, lenv, changes, source, &hash); + process_hash_pattern(genv, lenv, changes, source, &hash, predicate_vtx); return; } // AlternationPatternNode: 1 | 2 | 3 if let Some(alt) = pattern.as_alternation_pattern_node() { - process_pattern(genv, lenv, changes, source, &alt.left()); - process_pattern(genv, lenv, changes, source, &alt.right()); + process_pattern(genv, lenv, changes, source, &alt.left(), predicate_vtx); + process_pattern(genv, lenv, changes, source, &alt.right(), predicate_vtx); return; } @@ -347,6 +349,16 @@ fn process_capture_pattern( install_local_var_write(genv, lenv, changes, var_name, type_vtx); } +// TODO: Remove clone +fn type_arg_source(genv: &mut GlobalEnv, vtx: VertexId, index: usize) -> Option { + let source = genv.get_source(vtx)?; + let ty = match &source.ty { + Type::Generic { type_args, .. } => type_args.get(index)?.clone(), + _ => return None, + }; + Some(genv.new_source(ty)) +} + /// Process array pattern: [x, y] or [x, *rest] fn process_array_pattern( genv: &mut GlobalEnv, @@ -354,9 +366,12 @@ fn process_array_pattern( changes: &mut ChangeSet, source: &str, arr: &ArrayPatternNode, + predicate_vtx: Option, ) { + let element_vtx = predicate_vtx.and_then(|vtx| type_arg_source(genv, vtx, 0)); + for elem in &arr.requireds() { - process_pattern(genv, lenv, changes, source, &elem); + process_pattern(genv, lenv, changes, source, &elem, element_vtx); } if let Some(target) = arr @@ -366,12 +381,12 @@ fn process_array_pattern( .and_then(|e| e.as_local_variable_target_node()) { let var_name = bytes_to_name(target.name().as_slice()); - let array_vtx = genv.new_source(Type::array_of(Type::Bot)); - install_local_var_write(genv, lenv, changes, var_name, array_vtx); + let rest_vtx = predicate_vtx.unwrap_or_else(|| genv.new_source(Type::array_of(Type::Bot))); + install_local_var_write(genv, lenv, changes, var_name, rest_vtx); } for elem in &arr.posts() { - process_pattern(genv, lenv, changes, source, &elem); + process_pattern(genv, lenv, changes, source, &elem, element_vtx); } } @@ -382,10 +397,13 @@ fn process_hash_pattern( changes: &mut ChangeSet, source: &str, hash: &HashPatternNode, + predicate_vtx: Option, ) { + let value_vtx = predicate_vtx.and_then(|vtx| type_arg_source(genv, vtx, 1)); + for elem in &hash.elements() { if let Some(assoc) = elem.as_assoc_node() { - process_pattern(genv, lenv, changes, source, &assoc.value()); + process_pattern(genv, lenv, changes, source, &assoc.value(), value_vtx); } } @@ -396,7 +414,7 @@ fn process_hash_pattern( .and_then(|v| v.as_local_variable_target_node()) { let var_name = bytes_to_name(target.name().as_slice()); - let hash_vtx = genv.new_source(Type::instance("Hash")); + let hash_vtx = genv.new_source(Type::hash()); install_local_var_write(genv, lenv, changes, var_name, hash_vtx); } } @@ -408,19 +426,22 @@ fn process_find_pattern( changes: &mut ChangeSet, source: &str, find: &FindPatternNode, + predicate_vtx: Option, ) { + let element_vtx = predicate_vtx.and_then(|vtx| type_arg_source(genv, vtx, 0)); + let rest_vtx = predicate_vtx.unwrap_or_else(|| genv.new_source(Type::array_of(Type::Bot))); + if let Some(target) = find .left() .expression() .and_then(|e| e.as_local_variable_target_node()) { let var_name = bytes_to_name(target.name().as_slice()); - let array_vtx = genv.new_source(Type::array_of(Type::Bot)); - install_local_var_write(genv, lenv, changes, var_name, array_vtx); + install_local_var_write(genv, lenv, changes, var_name, rest_vtx); } for elem in &find.requireds() { - process_pattern(genv, lenv, changes, source, &elem); + process_pattern(genv, lenv, changes, source, &elem, element_vtx); } if let Some(target) = find @@ -430,7 +451,6 @@ fn process_find_pattern( .and_then(|e| e.as_local_variable_target_node()) { let var_name = bytes_to_name(target.name().as_slice()); - let array_vtx = genv.new_source(Type::array_of(Type::Bot)); - install_local_var_write(genv, lenv, changes, var_name, array_vtx); + install_local_var_write(genv, lenv, changes, var_name, rest_vtx); } } diff --git a/test/pattern_matching_test.rb b/test/pattern_matching_test.rb index f801609..be74603 100644 --- a/test/pattern_matching_test.rb +++ b/test/pattern_matching_test.rb @@ -142,10 +142,70 @@ def test_variable_binding_pattern assert_no_check_errors(source) end + def test_bare_variable_propagates_predicate_type + source = <<~RUBY + case 42 + in x + x.even? + end + RUBY + assert_no_check_errors(source) + end + + def test_array_pattern_propagates_element_type + source = <<~RUBY + case [1, 2, 3] + in [x, y, z] + x.even? + end + RUBY + assert_no_check_errors(source) + end + + def test_find_pattern_propagates_element_type + source = <<~RUBY + case [1, 2, 3] + in [*pre, x, *post] + x.even? + end + RUBY + assert_no_check_errors(source) + end + + def test_guard_pattern_propagates_predicate_type + source = <<~RUBY + case 42 + in x if x > 0 + x.even? + end + RUBY + assert_no_check_errors(source) + end + # ============================================ # Error Detection # ============================================ + def test_bare_variable_type_error + source = <<~RUBY + case 42 + in x + x.upcase + end + RUBY + assert_check_error(source, method_name: 'upcase', receiver_type: 'Integer') + end + + def test_array_pattern_element_type_error + source = <<~RUBY + case [1, 2, 3] + in [x, y, z] + x.upcase + end + RUBY + assert_check_error(source, method_name: 'upcase', receiver_type: 'Integer') + end + def test_capture_pattern_type_error source = <<~RUBY case 42