Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 43 additions & 81 deletions python-tools/src/meta/codegen_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

from .target import (
TargetExpr, Var, Lit, Symbol, Builtin, NamedFun, NewMessage, EnumValue, OneOf, ListExpr, Call, Lambda, Let,
Expand Down Expand Up @@ -64,6 +64,7 @@ class CodeGenerator(ABC):
def __init__(self, proto_messages: Optional[Dict[Tuple[str, str], Any]] = None) -> None:
self.builtin_registry: Dict[str, BuiltinSpec] = {}
self.proto_messages = proto_messages or {}
self._generate_cache: Dict[type, Callable] = {}

@abstractmethod
def escape_keyword(self, name: str) -> str:
Expand Down Expand Up @@ -401,81 +402,42 @@ def generate_lines(self, expr: TargetExpr, lines: List[str], indent: str = "") -
Returns the value expression as a string, or None if the expression
returns (i.e., contains a Return node that was executed).
"""
if isinstance(expr, Var):
return self.escape_identifier(expr.name)
t = type(expr)
method = self._generate_cache.get(t)
if method is None:
method = getattr(self, f'_generate_{t.__name__}', None)
if method is None:
raise ValueError(f"Unknown expression type: {t.__name__}")
self._generate_cache[t] = method
return method(expr, lines, indent)

elif isinstance(expr, Lit):
return self.gen_literal(expr.value)
def _generate_Var(self, expr: Var, lines: List[str], indent: str) -> str:
return self.escape_identifier(expr.name)

elif isinstance(expr, Symbol):
return self.gen_symbol(expr.name)
def _generate_Lit(self, expr: Lit, lines: List[str], indent: str) -> str:
return self.gen_literal(expr.value)

elif isinstance(expr, NewMessage):
return self._generate_newmessage(expr, lines, indent)
def _generate_Symbol(self, expr: Symbol, lines: List[str], indent: str) -> str:
return self.gen_symbol(expr.name)

elif isinstance(expr, EnumValue):
return self._generate_enum_value(expr, lines, indent)
def _generate_Builtin(self, expr: Builtin, lines: List[str], indent: str) -> str:
return self.gen_builtin_ref(expr.name)

elif isinstance(expr, Builtin):
return self.gen_builtin_ref(expr.name)
def _generate_NamedFun(self, expr: NamedFun, lines: List[str], indent: str) -> str:
return self.gen_named_fun_ref(expr.name)

elif isinstance(expr, NamedFun):
return self.gen_named_fun_ref(expr.name)
def _generate_PrintNonterminal(self, expr: PrintNonterminal, lines: List[str], indent: str) -> str:
return self.gen_pretty_nonterminal_ref(expr.nonterminal.name)

elif isinstance(expr, PrintNonterminal):
return self.gen_pretty_nonterminal_ref(expr.nonterminal.name)
def _generate_ParseNonterminal(self, expr: ParseNonterminal, lines: List[str], indent: str) -> str:
return self.gen_parse_nonterminal_ref(expr.nonterminal.name)

elif isinstance(expr, ParseNonterminal):
return self.gen_parse_nonterminal_ref(expr.nonterminal.name)
def _generate_GetField(self, expr: GetField, lines: List[str], indent: str) -> str:
obj_code = self.generate_lines(expr.object, lines, indent)
assert obj_code is not None, "GetField object should not contain a return"
return self.gen_field_access(obj_code, expr.field_name)

elif isinstance(expr, OneOf):
return self._generate_oneof(expr, lines, indent)

elif isinstance(expr, ListExpr):
return self._generate_list_expr(expr, lines, indent)

elif isinstance(expr, GetField):
obj_code = self.generate_lines(expr.object, lines, indent)
assert obj_code is not None, "GetField object should not contain a return"
return self.gen_field_access(obj_code, expr.field_name)

elif isinstance(expr, GetElement):
return self._generate_get_element(expr, lines, indent)

elif isinstance(expr, Call):
return self._generate_call(expr, lines, indent)

elif isinstance(expr, Lambda):
return self._generate_lambda(expr, lines, indent)

elif isinstance(expr, Let):
return self._generate_let(expr, lines, indent)

elif isinstance(expr, IfElse):
return self._generate_if_else(expr, lines, indent)

elif isinstance(expr, Seq):
return self._generate_seq(expr, lines, indent)

elif isinstance(expr, While):
return self._generate_while(expr, lines, indent)

elif isinstance(expr, Foreach):
return self._generate_foreach(expr, lines, indent)

elif isinstance(expr, ForeachEnumerated):
return self._generate_foreach_enumerated(expr, lines, indent)

elif isinstance(expr, Assign):
return self._generate_assign(expr, lines, indent)

elif isinstance(expr, Return):
return self._generate_return(expr, lines, indent)

else:
raise ValueError(f"Unknown expression type: {type(expr)}")

def _generate_call(self, expr: Call, lines: List[str], indent: str) -> Optional[str]:
def _generate_Call(self, expr: Call, lines: List[str], indent: str) -> Optional[str]:
"""Generate code for a function call."""
# NewMessage should be handled directly, not wrapped in Call
assert not isinstance(expr.func, NewMessage), \
Expand Down Expand Up @@ -569,7 +531,7 @@ def _generate_short_circuit_call(self, op: str, left: TargetExpr, right: TargetE
lines.append(f"{indent}{end}")
return tmp

def _generate_newmessage(self, expr: NewMessage, lines: List[str], indent: str) -> str:
def _generate_NewMessage(self, expr: NewMessage, lines: List[str], indent: str) -> str:
"""Generate code for a NewMessage expression.

Default implementation uses positional constructor args.
Expand Down Expand Up @@ -599,7 +561,7 @@ def _generate_newmessage(self, expr: NewMessage, lines: List[str], indent: str)
lines.append(f"{indent}{self.gen_assignment(tmp, call, is_declaration=True)}")
return tmp

def _generate_get_element(self, expr: GetElement, lines: List[str], indent: str) -> str:
def _generate_GetElement(self, expr: GetElement, lines: List[str], indent: str) -> str:
"""Generate code for a GetElement expression.

Default implementation uses 0-based indexing.
Expand All @@ -608,11 +570,11 @@ def _generate_get_element(self, expr: GetElement, lines: List[str], indent: str)
tuple_code = self.generate_lines(expr.tuple_expr, lines, indent)
return f"{tuple_code}[{expr.index}]"

def _generate_enum_value(self, expr: EnumValue, lines: List[str], indent: str) -> str:
def _generate_EnumValue(self, expr: EnumValue, lines: List[str], indent: str) -> str:
"""Generate code for an enum value reference."""
return self.gen_enum_value(expr.module, expr.enum_name, expr.value_name)

def _generate_oneof(self, expr: OneOf, lines: List[str], indent: str) -> str:
def _generate_OneOf(self, expr: OneOf, lines: List[str], indent: str) -> str:
"""Generate code for a OneOf expression.

Default implementation treats it as an error since OneOf should only
Expand All @@ -621,7 +583,7 @@ def _generate_oneof(self, expr: OneOf, lines: List[str], indent: str) -> str:
"""
raise ValueError(f"OneOf should only appear as arguments to Message constructors: {expr}")

def _generate_list_expr(self, expr: ListExpr, lines: List[str], indent: str) -> str:
def _generate_ListExpr(self, expr: ListExpr, lines: List[str], indent: str) -> str:
"""Generate code for a list expression."""
elements: List[str] = []
for elem in expr.elements:
Expand All @@ -630,7 +592,7 @@ def _generate_list_expr(self, expr: ListExpr, lines: List[str], indent: str) ->
elements.append(elem_code)
return self.gen_list_literal(elements, expr.element_type)

def _generate_lambda(self, expr: Lambda, lines: List[str], indent: str) -> str:
def _generate_Lambda(self, expr: Lambda, lines: List[str], indent: str) -> str:
"""Generate code for a lambda expression."""
params = [self.escape_identifier(p.name) for p in expr.params]
f = gensym()
Expand All @@ -650,15 +612,15 @@ def _generate_lambda(self, expr: Lambda, lines: List[str], indent: str) -> str:
lines.append(f"{indent}{after}")
return f

def _generate_let(self, expr: Let, lines: List[str], indent: str) -> Optional[str]:
def _generate_Let(self, expr: Let, lines: List[str], indent: str) -> Optional[str]:
"""Generate code for a let binding."""
var_name = self.escape_identifier(expr.var.name)
init_val = self.generate_lines(expr.init, lines, indent)
assert init_val is not None, "Let initializer should not contain a return"
lines.append(f"{indent}{self.gen_assignment(var_name, init_val, is_declaration=True)}")
return self.generate_lines(expr.body, lines, indent)

def _generate_if_else(self, expr: IfElse, lines: List[str], indent: str) -> Optional[str]:
def _generate_IfElse(self, expr: IfElse, lines: List[str], indent: str) -> Optional[str]:
"""Generate code for an if-else expression."""
cond_code = self.generate_lines(expr.condition, lines, indent)
assert cond_code is not None, "If condition should not contain a return"
Expand Down Expand Up @@ -756,7 +718,7 @@ def _generate_nil_else_branch(
lines.append(f"{body_indent}{self.gen_assignment(tmp, else_code)}")
return else_code

def _generate_seq(self, expr: Seq, lines: List[str], indent: str) -> Optional[str]:
def _generate_Seq(self, expr: Seq, lines: List[str], indent: str) -> Optional[str]:
"""Generate code for a sequence of expressions.

If any expression returns None (indicating a return statement was executed),
Expand All @@ -769,7 +731,7 @@ def _generate_seq(self, expr: Seq, lines: List[str], indent: str) -> Optional[st
break
return result

def _generate_while(self, expr: While, lines: List[str], indent: str) -> str:
def _generate_While(self, expr: While, lines: List[str], indent: str) -> str:
"""Generate code for a while loop."""
m = len(lines)
cond_code = self.generate_lines(expr.condition, lines, indent)
Expand Down Expand Up @@ -798,7 +760,7 @@ def _generate_while(self, expr: While, lines: List[str], indent: str) -> str:

return self.gen_none()

def _generate_foreach(self, expr: Foreach, lines: List[str], indent: str) -> str:
def _generate_Foreach(self, expr: Foreach, lines: List[str], indent: str) -> str:
"""Generate code for a foreach loop."""
collection_code = self.generate_lines(expr.collection, lines, indent)
assert collection_code is not None, "Foreach collection should not contain a return"
Expand All @@ -814,7 +776,7 @@ def _generate_foreach(self, expr: Foreach, lines: List[str], indent: str) -> str

return self.gen_none()

def _generate_foreach_enumerated(self, expr: ForeachEnumerated, lines: List[str], indent: str) -> str:
def _generate_ForeachEnumerated(self, expr: ForeachEnumerated, lines: List[str], indent: str) -> str:
"""Generate code for a foreach enumerated loop."""
collection_code = self.generate_lines(expr.collection, lines, indent)
assert collection_code is not None, "ForeachEnumerated collection should not contain a return"
Expand All @@ -831,15 +793,15 @@ def _generate_foreach_enumerated(self, expr: ForeachEnumerated, lines: List[str]

return self.gen_none()

def _generate_assign(self, expr: Assign, lines: List[str], indent: str) -> str:
def _generate_Assign(self, expr: Assign, lines: List[str], indent: str) -> str:
"""Generate code for an assignment."""
var_name = self.escape_identifier(expr.var.name)
expr_code = self.generate_lines(expr.expr, lines, indent)
assert expr_code is not None, "Assignment expression should not contain a return"
lines.append(f"{indent}{self.gen_assignment(var_name, expr_code)}")
return self.gen_none()

def _generate_return(self, expr: Return, lines: List[str], indent: str) -> None:
def _generate_Return(self, expr: Return, lines: List[str], indent: str) -> None:
"""Generate code for a return statement.

Returns None to indicate that the caller should not add another return
Expand Down
27 changes: 12 additions & 15 deletions python-tools/src/meta/codegen_go.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def _generate_nil_else_branch(
"""Go's var declarations zero-initialize, so no else branch needed."""
return self.gen_none()

def _generate_get_element(self, expr: GetElement, lines: List[str], indent: str) -> str:
def _generate_GetElement(self, expr: GetElement, lines: List[str], indent: str) -> str:
"""Go uses 0-based indexing with type assertion for tuple elements."""
tuple_code = self.generate_lines(expr.tuple_expr, lines, indent)
# Add type assertion since tuple elements are interface{}
Expand Down Expand Up @@ -367,18 +367,15 @@ def _is_optional_scalar_field(self, expr) -> bool:
inner_go = self.gen_type(expr.field_type.element_type)
return not self._is_nullable_go_type(inner_go)

def generate_lines(self, expr: TargetExpr, lines: List[str], indent: str = "") -> Optional[str]:
from .target import GetField
# For optional scalar proto fields, use direct PascalCase access
# to preserve pointer type (getters strip it).
if isinstance(expr, GetField) and self._is_optional_scalar_field(expr):
def _generate_GetField(self, expr, lines: List[str], indent: str) -> str:
if self._is_optional_scalar_field(expr):
obj_code = self.generate_lines(expr.object, lines, indent)
assert obj_code is not None
pascal_field = to_pascal_case(expr.field_name)
return f"{obj_code}.{pascal_field}"
return super().generate_lines(expr, lines, indent)
return super()._generate_GetField(expr, lines, indent)

def _generate_seq(self, expr: Seq, lines: List[str], indent: str) -> Optional[str]:
def _generate_Seq(self, expr: Seq, lines: List[str], indent: str) -> Optional[str]:
"""Generate Go sequence, suppressing unused variable errors.

In Go, declared-but-unused variables are compile errors. When an
Expand All @@ -395,7 +392,7 @@ def _generate_seq(self, expr: Seq, lines: List[str], indent: str) -> Optional[st
lines.append(f"{indent}_ = {result}")
return result

def _generate_newmessage(self, expr: NewMessage, lines: List[str], indent: str) -> str:
def _generate_NewMessage(self, expr: NewMessage, lines: List[str], indent: str) -> str:
"""Generate Go code for NewMessage with fields containing OneOf calls.

In Go protobuf, OneOf fields require wrapping values in the appropriate
Expand Down Expand Up @@ -542,7 +539,7 @@ def unwrap_if_option(field_expr, field_value: str) -> Tuple[str, str]:

return tmp

def _generate_call(self, expr: Call, lines: List[str], indent: str) -> Optional[str]:
def _generate_Call(self, expr: Call, lines: List[str], indent: str) -> Optional[str]:
"""Override to handle OneOf, Parse/PrintNonterminal, NamedFun, and option builtins for Go."""
from .target import NamedFun, FunctionType, ListType, BaseType, Builtin, OptionType

Expand Down Expand Up @@ -601,7 +598,7 @@ def _generate_call(self, expr: Call, lines: List[str], indent: str) -> Optional[
return tmp

# Fall back to base implementation
return super()._generate_call(expr, lines, indent)
return super()._generate_Call(expr, lines, indent)

def _generate_option_builtin(self, expr: Call, lines: List[str], indent: str) -> Optional[str]:
"""Generate Go code for option-related builtins using pointer/nil idioms."""
Expand Down Expand Up @@ -680,17 +677,17 @@ def _generate_option_builtin(self, expr: Call, lines: List[str], indent: str) ->
return f"deref({opt_code}, {default_code})"

# Should not reach here
return super()._generate_call(expr, lines, indent)
return super()._generate_Call(expr, lines, indent)

def _generate_oneof(self, expr: OneOf, lines: List[str], indent: str) -> str:
def _generate_OneOf(self, expr: OneOf, lines: List[str], indent: str) -> str:
"""Generate Go OneOf reference.

OneOf should only appear as the function in Call(OneOf(...), [value]).
This method shouldn't normally be called.
"""
raise ValueError(f"OneOf should only appear in Call(OneOf(...), [value]) pattern: {expr}")

def _generate_return(self, expr, lines: List[str], indent: str) -> None:
def _generate_Return(self, expr, lines: List[str], indent: str) -> None:
"""Generate Go return statement, wrapping with ptr() for Option types when needed."""
from .target import Lit, Call, Builtin

Expand All @@ -715,7 +712,7 @@ def _generate_return(self, expr, lines: List[str], indent: str) -> None:
lines.append(f"{indent}{self.gen_return(expr_code)}")
return None

def _generate_assign(self, expr, lines: List[str], indent: str) -> str:
def _generate_Assign(self, expr, lines: List[str], indent: str) -> str:
"""Generate Go assignment, handling type-annotated nil declarations.

In Go, `var_name := nil` is not valid because nil has no type.
Expand Down
Loading
Loading