diff --git a/docs/changelog.md b/docs/changelog.md index cf43ce81..ce67da63 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -13,6 +13,7 @@ _This project uses semantic versioning_ - Improve doctest support, teaching expressions about their `__module__`, `__dir__`, and special methods. - Surface original Python exceptions from the runtime and tighten pretty-printing of values that cannot be re-parsed to make debugging e-graph executions easier. - Update the bundled Egglog crate, visualizer, and related dev dependencies (including `ipykernel`) to pick up the latest backend fixes. +- Fix lookup of cost model based on value (see [zulip for issue](https://egraphs.zulipchat.com/#narrow/channel/375765-egg.2Fegglog/topic/Cost.20function.3A.20using.20function.20values.20of.20subtrees/near/577062352)) ## 11.4.0 (2025-10-02) diff --git a/python/egglog/egraph_state.py b/python/egglog/egraph_state.py index c82e38ff..2a72418c 100644 --- a/python/egglog/egraph_state.py +++ b/python/egglog/egraph_state.py @@ -683,6 +683,8 @@ def _generate_callable_egg_name(self, ref: CallableRef) -> str: assert_never(ref) def typed_expr_to_value(self, typed_expr: TypedExprDecl) -> bindings.Value: + if isinstance(typed_expr.expr, ValueDecl): + return typed_expr.expr.value egg_expr = self.typed_expr_to_egg(typed_expr, False) return self.egraph.eval_expr(egg_expr)[1] diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index 173e817a..d658844e 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -1568,3 +1568,37 @@ def __radd__(self, other: object) -> tuple[X, X]: ... assert X(1) + 10 == (X(1), 10) assert 10 + X(1) == (X(10), X(1)) + + +def test_custom_cost_model_size(): + """ + https://egraphs.zulipchat.com/#narrow/channel/375765-egg.2Fegglog/topic/Cost.20function.3A.20using.20function.20values.20of.20subtrees/near/577062352 + """ + + class KAT(Expr): + @classmethod + def eps(cls) -> KAT: ... + + @classmethod + def emp(cls) -> KAT: ... + + def func(self, other: KAT) -> KAT: ... + + def size(self) -> i64: ... + + eps, emp = KAT.eps(), KAT.emp() + + eg = EGraph() + q0 = eg.let("q0", KAT.func(eps, emp)) + + eg.register(set_(eps.size()).to(i64(1))) + eg.register(set_(emp.size()).to(i64(0))) + + def conv_cost(eg, expr, child_costs): + if isinstance(expr, KAT): + args = get_callable_args(expr) + return sum(int(eg.lookup_function_value(cast("KAT", a).size())) for a in args) + + return 2 + + assert eg.extract(q0, include_cost=True, cost_model=conv_cost) == (KAT.eps().func(KAT.emp()), 1)