From 24f5d4cce508c3ba31b59258f0dbd16aac1787e2 Mon Sep 17 00:00:00 2001 From: Nate Nystrom Date: Fri, 13 Feb 2026 14:11:09 +0100 Subject: [PATCH] Introduce a visitor for target IR --- python-tools/src/meta/codegen_base.py | 124 +++++++------------ python-tools/src/meta/codegen_go.py | 27 ++--- python-tools/src/meta/codegen_julia.py | 23 ++-- python-tools/src/meta/codegen_python.py | 2 +- python-tools/src/meta/grammar_validator.py | 131 +++++++++------------ python-tools/src/meta/target_utils.py | 97 ++++++++------- python-tools/src/meta/target_visitor.py | 85 +++++++++++++ 7 files changed, 253 insertions(+), 236 deletions(-) create mode 100644 python-tools/src/meta/target_visitor.py diff --git a/python-tools/src/meta/codegen_base.py b/python-tools/src/meta/codegen_base.py index 98518d22..662f9c4f 100644 --- a/python-tools/src/meta/codegen_base.py +++ b/python-tools/src/meta/codegen_base.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from .target import ( TargetExpr, Var, Lit, Symbol, Builtin, NamedFun, NewMessage, EnumValue, OneOf, ListExpr, Call, Lambda, Let, @@ -64,6 +64,7 @@ class CodeGenerator(ABC): def __init__(self, proto_messages: Optional[Dict[Tuple[str, str], Any]] = None) -> None: self.builtin_registry: Dict[str, BuiltinSpec] = {} self.proto_messages = proto_messages or {} + self._generate_cache: Dict[type, Callable] = {} @abstractmethod def escape_keyword(self, name: str) -> str: @@ -401,81 +402,42 @@ def generate_lines(self, expr: TargetExpr, lines: List[str], indent: str = "") - Returns the value expression as a string, or None if the expression returns (i.e., contains a Return node that was executed). """ - if isinstance(expr, Var): - return self.escape_identifier(expr.name) + t = type(expr) + method = self._generate_cache.get(t) + if method is None: + method = getattr(self, f'_generate_{t.__name__}', None) + if method is None: + raise ValueError(f"Unknown expression type: {t.__name__}") + self._generate_cache[t] = method + return method(expr, lines, indent) - elif isinstance(expr, Lit): - return self.gen_literal(expr.value) + def _generate_Var(self, expr: Var, lines: List[str], indent: str) -> str: + return self.escape_identifier(expr.name) - elif isinstance(expr, Symbol): - return self.gen_symbol(expr.name) + def _generate_Lit(self, expr: Lit, lines: List[str], indent: str) -> str: + return self.gen_literal(expr.value) - elif isinstance(expr, NewMessage): - return self._generate_newmessage(expr, lines, indent) + def _generate_Symbol(self, expr: Symbol, lines: List[str], indent: str) -> str: + return self.gen_symbol(expr.name) - elif isinstance(expr, EnumValue): - return self._generate_enum_value(expr, lines, indent) + def _generate_Builtin(self, expr: Builtin, lines: List[str], indent: str) -> str: + return self.gen_builtin_ref(expr.name) - elif isinstance(expr, Builtin): - return self.gen_builtin_ref(expr.name) + def _generate_NamedFun(self, expr: NamedFun, lines: List[str], indent: str) -> str: + return self.gen_named_fun_ref(expr.name) - elif isinstance(expr, NamedFun): - return self.gen_named_fun_ref(expr.name) + def _generate_PrintNonterminal(self, expr: PrintNonterminal, lines: List[str], indent: str) -> str: + return self.gen_pretty_nonterminal_ref(expr.nonterminal.name) - elif isinstance(expr, PrintNonterminal): - return self.gen_pretty_nonterminal_ref(expr.nonterminal.name) + def _generate_ParseNonterminal(self, expr: ParseNonterminal, lines: List[str], indent: str) -> str: + return self.gen_parse_nonterminal_ref(expr.nonterminal.name) - elif isinstance(expr, ParseNonterminal): - return self.gen_parse_nonterminal_ref(expr.nonterminal.name) + def _generate_GetField(self, expr: GetField, lines: List[str], indent: str) -> str: + obj_code = self.generate_lines(expr.object, lines, indent) + assert obj_code is not None, "GetField object should not contain a return" + return self.gen_field_access(obj_code, expr.field_name) - elif isinstance(expr, OneOf): - return self._generate_oneof(expr, lines, indent) - - elif isinstance(expr, ListExpr): - return self._generate_list_expr(expr, lines, indent) - - elif isinstance(expr, GetField): - obj_code = self.generate_lines(expr.object, lines, indent) - assert obj_code is not None, "GetField object should not contain a return" - return self.gen_field_access(obj_code, expr.field_name) - - elif isinstance(expr, GetElement): - return self._generate_get_element(expr, lines, indent) - - elif isinstance(expr, Call): - return self._generate_call(expr, lines, indent) - - elif isinstance(expr, Lambda): - return self._generate_lambda(expr, lines, indent) - - elif isinstance(expr, Let): - return self._generate_let(expr, lines, indent) - - elif isinstance(expr, IfElse): - return self._generate_if_else(expr, lines, indent) - - elif isinstance(expr, Seq): - return self._generate_seq(expr, lines, indent) - - elif isinstance(expr, While): - return self._generate_while(expr, lines, indent) - - elif isinstance(expr, Foreach): - return self._generate_foreach(expr, lines, indent) - - elif isinstance(expr, ForeachEnumerated): - return self._generate_foreach_enumerated(expr, lines, indent) - - elif isinstance(expr, Assign): - return self._generate_assign(expr, lines, indent) - - elif isinstance(expr, Return): - return self._generate_return(expr, lines, indent) - - else: - raise ValueError(f"Unknown expression type: {type(expr)}") - - def _generate_call(self, expr: Call, lines: List[str], indent: str) -> Optional[str]: + def _generate_Call(self, expr: Call, lines: List[str], indent: str) -> Optional[str]: """Generate code for a function call.""" # NewMessage should be handled directly, not wrapped in Call assert not isinstance(expr.func, NewMessage), \ @@ -569,7 +531,7 @@ def _generate_short_circuit_call(self, op: str, left: TargetExpr, right: TargetE lines.append(f"{indent}{end}") return tmp - def _generate_newmessage(self, expr: NewMessage, lines: List[str], indent: str) -> str: + def _generate_NewMessage(self, expr: NewMessage, lines: List[str], indent: str) -> str: """Generate code for a NewMessage expression. Default implementation uses positional constructor args. @@ -599,7 +561,7 @@ def _generate_newmessage(self, expr: NewMessage, lines: List[str], indent: str) lines.append(f"{indent}{self.gen_assignment(tmp, call, is_declaration=True)}") return tmp - def _generate_get_element(self, expr: GetElement, lines: List[str], indent: str) -> str: + def _generate_GetElement(self, expr: GetElement, lines: List[str], indent: str) -> str: """Generate code for a GetElement expression. Default implementation uses 0-based indexing. @@ -608,11 +570,11 @@ def _generate_get_element(self, expr: GetElement, lines: List[str], indent: str) tuple_code = self.generate_lines(expr.tuple_expr, lines, indent) return f"{tuple_code}[{expr.index}]" - def _generate_enum_value(self, expr: EnumValue, lines: List[str], indent: str) -> str: + def _generate_EnumValue(self, expr: EnumValue, lines: List[str], indent: str) -> str: """Generate code for an enum value reference.""" return self.gen_enum_value(expr.module, expr.enum_name, expr.value_name) - def _generate_oneof(self, expr: OneOf, lines: List[str], indent: str) -> str: + def _generate_OneOf(self, expr: OneOf, lines: List[str], indent: str) -> str: """Generate code for a OneOf expression. Default implementation treats it as an error since OneOf should only @@ -621,7 +583,7 @@ def _generate_oneof(self, expr: OneOf, lines: List[str], indent: str) -> str: """ raise ValueError(f"OneOf should only appear as arguments to Message constructors: {expr}") - def _generate_list_expr(self, expr: ListExpr, lines: List[str], indent: str) -> str: + def _generate_ListExpr(self, expr: ListExpr, lines: List[str], indent: str) -> str: """Generate code for a list expression.""" elements: List[str] = [] for elem in expr.elements: @@ -630,7 +592,7 @@ def _generate_list_expr(self, expr: ListExpr, lines: List[str], indent: str) -> elements.append(elem_code) return self.gen_list_literal(elements, expr.element_type) - def _generate_lambda(self, expr: Lambda, lines: List[str], indent: str) -> str: + def _generate_Lambda(self, expr: Lambda, lines: List[str], indent: str) -> str: """Generate code for a lambda expression.""" params = [self.escape_identifier(p.name) for p in expr.params] f = gensym() @@ -650,7 +612,7 @@ def _generate_lambda(self, expr: Lambda, lines: List[str], indent: str) -> str: lines.append(f"{indent}{after}") return f - def _generate_let(self, expr: Let, lines: List[str], indent: str) -> Optional[str]: + def _generate_Let(self, expr: Let, lines: List[str], indent: str) -> Optional[str]: """Generate code for a let binding.""" var_name = self.escape_identifier(expr.var.name) init_val = self.generate_lines(expr.init, lines, indent) @@ -658,7 +620,7 @@ def _generate_let(self, expr: Let, lines: List[str], indent: str) -> Optional[st lines.append(f"{indent}{self.gen_assignment(var_name, init_val, is_declaration=True)}") return self.generate_lines(expr.body, lines, indent) - def _generate_if_else(self, expr: IfElse, lines: List[str], indent: str) -> Optional[str]: + def _generate_IfElse(self, expr: IfElse, lines: List[str], indent: str) -> Optional[str]: """Generate code for an if-else expression.""" cond_code = self.generate_lines(expr.condition, lines, indent) assert cond_code is not None, "If condition should not contain a return" @@ -756,7 +718,7 @@ def _generate_nil_else_branch( lines.append(f"{body_indent}{self.gen_assignment(tmp, else_code)}") return else_code - def _generate_seq(self, expr: Seq, lines: List[str], indent: str) -> Optional[str]: + def _generate_Seq(self, expr: Seq, lines: List[str], indent: str) -> Optional[str]: """Generate code for a sequence of expressions. If any expression returns None (indicating a return statement was executed), @@ -769,7 +731,7 @@ def _generate_seq(self, expr: Seq, lines: List[str], indent: str) -> Optional[st break return result - def _generate_while(self, expr: While, lines: List[str], indent: str) -> str: + def _generate_While(self, expr: While, lines: List[str], indent: str) -> str: """Generate code for a while loop.""" m = len(lines) cond_code = self.generate_lines(expr.condition, lines, indent) @@ -798,7 +760,7 @@ def _generate_while(self, expr: While, lines: List[str], indent: str) -> str: return self.gen_none() - def _generate_foreach(self, expr: Foreach, lines: List[str], indent: str) -> str: + def _generate_Foreach(self, expr: Foreach, lines: List[str], indent: str) -> str: """Generate code for a foreach loop.""" collection_code = self.generate_lines(expr.collection, lines, indent) assert collection_code is not None, "Foreach collection should not contain a return" @@ -814,7 +776,7 @@ def _generate_foreach(self, expr: Foreach, lines: List[str], indent: str) -> str return self.gen_none() - def _generate_foreach_enumerated(self, expr: ForeachEnumerated, lines: List[str], indent: str) -> str: + def _generate_ForeachEnumerated(self, expr: ForeachEnumerated, lines: List[str], indent: str) -> str: """Generate code for a foreach enumerated loop.""" collection_code = self.generate_lines(expr.collection, lines, indent) assert collection_code is not None, "ForeachEnumerated collection should not contain a return" @@ -831,7 +793,7 @@ def _generate_foreach_enumerated(self, expr: ForeachEnumerated, lines: List[str] return self.gen_none() - def _generate_assign(self, expr: Assign, lines: List[str], indent: str) -> str: + def _generate_Assign(self, expr: Assign, lines: List[str], indent: str) -> str: """Generate code for an assignment.""" var_name = self.escape_identifier(expr.var.name) expr_code = self.generate_lines(expr.expr, lines, indent) @@ -839,7 +801,7 @@ def _generate_assign(self, expr: Assign, lines: List[str], indent: str) -> str: lines.append(f"{indent}{self.gen_assignment(var_name, expr_code)}") return self.gen_none() - def _generate_return(self, expr: Return, lines: List[str], indent: str) -> None: + def _generate_Return(self, expr: Return, lines: List[str], indent: str) -> None: """Generate code for a return statement. Returns None to indicate that the caller should not add another return diff --git a/python-tools/src/meta/codegen_go.py b/python-tools/src/meta/codegen_go.py index fb6212f8..fff2fef8 100644 --- a/python-tools/src/meta/codegen_go.py +++ b/python-tools/src/meta/codegen_go.py @@ -334,7 +334,7 @@ def _generate_nil_else_branch( """Go's var declarations zero-initialize, so no else branch needed.""" return self.gen_none() - def _generate_get_element(self, expr: GetElement, lines: List[str], indent: str) -> str: + def _generate_GetElement(self, expr: GetElement, lines: List[str], indent: str) -> str: """Go uses 0-based indexing with type assertion for tuple elements.""" tuple_code = self.generate_lines(expr.tuple_expr, lines, indent) # Add type assertion since tuple elements are interface{} @@ -367,18 +367,15 @@ def _is_optional_scalar_field(self, expr) -> bool: inner_go = self.gen_type(expr.field_type.element_type) return not self._is_nullable_go_type(inner_go) - def generate_lines(self, expr: TargetExpr, lines: List[str], indent: str = "") -> Optional[str]: - from .target import GetField - # For optional scalar proto fields, use direct PascalCase access - # to preserve pointer type (getters strip it). - if isinstance(expr, GetField) and self._is_optional_scalar_field(expr): + def _generate_GetField(self, expr, lines: List[str], indent: str) -> str: + if self._is_optional_scalar_field(expr): obj_code = self.generate_lines(expr.object, lines, indent) assert obj_code is not None pascal_field = to_pascal_case(expr.field_name) return f"{obj_code}.{pascal_field}" - return super().generate_lines(expr, lines, indent) + return super()._generate_GetField(expr, lines, indent) - def _generate_seq(self, expr: Seq, lines: List[str], indent: str) -> Optional[str]: + def _generate_Seq(self, expr: Seq, lines: List[str], indent: str) -> Optional[str]: """Generate Go sequence, suppressing unused variable errors. In Go, declared-but-unused variables are compile errors. When an @@ -395,7 +392,7 @@ def _generate_seq(self, expr: Seq, lines: List[str], indent: str) -> Optional[st lines.append(f"{indent}_ = {result}") return result - def _generate_newmessage(self, expr: NewMessage, lines: List[str], indent: str) -> str: + def _generate_NewMessage(self, expr: NewMessage, lines: List[str], indent: str) -> str: """Generate Go code for NewMessage with fields containing OneOf calls. In Go protobuf, OneOf fields require wrapping values in the appropriate @@ -542,7 +539,7 @@ def unwrap_if_option(field_expr, field_value: str) -> Tuple[str, str]: return tmp - def _generate_call(self, expr: Call, lines: List[str], indent: str) -> Optional[str]: + def _generate_Call(self, expr: Call, lines: List[str], indent: str) -> Optional[str]: """Override to handle OneOf, Parse/PrintNonterminal, NamedFun, and option builtins for Go.""" from .target import NamedFun, FunctionType, ListType, BaseType, Builtin, OptionType @@ -601,7 +598,7 @@ def _generate_call(self, expr: Call, lines: List[str], indent: str) -> Optional[ return tmp # Fall back to base implementation - return super()._generate_call(expr, lines, indent) + return super()._generate_Call(expr, lines, indent) def _generate_option_builtin(self, expr: Call, lines: List[str], indent: str) -> Optional[str]: """Generate Go code for option-related builtins using pointer/nil idioms.""" @@ -680,9 +677,9 @@ def _generate_option_builtin(self, expr: Call, lines: List[str], indent: str) -> return f"deref({opt_code}, {default_code})" # Should not reach here - return super()._generate_call(expr, lines, indent) + return super()._generate_Call(expr, lines, indent) - def _generate_oneof(self, expr: OneOf, lines: List[str], indent: str) -> str: + def _generate_OneOf(self, expr: OneOf, lines: List[str], indent: str) -> str: """Generate Go OneOf reference. OneOf should only appear as the function in Call(OneOf(...), [value]). @@ -690,7 +687,7 @@ def _generate_oneof(self, expr: OneOf, lines: List[str], indent: str) -> str: """ raise ValueError(f"OneOf should only appear in Call(OneOf(...), [value]) pattern: {expr}") - def _generate_return(self, expr, lines: List[str], indent: str) -> None: + def _generate_Return(self, expr, lines: List[str], indent: str) -> None: """Generate Go return statement, wrapping with ptr() for Option types when needed.""" from .target import Lit, Call, Builtin @@ -715,7 +712,7 @@ def _generate_return(self, expr, lines: List[str], indent: str) -> None: lines.append(f"{indent}{self.gen_return(expr_code)}") return None - def _generate_assign(self, expr, lines: List[str], indent: str) -> str: + def _generate_Assign(self, expr, lines: List[str], indent: str) -> str: """Generate Go assignment, handling type-annotated nil declarations. In Go, `var_name := nil` is not valid because nil has no type. diff --git a/python-tools/src/meta/codegen_julia.py b/python-tools/src/meta/codegen_julia.py index 7c019f3d..abad4174 100644 --- a/python-tools/src/meta/codegen_julia.py +++ b/python-tools/src/meta/codegen_julia.py @@ -133,7 +133,7 @@ def gen_builtin_ref(self, name: str) -> str: def gen_named_fun_ref(self, name: str) -> str: # In Julia, named functions are regular functions (not methods on Parser) - # They take parser as the first argument, handled in _generate_call + # They take parser as the first argument, handled in _generate_Call return name def gen_parse_nonterminal_ref(self, name: str) -> str: @@ -246,7 +246,7 @@ def gen_func_def_header(self, name: str, params: List[Tuple[str, str]], def gen_func_def_end(self) -> str: return "end" - def _generate_foreach_enumerated(self, expr, lines: List[str], indent: str) -> str: + def _generate_ForeachEnumerated(self, expr, lines: List[str], indent: str) -> str: """Override to adjust for Julia's 1-based enumerate. The IR generates guards like `index > 0` assuming 0-based indexing. @@ -268,7 +268,7 @@ def _generate_foreach_enumerated(self, expr, lines: List[str], indent: str) -> s lines.append(f"{indent}end") return self.gen_none() - def _generate_get_element(self, expr: GetElement, lines: List[str], indent: str) -> str: + def _generate_GetElement(self, expr: GetElement, lines: List[str], indent: str) -> str: """Julia uses 1-based indexing.""" tuple_code = self.generate_lines(expr.tuple_expr, lines, indent) julia_index = expr.index + 1 @@ -286,7 +286,7 @@ def _build_oneof_alt_set(self) -> Set[tuple]: self._oneof_alt_set = result return result - def _generate_newmessage(self, expr: NewMessage, lines: List[str], indent: str) -> str: + def _generate_NewMessage(self, expr: NewMessage, lines: List[str], indent: str) -> str: """Generate NewMessage for Julia proto structs. Julia ProtoBuf represents oneof fields as a single struct field @@ -305,7 +305,7 @@ def _generate_newmessage(self, expr: NewMessage, lines: List[str], indent: str) proto_msg = self.proto_messages.get(msg_key) if proto_msg is None: - return super()._generate_newmessage(expr, lines, indent) + return super()._generate_NewMessage(expr, lines, indent) # Build alt_name → oneof_parent_name mapping alt_to_parent: Dict[str, str] = {} @@ -384,15 +384,14 @@ def _is_oneof_getfield(self, expr: GetField) -> bool: key = (expr.message_type.module, expr.message_type.name, expr.field_name) return key in self._build_oneof_alt_set() - def generate_lines(self, expr: 'TargetExpr', lines: List[str], indent: str = "") -> Optional[str]: - """Override to intercept GetField for OneOf alternatives.""" - if isinstance(expr, GetField) and self._is_oneof_getfield(expr): + def _generate_GetField(self, expr: GetField, lines: List[str], indent: str) -> str: + if self._is_oneof_getfield(expr): obj_code = self.generate_lines(expr.object, lines, indent) sym = self._gen_oneof_symbol(expr.field_name) return f"_get_oneof_field({obj_code}, {sym})" - return super().generate_lines(expr, lines, indent) + return super()._generate_GetField(expr, lines, indent) - def _generate_call(self, expr: Call, lines: List[str], indent: str) -> Optional[str]: + def _generate_Call(self, expr: Call, lines: List[str], indent: str) -> Optional[str]: """Override to handle OneOf and Parse/PrintNonterminal specially for Julia.""" # Check for Call(OneOf(Symbol), [value]) pattern (not in Message constructor) if isinstance(expr.func, OneOf) and len(expr.args) == 1: @@ -422,9 +421,9 @@ def _generate_call(self, expr: Call, lines: List[str], indent: str) -> Optional[ return tmp # Fall back to base implementation - return super()._generate_call(expr, lines, indent) + return super()._generate_Call(expr, lines, indent) - def _generate_oneof(self, expr: OneOf, lines: List[str], indent: str) -> str: + def _generate_OneOf(self, expr: OneOf, lines: List[str], indent: str) -> str: """Generate Julia OneOf reference. OneOf should only appear as the function in Call(OneOf(...), [value]). diff --git a/python-tools/src/meta/codegen_python.py b/python-tools/src/meta/codegen_python.py index 4c76d44b..1fc08fb1 100644 --- a/python-tools/src/meta/codegen_python.py +++ b/python-tools/src/meta/codegen_python.py @@ -201,7 +201,7 @@ def gen_func_def_end(self) -> str: # --- NewMessage generation for Python protobuf --- - def _generate_newmessage(self, expr: NewMessage, lines: List[str], indent: str) -> str: + def _generate_NewMessage(self, expr: NewMessage, lines: List[str], indent: str) -> str: """Generate Python code for NewMessage with keyword-safe field handling. Python protobuf constructors use keyword arguments, but field names diff --git a/python-tools/src/meta/grammar_validator.py b/python-tools/src/meta/grammar_validator.py index 7afc027d..e9570f35 100644 --- a/python-tools/src/meta/grammar_validator.py +++ b/python-tools/src/meta/grammar_validator.py @@ -26,9 +26,11 @@ from .proto_parser import ProtoParser from .proto_ast import ProtoMessage, ProtoField from .target import ( - TargetType, TargetExpr, Call, NewMessage, Builtin, IfElse, Let, Seq, ListExpr, GetField, - GetElement, BaseType, VarType, MessageType, SequenceType, ListType, OptionType, TupleType, Lambda, OneOf + TargetType, TargetExpr, Call, NewMessage, Builtin, IfElse, Let, ListExpr, GetField, + GetElement, BaseType, VarType, MessageType, SequenceType, ListType, OptionType, TupleType, + Lambda, OneOf, ) +from .target_visitor import TargetExprVisitor from .type_env import TypeEnv from .validation_result import ValidationResult @@ -177,80 +179,8 @@ def _check_rule_types(self, rule: Rule) -> None: return None def _check_expr_types(self, expr: TargetExpr, context: str) -> None: - """Recursively type check an expression. - - Validates: - - Message constructor calls have correct argument types - - Builtin calls have correct argument types - """ - if isinstance(expr, Call): - if isinstance(expr.func, Builtin): - self._check_builtin_call_types(expr.func, expr.args, context) - else: - # Recursively check function and args - self._check_expr_types(expr.func, context) - - for arg in expr.args: - self._check_expr_types(arg, context) - - elif isinstance(expr, Let): - self._check_expr_types(expr.init, context) - self._check_expr_types(expr.body, context) - - elif isinstance(expr, IfElse): - self._check_expr_types(expr.condition, context) - self._check_expr_types(expr.then_branch, context) - self._check_expr_types(expr.else_branch, context) - - elif isinstance(expr, Seq): - for sub in expr.exprs: - self._check_expr_types(sub, context) - - elif isinstance(expr, ListExpr): - for elem in expr.elements: - self._check_expr_types(elem, context) - - elif isinstance(expr, GetField): - # Check the object expression - self._check_expr_types(expr.object, context) - # Check that object type matches expected message_type - obj_type = self._infer_expr_type(expr.object) - if obj_type is not None and not self._is_subtype(obj_type, expr.message_type): - self.result.add_error( - "type_field_access", - f"In {context}: GetField expects object of type {expr.message_type}, got {obj_type}", - rule_name=context - ) - - elif isinstance(expr, GetElement): - # Check the tuple expression - self._check_expr_types(expr.tuple_expr, context) - # Check that tuple_expr has tuple type - tuple_type = self._infer_expr_type(expr.tuple_expr) - if tuple_type is not None and not isinstance(tuple_type, TupleType): - self.result.add_error( - "type_tuple_element", - f"In {context}: GetElement expects tuple type, got {tuple_type}", - rule_name=context - ) - # Check index bounds - if isinstance(tuple_type, TupleType): - if expr.index < 0 or expr.index >= len(tuple_type.elements): - self.result.add_error( - "type_tuple_element", - f"In {context}: GetElement index {expr.index} out of bounds for tuple with {len(tuple_type.elements)} elements", - rule_name=context - ) - - elif isinstance(expr, NewMessage): - # Check field names and types against proto spec - self._check_new_message_types(expr, context) - # Recursively check field expressions - for _, field_expr in expr.fields: - self._check_expr_types(field_expr, context) - - # Other expression types (Var, Lit, Symbol, Builtin, etc.) are leaves - return None + """Recursively type check an expression.""" + _ExprTypeChecker(self, context).visit(expr) def _check_new_message_types(self, new_msg: NewMessage, context: str) -> None: """Check that NewMessage has correct field names and types. @@ -791,6 +721,55 @@ def visit(r: Rhs) -> None: visit(rhs) return names +class _ExprTypeChecker(TargetExprVisitor): + """Type checker for TargetExpr trees.""" + + def __init__(self, validator: GrammarValidator, context: str): + super().__init__() + self._validator = validator + self._context = context + + def visit_Call(self, expr: Call) -> None: + if isinstance(expr.func, Builtin): + self._validator._check_builtin_call_types(expr.func, expr.args, self._context) + else: + self.visit(expr.func) + for arg in expr.args: + self.visit(arg) + + def visit_NewMessage(self, expr: NewMessage) -> None: + self._validator._check_new_message_types(expr, self._context) + for _, field_expr in expr.fields: + self.visit(field_expr) + + def visit_GetField(self, expr: GetField) -> None: + self.visit(expr.object) + obj_type = self._validator._infer_expr_type(expr.object) + if obj_type is not None and not self._validator._is_subtype(obj_type, expr.message_type): + self._validator.result.add_error( + "type_field_access", + f"In {self._context}: GetField expects object of type {expr.message_type}, got {obj_type}", + rule_name=self._context + ) + + def visit_GetElement(self, expr: GetElement) -> None: + self.visit(expr.tuple_expr) + tuple_type = self._validator._infer_expr_type(expr.tuple_expr) + if tuple_type is not None and not isinstance(tuple_type, TupleType): + self._validator.result.add_error( + "type_tuple_element", + f"In {self._context}: GetElement expects tuple type, got {tuple_type}", + rule_name=self._context + ) + if isinstance(tuple_type, TupleType): + if expr.index < 0 or expr.index >= len(tuple_type.elements): + self._validator.result.add_error( + "type_tuple_element", + f"In {self._context}: GetElement index {expr.index} out of bounds for tuple with {len(tuple_type.elements)} elements", + rule_name=self._context + ) + + def validate_grammar( grammar: Grammar, parser: ProtoParser, diff --git a/python-tools/src/meta/target_utils.py b/python-tools/src/meta/target_utils.py index 569f9bdd..684ed7a8 100644 --- a/python-tools/src/meta/target_utils.py +++ b/python-tools/src/meta/target_utils.py @@ -24,9 +24,13 @@ from .target import ( BaseType, Builtin, Lambda, Var, Lit, Call, TargetExpr, TargetType, - SequenceType, ListType, OptionType, TupleType, DictType, VarType + SequenceType, ListType, OptionType, TupleType, DictType, VarType, + Let, Assign, Seq, IfElse, While, Foreach, ForeachEnumerated, Return, + NewMessage, GetField, GetElement, ListExpr, ) +from .gensym import gensym from .target_builtins import make_builtin +from .target_visitor import TargetExprVisitor def is_subtype(t1: TargetType, t2: TargetType) -> bool: @@ -137,7 +141,6 @@ def create_identity_function(param_type: TargetType) -> Lambda: def _is_simple_expr(expr: TargetExpr) -> bool: """Check if an expression is cheap to evaluate and has no side effects.""" - from .target import IfElse, Let if isinstance(expr, (Var, Lit)): return True if isinstance(expr, Call): @@ -154,7 +157,6 @@ def _is_simple_expr(expr: TargetExpr) -> bool: def _count_var_occurrences(expr: TargetExpr, var: str) -> int: """Count occurrences of a variable in an expression.""" - from .target import Let, Assign, Seq, IfElse, While, Foreach, ForeachEnumerated, Return, NewMessage, GetField, GetElement, ListExpr if isinstance(expr, Var) and expr.name == var: return 1 elif isinstance(expr, Lambda): @@ -220,7 +222,6 @@ def _new_mapping(mapping: Mapping[str, TargetExpr], shadowed: list[str]): def _subst_inner(expr: TargetExpr, mapping: Mapping[str, TargetExpr]) -> TargetExpr: """Inner substitution helper - performs actual substitution.""" - from .target import Let, Assign, Seq, IfElse, While, Foreach, ForeachEnumerated, Return, NewMessage, GetField, GetElement, ListExpr if isinstance(expr, Var) and expr.name in mapping: return mapping[expr.name] elif isinstance(expr, Lambda): @@ -262,51 +263,49 @@ def _subst_inner(expr: TargetExpr, mapping: Mapping[str, TargetExpr]) -> TargetE return expr +class _SubstTypeValidator(TargetExprVisitor): + """Validates that substitution values have compatible types with their target variables. + + Scope-introducing nodes create new validators with filtered mappings for the body. + """ + + def __init__(self, mapping: Mapping[str, TargetExpr]): + super().__init__() + self._mapping = mapping + + def visit(self, expr: TargetExpr) -> None: + if not self._mapping: + return + super().visit(expr) + + def visit_Var(self, expr: Var) -> None: + if expr.name in self._mapping: + val = self._mapping[expr.name] + assert is_subtype(val.target_type(), expr.type), \ + f"Type mismatch in subst: {expr.name} has type {expr.type} but value has type {val.target_type()}" + + def _visit_scope(self, shadowed: set, children_before_body, body: TargetExpr) -> None: + for child in children_before_body: + self.visit(child) + filtered = {k: v for k, v in self._mapping.items() if k not in shadowed} + _SubstTypeValidator(filtered).visit(body) + + def visit_Lambda(self, expr: Lambda) -> None: + self._visit_scope({p.name for p in expr.params}, [], expr.body) + + def visit_Let(self, expr: Let) -> None: + self._visit_scope({expr.var.name}, [expr.init], expr.body) + + def visit_Foreach(self, expr: Foreach) -> None: + self._visit_scope({expr.var.name}, [expr.collection], expr.body) + + def visit_ForeachEnumerated(self, expr: ForeachEnumerated) -> None: + self._visit_scope({expr.index_var.name, expr.var.name}, [expr.collection], expr.body) + + def _validate_subst_types(expr: TargetExpr, mapping: Mapping[str, TargetExpr]) -> None: """Check that substitution values have compatible types with their target variables.""" - from .target import Let, Assign, Seq, IfElse, While, Foreach, ForeachEnumerated, Return - if isinstance(expr, Var) and expr.name in mapping: - val = mapping[expr.name] - assert is_subtype(val.target_type(), expr.type), \ - f"Type mismatch in subst: {expr.name} has type {expr.type} but value has type {val.target_type()}" - elif isinstance(expr, Lambda): - shadowed = {p.name for p in expr.params} - filtered = {k: v for k, v in mapping.items() if k not in shadowed} - if filtered: - _validate_subst_types(expr.body, filtered) - elif isinstance(expr, Let): - _validate_subst_types(expr.init, mapping) - filtered = {k: v for k, v in mapping.items() if k != expr.var.name} - if filtered: - _validate_subst_types(expr.body, filtered) - elif isinstance(expr, Call): - _validate_subst_types(expr.func, mapping) - for arg in expr.args: - _validate_subst_types(arg, mapping) - elif isinstance(expr, Seq): - for e in expr.exprs: - _validate_subst_types(e, mapping) - elif isinstance(expr, IfElse): - _validate_subst_types(expr.condition, mapping) - _validate_subst_types(expr.then_branch, mapping) - _validate_subst_types(expr.else_branch, mapping) - elif isinstance(expr, While): - _validate_subst_types(expr.condition, mapping) - _validate_subst_types(expr.body, mapping) - elif isinstance(expr, Foreach): - _validate_subst_types(expr.collection, mapping) - filtered = {k: v for k, v in mapping.items() if k != expr.var.name} - if filtered: - _validate_subst_types(expr.body, filtered) - elif isinstance(expr, ForeachEnumerated): - _validate_subst_types(expr.collection, mapping) - filtered = {k: v for k, v in mapping.items() if k not in {expr.index_var.name, expr.var.name}} - if filtered: - _validate_subst_types(expr.body, filtered) - elif isinstance(expr, Assign): - _validate_subst_types(expr.expr, mapping) - elif isinstance(expr, Return): - _validate_subst_types(expr.expr, mapping) + _SubstTypeValidator(mapping).visit(expr) def subst(expr: TargetExpr, mapping: Mapping[str, TargetExpr]) -> TargetExpr: @@ -321,8 +320,6 @@ def subst(expr: TargetExpr, mapping: Mapping[str, TargetExpr]) -> TargetExpr: and its variable occurs more than once, introduces a Let binding to avoid duplicating side effects. """ - from .target import Let - from .gensym import gensym if not mapping: return expr @@ -368,7 +365,6 @@ def make_which_oneof(msg, oneof_name): def make_get_field(obj, field_name, message_type, field_type): """Get field value from message: obj.field_name.""" - from .target import GetField # If field_name is a Lit, extract the string value if isinstance(field_name, Lit): field_name = field_name.value @@ -384,7 +380,6 @@ def make_tuple(*args): def make_get_element(tuple_expr, index): """Extract element from tuple at constant index: tuple_expr[index].""" - from .target import GetElement return GetElement(tuple_expr, index) def make_fst(pair): diff --git a/python-tools/src/meta/target_visitor.py b/python-tools/src/meta/target_visitor.py new file mode 100644 index 00000000..7c68739b --- /dev/null +++ b/python-tools/src/meta/target_visitor.py @@ -0,0 +1,85 @@ +"""Visitor base class for TargetExpr trees. + +TargetExprVisitor walks a TargetExpr tree for side effects (returns None). +Subclasses override visit_Foo methods and call self.visit_children(expr) +when they want default recursion. +""" + +from typing import Callable + +from .target import ( + TargetExpr, Var, Lit, Symbol, Builtin, NamedFun, EnumValue, OneOf, + ParseNonterminal, PrintNonterminal, + NewMessage, ListExpr, Call, GetField, GetElement, + Lambda, Let, IfElse, Seq, While, Foreach, ForeachEnumerated, + Assign, Return, +) + + +class TargetExprVisitor: + """Side-effecting tree walker over TargetExpr nodes. + + For each node, visit() dispatches to visit_ClassName (e.g., visit_Call). + If no such method exists, falls back to visit_children which recurses + into child TargetExpr nodes. + + Subclasses override visit_Foo to add behavior, calling + self.visit_children(expr) when they want default recursion. + """ + + def __init__(self): + self._visit_cache: dict[type, Callable] = {} + + def visit(self, expr: TargetExpr) -> None: + t = type(expr) + method = self._visit_cache.get(t) + if method is None: + method = getattr(self, f'visit_{t.__name__}', self.visit_children) + self._visit_cache[t] = method + method(expr) + + def visit_children(self, expr: TargetExpr) -> None: + """Visit all child TargetExpr nodes.""" + if isinstance(expr, (Var, Lit, Symbol, Builtin, NamedFun, EnumValue, + OneOf, ParseNonterminal, PrintNonterminal)): + return None + if isinstance(expr, NewMessage): + for _, field_expr in expr.fields: + self.visit(field_expr) + elif isinstance(expr, ListExpr): + for elem in expr.elements: + self.visit(elem) + elif isinstance(expr, Call): + self.visit(expr.func) + for arg in expr.args: + self.visit(arg) + elif isinstance(expr, GetField): + self.visit(expr.object) + elif isinstance(expr, GetElement): + self.visit(expr.tuple_expr) + elif isinstance(expr, Lambda): + self.visit(expr.body) + elif isinstance(expr, Let): + self.visit(expr.init) + self.visit(expr.body) + elif isinstance(expr, IfElse): + self.visit(expr.condition) + self.visit(expr.then_branch) + self.visit(expr.else_branch) + elif isinstance(expr, Seq): + for e in expr.exprs: + self.visit(e) + elif isinstance(expr, While): + self.visit(expr.condition) + self.visit(expr.body) + elif isinstance(expr, Foreach): + self.visit(expr.collection) + self.visit(expr.body) + elif isinstance(expr, ForeachEnumerated): + self.visit(expr.collection) + self.visit(expr.body) + elif isinstance(expr, Assign): + self.visit(expr.expr) + elif isinstance(expr, Return): + self.visit(expr.expr) + return None