Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 45 additions & 25 deletions core/src/analyzer/conditionals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,15 @@ pub(crate) fn process_case_match_node(
source: &str,
node: &CaseMatchNode,
) -> Option<VertexId> {
if let Some(pred) = node.predicate() {
install_node(genv, lenv, changes, source, &pred);
}
let predicate_vtx = node
.predicate()
.and_then(|pred| install_node(genv, lenv, changes, source, &pred));

let result_vtx = genv.new_vertex();

for condition in &node.conditions() {
if let Some(in_node) = condition.as_in_node() {
let vtx = process_in_clause(genv, lenv, changes, source, &in_node);
let vtx = process_in_clause(genv, lenv, changes, source, &in_node, predicate_vtx);
if let Some(vtx) = vtx {
genv.add_edge(vtx, result_vtx);
}
Expand Down Expand Up @@ -242,8 +242,9 @@ fn process_in_clause(
changes: &mut ChangeSet,
source: &str,
in_node: &InNode,
predicate_vtx: Option<VertexId>,
) -> Option<VertexId> {
process_pattern(genv, lenv, changes, source, &in_node.pattern());
process_pattern(genv, lenv, changes, source, &in_node.pattern(), predicate_vtx);
in_node
.statements()
.and_then(|s| install_statements(genv, lenv, changes, source, &s))
Expand All @@ -256,12 +257,13 @@ fn process_pattern(
changes: &mut ChangeSet,
source: &str,
pattern: &Node,
predicate_vtx: Option<VertexId>,
) {
// Guard pattern (in x if condition)
if let Some(if_node) = pattern.as_if_node() {
if let Some(stmts) = if_node.statements() {
for stmt in &stmts.body() {
process_pattern(genv, lenv, changes, source, &stmt);
process_pattern(genv, lenv, changes, source, &stmt, predicate_vtx);
}
}
install_node(genv, lenv, changes, source, &if_node.predicate());
Expand All @@ -275,37 +277,37 @@ fn process_pattern(

// ImplicitNode: hash shorthand pattern { name: } wraps LocalVariableTargetNode
if let Some(implicit) = pattern.as_implicit_node() {
process_pattern(genv, lenv, changes, source, &implicit.value());
process_pattern(genv, lenv, changes, source, &implicit.value(), predicate_vtx);
return;
}

// LocalVariableTargetNode: single variable binding (in x)
if let Some(target) = pattern.as_local_variable_target_node() {
let var_name = bytes_to_name(target.name().as_slice());
let bot_vtx = genv.new_source(Type::Bot);
install_local_var_write(genv, lenv, changes, var_name, bot_vtx);
let type_vtx = predicate_vtx.unwrap_or_else(|| genv.new_source(Type::Bot));
install_local_var_write(genv, lenv, changes, var_name, type_vtx);
return;
}

if let Some(arr) = pattern.as_array_pattern_node() {
process_array_pattern(genv, lenv, changes, source, &arr);
process_array_pattern(genv, lenv, changes, source, &arr, predicate_vtx);
return;
}

if let Some(find) = pattern.as_find_pattern_node() {
process_find_pattern(genv, lenv, changes, source, &find);
process_find_pattern(genv, lenv, changes, source, &find, predicate_vtx);
return;
}

if let Some(hash) = pattern.as_hash_pattern_node() {
process_hash_pattern(genv, lenv, changes, source, &hash);
process_hash_pattern(genv, lenv, changes, source, &hash, predicate_vtx);
return;
}

// AlternationPatternNode: 1 | 2 | 3
if let Some(alt) = pattern.as_alternation_pattern_node() {
process_pattern(genv, lenv, changes, source, &alt.left());
process_pattern(genv, lenv, changes, source, &alt.right());
process_pattern(genv, lenv, changes, source, &alt.left(), predicate_vtx);
process_pattern(genv, lenv, changes, source, &alt.right(), predicate_vtx);
return;
}

Expand Down Expand Up @@ -347,16 +349,29 @@ fn process_capture_pattern(
install_local_var_write(genv, lenv, changes, var_name, type_vtx);
}

// TODO: Remove clone
fn type_arg_source(genv: &mut GlobalEnv, vtx: VertexId, index: usize) -> Option<VertexId> {
let source = genv.get_source(vtx)?;
let ty = match &source.ty {
Type::Generic { type_args, .. } => type_args.get(index)?.clone(),
_ => return None,
};
Some(genv.new_source(ty))
}

/// Process array pattern: [x, y] or [x, *rest]
fn process_array_pattern(
genv: &mut GlobalEnv,
lenv: &mut LocalEnv,
changes: &mut ChangeSet,
source: &str,
arr: &ArrayPatternNode,
predicate_vtx: Option<VertexId>,
) {
let element_vtx = predicate_vtx.and_then(|vtx| type_arg_source(genv, vtx, 0));

for elem in &arr.requireds() {
process_pattern(genv, lenv, changes, source, &elem);
process_pattern(genv, lenv, changes, source, &elem, element_vtx);
}

if let Some(target) = arr
Expand All @@ -366,12 +381,12 @@ fn process_array_pattern(
.and_then(|e| e.as_local_variable_target_node())
{
let var_name = bytes_to_name(target.name().as_slice());
let array_vtx = genv.new_source(Type::array_of(Type::Bot));
install_local_var_write(genv, lenv, changes, var_name, array_vtx);
let rest_vtx = predicate_vtx.unwrap_or_else(|| genv.new_source(Type::array_of(Type::Bot)));
install_local_var_write(genv, lenv, changes, var_name, rest_vtx);
}

for elem in &arr.posts() {
process_pattern(genv, lenv, changes, source, &elem);
process_pattern(genv, lenv, changes, source, &elem, element_vtx);
}
}

Expand All @@ -382,10 +397,13 @@ fn process_hash_pattern(
changes: &mut ChangeSet,
source: &str,
hash: &HashPatternNode,
predicate_vtx: Option<VertexId>,
) {
let value_vtx = predicate_vtx.and_then(|vtx| type_arg_source(genv, vtx, 1));

for elem in &hash.elements() {
if let Some(assoc) = elem.as_assoc_node() {
process_pattern(genv, lenv, changes, source, &assoc.value());
process_pattern(genv, lenv, changes, source, &assoc.value(), value_vtx);
}
}

Expand All @@ -396,7 +414,7 @@ fn process_hash_pattern(
.and_then(|v| v.as_local_variable_target_node())
{
let var_name = bytes_to_name(target.name().as_slice());
let hash_vtx = genv.new_source(Type::instance("Hash"));
let hash_vtx = genv.new_source(Type::hash());
install_local_var_write(genv, lenv, changes, var_name, hash_vtx);
}
}
Expand All @@ -408,19 +426,22 @@ fn process_find_pattern(
changes: &mut ChangeSet,
source: &str,
find: &FindPatternNode,
predicate_vtx: Option<VertexId>,
) {
let element_vtx = predicate_vtx.and_then(|vtx| type_arg_source(genv, vtx, 0));
let rest_vtx = predicate_vtx.unwrap_or_else(|| genv.new_source(Type::array_of(Type::Bot)));

if let Some(target) = find
.left()
.expression()
.and_then(|e| e.as_local_variable_target_node())
{
let var_name = bytes_to_name(target.name().as_slice());
let array_vtx = genv.new_source(Type::array_of(Type::Bot));
install_local_var_write(genv, lenv, changes, var_name, array_vtx);
install_local_var_write(genv, lenv, changes, var_name, rest_vtx);
}

for elem in &find.requireds() {
process_pattern(genv, lenv, changes, source, &elem);
process_pattern(genv, lenv, changes, source, &elem, element_vtx);
}

if let Some(target) = find
Expand All @@ -430,7 +451,6 @@ fn process_find_pattern(
.and_then(|e| e.as_local_variable_target_node())
{
let var_name = bytes_to_name(target.name().as_slice());
let array_vtx = genv.new_source(Type::array_of(Type::Bot));
install_local_var_write(genv, lenv, changes, var_name, array_vtx);
install_local_var_write(genv, lenv, changes, var_name, rest_vtx);
}
}
60 changes: 60 additions & 0 deletions test/pattern_matching_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,70 @@ def test_variable_binding_pattern
assert_no_check_errors(source)
end

def test_bare_variable_propagates_predicate_type
source = <<~RUBY
case 42
in x
x.even?
end
RUBY
assert_no_check_errors(source)
end

def test_array_pattern_propagates_element_type
source = <<~RUBY
case [1, 2, 3]
in [x, y, z]
x.even?
end
RUBY
assert_no_check_errors(source)
end

def test_find_pattern_propagates_element_type
source = <<~RUBY
case [1, 2, 3]
in [*pre, x, *post]
x.even?
end
RUBY
assert_no_check_errors(source)
end

def test_guard_pattern_propagates_predicate_type
source = <<~RUBY
case 42
in x if x > 0
x.even?
end
RUBY
assert_no_check_errors(source)
end

# ============================================
# Error Detection
# ============================================

def test_bare_variable_type_error
source = <<~RUBY
case 42
in x
x.upcase
end
RUBY
assert_check_error(source, method_name: 'upcase', receiver_type: 'Integer')
end

def test_array_pattern_element_type_error
source = <<~RUBY
case [1, 2, 3]
in [x, y, z]
x.upcase
end
RUBY
assert_check_error(source, method_name: 'upcase', receiver_type: 'Integer')
end

def test_capture_pattern_type_error
source = <<~RUBY
case 42
Expand Down