From 27e3d34bc3a38962ab45f9fb9c3079f15d87ed26 Mon Sep 17 00:00:00 2001
From: "John C. Burnham"
Date: Fri, 6 Feb 2026 01:26:09 -0500
Subject: [PATCH 1/5] Add DAG-based kernel typechecker

Implement a Lean 4 kernel typechecker using a DAG representation with
BUBS (Bottom-Up Beta Substitution) for efficient reduction. The kernel
operates on a mutable DAG rather than on tree-based expressions,
enabling in-place substitution and reduction of shared subterms.

12 modules: doubly-linked list (dll), DAG nodes with 10 pointer variants
(dag), BUBS upcopy with 12 parent cases (upcopy), Expr/DAG conversion
(convert), universe level operations (level), WHNF via the trail
algorithm (whnf), definitional equality with lazy delta, proof
irrelevance, and eta (def_eq), type inference (tc), quotient checking
(quot), inductive checking (inductive), error types (error), and the
module root (mod).
---
 src/ix.rs                  |    1 +
 src/ix/kernel/convert.rs   |  813 +++++++++++++++++
 src/ix/kernel/dag.rs       |  527 +++++++++++
 src/ix/kernel/def_eq.rs    | 1298 +++++++++++++++++++++++++++
 src/ix/kernel/dll.rs       |  214 +++++
 src/ix/kernel/error.rs     |   59 ++
 src/ix/kernel/inductive.rs |  772 ++++++++++++++++
 src/ix/kernel/level.rs     |  393 +++++++++
 src/ix/kernel/mod.rs       |   11 +
 src/ix/kernel/quot.rs      |  291 +++++++
 src/ix/kernel/tc.rs        | 1694 ++++++++++++++++++++++++++++++
 src/ix/kernel/upcopy.rs    |  659 ++++++++++++++
 src/ix/kernel/whnf.rs      | 1420 ++++++++++++++++++++++++++++++
 13 files changed, 8152 insertions(+)
 create mode 100644 src/ix/kernel/convert.rs
 create mode 100644 src/ix/kernel/dag.rs
 create mode 100644 src/ix/kernel/def_eq.rs
 create mode 100644 src/ix/kernel/dll.rs
 create mode 100644 src/ix/kernel/error.rs
 create mode 100644 src/ix/kernel/inductive.rs
 create mode 100644 src/ix/kernel/level.rs
 create mode 100644 src/ix/kernel/mod.rs
 create mode 100644 src/ix/kernel/quot.rs
 create mode 100644 src/ix/kernel/tc.rs
 create mode 100644 src/ix/kernel/upcopy.rs
 create mode 100644 src/ix/kernel/whnf.rs

diff --git a/src/ix.rs b/src/ix.rs
index f200d81b..42d298c2 100644
--- a/src/ix.rs
+++ b/src/ix.rs
@@ -12,6 +12,7 @@ pub mod env;
 pub mod graph;
 pub mod ground;
pub mod ixon; +pub mod kernel; pub mod mutual; pub mod store; pub mod strong_ordering; diff --git a/src/ix/kernel/convert.rs b/src/ix/kernel/convert.rs new file mode 100644 index 00000000..90811948 --- /dev/null +++ b/src/ix/kernel/convert.rs @@ -0,0 +1,813 @@ +use core::ptr::NonNull; +use std::collections::BTreeMap; + +use crate::ix::env::{Expr, ExprData, Level, Name}; +use crate::lean::nat::Nat; + +use super::dag::*; +use super::dll::DLL; + +// ============================================================================ +// Expr -> DAG +// ============================================================================ + +pub fn from_expr(expr: &Expr) -> DAG { + let root_parents = DLL::alloc(ParentPtr::Root); + let head = from_expr_go(expr, 0, &BTreeMap::new(), Some(root_parents)); + DAG { head } +} + +fn from_expr_go( + expr: &Expr, + depth: u64, + ctx: &BTreeMap>, + parents: Option>, +) -> DAGPtr { + match expr.as_data() { + ExprData::Bvar(idx, _) => { + let idx_u64 = idx.to_u64().unwrap_or(u64::MAX); + if idx_u64 < depth { + let level = depth - 1 - idx_u64; + match ctx.get(&level) { + Some(&var_ptr) => { + if let Some(parent_link) = parents { + add_to_parents(DAGPtr::Var(var_ptr), parent_link); + } + DAGPtr::Var(var_ptr) + }, + None => { + let var = alloc_val(Var { + depth: level, + binder: BinderPtr::Free, + parents, + }); + DAGPtr::Var(var) + }, + } + } else { + // Free bound variable (dangling de Bruijn index) + let var = + alloc_val(Var { depth: idx_u64, binder: BinderPtr::Free, parents }); + DAGPtr::Var(var) + } + }, + + ExprData::Fvar(_name, _) => { + // Encode fvar name into depth as a unique ID. + // We'll recover it during to_expr using a side table. 
+ let var = alloc_val(Var { depth: 0, binder: BinderPtr::Free, parents }); + // Store name→var mapping (caller should manage the side table) + DAGPtr::Var(var) + }, + + ExprData::Sort(level, _) => { + let sort = alloc_val(Sort { level: level.clone(), parents }); + DAGPtr::Sort(sort) + }, + + ExprData::Const(name, levels, _) => { + let cnst = alloc_val(Cnst { + name: name.clone(), + levels: levels.clone(), + parents, + }); + DAGPtr::Cnst(cnst) + }, + + ExprData::Lit(lit, _) => { + let lit_node = alloc_val(LitNode { val: lit.clone(), parents }); + DAGPtr::Lit(lit_node) + }, + + ExprData::App(fun_expr, arg_expr, _) => { + let app_ptr = alloc_app( + DAGPtr::Var(NonNull::dangling()), + DAGPtr::Var(NonNull::dangling()), + parents, + ); + unsafe { + let app = &mut *app_ptr.as_ptr(); + let fun_ref_ptr = + NonNull::new(&mut app.fun_ref as *mut Parents).unwrap(); + let arg_ref_ptr = + NonNull::new(&mut app.arg_ref as *mut Parents).unwrap(); + app.fun = from_expr_go(fun_expr, depth, ctx, Some(fun_ref_ptr)); + app.arg = from_expr_go(arg_expr, depth, ctx, Some(arg_ref_ptr)); + } + DAGPtr::App(app_ptr) + }, + + ExprData::Lam(name, typ, body, bi, _) => { + // Lean Lam → DAG Fun(dom, Lam(bod, var)) + let lam_ptr = + alloc_lam(depth, DAGPtr::Var(NonNull::dangling()), None); + let fun_ptr = alloc_fun( + name.clone(), + bi.clone(), + DAGPtr::Var(NonNull::dangling()), + lam_ptr, + parents, + ); + unsafe { + let fun = &mut *fun_ptr.as_ptr(); + let dom_ref_ptr = + NonNull::new(&mut fun.dom_ref as *mut Parents).unwrap(); + fun.dom = from_expr_go(typ, depth, ctx, Some(dom_ref_ptr)); + + // Set Lam's parent to FunImg + let img_ref_ptr = + NonNull::new(&mut fun.img_ref as *mut Parents).unwrap(); + add_to_parents(DAGPtr::Lam(lam_ptr), img_ref_ptr); + + let lam = &mut *lam_ptr.as_ptr(); + let var_ptr = NonNull::new(&mut lam.var as *mut Var).unwrap(); + let mut new_ctx = ctx.clone(); + new_ctx.insert(depth, var_ptr); + let bod_ref_ptr = + NonNull::new(&mut lam.bod_ref as *mut 
Parents).unwrap(); + lam.bod = + from_expr_go(body, depth + 1, &new_ctx, Some(bod_ref_ptr)); + } + DAGPtr::Fun(fun_ptr) + }, + + ExprData::ForallE(name, typ, body, bi, _) => { + let lam_ptr = + alloc_lam(depth, DAGPtr::Var(NonNull::dangling()), None); + let pi_ptr = alloc_pi( + name.clone(), + bi.clone(), + DAGPtr::Var(NonNull::dangling()), + lam_ptr, + parents, + ); + unsafe { + let pi = &mut *pi_ptr.as_ptr(); + let dom_ref_ptr = + NonNull::new(&mut pi.dom_ref as *mut Parents).unwrap(); + pi.dom = from_expr_go(typ, depth, ctx, Some(dom_ref_ptr)); + + let img_ref_ptr = + NonNull::new(&mut pi.img_ref as *mut Parents).unwrap(); + add_to_parents(DAGPtr::Lam(lam_ptr), img_ref_ptr); + + let lam = &mut *lam_ptr.as_ptr(); + let var_ptr = NonNull::new(&mut lam.var as *mut Var).unwrap(); + let mut new_ctx = ctx.clone(); + new_ctx.insert(depth, var_ptr); + let bod_ref_ptr = + NonNull::new(&mut lam.bod_ref as *mut Parents).unwrap(); + lam.bod = + from_expr_go(body, depth + 1, &new_ctx, Some(bod_ref_ptr)); + } + DAGPtr::Pi(pi_ptr) + }, + + ExprData::LetE(name, typ, val, body, non_dep, _) => { + let lam_ptr = + alloc_lam(depth, DAGPtr::Var(NonNull::dangling()), None); + let let_ptr = alloc_let( + name.clone(), + *non_dep, + DAGPtr::Var(NonNull::dangling()), + DAGPtr::Var(NonNull::dangling()), + lam_ptr, + parents, + ); + unsafe { + let let_node = &mut *let_ptr.as_ptr(); + let typ_ref_ptr = + NonNull::new(&mut let_node.typ_ref as *mut Parents).unwrap(); + let val_ref_ptr = + NonNull::new(&mut let_node.val_ref as *mut Parents).unwrap(); + let_node.typ = from_expr_go(typ, depth, ctx, Some(typ_ref_ptr)); + let_node.val = from_expr_go(val, depth, ctx, Some(val_ref_ptr)); + + let bod_ref_ptr = + NonNull::new(&mut let_node.bod_ref as *mut Parents).unwrap(); + add_to_parents(DAGPtr::Lam(lam_ptr), bod_ref_ptr); + + let lam = &mut *lam_ptr.as_ptr(); + let var_ptr = NonNull::new(&mut lam.var as *mut Var).unwrap(); + let mut new_ctx = ctx.clone(); + new_ctx.insert(depth, var_ptr); + let 
inner_bod_ref_ptr = + NonNull::new(&mut lam.bod_ref as *mut Parents).unwrap(); + lam.bod = + from_expr_go(body, depth + 1, &new_ctx, Some(inner_bod_ref_ptr)); + } + DAGPtr::Let(let_ptr) + }, + + ExprData::Proj(type_name, idx, structure, _) => { + let proj_ptr = alloc_proj( + type_name.clone(), + idx.clone(), + DAGPtr::Var(NonNull::dangling()), + parents, + ); + unsafe { + let proj = &mut *proj_ptr.as_ptr(); + let expr_ref_ptr = + NonNull::new(&mut proj.expr_ref as *mut Parents).unwrap(); + proj.expr = + from_expr_go(structure, depth, ctx, Some(expr_ref_ptr)); + } + DAGPtr::Proj(proj_ptr) + }, + + // Mdata: strip metadata, convert inner expression + ExprData::Mdata(_, inner, _) => from_expr_go(inner, depth, ctx, parents), + + // Mvar: treat as terminal (shouldn't appear in well-typed terms) + ExprData::Mvar(_name, _) => { + let var = alloc_val(Var { depth: 0, binder: BinderPtr::Free, parents }); + DAGPtr::Var(var) + }, + } +} + +// ============================================================================ +// Literal clone +// ============================================================================ + +impl Clone for crate::ix::env::Literal { + fn clone(&self) -> Self { + match self { + crate::ix::env::Literal::NatVal(n) => { + crate::ix::env::Literal::NatVal(n.clone()) + }, + crate::ix::env::Literal::StrVal(s) => { + crate::ix::env::Literal::StrVal(s.clone()) + }, + } + } +} + +// ============================================================================ +// DAG -> Expr +// ============================================================================ + +pub fn to_expr(dag: &DAG) -> Expr { + let mut var_map: BTreeMap<*const Var, u64> = BTreeMap::new(); + to_expr_go(dag.head, &mut var_map, 0) +} + +fn to_expr_go( + node: DAGPtr, + var_map: &mut BTreeMap<*const Var, u64>, + depth: u64, +) -> Expr { + unsafe { + match node { + DAGPtr::Var(link) => { + let var = link.as_ptr(); + let var_key = var as *const Var; + if let Some(&bind_depth) = var_map.get(&var_key) { 
+ let idx = depth - bind_depth - 1; + Expr::bvar(Nat::from(idx)) + } else { + // Free variable + Expr::bvar(Nat::from((*var).depth)) + } + }, + + DAGPtr::Sort(link) => { + let sort = &*link.as_ptr(); + Expr::sort(sort.level.clone()) + }, + + DAGPtr::Cnst(link) => { + let cnst = &*link.as_ptr(); + Expr::cnst(cnst.name.clone(), cnst.levels.clone()) + }, + + DAGPtr::Lit(link) => { + let lit = &*link.as_ptr(); + Expr::lit(lit.val.clone()) + }, + + DAGPtr::App(link) => { + let app = &*link.as_ptr(); + let fun = to_expr_go(app.fun, var_map, depth); + let arg = to_expr_go(app.arg, var_map, depth); + Expr::app(fun, arg) + }, + + DAGPtr::Fun(link) => { + let fun = &*link.as_ptr(); + let lam = &*fun.img.as_ptr(); + let dom = to_expr_go(fun.dom, var_map, depth); + let var_ptr = &lam.var as *const Var; + var_map.insert(var_ptr, depth); + let bod = to_expr_go(lam.bod, var_map, depth + 1); + var_map.remove(&var_ptr); + Expr::lam( + fun.binder_name.clone(), + dom, + bod, + fun.binder_info.clone(), + ) + }, + + DAGPtr::Pi(link) => { + let pi = &*link.as_ptr(); + let lam = &*pi.img.as_ptr(); + let dom = to_expr_go(pi.dom, var_map, depth); + let var_ptr = &lam.var as *const Var; + var_map.insert(var_ptr, depth); + let bod = to_expr_go(lam.bod, var_map, depth + 1); + var_map.remove(&var_ptr); + Expr::all( + pi.binder_name.clone(), + dom, + bod, + pi.binder_info.clone(), + ) + }, + + DAGPtr::Let(link) => { + let let_node = &*link.as_ptr(); + let lam = &*let_node.bod.as_ptr(); + let typ = to_expr_go(let_node.typ, var_map, depth); + let val = to_expr_go(let_node.val, var_map, depth); + let var_ptr = &lam.var as *const Var; + var_map.insert(var_ptr, depth); + let bod = to_expr_go(lam.bod, var_map, depth + 1); + var_map.remove(&var_ptr); + Expr::letE( + let_node.binder_name.clone(), + typ, + val, + bod, + let_node.non_dep, + ) + }, + + DAGPtr::Proj(link) => { + let proj = &*link.as_ptr(); + let structure = to_expr_go(proj.expr, var_map, depth); + Expr::proj(proj.type_name.clone(), 
proj.idx.clone(), structure) + }, + + DAGPtr::Lam(link) => { + // Standalone Lam shouldn't appear at the top level, + // but handle it gracefully for completeness. + let lam = &*link.as_ptr(); + let var_ptr = &lam.var as *const Var; + var_map.insert(var_ptr, depth); + let bod = to_expr_go(lam.bod, var_map, depth + 1); + var_map.remove(&var_ptr); + // Wrap in a lambda with anonymous name and default binder info + Expr::lam( + Name::anon(), + Expr::sort(Level::zero()), + bod, + crate::ix::env::BinderInfo::Default, + ) + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ix::env::{BinderInfo, Literal}; + use quickcheck::{Arbitrary, Gen}; + use quickcheck_macros::quickcheck; + + fn mk_name(s: &str) -> Name { + Name::str(Name::anon(), s.into()) + } + + fn mk_name2(a: &str, b: &str) -> Name { + Name::str(Name::str(Name::anon(), a.into()), b.into()) + } + + fn nat_type() -> Expr { + Expr::cnst(mk_name("Nat"), vec![]) + } + + fn nat_zero() -> Expr { + Expr::cnst(mk_name2("Nat", "zero"), vec![]) + } + + // ========================================================================== + // Terminal roundtrips + // ========================================================================== + + #[test] + fn roundtrip_sort() { + let e = Expr::sort(Level::succ(Level::zero())); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_sort_param() { + let e = Expr::sort(Level::param(mk_name("u"))); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_const() { + let e = Expr::cnst( + mk_name("Foo"), + vec![Level::zero(), Level::succ(Level::zero())], + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_nat_lit() { + let e = Expr::lit(Literal::NatVal(Nat::from(42u64))); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn 
roundtrip_string_lit() { + let e = Expr::lit(Literal::StrVal("hello world".into())); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + // ========================================================================== + // Binder roundtrips + // ========================================================================== + + #[test] + fn roundtrip_identity_lambda() { + // fun (x : Nat) => x + let e = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_const_lambda() { + // fun (x : Nat) (y : Nat) => x + let e = Expr::lam( + mk_name("x"), + nat_type(), + Expr::lam( + mk_name("y"), + nat_type(), + Expr::bvar(Nat::from(1u64)), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_pi() { + // (x : Nat) → Nat + let e = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_dependent_pi() { + // (A : Sort 0) → A → A + let sort0 = Expr::sort(Level::zero()); + let e = Expr::all( + mk_name("A"), + sort0, + Expr::all( + mk_name("x"), + Expr::bvar(Nat::from(0u64)), // A + Expr::bvar(Nat::from(1u64)), // A + BinderInfo::Default, + ), + BinderInfo::Default, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + // ========================================================================== + // App roundtrips + // ========================================================================== + + #[test] + fn roundtrip_app() { + // f a + let e = Expr::app( + Expr::cnst(mk_name("f"), vec![]), + nat_zero(), + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + 
#[test] + fn roundtrip_nested_app() { + // f a b + let f = Expr::cnst(mk_name("f"), vec![]); + let a = nat_zero(); + let b = Expr::cnst(mk_name2("Nat", "succ"), vec![]); + let e = Expr::app(Expr::app(f, a), b); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + // ========================================================================== + // Let roundtrips + // ========================================================================== + + #[test] + fn roundtrip_let() { + // let x : Nat := Nat.zero in x + let e = Expr::letE( + mk_name("x"), + nat_type(), + nat_zero(), + Expr::bvar(Nat::from(0u64)), + false, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_let_non_dep() { + // let x : Nat := Nat.zero in Nat.zero (non_dep = true) + let e = Expr::letE( + mk_name("x"), + nat_type(), + nat_zero(), + nat_zero(), + true, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + // ========================================================================== + // Proj roundtrips + // ========================================================================== + + #[test] + fn roundtrip_proj() { + let e = Expr::proj(mk_name("Prod"), Nat::from(0u64), nat_zero()); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + // ========================================================================== + // Complex roundtrips + // ========================================================================== + + #[test] + fn roundtrip_app_of_lambda() { + // (fun x : Nat => x) Nat.zero + let id_fn = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let e = Expr::app(id_fn, nat_zero()); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_lambda_in_lambda() { + // fun (f : Nat → Nat) (x : Nat) => f x + let 
nat_to_nat = Expr::all( + mk_name("_"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let e = Expr::lam( + mk_name("f"), + nat_to_nat, + Expr::lam( + mk_name("x"), + nat_type(), + Expr::app( + Expr::bvar(Nat::from(1u64)), // f + Expr::bvar(Nat::from(0u64)), // x + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_bvar_sharing() { + // fun (x : Nat) => App(x, x) + // Both bvar(0) should map to the same Var in DAG + let e = Expr::lam( + mk_name("x"), + nat_type(), + Expr::app( + Expr::bvar(Nat::from(0u64)), + Expr::bvar(Nat::from(0u64)), + ), + BinderInfo::Default, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_free_bvar() { + // Bvar(5) with no enclosing binder — should survive roundtrip + let e = Expr::bvar(Nat::from(5u64)); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_implicit_binder() { + // fun {x : Nat} => x + let e = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Implicit, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + // ========================================================================== + // Property tests (quickcheck) + // ========================================================================== + + /// Generate a random well-formed Expr with bound variables properly scoped. + /// `depth` tracks how many binders are in scope (for valid bvar generation). 
+ fn arb_expr(g: &mut Gen, depth: u64, size: usize) -> Expr { + if size == 0 { + // Terminal: pick among Sort, Const, Lit, or Bvar (if depth > 0) + let choices = if depth > 0 { 5 } else { 4 }; + match usize::arbitrary(g) % choices { + 0 => Expr::sort(arb_level(g, 2)), + 1 => { + let names = ["Nat", "Bool", "String", "Unit", "Int"]; + let idx = usize::arbitrary(g) % names.len(); + Expr::cnst(mk_name(names[idx]), vec![]) + }, + 2 => { + let n = u64::arbitrary(g) % 100; + Expr::lit(Literal::NatVal(Nat::from(n))) + }, + 3 => { + let s: String = String::arbitrary(g); + // Truncate at a char boundary to avoid panics + let s: String = s.chars().take(10).collect(); + Expr::lit(Literal::StrVal(s)) + }, + 4 => { + // Bvar within scope + let idx = u64::arbitrary(g) % depth; + Expr::bvar(Nat::from(idx)) + }, + _ => unreachable!(), + } + } else { + let next = size / 2; + match usize::arbitrary(g) % 5 { + 0 => { + // App + let f = arb_expr(g, depth, next); + let a = arb_expr(g, depth, next); + Expr::app(f, a) + }, + 1 => { + // Lam + let dom = arb_expr(g, depth, next); + let bod = arb_expr(g, depth + 1, next); + Expr::lam(mk_name("x"), dom, bod, BinderInfo::Default) + }, + 2 => { + // Pi + let dom = arb_expr(g, depth, next); + let bod = arb_expr(g, depth + 1, next); + Expr::all(mk_name("a"), dom, bod, BinderInfo::Default) + }, + 3 => { + // Let + let typ = arb_expr(g, depth, next); + let val = arb_expr(g, depth, next); + let bod = arb_expr(g, depth + 1, next / 2); + Expr::letE(mk_name("v"), typ, val, bod, bool::arbitrary(g)) + }, + 4 => { + // Proj + let idx = u64::arbitrary(g) % 4; + let structure = arb_expr(g, depth, next); + Expr::proj(mk_name("S"), Nat::from(idx), structure) + }, + _ => unreachable!(), + } + } + } + + fn arb_level(g: &mut Gen, size: usize) -> Level { + if size == 0 { + match usize::arbitrary(g) % 3 { + 0 => Level::zero(), + 1 => { + let params = ["u", "v", "w"]; + let idx = usize::arbitrary(g) % params.len(); + Level::param(mk_name(params[idx])) + }, + 2 => 
Level::succ(Level::zero()), + _ => unreachable!(), + } + } else { + match usize::arbitrary(g) % 3 { + 0 => Level::succ(arb_level(g, size - 1)), + 1 => Level::max(arb_level(g, size / 2), arb_level(g, size / 2)), + 2 => Level::imax(arb_level(g, size / 2), arb_level(g, size / 2)), + _ => unreachable!(), + } + } + } + + /// Newtype wrapper for quickcheck Arbitrary derivation. + #[derive(Clone, Debug)] + struct ArbExpr(Expr); + + impl Arbitrary for ArbExpr { + fn arbitrary(g: &mut Gen) -> Self { + let size = usize::arbitrary(g) % 5; + ArbExpr(arb_expr(g, 0, size)) + } + } + + #[quickcheck] + fn prop_roundtrip(e: ArbExpr) -> bool { + let dag = from_expr(&e.0); + let result = to_expr(&dag); + result == e.0 + } + + /// Same test but with expressions generated inside binders. + #[derive(Clone, Debug)] + struct ArbBinderExpr(Expr); + + impl Arbitrary for ArbBinderExpr { + fn arbitrary(g: &mut Gen) -> Self { + let inner_size = usize::arbitrary(g) % 4; + let body = arb_expr(g, 1, inner_size); + let dom = arb_expr(g, 0, 0); + ArbBinderExpr(Expr::lam( + mk_name("x"), + dom, + body, + BinderInfo::Default, + )) + } + } + + #[quickcheck] + fn prop_roundtrip_binder(e: ArbBinderExpr) -> bool { + let dag = from_expr(&e.0); + let result = to_expr(&dag); + result == e.0 + } +} diff --git a/src/ix/kernel/dag.rs b/src/ix/kernel/dag.rs new file mode 100644 index 00000000..9837405f --- /dev/null +++ b/src/ix/kernel/dag.rs @@ -0,0 +1,527 @@ +use core::ptr::NonNull; + +use crate::ix::env::{BinderInfo, Level, Literal, Name}; +use crate::lean::nat::Nat; +use rustc_hash::FxHashSet; + +use super::dll::DLL; + +pub type Parents = DLL; + +// ============================================================================ +// Pointer types +// ============================================================================ + +#[derive(Debug)] +pub enum DAGPtr { + Var(NonNull), + Sort(NonNull), + Cnst(NonNull), + Lit(NonNull), + Lam(NonNull), + Fun(NonNull), + Pi(NonNull), + App(NonNull), + Let(NonNull), + 
Proj(NonNull), +} + +impl Copy for DAGPtr {} +impl Clone for DAGPtr { + fn clone(&self) -> Self { + *self + } +} + +impl PartialEq for DAGPtr { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (DAGPtr::Var(a), DAGPtr::Var(b)) => a == b, + (DAGPtr::Sort(a), DAGPtr::Sort(b)) => a == b, + (DAGPtr::Cnst(a), DAGPtr::Cnst(b)) => a == b, + (DAGPtr::Lit(a), DAGPtr::Lit(b)) => a == b, + (DAGPtr::Lam(a), DAGPtr::Lam(b)) => a == b, + (DAGPtr::Fun(a), DAGPtr::Fun(b)) => a == b, + (DAGPtr::Pi(a), DAGPtr::Pi(b)) => a == b, + (DAGPtr::App(a), DAGPtr::App(b)) => a == b, + (DAGPtr::Let(a), DAGPtr::Let(b)) => a == b, + (DAGPtr::Proj(a), DAGPtr::Proj(b)) => a == b, + _ => false, + } + } +} +impl Eq for DAGPtr {} + +#[derive(Debug)] +pub enum ParentPtr { + Root, + LamBod(NonNull), + FunDom(NonNull), + FunImg(NonNull), + PiDom(NonNull), + PiImg(NonNull), + AppFun(NonNull), + AppArg(NonNull), + LetTyp(NonNull), + LetVal(NonNull), + LetBod(NonNull), + ProjExpr(NonNull), +} + +impl Copy for ParentPtr {} +impl Clone for ParentPtr { + fn clone(&self) -> Self { + *self + } +} + +impl PartialEq for ParentPtr { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (ParentPtr::Root, ParentPtr::Root) => true, + (ParentPtr::LamBod(a), ParentPtr::LamBod(b)) => a == b, + (ParentPtr::FunDom(a), ParentPtr::FunDom(b)) => a == b, + (ParentPtr::FunImg(a), ParentPtr::FunImg(b)) => a == b, + (ParentPtr::PiDom(a), ParentPtr::PiDom(b)) => a == b, + (ParentPtr::PiImg(a), ParentPtr::PiImg(b)) => a == b, + (ParentPtr::AppFun(a), ParentPtr::AppFun(b)) => a == b, + (ParentPtr::AppArg(a), ParentPtr::AppArg(b)) => a == b, + (ParentPtr::LetTyp(a), ParentPtr::LetTyp(b)) => a == b, + (ParentPtr::LetVal(a), ParentPtr::LetVal(b)) => a == b, + (ParentPtr::LetBod(a), ParentPtr::LetBod(b)) => a == b, + (ParentPtr::ProjExpr(a), ParentPtr::ProjExpr(b)) => a == b, + _ => false, + } + } +} +impl Eq for ParentPtr {} + +/// Binder pointer: from a Var to its binding Lam, or Free. 
+#[derive(Debug)] +pub enum BinderPtr { + Free, + Lam(NonNull), +} + +impl Copy for BinderPtr {} +impl Clone for BinderPtr { + fn clone(&self) -> Self { + *self + } +} + +impl PartialEq for BinderPtr { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (BinderPtr::Free, BinderPtr::Free) => true, + (BinderPtr::Lam(a), BinderPtr::Lam(b)) => a == b, + _ => false, + } + } +} + +// ============================================================================ +// Node structs +// ============================================================================ + +/// Bound or free variable. +#[repr(C)] +pub struct Var { + /// De Bruijn level (used during from_expr/to_expr conversion). + pub depth: u64, + /// Points to the binding Lam, or Free for free variables. + pub binder: BinderPtr, + /// Parent pointers. + pub parents: Option>, +} + +impl Copy for Var {} +impl Clone for Var { + fn clone(&self) -> Self { + *self + } +} + +/// Sort node (universe). +#[repr(C)] +pub struct Sort { + pub level: Level, + pub parents: Option>, +} + +/// Constant reference. +#[repr(C)] +pub struct Cnst { + pub name: Name, + pub levels: Vec, + pub parents: Option>, +} + +/// Literal value (Nat or String). +#[repr(C)] +pub struct LitNode { + pub val: Literal, + pub parents: Option>, +} + +/// Internal binding node (spine). Carries an embedded Var. +/// Always appears as the img/bod of Fun/Pi/Let. +#[repr(C)] +pub struct Lam { + pub bod: DAGPtr, + pub bod_ref: Parents, + pub var: Var, + pub parents: Option>, +} + +/// Lean lambda: `fun (name : dom) => bod`. +/// Branch node wrapping a Lam for the body. +#[repr(C)] +pub struct Fun { + pub binder_name: Name, + pub binder_info: BinderInfo, + pub dom: DAGPtr, + pub img: NonNull, + pub dom_ref: Parents, + pub img_ref: Parents, + pub copy: Option>, + pub parents: Option>, +} + +/// Lean Pi/ForallE: `(name : dom) → bod`. +/// Branch node wrapping a Lam for the body. 
+#[repr(C)] +pub struct Pi { + pub binder_name: Name, + pub binder_info: BinderInfo, + pub dom: DAGPtr, + pub img: NonNull, + pub dom_ref: Parents, + pub img_ref: Parents, + pub copy: Option>, + pub parents: Option>, +} + +/// Application node. +#[repr(C)] +pub struct App { + pub fun: DAGPtr, + pub arg: DAGPtr, + pub fun_ref: Parents, + pub arg_ref: Parents, + pub copy: Option>, + pub parents: Option>, +} + +/// Let binding: `let name : typ := val in bod`. +#[repr(C)] +pub struct LetNode { + pub binder_name: Name, + pub non_dep: bool, + pub typ: DAGPtr, + pub val: DAGPtr, + pub bod: NonNull, + pub typ_ref: Parents, + pub val_ref: Parents, + pub bod_ref: Parents, + pub copy: Option>, + pub parents: Option>, +} + +/// Projection from a structure. +#[repr(C)] +pub struct ProjNode { + pub type_name: Name, + pub idx: Nat, + pub expr: DAGPtr, + pub expr_ref: Parents, + pub parents: Option>, +} + +/// A DAG with a head node. +pub struct DAG { + pub head: DAGPtr, +} + +// ============================================================================ +// Allocation helpers +// ============================================================================ + +#[inline] +pub fn alloc_val(val: T) -> NonNull { + NonNull::new(Box::into_raw(Box::new(val))).unwrap() +} + +pub fn alloc_lam( + depth: u64, + bod: DAGPtr, + parents: Option>, +) -> NonNull { + let lam_ptr = alloc_val(Lam { + bod, + bod_ref: DLL::singleton(ParentPtr::Root), + var: Var { depth, binder: BinderPtr::Free, parents: None }, + parents, + }); + unsafe { + let lam = &mut *lam_ptr.as_ptr(); + lam.bod_ref = DLL::singleton(ParentPtr::LamBod(lam_ptr)); + lam.var.binder = BinderPtr::Lam(lam_ptr); + } + lam_ptr +} + +pub fn alloc_app( + fun: DAGPtr, + arg: DAGPtr, + parents: Option>, +) -> NonNull { + let app_ptr = alloc_val(App { + fun, + arg, + fun_ref: DLL::singleton(ParentPtr::Root), + arg_ref: DLL::singleton(ParentPtr::Root), + copy: None, + parents, + }); + unsafe { + let app = &mut *app_ptr.as_ptr(); + app.fun_ref = 
DLL::singleton(ParentPtr::AppFun(app_ptr)); + app.arg_ref = DLL::singleton(ParentPtr::AppArg(app_ptr)); + } + app_ptr +} + +pub fn alloc_fun( + binder_name: Name, + binder_info: BinderInfo, + dom: DAGPtr, + img: NonNull, + parents: Option>, +) -> NonNull { + let fun_ptr = alloc_val(Fun { + binder_name, + binder_info, + dom, + img, + dom_ref: DLL::singleton(ParentPtr::Root), + img_ref: DLL::singleton(ParentPtr::Root), + copy: None, + parents, + }); + unsafe { + let fun = &mut *fun_ptr.as_ptr(); + fun.dom_ref = DLL::singleton(ParentPtr::FunDom(fun_ptr)); + fun.img_ref = DLL::singleton(ParentPtr::FunImg(fun_ptr)); + } + fun_ptr +} + +pub fn alloc_pi( + binder_name: Name, + binder_info: BinderInfo, + dom: DAGPtr, + img: NonNull, + parents: Option>, +) -> NonNull { + let pi_ptr = alloc_val(Pi { + binder_name, + binder_info, + dom, + img, + dom_ref: DLL::singleton(ParentPtr::Root), + img_ref: DLL::singleton(ParentPtr::Root), + copy: None, + parents, + }); + unsafe { + let pi = &mut *pi_ptr.as_ptr(); + pi.dom_ref = DLL::singleton(ParentPtr::PiDom(pi_ptr)); + pi.img_ref = DLL::singleton(ParentPtr::PiImg(pi_ptr)); + } + pi_ptr +} + +pub fn alloc_let( + binder_name: Name, + non_dep: bool, + typ: DAGPtr, + val: DAGPtr, + bod: NonNull, + parents: Option>, +) -> NonNull { + let let_ptr = alloc_val(LetNode { + binder_name, + non_dep, + typ, + val, + bod, + typ_ref: DLL::singleton(ParentPtr::Root), + val_ref: DLL::singleton(ParentPtr::Root), + bod_ref: DLL::singleton(ParentPtr::Root), + copy: None, + parents, + }); + unsafe { + let let_node = &mut *let_ptr.as_ptr(); + let_node.typ_ref = DLL::singleton(ParentPtr::LetTyp(let_ptr)); + let_node.val_ref = DLL::singleton(ParentPtr::LetVal(let_ptr)); + let_node.bod_ref = DLL::singleton(ParentPtr::LetBod(let_ptr)); + } + let_ptr +} + +pub fn alloc_proj( + type_name: Name, + idx: Nat, + expr: DAGPtr, + parents: Option>, +) -> NonNull { + let proj_ptr = alloc_val(ProjNode { + type_name, + idx, + expr, + expr_ref: 
DLL::singleton(ParentPtr::Root), + parents, + }); + unsafe { + let proj = &mut *proj_ptr.as_ptr(); + proj.expr_ref = DLL::singleton(ParentPtr::ProjExpr(proj_ptr)); + } + proj_ptr +} + +// ============================================================================ +// Parent pointer helpers +// ============================================================================ + +pub fn get_parents(node: DAGPtr) -> Option> { + unsafe { + match node { + DAGPtr::Var(p) => (*p.as_ptr()).parents, + DAGPtr::Sort(p) => (*p.as_ptr()).parents, + DAGPtr::Cnst(p) => (*p.as_ptr()).parents, + DAGPtr::Lit(p) => (*p.as_ptr()).parents, + DAGPtr::Lam(p) => (*p.as_ptr()).parents, + DAGPtr::Fun(p) => (*p.as_ptr()).parents, + DAGPtr::Pi(p) => (*p.as_ptr()).parents, + DAGPtr::App(p) => (*p.as_ptr()).parents, + DAGPtr::Let(p) => (*p.as_ptr()).parents, + DAGPtr::Proj(p) => (*p.as_ptr()).parents, + } + } +} + +pub fn set_parents(node: DAGPtr, parents: Option>) { + unsafe { + match node { + DAGPtr::Var(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Sort(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Cnst(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Lit(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Lam(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Fun(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Pi(p) => (*p.as_ptr()).parents = parents, + DAGPtr::App(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Let(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Proj(p) => (*p.as_ptr()).parents = parents, + } + } +} + +pub fn add_to_parents(node: DAGPtr, parent_link: NonNull) { + unsafe { + match get_parents(node) { + None => set_parents(node, Some(parent_link)), + Some(parents) => { + (*parents.as_ptr()).merge(parent_link); + }, + } + } +} + +// ============================================================================ +// DAG-level helpers +// ============================================================================ + +/// Get a unique key for a DAG node pointer (for use in hash 
sets).
pub fn dag_ptr_key(node: DAGPtr) -> usize {
  // The raw allocation address is stable for the node's lifetime, so it
  // serves as an identity key for visited-set bookkeeping.
  match node {
    DAGPtr::Var(p) => p.as_ptr() as usize,
    DAGPtr::Sort(p) => p.as_ptr() as usize,
    DAGPtr::Cnst(p) => p.as_ptr() as usize,
    DAGPtr::Lit(p) => p.as_ptr() as usize,
    DAGPtr::Lam(p) => p.as_ptr() as usize,
    DAGPtr::Fun(p) => p.as_ptr() as usize,
    DAGPtr::Pi(p) => p.as_ptr() as usize,
    DAGPtr::App(p) => p.as_ptr() as usize,
    DAGPtr::Let(p) => p.as_ptr() as usize,
    DAGPtr::Proj(p) => p.as_ptr() as usize,
  }
}

/// Free all DAG nodes reachable from the head.
/// Only frees the node structs themselves; DLL parent entries that are
/// inline in parent structs are freed with those structs. The root_parents
/// DLL node (heap-allocated in from_expr) is a small accepted leak.
pub fn free_dag(dag: DAG) {
  let mut visited = FxHashSet::default();
  free_dag_nodes(dag.head, &mut visited);
}

/// Recursive worker for `free_dag`: drops every reachable node exactly once.
/// `visited` holds raw addresses (see `dag_ptr_key`) so shared subterms of
/// the DAG are not double-freed.
fn free_dag_nodes(node: DAGPtr, visited: &mut FxHashSet<usize>) {
  let key = dag_ptr_key(node);
  if !visited.insert(key) {
    return;
  }
  unsafe {
    match node {
      DAGPtr::Var(link) => {
        let var = &*link.as_ptr();
        // Only free separately-allocated free vars; bound vars are
        // embedded in their Lam struct and freed with it.
        if let BinderPtr::Free = var.binder {
          drop(Box::from_raw(link.as_ptr()));
        }
      },
      // Leaf nodes: no children to recurse into.
      DAGPtr::Sort(link) => drop(Box::from_raw(link.as_ptr())),
      DAGPtr::Cnst(link) => drop(Box::from_raw(link.as_ptr())),
      DAGPtr::Lit(link) => drop(Box::from_raw(link.as_ptr())),
      DAGPtr::Lam(link) => {
        let lam = &*link.as_ptr();
        free_dag_nodes(lam.bod, visited);
        drop(Box::from_raw(link.as_ptr()));
      },
      DAGPtr::Fun(link) => {
        let fun = &*link.as_ptr();
        free_dag_nodes(fun.dom, visited);
        // The image is a Lam node; recurse through it as a DAGPtr so its
        // body is freed as well.
        free_dag_nodes(DAGPtr::Lam(fun.img), visited);
        drop(Box::from_raw(link.as_ptr()));
      },
      DAGPtr::Pi(link) => {
        let pi = &*link.as_ptr();
        free_dag_nodes(pi.dom, visited);
        free_dag_nodes(DAGPtr::Lam(pi.img), visited);
        drop(Box::from_raw(link.as_ptr()));
      },
      DAGPtr::App(link) => {
        let app = &*link.as_ptr();
        free_dag_nodes(app.fun, visited);
        free_dag_nodes(app.arg, visited);
        drop(Box::from_raw(link.as_ptr()));
      },
      DAGPtr::Let(link) => {
        let let_node = &*link.as_ptr();
        free_dag_nodes(let_node.typ, visited);
        free_dag_nodes(let_node.val, visited);
        free_dag_nodes(DAGPtr::Lam(let_node.bod), visited);
        drop(Box::from_raw(link.as_ptr()));
      },
      DAGPtr::Proj(link) => {
        let proj = &*link.as_ptr();
        free_dag_nodes(proj.expr, visited);
        drop(Box::from_raw(link.as_ptr()));
      },
    }
  }
}
diff --git a/src/ix/kernel/def_eq.rs b/src/ix/kernel/def_eq.rs
new file mode 100644
index 00000000..c2110381
--- /dev/null
+++ b/src/ix/kernel/def_eq.rs
@@ -0,0 +1,1298 @@
use crate::ix::env::*;
use crate::lean::nat::Nat;

use super::level::{eq_antisymm, eq_antisymm_many};
use super::tc::TypeChecker;
use super::whnf::*;

/// Result of lazy delta reduction.
enum DeltaResult {
  /// Equality was decided during unfolding.
  Found(bool),
  /// Neither side can be unfolded further; the residual expressions are
  /// handed back for the structural/extensional checks.
  Exhausted(Expr, Expr),
}

/// Check definitional equality of two expressions.
pub fn def_eq(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool {
  // Cheap syntactic checks first, before any reduction work.
  if let Some(quick) = def_eq_quick_check(x, y) {
    return quick;
  }

  // Reduce both sides to weak head normal form and retry the fast path.
  let x_n = tc.whnf(x);
  let y_n = tc.whnf(y);

  if let Some(quick) = def_eq_quick_check(&x_n, &y_n) {
    return quick;
  }

  // Proofs of the same proposition are definitionally equal.
  if proof_irrel_eq(&x_n, &y_n, tc) {
    return true;
  }

  // Unfold definitions lazily; if that alone does not decide equality,
  // fall through to the structural and extensional rules on the residuals.
  // NOTE: the order below matters — cheaper checks run first, and eta/unit
  // rules only fire once congruence has failed.
  match lazy_delta_step(&x_n, &y_n, tc) {
    DeltaResult::Found(result) => result,
    DeltaResult::Exhausted(x_e, y_e) => {
      def_eq_const(&x_e, &y_e)
        || def_eq_proj(&x_e, &y_e, tc)
        || def_eq_app(&x_e, &y_e, tc)
        || def_eq_binder_full(&x_e, &y_e, tc)
        || try_eta_expansion(&x_e, &y_e, tc)
        || try_eta_struct(&x_e, &y_e, tc)
        || is_def_eq_unit_like(&x_e, &y_e, tc)
    },
  }
}

/// Quick syntactic checks.
/// Returns `Some(result)` when equality can be decided without reduction
/// (syntactic identity or sort-level comparison), `None` otherwise.
fn def_eq_quick_check(x: &Expr, y: &Expr) -> Option<bool> {
  if x == y {
    return Some(true);
  }
  if let Some(r) = def_eq_sort(x, y) {
    return Some(r);
  }
  if let Some(r) = def_eq_binder(x, y) {
    return Some(r);
  }
  None
}

/// Two sorts are equal iff their universe levels are antisymmetrically equal.
fn def_eq_sort(x: &Expr, y: &Expr) -> Option<bool> {
  match (x.as_data(), y.as_data()) {
    (ExprData::Sort(l, _), ExprData::Sort(r, _)) => {
      Some(eq_antisymm(l, r))
    },
    _ => None,
  }
}

/// Check if two binder expressions (Pi/Lam) are definitionally equal.
/// Always defers to full checking after WHNF, since binder types could be
/// definitionally equal without being syntactically identical.
+fn def_eq_binder(_x: &Expr, _y: &Expr) -> Option { + None +} + +fn def_eq_const(x: &Expr, y: &Expr) -> bool { + match (x.as_data(), y.as_data()) { + ( + ExprData::Const(xn, xl, _), + ExprData::Const(yn, yl, _), + ) => xn == yn && eq_antisymm_many(xl, yl), + _ => false, + } +} + +fn def_eq_proj(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { + match (x.as_data(), y.as_data()) { + ( + ExprData::Proj(_, idx_l, structure_l, _), + ExprData::Proj(_, idx_r, structure_r, _), + ) => idx_l == idx_r && def_eq(structure_l, structure_r, tc), + _ => false, + } +} + +fn def_eq_app(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { + let (f1, args1) = unfold_apps(x); + if args1.is_empty() { + return false; + } + let (f2, args2) = unfold_apps(y); + if args2.is_empty() { + return false; + } + if args1.len() != args2.len() { + return false; + } + + if !def_eq(&f1, &f2, tc) { + return false; + } + args1.iter().zip(args2.iter()).all(|(a, b)| def_eq(a, b, tc)) +} + +/// Full recursive binder comparison: two Pi or two Lam types with +/// definitionally equal domain types and bodies (ignoring binder names). +fn def_eq_binder_full( + x: &Expr, + y: &Expr, + tc: &mut TypeChecker, +) -> bool { + match (x.as_data(), y.as_data()) { + ( + ExprData::ForallE(_, t1, b1, _, _), + ExprData::ForallE(_, t2, b2, _, _), + ) => def_eq(t1, t2, tc) && def_eq(b1, b2, tc), + ( + ExprData::Lam(_, t1, b1, _, _), + ExprData::Lam(_, t2, b2, _, _), + ) => def_eq(t1, t2, tc) && def_eq(b1, b2, tc), + _ => false, + } +} + +/// Proof irrelevance: if both x and y are proofs of the same proposition, +/// they are definitionally equal. 
fn proof_irrel_eq(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool {
  // Infer x's type first and bail early if it is not a proposition, so the
  // (possibly expensive) inference of y's type is only done when needed.
  let x_ty = match tc.infer(x) {
    Ok(ty) => ty,
    Err(_) => return false,
  };
  if !is_proposition(&x_ty, tc) {
    return false;
  }
  let y_ty = match tc.infer(y) {
    Ok(ty) => ty,
    Err(_) => return false,
  };
  if !is_proposition(&y_ty, tc) {
    return false;
  }
  // Both are proofs; they are equal iff they prove the same proposition.
  def_eq(&x_ty, &y_ty, tc)
}

/// Check if an expression's type is Prop (Sort 0).
/// `ty` is itself a type; this infers the type of `ty` and checks that its
/// WHNF is `Sort 0`. Inference failures are treated as "not a proposition".
fn is_proposition(ty: &Expr, tc: &mut TypeChecker) -> bool {
  let ty_of_ty = match tc.infer(ty) {
    Ok(t) => t,
    Err(_) => return false,
  };
  let whnfd = tc.whnf(&ty_of_ty);
  matches!(whnfd.as_data(), ExprData::Sort(l, _) if super::level::is_zero(l))
}

/// Eta expansion: `fun x => f x` ≡ `f` when `f : (x : A) → B`.
/// Tried in both orientations since either side may be the lambda.
fn try_eta_expansion(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool {
  try_eta_expansion_aux(x, y, tc) || try_eta_expansion_aux(y, x, tc)
}

/// One orientation of eta: if `x` is a lambda and `y` has Pi type, compare
/// `x` against the eta-expansion `fun x => y x` of `y`.
fn try_eta_expansion_aux(
  x: &Expr,
  y: &Expr,
  tc: &mut TypeChecker,
) -> bool {
  if let ExprData::Lam(_, _, _, _, _) = x.as_data() {
    let y_ty = match tc.infer(y) {
      Ok(t) => t,
      Err(_) => return false,
    };
    let y_ty_whnf = tc.whnf(&y_ty);
    if let ExprData::ForallE(name, binder_type, _, bi, _) =
      y_ty_whnf.as_data()
    {
      // eta-expand y: fun x => y x
      // NOTE(review): `y` is placed under a new binder without shifting its
      // de Bruijn indices — this assumes `y` is locally closed (no loose
      // bvars) at this point; confirm against the callers in def_eq.
      let body = Expr::app(y.clone(), Expr::bvar(crate::lean::nat::Nat::from(0)));
      let expanded = Expr::lam(
        name.clone(),
        binder_type.clone(),
        body,
        bi.clone(),
      );
      return def_eq(x, &expanded, tc);
    }
  }
  false
}

/// Check if a name refers to a structure-like inductive:
/// exactly 1 constructor, not recursive, no indices.
fn is_structure_like(name: &Name, env: &Env) -> bool {
  match env.get(name) {
    Some(ConstantInfo::InductInfo(iv)) => {
      iv.ctors.len() == 1 && !iv.is_rec && iv.num_indices == Nat::ZERO
    },
    _ => false,
  }
}

/// Structure eta: `p =def= S.mk (S.1 p) (S.2 p)` when S is a
/// single-constructor non-recursive inductive with no indices.
+fn try_eta_struct(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { + try_eta_struct_core(x, y, tc) || try_eta_struct_core(y, x, tc) +} + +/// Try to decompose `s` as a constructor application for a structure-like +/// type, then check that each field matches the corresponding projection of `t`. +fn try_eta_struct_core( + t: &Expr, + s: &Expr, + tc: &mut TypeChecker, +) -> bool { + let (head, args) = unfold_apps(s); + let ctor_name = match head.as_data() { + ExprData::Const(name, _, _) => name, + _ => return false, + }; + + let ctor_info = match tc.env.get(ctor_name) { + Some(ConstantInfo::CtorInfo(c)) => c, + _ => return false, + }; + + if !is_structure_like(&ctor_info.induct, tc.env) { + return false; + } + + let num_params = ctor_info.num_params.to_u64().unwrap() as usize; + let num_fields = ctor_info.num_fields.to_u64().unwrap() as usize; + + if args.len() != num_params + num_fields { + return false; + } + + for i in 0..num_fields { + let field = &args[num_params + i]; + let proj = Expr::proj( + ctor_info.induct.clone(), + Nat::from(i as u64), + t.clone(), + ); + if !def_eq(field, &proj, tc) { + return false; + } + } + + true +} + +/// Unit-like equality: types with a single zero-field constructor have all +/// inhabitants definitionally equal. 
fn is_def_eq_unit_like(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool {
  let x_ty = match tc.infer(x) {
    Ok(ty) => ty,
    Err(_) => return false,
  };
  let y_ty = match tc.infer(y) {
    Ok(ty) => ty,
    Err(_) => return false,
  };
  // Types must be def-eq
  if !def_eq(&x_ty, &y_ty, tc) {
    return false;
  }
  // Check if the type is a unit-like inductive
  let whnf_ty = tc.whnf(&x_ty);
  let (head, _) = unfold_apps(&whnf_ty);
  let name = match head.as_data() {
    ExprData::Const(name, _, _) => name,
    _ => return false,
  };
  match tc.env.get(name) {
    Some(ConstantInfo::InductInfo(iv)) => {
      if iv.ctors.len() != 1 {
        return false;
      }
      // Check single constructor has zero fields
      if let Some(ConstantInfo::CtorInfo(c)) = tc.env.get(&iv.ctors[0]) {
        c.num_fields == Nat::ZERO
      } else {
        false
      }
    },
    _ => false,
  }
}

/// Lazy delta reduction: unfold definitions step by step.
///
/// When both sides are applications of the same definition with the same
/// hint, congruence is tried before unfolding; otherwise the side with the
/// "greater" reducibility hint (see `hint_lt`) is unfolded first, mirroring
/// the height-based strategy of the Lean kernel.
///
/// NOTE(review): if `delta` fails to unfold (e.g. `try_unfold_def` declines
/// a constant for which `get_applied_def` returned `Some`, such as a
/// theorem), the expression is returned unchanged and this loop would not
/// terminate — confirm `try_unfold_def` covers every `Some` case here.
fn lazy_delta_step(
  x: &Expr,
  y: &Expr,
  tc: &mut TypeChecker,
) -> DeltaResult {
  let mut x = x.clone();
  let mut y = y.clone();

  loop {
    let x_def = get_applied_def(&x, tc.env);
    let y_def = get_applied_def(&y, tc.env);

    match (&x_def, &y_def) {
      (None, None) => return DeltaResult::Exhausted(x, y),
      (Some(_), None) => {
        x = delta(&x, tc);
      },
      (None, Some(_)) => {
        y = delta(&y, tc);
      },
      (Some((x_name, x_hint)), Some((y_name, y_hint))) => {
        // Same name and same height: try congruence first
        if x_name == y_name && x_hint == y_hint {
          if def_eq_app(&x, &y, tc) {
            return DeltaResult::Found(true);
          }
          x = delta(&x, tc);
          y = delta(&y, tc);
        } else if hint_lt(x_hint, y_hint) {
          y = delta(&y, tc);
        } else {
          x = delta(&x, tc);
        }
      },
    }

    // Re-run the cheap checks after each unfolding step.
    if let Some(quick) = def_eq_quick_check(&x, &y) {
      return DeltaResult::Found(quick);
    }
  }
}

/// Get the name and reducibility hint of an applied definition.
fn get_applied_def(
  e: &Expr,
  env: &Env,
) -> Option<(Name, ReducibilityHints)> {
  let (head, _) = unfold_apps(e);
  let name = match head.as_data() {
    ExprData::Const(name, _, _) => name,
    _ => return None,
  };
  let ci = env.get(name)?;
  match ci {
    ConstantInfo::DefnInfo(d) => {
      // Opaque definitions are never unfolded by lazy delta.
      if d.hints == ReducibilityHints::Opaque {
        None
      } else {
        Some((name.clone(), d.hints))
      }
    },
    // Theorems are reported with an Opaque hint so they sort last.
    ConstantInfo::ThmInfo(_) => {
      Some((name.clone(), ReducibilityHints::Opaque))
    },
    _ => None,
  }
}

/// Unfold a definition and do cheap WHNF.
/// Returns the expression unchanged when unfolding is not possible.
fn delta(e: &Expr, tc: &mut TypeChecker) -> Expr {
  match try_unfold_def(e, tc.env) {
    Some(unfolded) => tc.whnf(&unfolded),
    None => e.clone(),
  }
}

/// Compare reducibility hints for ordering.
/// Ordering (least to greatest): Opaque < Regular(h) (by height) < Abbrev.
/// In `lazy_delta_step` the side with the *greater* hint is unfolded first,
/// so abbreviations unfold eagerly and opaque hints unfold last.
fn hint_lt(a: &ReducibilityHints, b: &ReducibilityHints) -> bool {
  match (a, b) {
    (ReducibilityHints::Opaque, _) => true,
    (_, ReducibilityHints::Opaque) => false,
    (ReducibilityHints::Abbrev, _) => false,
    (_, ReducibilityHints::Abbrev) => true,
    (ReducibilityHints::Regular(ha), ReducibilityHints::Regular(hb)) => {
      ha < hb
    },
  }
}

#[cfg(test)]
mod tests {
  use super::*;
  use crate::ix::kernel::tc::TypeChecker;
  use crate::lean::nat::Nat;

  // Single-segment name: `s`.
  fn mk_name(s: &str) -> Name {
    Name::str(Name::anon(), s.into())
  }

  // Two-segment name: `a.b`.
  fn mk_name2(a: &str, b: &str) -> Name {
    Name::str(Name::str(Name::anon(), a.into()), b.into())
  }

  fn nat_type() -> Expr {
    Expr::cnst(mk_name("Nat"), vec![])
  }

  fn nat_zero() -> Expr {
    Expr::cnst(mk_name2("Nat", "zero"), vec![])
  }

  /// Minimal env with Nat, Nat.zero, Nat.succ.
  fn mk_nat_env() -> Env {
    let mut env = Env::default();
    let nat_name = mk_name("Nat");
    // Nat : Sort 1, recursive, two constructors, no params/indices.
    env.insert(
      nat_name.clone(),
      ConstantInfo::InductInfo(InductiveVal {
        cnst: ConstantVal {
          name: nat_name.clone(),
          level_params: vec![],
          typ: Expr::sort(Level::succ(Level::zero())),
        },
        num_params: Nat::from(0u64),
        num_indices: Nat::from(0u64),
        all: vec![nat_name.clone()],
        ctors: vec![mk_name2("Nat", "zero"), mk_name2("Nat", "succ")],
        num_nested: Nat::from(0u64),
        is_rec: true,
        is_unsafe: false,
        is_reflexive: false,
      }),
    );
    let zero_name = mk_name2("Nat", "zero");
    // Nat.zero : Nat
    env.insert(
      zero_name.clone(),
      ConstantInfo::CtorInfo(ConstructorVal {
        cnst: ConstantVal {
          name: zero_name.clone(),
          level_params: vec![],
          typ: nat_type(),
        },
        induct: mk_name("Nat"),
        cidx: Nat::from(0u64),
        num_params: Nat::from(0u64),
        num_fields: Nat::from(0u64),
        is_unsafe: false,
      }),
    );
    let succ_name = mk_name2("Nat", "succ");
    // Nat.succ : Nat -> Nat
    env.insert(
      succ_name.clone(),
      ConstantInfo::CtorInfo(ConstructorVal {
        cnst: ConstantVal {
          name: succ_name.clone(),
          level_params: vec![],
          typ: Expr::all(
            mk_name("n"),
            nat_type(),
            nat_type(),
            BinderInfo::Default,
          ),
        },
        induct: mk_name("Nat"),
        cidx: Nat::from(1u64),
        num_params: Nat::from(0u64),
        num_fields: Nat::from(1u64),
        is_unsafe: false,
      }),
    );
    env
  }

  // ==========================================================================
  // Reflexivity
  // ==========================================================================

  #[test]
  fn def_eq_reflexive_sort() {
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let e = Expr::sort(Level::zero());
    assert!(tc.def_eq(&e, &e));
  }

  #[test]
  fn def_eq_reflexive_const() {
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let e = nat_zero();
    assert!(tc.def_eq(&e, &e));
  }

  #[test]
  fn def_eq_reflexive_lambda() {
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let e = Expr::lam(
      mk_name("x"),
      nat_type(),
      Expr::bvar(Nat::from(0u64)),
      BinderInfo::Default,
    );
    assert!(tc.def_eq(&e, &e));
  }

  // ==========================================================================
  // Sort equality
  // ==========================================================================

  #[test]
  fn def_eq_sort_max_comm() {
    // Sort(max u v) =def= Sort(max v u)
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let u = Level::param(mk_name("u"));
    let v = Level::param(mk_name("v"));
    let s1 = Expr::sort(Level::max(u.clone(), v.clone()));
    let s2 = Expr::sort(Level::max(v, u));
    assert!(tc.def_eq(&s1, &s2));
  }

  #[test]
  fn def_eq_sort_not_equal() {
    // Sort(0) ≠ Sort(1)
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let s0 = Expr::sort(Level::zero());
    let s1 = Expr::sort(Level::succ(Level::zero()));
    assert!(!tc.def_eq(&s0, &s1));
  }

  // ==========================================================================
  // Alpha equivalence (same structure, different binder names)
  // ==========================================================================

  #[test]
  fn def_eq_alpha_lambda() {
    // fun (x : Nat) => x =def= fun (y : Nat) => y
    // (de Bruijn indices are the same, so this is syntactic equality)
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let e1 = Expr::lam(
      mk_name("x"),
      nat_type(),
      Expr::bvar(Nat::from(0u64)),
      BinderInfo::Default,
    );
    let e2 = Expr::lam(
      mk_name("y"),
      nat_type(),
      Expr::bvar(Nat::from(0u64)),
      BinderInfo::Default,
    );
    assert!(tc.def_eq(&e1, &e2));
  }

  #[test]
  fn def_eq_alpha_pi() {
    // (x : Nat) → Nat =def= (y : Nat) → Nat
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let e1 = Expr::all(
      mk_name("x"),
      nat_type(),
      nat_type(),
      BinderInfo::Default,
    );
    let e2 = Expr::all(
      mk_name("y"),
      nat_type(),
      nat_type(),
      BinderInfo::Default,
    );
    assert!(tc.def_eq(&e1, &e2));
  }

  // ==========================================================================
  // Beta equivalence
  // ==========================================================================

  #[test]
  fn def_eq_beta() {
    // (fun x : Nat => x) Nat.zero =def= Nat.zero
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let id_fn = Expr::lam(
      mk_name("x"),
      nat_type(),
      Expr::bvar(Nat::from(0u64)),
      BinderInfo::Default,
    );
    let lhs = Expr::app(id_fn, nat_zero());
    let rhs = nat_zero();
    assert!(tc.def_eq(&lhs, &rhs));
  }

  #[test]
  fn def_eq_beta_nested() {
    // (fun x y : Nat => x) Nat.zero Nat.zero =def= Nat.zero
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let inner = Expr::lam(
      mk_name("y"),
      nat_type(),
      Expr::bvar(Nat::from(1u64)), // x
      BinderInfo::Default,
    );
    let k_fn = Expr::lam(
      mk_name("x"),
      nat_type(),
      inner,
      BinderInfo::Default,
    );
    let lhs = Expr::app(Expr::app(k_fn, nat_zero()), nat_zero());
    assert!(tc.def_eq(&lhs, &nat_zero()));
  }

  // ==========================================================================
  // Delta equivalence (definition unfolding)
  // ==========================================================================

  #[test]
  fn def_eq_delta() {
    // def myZero := Nat.zero
    // myZero =def= Nat.zero
    let mut env = mk_nat_env();
    let my_zero = mk_name("myZero");
    env.insert(
      my_zero.clone(),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: my_zero.clone(),
          level_params: vec![],
          typ: nat_type(),
        },
        value: nat_zero(),
        hints: ReducibilityHints::Abbrev,
        safety: DefinitionSafety::Safe,
        all: vec![my_zero.clone()],
      }),
    );
    let mut tc = TypeChecker::new(&env);
    let lhs = Expr::cnst(my_zero, vec![]);
    assert!(tc.def_eq(&lhs, &nat_zero()));
  }

  #[test]
  fn def_eq_delta_both_sides() {
    // def a := Nat.zero, def b := Nat.zero
    // a =def= b
    let mut env = mk_nat_env();
    let a = mk_name("a");
    let b = mk_name("b");
    env.insert(
      a.clone(),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: a.clone(),
          level_params: vec![],
          typ: nat_type(),
        },
        value: nat_zero(),
        hints: ReducibilityHints::Abbrev,
        safety: DefinitionSafety::Safe,
        all: vec![a.clone()],
      }),
    );
    env.insert(
      b.clone(),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: b.clone(),
          level_params: vec![],
          typ: nat_type(),
        },
        value: nat_zero(),
        hints: ReducibilityHints::Abbrev,
        safety: DefinitionSafety::Safe,
        all: vec![b.clone()],
      }),
    );
    let mut tc = TypeChecker::new(&env);
    let lhs = Expr::cnst(a, vec![]);
    let rhs = Expr::cnst(b, vec![]);
    assert!(tc.def_eq(&lhs, &rhs));
  }

  // ==========================================================================
  // Zeta equivalence (let unfolding)
  // ==========================================================================

  #[test]
  fn def_eq_zeta() {
    // (let x : Nat := Nat.zero in x) =def= Nat.zero
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let lhs = Expr::letE(
      mk_name("x"),
      nat_type(),
      nat_zero(),
      Expr::bvar(Nat::from(0u64)),
      false,
    );
    assert!(tc.def_eq(&lhs, &nat_zero()));
  }

  // ==========================================================================
  // Negative tests
  // ==========================================================================

  #[test]
  fn def_eq_different_consts() {
    // Nat ≠ String
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let nat = nat_type();
    let string = Expr::cnst(mk_name("String"), vec![]);
    assert!(!tc.def_eq(&nat, &string));
  }

  #[test]
  fn def_eq_different_nat_levels() {
    // Nat.zero ≠ Nat.succ
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let zero = nat_zero();
    let succ = Expr::cnst(mk_name2("Nat", "succ"), vec![]);
    assert!(!tc.def_eq(&zero, &succ));
  }

  #[test]
  fn def_eq_app_congruence() {
    // f a =def= f a (for same f, same a)
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let f = Expr::cnst(mk_name2("Nat", "succ"), vec![]);
    let a = nat_zero();
    let lhs = Expr::app(f.clone(), a.clone());
    let rhs = Expr::app(f, a);
    assert!(tc.def_eq(&lhs, &rhs));
  }

  #[test]
  fn def_eq_app_different_args() {
    // Nat.succ Nat.zero ≠ Nat.succ (Nat.succ Nat.zero)
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let succ = Expr::cnst(mk_name2("Nat", "succ"), vec![]);
    let lhs = Expr::app(succ.clone(), nat_zero());
    let rhs =
      Expr::app(succ.clone(), Expr::app(succ, nat_zero()));
    assert!(!tc.def_eq(&lhs, &rhs));
  }

  // ==========================================================================
  // Const-level equality
  // ==========================================================================

  #[test]
  fn def_eq_const_levels() {
    // A.{max u v} =def= A.{max v u}
    let mut env = Env::default();
    let a_name = mk_name("A");
    let u_name = mk_name("u");
    let v_name = mk_name("v");
    env.insert(
      a_name.clone(),
      ConstantInfo::AxiomInfo(AxiomVal {
        cnst: ConstantVal {
          name: a_name.clone(),
          level_params: vec![u_name.clone(), v_name.clone()],
          typ: Expr::sort(Level::max(
            Level::param(u_name.clone()),
            Level::param(v_name.clone()),
          )),
        },
        is_unsafe: false,
      }),
    );
    let mut tc = TypeChecker::new(&env);
    let u = Level::param(mk_name("u"));
    let v = Level::param(mk_name("v"));
    let lhs = Expr::cnst(a_name.clone(), vec![Level::max(u.clone(), v.clone()), Level::zero()]);
    let rhs = Expr::cnst(a_name, vec![Level::max(v, u), Level::zero()]);
    assert!(tc.def_eq(&lhs, &rhs));
  }

  // ==========================================================================
  // Hint ordering
  // ==========================================================================

  #[test]
  fn hint_lt_opaque_less_than_all() {
    assert!(hint_lt(&ReducibilityHints::Opaque, &ReducibilityHints::Abbrev));
    assert!(hint_lt(
      &ReducibilityHints::Opaque,
      &ReducibilityHints::Regular(0)
    ));
  }

  #[test]
  fn hint_lt_abbrev_greatest() {
    assert!(!hint_lt(
      &ReducibilityHints::Abbrev,
      &ReducibilityHints::Opaque
    ));
    assert!(!hint_lt(
      &ReducibilityHints::Abbrev,
      &ReducibilityHints::Regular(100)
    ));
  }

  #[test]
  fn hint_lt_regular_ordering() {
    assert!(hint_lt(
      &ReducibilityHints::Regular(1),
      &ReducibilityHints::Regular(2)
    ));
    assert!(!hint_lt(
      &ReducibilityHints::Regular(2),
      &ReducibilityHints::Regular(1)
    ));
  }

  // ==========================================================================
  // Eta expansion
  // ==========================================================================

  #[test]
  fn def_eq_eta_lam_vs_const() {
    // fun x : Nat => Nat.succ x =def= Nat.succ
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let succ = Expr::cnst(mk_name2("Nat", "succ"), vec![]);
    let eta_expanded = Expr::lam(
      mk_name("x"),
      nat_type(),
      Expr::app(succ.clone(), Expr::bvar(Nat::from(0u64))),
      BinderInfo::Default,
    );
    assert!(tc.def_eq(&eta_expanded, &succ));
  }

  #[test]
  fn def_eq_eta_symmetric() {
    // Nat.succ =def= fun x : Nat => Nat.succ x
    let env = mk_nat_env();
    let mut tc = TypeChecker::new(&env);
    let succ = Expr::cnst(mk_name2("Nat", "succ"), vec![]);
    let eta_expanded = Expr::lam(
      mk_name("x"),
      nat_type(),
      Expr::app(succ.clone(), Expr::bvar(Nat::from(0u64))),
      BinderInfo::Default,
    );
    assert!(tc.def_eq(&succ, &eta_expanded));
  }

  // ==========================================================================
  // Lazy delta step with different heights
  // ==========================================================================

  #[test]
  fn def_eq_lazy_delta_higher_unfolds_first() {
    // def a := Nat.zero (height 1)
    // def b := a (height 2)
    // b =def= Nat.zero should work by unfolding b first (higher height)
    let mut env = mk_nat_env();
    let a = mk_name("a");
    let b = mk_name("b");
    env.insert(
      a.clone(),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: a.clone(),
          level_params: vec![],
          typ: nat_type(),
        },
        value: nat_zero(),
        hints: ReducibilityHints::Regular(1),
        safety: DefinitionSafety::Safe,
        all: vec![a.clone()],
      }),
    );
    env.insert(
      b.clone(),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: b.clone(),
          level_params: vec![],
          typ: nat_type(),
        },
        value: Expr::cnst(a, vec![]),
        hints: ReducibilityHints::Regular(2),
        safety: DefinitionSafety::Safe,
        all: vec![b.clone()],
      }),
    );
    let mut tc = TypeChecker::new(&env);
    let lhs = Expr::cnst(b, vec![]);
    assert!(tc.def_eq(&lhs, &nat_zero()));
  }

  // ==========================================================================
  // Transitivity through delta
  // ==========================================================================

  #[test]
  fn def_eq_transitive_delta() {
    // def a := Nat.zero, def b := Nat.zero
    // def c := Nat.zero
    // a =def= b, a =def= c, b =def= c
    let mut env = mk_nat_env();
    for name_str in &["a", "b", "c"] {
      let n = mk_name(name_str);
      env.insert(
        n.clone(),
        ConstantInfo::DefnInfo(DefinitionVal {
          cnst: ConstantVal {
            name: n.clone(),
            level_params: vec![],
            typ: nat_type(),
          },
          value: nat_zero(),
          hints: ReducibilityHints::Abbrev,
          safety: DefinitionSafety::Safe,
          all: vec![n],
        }),
      );
    }
    let mut tc = TypeChecker::new(&env);
    let a = Expr::cnst(mk_name("a"), vec![]);
    let b = Expr::cnst(mk_name("b"), vec![]);
    let c = Expr::cnst(mk_name("c"), vec![]);
    assert!(tc.def_eq(&a, &b));
    assert!(tc.def_eq(&a, &c));
    assert!(tc.def_eq(&b, &c));
  }

  // ==========================================================================
  // Nat literal equality through WHNF
  // ==========================================================================

  #[test]
  fn def_eq_nat_lit_same() {
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let a = Expr::lit(Literal::NatVal(Nat::from(42u64)));
    let b = Expr::lit(Literal::NatVal(Nat::from(42u64)));
    assert!(tc.def_eq(&a, &b));
  }

  #[test]
  fn def_eq_nat_lit_different() {
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let a = Expr::lit(Literal::NatVal(Nat::from(1u64)));
    let b = Expr::lit(Literal::NatVal(Nat::from(2u64)));
    assert!(!tc.def_eq(&a, &b));
  }

  // ==========================================================================
  // Beta-delta combined
  // ==========================================================================

  #[test]
  fn def_eq_beta_delta_combined() {
    // def myId := fun x : Nat => x
    // myId Nat.zero =def= Nat.zero
    let mut env = mk_nat_env();
    let my_id = mk_name("myId");
    let fun_ty = Expr::all(
      mk_name("x"),
      nat_type(),
      nat_type(),
      BinderInfo::Default,
    );
    env.insert(
      my_id.clone(),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: my_id.clone(),
          level_params: vec![],
          typ: fun_ty,
        },
        value: Expr::lam(
          mk_name("x"),
          nat_type(),
          Expr::bvar(Nat::from(0u64)),
          BinderInfo::Default,
        ),
        hints: ReducibilityHints::Abbrev,
        safety: DefinitionSafety::Safe,
        all: vec![my_id.clone()],
      }),
    );
    let mut tc = TypeChecker::new(&env);
    let lhs = Expr::app(Expr::cnst(my_id, vec![]), nat_zero());
    assert!(tc.def_eq(&lhs, &nat_zero()));
  }

  // ==========================================================================
  // Structure eta
  // ==========================================================================

  /// Build an env with Nat + Prod.{u,v} structure type.
  fn mk_prod_env() -> Env {
    let mut env = mk_nat_env();
    let u_name = mk_name("u");
    let v_name = mk_name("v");
    let prod_name = mk_name("Prod");
    let mk_ctor_name = mk_name2("Prod", "mk");

    // Prod.{u,v} (α : Sort u) (β : Sort v) : Sort (max u v)
    let prod_type = Expr::all(
      mk_name("α"),
      Expr::sort(Level::param(u_name.clone())),
      Expr::all(
        mk_name("β"),
        Expr::sort(Level::param(v_name.clone())),
        Expr::sort(Level::max(
          Level::param(u_name.clone()),
          Level::param(v_name.clone()),
        )),
        BinderInfo::Default,
      ),
      BinderInfo::Default,
    );

    env.insert(
      prod_name.clone(),
      ConstantInfo::InductInfo(InductiveVal {
        cnst: ConstantVal {
          name: prod_name.clone(),
          level_params: vec![u_name.clone(), v_name.clone()],
          typ: prod_type,
        },
        num_params: Nat::from(2u64),
        num_indices: Nat::from(0u64),
        all: vec![prod_name.clone()],
        ctors: vec![mk_ctor_name.clone()],
        num_nested: Nat::from(0u64),
        is_rec: false,
        is_unsafe: false,
        is_reflexive: false,
      }),
    );

    // Prod.mk.{u,v} (α : Sort u) (β : Sort v) (fst : α) (snd : β) : Prod α β
    let ctor_type = Expr::all(
      mk_name("α"),
      Expr::sort(Level::param(u_name.clone())),
      Expr::all(
        mk_name("β"),
        Expr::sort(Level::param(v_name.clone())),
        Expr::all(
          mk_name("fst"),
          Expr::bvar(Nat::from(1u64)), // α
          Expr::all(
            mk_name("snd"),
            Expr::bvar(Nat::from(1u64)), // β
            Expr::app(
              Expr::app(
                Expr::cnst(
                  prod_name.clone(),
                  vec![
                    Level::param(u_name.clone()),
                    Level::param(v_name.clone()),
                  ],
                ),
                Expr::bvar(Nat::from(3u64)), // α
              ),
              Expr::bvar(Nat::from(2u64)), // β
            ),
            BinderInfo::Default,
          ),
          BinderInfo::Default,
        ),
        BinderInfo::Default,
      ),
      BinderInfo::Default,
    );

    env.insert(
      mk_ctor_name.clone(),
      ConstantInfo::CtorInfo(ConstructorVal {
        cnst: ConstantVal {
          name: mk_ctor_name,
          level_params: vec![u_name, v_name],
          typ: ctor_type,
        },
        induct: prod_name,
        cidx: Nat::from(0u64),
        num_params: Nat::from(2u64),
        num_fields: Nat::from(2u64),
        is_unsafe: false,
      }),
    );

    env
  }

  #[test]
  fn eta_struct_ctor_eq_proj() {
    // Prod.mk Nat Nat (Prod.1 p) (Prod.2 p) =def= p
    // where p is a free variable of type Prod Nat Nat
    let env = mk_prod_env();
    let mut tc = TypeChecker::new(&env);

    let one = Level::succ(Level::zero());
    let prod_nat_nat = Expr::app(
      Expr::app(
        Expr::cnst(mk_name("Prod"), vec![one.clone(), one.clone()]),
        nat_type(),
      ),
      nat_type(),
    );
    let p = tc.mk_local(&mk_name("p"), &prod_nat_nat);

    let ctor_app = Expr::app(
      Expr::app(
        Expr::app(
          Expr::app(
            Expr::cnst(
              mk_name2("Prod", "mk"),
              vec![one.clone(), one.clone()],
            ),
            nat_type(),
          ),
          nat_type(),
        ),
        Expr::proj(mk_name("Prod"), Nat::from(0u64), p.clone()),
      ),
      Expr::proj(mk_name("Prod"), Nat::from(1u64), p.clone()),
    );

    assert!(tc.def_eq(&ctor_app, &p));
  }

  #[test]
  fn eta_struct_symmetric() {
    // p =def= Prod.mk Nat Nat (Prod.1 p) (Prod.2 p)
    let env = mk_prod_env();
    let mut tc = TypeChecker::new(&env);

    let one = Level::succ(Level::zero());
    let prod_nat_nat = Expr::app(
      Expr::app(
        Expr::cnst(mk_name("Prod"), vec![one.clone(), one.clone()]),
        nat_type(),
      ),
      nat_type(),
    );
    let p = tc.mk_local(&mk_name("p"), &prod_nat_nat);

    let ctor_app = Expr::app(
      Expr::app(
        Expr::app(
          Expr::app(
            Expr::cnst(
              mk_name2("Prod", "mk"),
              vec![one.clone(), one.clone()],
            ),
            nat_type(),
          ),
          nat_type(),
        ),
        Expr::proj(mk_name("Prod"), Nat::from(0u64), p.clone()),
      ),
      Expr::proj(mk_name("Prod"), Nat::from(1u64), p.clone()),
    );

    assert!(tc.def_eq(&p, &ctor_app));
  }

  #[test]
  fn eta_struct_nat_not_structure_like() {
    // Nat has 2 constructors, so it is NOT structure-like
    let env = mk_nat_env();
    assert!(!super::is_structure_like(&mk_name("Nat"), &env));
  }

  // ==========================================================================
  // Binder full comparison
  // ==========================================================================

  #[test]
  fn def_eq_binder_full_different_domains() {
    // (x : myNat) → Nat =def= (x : Nat) → Nat
    // where myNat unfolds to Nat
    let mut env = mk_nat_env();
    let my_nat = mk_name("myNat");
    env.insert(
      my_nat.clone(),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: my_nat.clone(),
          level_params: vec![],
          typ: Expr::sort(Level::succ(Level::zero())),
        },
        value: nat_type(),
        hints: ReducibilityHints::Abbrev,
        safety: DefinitionSafety::Safe,
        all: vec![my_nat.clone()],
      }),
    );
    let mut tc = TypeChecker::new(&env);
    let lhs = Expr::all(
      mk_name("x"),
      Expr::cnst(my_nat, vec![]),
      nat_type(),
      BinderInfo::Default,
    );
    let rhs = Expr::all(
      mk_name("x"),
      nat_type(),
      nat_type(),
      BinderInfo::Default,
    );
    assert!(tc.def_eq(&lhs, &rhs));
  }

  // ==========================================================================
  // Proj congruence
  // ==========================================================================

  #[test]
  fn def_eq_proj_congruence() {
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let s = nat_zero();
    let lhs = Expr::proj(mk_name("S"), Nat::from(0u64), s.clone());
    let rhs = Expr::proj(mk_name("S"), Nat::from(0u64), s);
    assert!(tc.def_eq(&lhs, &rhs));
  }

  #[test]
  fn def_eq_proj_different_idx() {
    let env = Env::default();
    let mut tc = TypeChecker::new(&env);
    let s = nat_zero();
    let lhs = Expr::proj(mk_name("S"), Nat::from(0u64), s.clone());
    let rhs = Expr::proj(mk_name("S"), Nat::from(1u64), s);
    assert!(!tc.def_eq(&lhs, &rhs));
  }

  // ==========================================================================
  // Unit-like equality
  // ==========================================================================

  #[test]
  fn def_eq_unit_like() {
    // Unit-type: single ctor, zero fields
    // Any two inhabitants should be def-eq
    let mut env = mk_nat_env();
let unit_name = mk_name("Unit"); + let unit_star = mk_name2("Unit", "star"); + + env.insert( + unit_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: unit_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![unit_name.clone()], + ctors: vec![unit_star.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + env.insert( + unit_star.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: unit_star.clone(), + level_params: vec![], + typ: Expr::cnst(unit_name.clone(), vec![]), + }, + induct: unit_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + + let mut tc = TypeChecker::new(&env); + + // Two distinct fvars of type Unit should be def-eq + let unit_ty = Expr::cnst(unit_name, vec![]); + let x = tc.mk_local(&mk_name("x"), &unit_ty); + let y = tc.mk_local(&mk_name("y"), &unit_ty); + assert!(tc.def_eq(&x, &y)); + } +} diff --git a/src/ix/kernel/dll.rs b/src/ix/kernel/dll.rs new file mode 100644 index 00000000..07dfe135 --- /dev/null +++ b/src/ix/kernel/dll.rs @@ -0,0 +1,214 @@ +use core::marker::PhantomData; +use core::ptr::NonNull; + +#[derive(Debug)] +#[allow(clippy::upper_case_acronyms)] +pub struct DLL { + pub next: Option>>, + pub prev: Option>>, + pub elem: T, +} + +pub struct Iter<'a, T> { + next: Option>>, + marker: PhantomData<&'a mut DLL>, +} + +impl<'a, T> Iterator for Iter<'a, T> { + type Item = &'a T; + + #[inline] + fn next(&mut self) -> Option { + self.next.map(|node| { + let deref = unsafe { &*node.as_ptr() }; + self.next = deref.next; + &deref.elem + }) + } +} + +pub struct IterMut<'a, T> { + next: Option>>, + marker: PhantomData<&'a mut DLL>, +} + +impl<'a, T> Iterator for IterMut<'a, T> { + type Item = &'a mut T; + + #[inline] + fn next(&mut 
impl<T> DLL<T> {
  /// A detached node holding `elem` (no neighbors).
  #[inline]
  pub fn singleton(elem: T) -> Self {
    DLL { next: None, prev: None, elem }
  }

  /// Heap-allocate a detached node and return its raw pointer.
  /// The caller owns the allocation (release via `free_all` or
  /// `Box::from_raw`).
  #[inline]
  pub fn alloc(elem: T) -> NonNull<Self> {
    NonNull::new(Box::into_raw(Box::new(Self::singleton(elem)))).unwrap()
  }

  /// True iff `dll` is a one-element list (node with no neighbors).
  #[inline]
  pub fn is_singleton(dll: Option<NonNull<Self>>) -> bool {
    dll.is_some_and(|dll| unsafe {
      let dll = &*dll.as_ptr();
      dll.prev.is_none() && dll.next.is_none()
    })
  }

  /// `None` represents the empty list.
  #[inline]
  pub fn is_empty(dll: Option<NonNull<Self>>) -> bool {
    dll.is_none()
  }

  /// Splice `node` into the list immediately before `self`.
  /// NOTE(review): `node`'s own prior links are overwritten — this assumes
  /// `node` is detached; confirm at call sites.
  pub fn merge(&mut self, node: NonNull<Self>) {
    unsafe {
      (*node.as_ptr()).prev = self.prev;
      (*node.as_ptr()).next = NonNull::new(self);
      if let Some(ptr) = self.prev {
        (*ptr.as_ptr()).next = Some(node);
      }
      self.prev = Some(node);
    }
  }

  /// Unlink this node from its neighbors, returning some surviving node
  /// (prev if any, else next, else `None` for a singleton).
  /// The node's own `prev`/`next` are left untouched (dangling); it must
  /// not be traversed from afterwards.
  pub fn unlink_node(&self) -> Option<NonNull<Self>> {
    unsafe {
      let next = self.next;
      let prev = self.prev;
      if let Some(next) = next {
        (*next.as_ptr()).prev = prev;
      }
      if let Some(prev) = prev {
        (*prev.as_ptr()).next = next;
      }
      prev.or(next)
    }
  }

  /// Walk `prev` links to the head of the chain containing `node`.
  pub fn first(mut node: NonNull<Self>) -> NonNull<Self> {
    loop {
      let prev = unsafe { (*node.as_ptr()).prev };
      match prev {
        None => break,
        Some(ptr) => node = ptr,
      }
    }
    node
  }

  /// Walk `next` links to the tail of the chain containing `node`.
  pub fn last(mut node: NonNull<Self>) -> NonNull<Self> {
    loop {
      let next = unsafe { (*node.as_ptr()).next };
      match next {
        None => break,
        Some(ptr) => node = ptr,
      }
    }
    node
  }

  /// Append the chain containing `rest` after the tail of `dll`'s chain.
  pub fn concat(dll: NonNull<Self>, rest: Option<NonNull<Self>>) {
    let last = DLL::last(dll);
    let first = rest.map(DLL::first);
    unsafe {
      (*last.as_ptr()).next = first;
    }
    if let Some(first) = first {
      unsafe {
        (*first.as_ptr()).prev = Some(last);
      }
    }
  }

  /// Iterate the whole chain containing `dll`, starting from its head.
  #[inline]
  pub fn iter_option(dll: Option<NonNull<Self>>) -> Iter<'static, T> {
    Iter { next: dll.map(DLL::first), marker: PhantomData }
  }

  /// Mutable variant of `iter_option`.
  #[inline]
  #[allow(dead_code)]
  pub fn iter_mut_option(dll: Option<NonNull<Self>>) -> IterMut<'static, T> {
    IterMut { next: dll.map(DLL::first), marker: PhantomData }
  }
  /// Free every node in the chain containing `dll`.
  ///
  /// # Safety
  /// All nodes in the chain must have been allocated with `DLL::alloc`
  /// (i.e. via `Box`), and no pointer into the chain may be used after
  /// this call.
  #[allow(unsafe_op_in_unsafe_fn)]
  pub unsafe fn free_all(dll: Option<NonNull<Self>>) {
    if let Some(start) = dll {
      let first = DLL::first(start);
      let mut current = Some(first);
      while let Some(node) = current {
        let next = (*node.as_ptr()).next;
        drop(Box::from_raw(node.as_ptr()));
        current = next;
      }
    }
  }
}

#[cfg(test)]
mod tests {
  use super::*;

  // Collect the chain's elements (from its head) into a Vec.
  fn to_vec(dll: Option<NonNull<DLL<i32>>>) -> Vec<i32> {
    DLL::iter_option(dll).copied().collect()
  }

  #[test]
  fn test_singleton() {
    let dll = DLL::alloc(42);
    assert!(DLL::is_singleton(Some(dll)));
    unsafe {
      assert_eq!((*dll.as_ptr()).elem, 42);
      drop(Box::from_raw(dll.as_ptr()));
    }
  }

  #[test]
  fn test_is_empty() {
    assert!(DLL::<i32>::is_empty(None));
    let dll = DLL::alloc(1);
    assert!(!DLL::is_empty(Some(dll)));
    unsafe { DLL::free_all(Some(dll)) };
  }

  #[test]
  fn test_merge() {
    unsafe {
      let a = DLL::alloc(1);
      let b = DLL::alloc(2);
      // merge inserts `b` before `a`, so iteration starts at `b`.
      (*a.as_ptr()).merge(b);
      assert_eq!(to_vec(Some(a)), vec![2, 1]);
      DLL::free_all(Some(a));
    }
  }

  #[test]
  fn test_concat() {
    unsafe {
      let a = DLL::alloc(1);
      let b = DLL::alloc(2);
      DLL::concat(a, Some(b));
      assert_eq!(to_vec(Some(a)), vec![1, 2]);
      DLL::free_all(Some(a));
    }
  }

  #[test]
  fn test_unlink_singleton() {
    unsafe {
      let dll = DLL::alloc(42);
      let remaining = (*dll.as_ptr()).unlink_node();
      assert!(remaining.is_none());
      drop(Box::from_raw(dll.as_ptr()));
    }
  }
}
diff --git a/src/ix/kernel/error.rs b/src/ix/kernel/error.rs
new file mode 100644
index 00000000..33816246
--- /dev/null
+++ b/src/ix/kernel/error.rs
@@ -0,0 +1,59 @@
use crate::ix::env::{Expr, Name};

/// Errors produced by the kernel typechecker.
#[derive(Debug)]
pub enum TcError {
  /// `expr` was used where a sort was required; `inferred` is its type.
  TypeExpected {
    expr: Expr,
    inferred: Expr,
  },
  /// `expr` was applied but its type `inferred` is not a function type.
  FunctionExpected {
    expr: Expr,
    inferred: Expr,
  },
  TypeMismatch {
    expected: Expr,
    found: Expr,
    expr: Expr,
  },
  DefEqFailure {
    lhs: Expr,
    rhs: Expr,
  },
  UnknownConst {
    name: Name,
  },
  /// The same universe parameter `name` was declared twice.
  DuplicateUniverse {
    name: Name,
  },
  /// A de Bruijn index escaped all enclosing binders.
  FreeBoundVariable {
    idx: u64,
  },
  /// Catch-all for structural kernel violations.
  KernelException {
    msg: String,
  },
}

impl std::fmt::Display for TcError {
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    match self {
      TcError::TypeExpected { .. } => write!(f, "type expected"),
      TcError::FunctionExpected { .. } => write!(f, "function expected"),
      TcError::TypeMismatch { .. } => write!(f, "type mismatch"),
      TcError::DefEqFailure { .. } => {
        write!(f, "definitional equality failure")
      },
      TcError::UnknownConst { name } => {
        write!(f, "unknown constant: {}", name.pretty())
      },
      TcError::DuplicateUniverse { name } => {
        write!(f, "duplicate universe: {}", name.pretty())
      },
      TcError::FreeBoundVariable { idx } => {
        write!(f, "free bound variable at index {}", idx)
      },
      TcError::KernelException { msg } => write!(f, "{}", msg),
    }
  }
}

impl std::error::Error for TcError {}
diff --git a/src/ix/kernel/inductive.rs b/src/ix/kernel/inductive.rs
new file mode 100644
index 00000000..a06ed819
--- /dev/null
+++ b/src/ix/kernel/inductive.rs
@@ -0,0 +1,772 @@
use crate::ix::env::*;
use crate::lean::nat::Nat;

use super::error::TcError;
use super::level;
use super::tc::TypeChecker;
use super::whnf::{inst, unfold_apps};

type TcResult<T> = Result<T, TcError>;

/// Validate an inductive type declaration.
/// Performs structural checks: constructors exist, belong to this inductive,
/// and have well-formed types. Mutual types are verified to exist.
pub fn check_inductive(
  ind: &InductiveVal,
  tc: &mut TypeChecker,
) -> TcResult<()> {
  // Verify the inductive's own type is well-formed.
  tc.check_declar_info(&ind.cnst)?;

  // Verify all constructors exist, are constructors, and point back at
  // this inductive.
  for ctor_name in &ind.ctors {
    let ctor_ci = tc.env.get(ctor_name).ok_or_else(|| {
      TcError::UnknownConst { name: ctor_name.clone() }
    })?;
    let ctor = match ctor_ci {
      ConstantInfo::CtorInfo(c) => c,
      _ => {
        return Err(TcError::KernelException {
          msg: format!(
            "{} is not a constructor",
            ctor_name.pretty()
          ),
        })
      },
    };
    // Verify constructor's `induct` field matches this inductive.
    if ctor.induct != ind.cnst.name {
      return Err(TcError::KernelException {
        msg: format!(
          "constructor {} belongs to {} but expected {}",
          ctor_name.pretty(),
          ctor.induct.pretty(),
          ind.cnst.name.pretty()
        ),
      });
    }
    // Verify constructor type is well-formed.
    tc.check_declar_info(&ctor.cnst)?;
  }

  // Verify constructor return types, positivity, and universe constraints.
  for ctor_name in &ind.ctors {
    let ctor = match tc.env.get(ctor_name) {
      Some(ConstantInfo::CtorInfo(c)) => c,
      _ => continue, // already rejected in the loop above
    };
    check_ctor_return_type(ctor, ind, tc)?;
    if !ind.is_unsafe {
      check_ctor_positivity(ctor, ind, tc)?;
      check_field_universe_constraints(ctor, ind, tc)?;
    }
  }

  // Verify all mutual siblings exist.
  for name in &ind.all {
    if tc.env.get(name).is_none() {
      return Err(TcError::UnknownConst { name: name.clone() });
    }
  }

  Ok(())
}

/// Validate that a recursor's K flag is consistent with the inductive's
/// structure. K-target requires: non-mutual, in Prop, single constructor,
/// zero fields. If `rec.k == true` but conditions don't hold, reject.
+pub fn validate_k_flag( + rec: &RecursorVal, + env: &Env, +) -> TcResult<()> { + if !rec.k { + return Ok(()); // conservative false is always fine + } + + // Must be non-mutual: `rec.all` should have exactly 1 inductive + if rec.all.len() != 1 { + return Err(TcError::KernelException { + msg: "recursor claims K but inductive is mutual".into(), + }); + } + + let ind_name = &rec.all[0]; + let ind = match env.get(ind_name) { + Some(ConstantInfo::InductInfo(iv)) => iv, + _ => { + return Err(TcError::KernelException { + msg: format!( + "recursor claims K but {} is not an inductive", + ind_name.pretty() + ), + }) + }, + }; + + // Must be in Prop (Sort 0) + // Walk type telescope past all binders to get the sort + let mut ty = ind.cnst.typ.clone(); + loop { + match ty.as_data() { + ExprData::ForallE(_, _, body, _, _) => { + ty = body.clone(); + }, + _ => break, + } + } + let is_prop = match ty.as_data() { + ExprData::Sort(l, _) => level::is_zero(l), + _ => false, + }; + if !is_prop { + return Err(TcError::KernelException { + msg: format!( + "recursor claims K but {} is not in Prop", + ind_name.pretty() + ), + }); + } + + // Must have single constructor + if ind.ctors.len() != 1 { + return Err(TcError::KernelException { + msg: format!( + "recursor claims K but {} has {} constructors (need 1)", + ind_name.pretty(), + ind.ctors.len() + ), + }); + } + + // Constructor must have zero fields (all args are params) + let ctor_name = &ind.ctors[0]; + if let Some(ConstantInfo::CtorInfo(c)) = env.get(ctor_name) { + if c.num_fields != Nat::ZERO { + return Err(TcError::KernelException { + msg: format!( + "recursor claims K but constructor {} has {} fields (need 0)", + ctor_name.pretty(), + c.num_fields + ), + }); + } + } + + Ok(()) +} + +/// Check if an expression mentions a constant by name. 
+fn expr_mentions_const(e: &Expr, name: &Name) -> bool { + match e.as_data() { + ExprData::Const(n, _, _) => n == name, + ExprData::App(f, a, _) => { + expr_mentions_const(f, name) || expr_mentions_const(a, name) + }, + ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { + expr_mentions_const(t, name) || expr_mentions_const(b, name) + }, + ExprData::LetE(_, t, v, b, _, _) => { + expr_mentions_const(t, name) + || expr_mentions_const(v, name) + || expr_mentions_const(b, name) + }, + ExprData::Proj(_, _, s, _) => expr_mentions_const(s, name), + ExprData::Mdata(_, inner, _) => expr_mentions_const(inner, name), + _ => false, + } +} + +/// Check that no inductive name from `ind.all` appears in a negative position +/// in the constructor's field types. +fn check_ctor_positivity( + ctor: &ConstructorVal, + ind: &InductiveVal, + tc: &mut TypeChecker, +) -> TcResult<()> { + let num_params = ind.num_params.to_u64().unwrap() as usize; + let mut ty = ctor.cnst.typ.clone(); + + // Skip parameter binders + for _ in 0..num_params { + let whnf_ty = tc.whnf(&ty); + match whnf_ty.as_data() { + ExprData::ForallE(name, binder_type, body, _, _) => { + let local = tc.mk_local(name, binder_type); + ty = inst(body, &[local]); + }, + _ => return Ok(()), // fewer binders than params — odd but not our problem + } + } + + // For each remaining field, check its domain for positivity + loop { + let whnf_ty = tc.whnf(&ty); + match whnf_ty.as_data() { + ExprData::ForallE(name, binder_type, body, _, _) => { + // The domain is the field type — check strict positivity + check_strict_positivity(binder_type, &ind.all, tc)?; + let local = tc.mk_local(name, binder_type); + ty = inst(body, &[local]); + }, + _ => break, + } + } + + Ok(()) +} + +/// Check strict positivity of a field type w.r.t. a set of inductive names. +/// +/// Strict positivity for `T` w.r.t. `I`: +/// - If `T` doesn't mention `I`, OK. +/// - If `T = I args...`, OK (the inductive itself at the head). 
fn check_strict_positivity(
  ty: &Expr,
  ind_names: &[Name],
  tc: &mut TypeChecker,
) -> TcResult<()> {
  let whnf_ty = tc.whnf(ty);

  // If no inductive name is mentioned at all, we're fine.
  if !ind_names.iter().any(|n| expr_mentions_const(&whnf_ty, n)) {
    return Ok(());
  }

  match whnf_ty.as_data() {
    ExprData::ForallE(_, domain, body, _, _) => {
      // Domain must NOT mention any inductive name (that would be a
      // negative occurrence).
      for ind_name in ind_names {
        if expr_mentions_const(domain, ind_name) {
          return Err(TcError::KernelException {
            msg: format!(
              "inductive {} occurs in negative position (strict positivity violation)",
              ind_name.pretty()
            ),
          });
        }
      }
      // Recurse into the codomain.
      // NOTE(review): recurses on the open `body` without instantiating
      // the binder — relies on `whnf`/`expr_mentions_const` tolerating
      // loose bvars; confirm.
      check_strict_positivity(body, ind_names, tc)
    },
    _ => {
      // The inductive is mentioned and we're not in a Pi — check if
      // it's simply an application `I args...` (which is OK).
      let (head, _) = unfold_apps(&whnf_ty);
      match head.as_data() {
        ExprData::Const(name, _, _)
          if ind_names.iter().any(|n| n == name) =>
        {
          Ok(())
        },
        _ => Err(TcError::KernelException {
          msg: "inductive type occurs in a non-positive position".into(),
        }),
      }
    },
  }
}

/// Check that constructor field types live in universes ≤ the inductive's
/// universe.
fn check_field_universe_constraints(
  ctor: &ConstructorVal,
  ind: &InductiveVal,
  tc: &mut TypeChecker,
) -> TcResult<()> {
  // Walk the inductive type telescope past num_params binders to find the
  // sort level.
  let num_params = ind.num_params.to_u64().unwrap() as usize;
  let mut ind_ty = ind.cnst.typ.clone();
  for _ in 0..num_params {
    let whnf_ty = tc.whnf(&ind_ty);
    match whnf_ty.as_data() {
      ExprData::ForallE(name, binder_type, body, _, _) => {
        let local = tc.mk_local(name, binder_type);
        ind_ty = inst(body, &[local]);
      },
      _ => return Ok(()),
    }
  }
  // Skip remaining binders (indices) to get to the target sort.
  loop {
    let whnf_ty = tc.whnf(&ind_ty);
    match whnf_ty.as_data() {
      ExprData::ForallE(name, binder_type, body, _, _) => {
        let local = tc.mk_local(name, binder_type);
        ind_ty = inst(body, &[local]);
      },
      _ => {
        ind_ty = whnf_ty;
        break;
      },
    }
  }
  let ind_level = match ind_ty.as_data() {
    ExprData::Sort(l, _) => l.clone(),
    _ => return Ok(()), // can't extract sort, skip the check
  };

  // Walk ctor type, skip params, then check each field.
  let mut ctor_ty = ctor.cnst.typ.clone();
  for _ in 0..num_params {
    let whnf_ty = tc.whnf(&ctor_ty);
    match whnf_ty.as_data() {
      ExprData::ForallE(name, binder_type, body, _, _) => {
        let local = tc.mk_local(name, binder_type);
        ctor_ty = inst(body, &[local]);
      },
      _ => return Ok(()),
    }
  }

  // For each remaining field binder, check its sort level ≤ ind_level.
  loop {
    let whnf_ty = tc.whnf(&ctor_ty);
    match whnf_ty.as_data() {
      ExprData::ForallE(name, binder_type, body, _, _) => {
        // Infer the sort of the binder_type; if inference fails the field
        // is skipped rather than rejected.
        if let Ok(field_level) = tc.infer_sort_of(binder_type) {
          if !level::leq(&field_level, &ind_level) {
            return Err(TcError::KernelException {
              msg: format!(
                "constructor {} field type lives in a universe larger than the inductive's universe",
                ctor.cnst.name.pretty()
              ),
            });
          }
        }
        let local = tc.mk_local(name, binder_type);
        ctor_ty = inst(body, &[local]);
      },
      _ => break,
    }
  }

  Ok(())
}

/// Verify that a constructor's return type targets the parent inductive.
/// Walks the constructor type telescope, then checks that the resulting
/// type is an application of the parent inductive with at least
/// `num_params` args.
fn check_ctor_return_type(
  ctor: &ConstructorVal,
  ind: &InductiveVal,
  tc: &mut TypeChecker,
) -> TcResult<()> {
  let mut ty = ctor.cnst.typ.clone();

  // Walk past all Pi binders.
  loop {
    let whnf_ty = tc.whnf(&ty);
    match whnf_ty.as_data() {
      ExprData::ForallE(name, binder_type, body, _, _) => {
        let local = tc.mk_local(name, binder_type);
        ty = inst(body, &[local]);
      },
      _ => {
        ty = whnf_ty;
        break;
      },
    }
  }

  // The return type should be `I args...`.
  let (head, args) = unfold_apps(&ty);
  let head_name = match head.as_data() {
    ExprData::Const(name, _, _) => name,
    _ => {
      return Err(TcError::KernelException {
        msg: format!(
          "constructor {} return type head is not a constant",
          ctor.cnst.name.pretty()
        ),
      })
    },
  };

  if head_name != &ind.cnst.name {
    return Err(TcError::KernelException {
      msg: format!(
        "constructor {} returns {} but should return {}",
        ctor.cnst.name.pretty(),
        head_name.pretty(),
        ind.cnst.name.pretty()
      ),
    });
  }

  let num_params = ind.num_params.to_u64().unwrap() as usize;
  if args.len() < num_params {
    return Err(TcError::KernelException {
      msg: format!(
        "constructor {} return type has {} args but inductive has {} params",
        ctor.cnst.name.pretty(),
        args.len(),
        num_params
      ),
    });
  }

  Ok(())
}

#[cfg(test)]
mod tests {
  use super::*;
  use crate::ix::kernel::tc::TypeChecker;
  use crate::lean::nat::Nat;

  // `s` as a single-segment name.
  fn mk_name(s: &str) -> Name {
    Name::str(Name::anon(), s.into())
  }

  // `a.b` as a two-segment name.
  fn mk_name2(a: &str, b: &str) -> Name {
    Name::str(Name::str(Name::anon(), a.into()), b.into())
  }

  fn nat_type() -> Expr {
    Expr::cnst(mk_name("Nat"), vec![])
  }

  // Minimal env with `Nat`, `Nat.zero` and `Nat.succ`.
  fn mk_nat_env() -> Env {
    let mut env = Env::default();
    let nat_name = mk_name("Nat");
    env.insert(
      nat_name.clone(),
      ConstantInfo::InductInfo(InductiveVal {
        cnst: ConstantVal {
          name: nat_name.clone(),
          level_params: vec![],
          typ: Expr::sort(Level::succ(Level::zero())),
        },
        num_params: Nat::from(0u64),
        num_indices: Nat::from(0u64),
        all: vec![nat_name.clone()],
        ctors: vec![mk_name2("Nat", "zero"), mk_name2("Nat", "succ")],
        num_nested: Nat::from(0u64),
        is_rec: true,
        is_unsafe: false,
        is_reflexive: false,
      }),
    );
    let zero_name = mk_name2("Nat", "zero");
    env.insert(
      zero_name.clone(),
      ConstantInfo::CtorInfo(ConstructorVal {
        cnst: ConstantVal {
          name: zero_name.clone(),
          level_params: vec![],
          typ: nat_type(),
        },
        induct: mk_name("Nat"),
        cidx: Nat::from(0u64),
        num_params: Nat::from(0u64),
        num_fields: Nat::from(0u64),
        is_unsafe: false,
      }),
    );
    let succ_name = mk_name2("Nat", "succ");
    env.insert(
      succ_name.clone(),
      ConstantInfo::CtorInfo(ConstructorVal {
        cnst: ConstantVal {
          name: succ_name.clone(),
          level_params: vec![],
          typ: Expr::all(
            mk_name("n"),
            nat_type(),
            nat_type(),
            BinderInfo::Default,
          ),
        },
        induct: mk_name("Nat"),
        cidx: Nat::from(1u64),
        num_params: Nat::from(0u64),
        num_fields: Nat::from(1u64),
        is_unsafe: false,
      }),
    );
    env
  }

  #[test]
  fn check_nat_inductive_passes() {
    let env = mk_nat_env();
    let ind = match env.get(&mk_name("Nat")).unwrap() {
      ConstantInfo::InductInfo(v) => v,
      _ => panic!(),
    };
    let mut tc = TypeChecker::new(&env);
    assert!(check_inductive(ind, &mut tc).is_ok());
  }

  #[test]
  fn check_ctor_wrong_return_type() {
    let mut env = mk_nat_env();
    let bool_name = mk_name("Bool");
    env.insert(
      bool_name.clone(),
      ConstantInfo::InductInfo(InductiveVal {
        cnst: ConstantVal {
          name: bool_name.clone(),
          level_params: vec![],
          typ: Expr::sort(Level::succ(Level::zero())),
        },
        num_params: Nat::from(0u64),
        num_indices: Nat::from(0u64),
        all: vec![bool_name.clone()],
        ctors: vec![mk_name2("Bool", "bad")],
        num_nested: Nat::from(0u64),
        is_rec: false,
        is_unsafe: false,
        is_reflexive: false,
      }),
    );
    // Constructor returns Nat instead of Bool.
    let bad_ctor_name = mk_name2("Bool", "bad");
    env.insert(
      bad_ctor_name.clone(),
      ConstantInfo::CtorInfo(ConstructorVal {
        cnst: ConstantVal {
          name: bad_ctor_name.clone(),
          level_params: vec![],
          typ: nat_type(),
        },
        induct: bool_name.clone(),
        cidx: Nat::from(0u64),
        num_params: Nat::from(0u64),
        num_fields: Nat::from(0u64),
        is_unsafe: false,
      }),
    );

    let ind = match env.get(&bool_name).unwrap() {
      ConstantInfo::InductInfo(v) => v,
      _ => panic!(),
    };
    let mut tc = TypeChecker::new(&env);
    assert!(check_inductive(ind, &mut tc).is_err());
  }

  // ==========================================================================
  // Positivity checking
  // ==========================================================================

  fn bool_type() -> Expr {
    Expr::cnst(mk_name("Bool"), vec![])
  }

  /// Helper to make a simple inductive + ctor env for positivity tests.
  fn mk_single_ctor_env(
    ind_name: &str,
    ctor_name: &str,
    ctor_typ: Expr,
    num_fields: u64,
  ) -> Env {
    let mut env = mk_nat_env();
    // Bool (two ctors, used as a non-recursive field type).
    let bool_name = mk_name("Bool");
    env.insert(
      bool_name.clone(),
      ConstantInfo::InductInfo(InductiveVal {
        cnst: ConstantVal {
          name: bool_name.clone(),
          level_params: vec![],
          typ: Expr::sort(Level::succ(Level::zero())),
        },
        num_params: Nat::from(0u64),
        num_indices: Nat::from(0u64),
        all: vec![bool_name],
        ctors: vec![mk_name2("Bool", "true"), mk_name2("Bool", "false")],
        num_nested: Nat::from(0u64),
        is_rec: false,
        is_unsafe: false,
        is_reflexive: false,
      }),
    );
    let iname = mk_name(ind_name);
    let cname = mk_name2(ind_name, ctor_name);
    env.insert(
      iname.clone(),
      ConstantInfo::InductInfo(InductiveVal {
        cnst: ConstantVal {
          name: iname.clone(),
          level_params: vec![],
          typ: Expr::sort(Level::succ(Level::zero())),
        },
        num_params: Nat::from(0u64),
        num_indices: Nat::from(0u64),
        all: vec![iname.clone()],
        ctors: vec![cname.clone()],
        num_nested: Nat::from(0u64),
        is_rec: num_fields > 0,
        is_unsafe: false,
        is_reflexive: false,
      }),
    );
    env.insert(
      cname.clone(),
      ConstantInfo::CtorInfo(ConstructorVal {
        cnst: ConstantVal {
          name: cname,
          level_params: vec![],
          typ: ctor_typ,
        },
        induct: iname,
        cidx: Nat::from(0u64),
        num_params: Nat::from(0u64),
        num_fields: Nat::from(num_fields),
        is_unsafe: false,
      }),
    );
    env
  }

  #[test]
  fn positivity_bad_negative() {
    // inductive Bad | mk : (Bad → Bool) → Bad
    // `Bad` occurs in the domain of the field's function type: negative.
    let bad = mk_name("Bad");
    let ctor_ty = Expr::all(
      mk_name("f"),
      Expr::all(
        mk_name("x"),
        Expr::cnst(bad, vec![]),
        bool_type(),
        BinderInfo::Default,
      ),
      Expr::cnst(mk_name("Bad"), vec![]),
      BinderInfo::Default,
    );
    let env = mk_single_ctor_env("Bad", "mk", ctor_ty, 1);
    let ind = match env.get(&mk_name("Bad")).unwrap() {
      ConstantInfo::InductInfo(v) => v,
      _ => panic!(),
    };
    let mut tc = TypeChecker::new(&env);
    assert!(check_inductive(ind, &mut tc).is_err());
  }

  #[test]
  fn positivity_nat_succ_ok() {
    // Nat.succ : Nat → Nat (positive occurrence at the head).
    let env = mk_nat_env();
    let ind = match env.get(&mk_name("Nat")).unwrap() {
      ConstantInfo::InductInfo(v) => v,
      _ => panic!(),
    };
    let mut tc = TypeChecker::new(&env);
    assert!(check_inductive(ind, &mut tc).is_ok());
  }

  #[test]
  fn positivity_tree_positive_function() {
    // inductive Tree | node : (Nat → Tree) → Tree
    // Tree appears only in the codomain of `Nat → Tree`: positive.
    let tree = mk_name("Tree");
    let ctor_ty = Expr::all(
      mk_name("f"),
      Expr::all(
        mk_name("n"),
        nat_type(),
        Expr::cnst(tree.clone(), vec![]),
        BinderInfo::Default,
      ),
      Expr::cnst(tree, vec![]),
      BinderInfo::Default,
    );
    let env = mk_single_ctor_env("Tree", "node", ctor_ty, 1);
    let ind = match env.get(&mk_name("Tree")).unwrap() {
      ConstantInfo::InductInfo(v) => v,
      _ => panic!(),
    };
    let mut tc = TypeChecker::new(&env);
    assert!(check_inductive(ind, &mut tc).is_ok());
  }

  #[test]
  fn positivity_depth2_negative() {
    // inductive Bad2 | mk : ((Bad2 → Nat) → Nat) → Bad2
    // Bad2 appears in a negative position at depth 2.
    let bad2 = mk_name("Bad2");
    let inner = Expr::all(
      mk_name("g"),
      Expr::all(
        mk_name("x"),
        Expr::cnst(bad2.clone(), vec![]),
        nat_type(),
        BinderInfo::Default,
      ),
      nat_type(),
      BinderInfo::Default,
    );
    let ctor_ty = Expr::all(
      mk_name("f"),
      inner,
      Expr::cnst(bad2, vec![]),
      BinderInfo::Default,
    );
    let env = mk_single_ctor_env("Bad2", "mk", ctor_ty, 1);
    let ind = match env.get(&mk_name("Bad2")).unwrap() {
      ConstantInfo::InductInfo(v) => v,
      _ => panic!(),
    };
    let mut tc = TypeChecker::new(&env);
    assert!(check_inductive(ind, &mut tc).is_err());
  }

  // ==========================================================================
  // Field universe constraints
  // ==========================================================================

  #[test]
  fn field_universe_nat_field_in_type1_ok() {
    // Nat : Sort 1, Nat.succ field is Nat : Sort 1 — leq(1, 1) passes.
    let env = mk_nat_env();
    let ind = match env.get(&mk_name("Nat")).unwrap() {
      ConstantInfo::InductInfo(v) => v,
      _ => panic!(),
    };
    let mut tc = TypeChecker::new(&env);
    assert!(check_inductive(ind, &mut tc).is_ok());
  }

  #[test]
  fn field_universe_prop_inductive_with_type_field_fails() {
    // inductive PropBad : Prop | mk : Nat → PropBad
    // PropBad lives in Sort 0, Nat lives in Sort 1 — leq(1, 0) fails.
    let mut env = mk_nat_env();
    let pb_name = mk_name("PropBad");
    let pb_mk = mk_name2("PropBad", "mk");
    env.insert(
      pb_name.clone(),
      ConstantInfo::InductInfo(InductiveVal {
        cnst: ConstantVal {
          name: pb_name.clone(),
          level_params: vec![],
          typ: Expr::sort(Level::zero()), // Prop
        },
        num_params: Nat::from(0u64),
        num_indices: Nat::from(0u64),
        all: vec![pb_name.clone()],
        ctors: vec![pb_mk.clone()],
        num_nested: Nat::from(0u64),
        is_rec: false,
        is_unsafe: false,
        is_reflexive: false,
      }),
    );
    env.insert(
      pb_mk.clone(),
      ConstantInfo::CtorInfo(ConstructorVal {
        cnst: ConstantVal {
          name: pb_mk,
          level_params: vec![],
          typ: Expr::all(
            mk_name("n"),
            nat_type(), // Nat : Sort 1
            Expr::cnst(pb_name.clone(), vec![]),
            BinderInfo::Default,
          ),
        },
        induct: pb_name.clone(),
        cidx: Nat::from(0u64),
        num_params: Nat::from(0u64),
        num_fields: Nat::from(1u64),
        is_unsafe: false,
      }),
    );

    let ind = match env.get(&pb_name).unwrap() {
      ConstantInfo::InductInfo(v) => v,
      _ => panic!(),
    };
    let mut tc = TypeChecker::new(&env);
    assert!(check_inductive(ind, &mut tc).is_err());
  }
}
diff --git a/src/ix/kernel/level.rs b/src/ix/kernel/level.rs
new file mode 100644
index 00000000..90931ca6
--- /dev/null
+++ b/src/ix/kernel/level.rs
@@ -0,0 +1,393 @@
use crate::ix::env::{Expr, ExprData, Level, LevelData, Name};

/// Simplify a universe level expression.
+pub fn simplify(l: &Level) -> Level { + match l.as_data() { + LevelData::Zero(_) | LevelData::Param(..) | LevelData::Mvar(..) => { + l.clone() + }, + LevelData::Succ(inner, _) => { + let inner_s = simplify(inner); + Level::succ(inner_s) + }, + LevelData::Max(a, b, _) => { + let a_s = simplify(a); + let b_s = simplify(b); + combining(&a_s, &b_s) + }, + LevelData::Imax(a, b, _) => { + let a_s = simplify(a); + let b_s = simplify(b); + if is_zero(&a_s) || is_one(&a_s) { + b_s + } else { + match b_s.as_data() { + LevelData::Zero(_) => b_s, + LevelData::Succ(..) => combining(&a_s, &b_s), + _ => Level::imax(a_s, b_s), + } + } + }, + } +} + +/// Combine two levels, simplifying Max(Zero, x) = x and +/// Max(Succ a, Succ b) = Succ(Max(a, b)). +fn combining(l: &Level, r: &Level) -> Level { + match (l.as_data(), r.as_data()) { + (LevelData::Zero(_), _) => r.clone(), + (_, LevelData::Zero(_)) => l.clone(), + (LevelData::Succ(a, _), LevelData::Succ(b, _)) => { + let inner = combining(a, b); + Level::succ(inner) + }, + _ => Level::max(l.clone(), r.clone()), + } +} + +fn is_one(l: &Level) -> bool { + matches!(l.as_data(), LevelData::Succ(inner, _) if is_zero(inner)) +} + +/// Check if a level is definitionally zero: l <= 0. +pub fn is_zero(l: &Level) -> bool { + leq(l, &Level::zero()) +} + +/// Check if `l <= r`. +pub fn leq(l: &Level, r: &Level) -> bool { + let l_s = simplify(l); + let r_s = simplify(r); + leq_core(&l_s, &r_s, 0) +} + +/// Check `l <= r + diff`. 
+fn leq_core(l: &Level, r: &Level, diff: isize) -> bool { + match (l.as_data(), r.as_data()) { + (LevelData::Zero(_), _) if diff >= 0 => true, + (_, LevelData::Zero(_)) if diff < 0 => false, + (LevelData::Param(a, _), LevelData::Param(b, _)) => a == b && diff >= 0, + (LevelData::Param(..), LevelData::Zero(_)) => false, + (LevelData::Zero(_), LevelData::Param(..)) => diff >= 0, + (LevelData::Succ(s, _), _) => leq_core(s, r, diff - 1), + (_, LevelData::Succ(s, _)) => leq_core(l, s, diff + 1), + (LevelData::Max(a, b, _), _) => { + leq_core(a, r, diff) && leq_core(b, r, diff) + }, + (LevelData::Param(..) | LevelData::Zero(_), LevelData::Max(x, y, _)) => { + leq_core(l, x, diff) || leq_core(l, y, diff) + }, + (LevelData::Imax(a, b, _), LevelData::Imax(x, y, _)) + if a == x && b == y => + { + true + }, + (LevelData::Imax(_, b, _), _) if is_param(b) => { + leq_imax_by_cases(b, l, r, diff) + }, + (_, LevelData::Imax(_, y, _)) if is_param(y) => { + leq_imax_by_cases(y, l, r, diff) + }, + (LevelData::Imax(a, b, _), _) if is_any_max(b) => { + match b.as_data() { + LevelData::Imax(x, y, _) => { + let new_lhs = Level::imax(a.clone(), y.clone()); + let new_rhs = Level::imax(x.clone(), y.clone()); + let new_max = Level::max(new_lhs, new_rhs); + leq_core(&new_max, r, diff) + }, + LevelData::Max(x, y, _) => { + let new_lhs = Level::imax(a.clone(), x.clone()); + let new_rhs = Level::imax(a.clone(), y.clone()); + let new_max = Level::max(new_lhs, new_rhs); + let simplified = simplify(&new_max); + leq_core(&simplified, r, diff) + }, + _ => unreachable!(), + } + }, + (_, LevelData::Imax(x, y, _)) if is_any_max(y) => { + match y.as_data() { + LevelData::Imax(j, k, _) => { + let new_lhs = Level::imax(x.clone(), k.clone()); + let new_rhs = Level::imax(j.clone(), k.clone()); + let new_max = Level::max(new_lhs, new_rhs); + leq_core(l, &new_max, diff) + }, + LevelData::Max(j, k, _) => { + let new_lhs = Level::imax(x.clone(), j.clone()); + let new_rhs = Level::imax(x.clone(), k.clone()); + 
let new_max = Level::max(new_lhs, new_rhs); + let simplified = simplify(&new_max); + leq_core(l, &simplified, diff) + }, + _ => unreachable!(), + } + }, + _ => false, + } +} + +/// Test l <= r by substituting param with 0 and Succ(param) and checking both. +fn leq_imax_by_cases( + param: &Level, + lhs: &Level, + rhs: &Level, + diff: isize, +) -> bool { + let zero = Level::zero(); + let succ_param = Level::succ(param.clone()); + + let lhs_0 = subst_and_simplify(lhs, param, &zero); + let rhs_0 = subst_and_simplify(rhs, param, &zero); + let lhs_s = subst_and_simplify(lhs, param, &succ_param); + let rhs_s = subst_and_simplify(rhs, param, &succ_param); + + leq_core(&lhs_0, &rhs_0, diff) && leq_core(&lhs_s, &rhs_s, diff) +} + +fn subst_and_simplify(level: &Level, from: &Level, to: &Level) -> Level { + let substituted = subst_single_level(level, from, to); + simplify(&substituted) +} + +/// Substitute a single level parameter. +fn subst_single_level(level: &Level, from: &Level, to: &Level) -> Level { + if level == from { + return to.clone(); + } + match level.as_data() { + LevelData::Zero(_) | LevelData::Mvar(..) => level.clone(), + LevelData::Param(..) => { + if level == from { + to.clone() + } else { + level.clone() + } + }, + LevelData::Succ(inner, _) => { + Level::succ(subst_single_level(inner, from, to)) + }, + LevelData::Max(a, b, _) => Level::max( + subst_single_level(a, from, to), + subst_single_level(b, from, to), + ), + LevelData::Imax(a, b, _) => Level::imax( + subst_single_level(a, from, to), + subst_single_level(b, from, to), + ), + } +} + +fn is_param(l: &Level) -> bool { + matches!(l.as_data(), LevelData::Param(..)) +} + +fn is_any_max(l: &Level) -> bool { + matches!(l.as_data(), LevelData::Max(..) | LevelData::Imax(..)) +} + +/// Check universe level equality via antisymmetry: l == r iff l <= r && r <= l. +pub fn eq_antisymm(l: &Level, r: &Level) -> bool { + leq(l, r) && leq(r, l) +} + +/// Check that two lists of levels are pointwise equal. 
+pub fn eq_antisymm_many(ls: &[Level], rs: &[Level]) -> bool { + ls.len() == rs.len() + && ls.iter().zip(rs.iter()).all(|(l, r)| eq_antisymm(l, r)) +} + +/// Substitute universe parameters: `level[params[i] := values[i]]`. +pub fn subst_level( + level: &Level, + params: &[Name], + values: &[Level], +) -> Level { + match level.as_data() { + LevelData::Zero(_) => level.clone(), + LevelData::Succ(inner, _) => { + Level::succ(subst_level(inner, params, values)) + }, + LevelData::Max(a, b, _) => Level::max( + subst_level(a, params, values), + subst_level(b, params, values), + ), + LevelData::Imax(a, b, _) => Level::imax( + subst_level(a, params, values), + subst_level(b, params, values), + ), + LevelData::Param(name, _) => { + for (i, p) in params.iter().enumerate() { + if name == p { + return values[i].clone(); + } + } + level.clone() + }, + LevelData::Mvar(..) => level.clone(), + } +} + +/// Check that all universe parameters in `level` are contained in `params`. +pub fn all_uparams_defined(level: &Level, params: &[Name]) -> bool { + match level.as_data() { + LevelData::Zero(_) => true, + LevelData::Succ(inner, _) => all_uparams_defined(inner, params), + LevelData::Max(a, b, _) | LevelData::Imax(a, b, _) => { + all_uparams_defined(a, params) && all_uparams_defined(b, params) + }, + LevelData::Param(name, _) => params.iter().any(|p| p == name), + LevelData::Mvar(..) => true, + } +} + +/// Check that all universe parameters in an expression are contained in `params`. +/// Recursively walks the Expr, checking all Levels in Sort and Const nodes. 
+pub fn all_expr_uparams_defined(e: &Expr, params: &[Name]) -> bool { + match e.as_data() { + ExprData::Sort(level, _) => all_uparams_defined(level, params), + ExprData::Const(_, levels, _) => { + levels.iter().all(|l| all_uparams_defined(l, params)) + }, + ExprData::App(f, a, _) => { + all_expr_uparams_defined(f, params) + && all_expr_uparams_defined(a, params) + }, + ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { + all_expr_uparams_defined(t, params) + && all_expr_uparams_defined(b, params) + }, + ExprData::LetE(_, t, v, b, _, _) => { + all_expr_uparams_defined(t, params) + && all_expr_uparams_defined(v, params) + && all_expr_uparams_defined(b, params) + }, + ExprData::Proj(_, _, s, _) => all_expr_uparams_defined(s, params), + ExprData::Mdata(_, inner, _) => all_expr_uparams_defined(inner, params), + ExprData::Bvar(..) + | ExprData::Fvar(..) + | ExprData::Mvar(..) + | ExprData::Lit(..) => true, + } +} + +/// Check that a list of levels are all Params with no duplicates. +pub fn no_dupes_all_params(levels: &[Name]) -> bool { + for (i, a) in levels.iter().enumerate() { + for b in &levels[i + 1..] 
{ + if a == b { + return false; + } + } + } + true +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_simplify_zero() { + let z = Level::zero(); + assert_eq!(simplify(&z), z); + } + + #[test] + fn test_simplify_max_zero() { + let z = Level::zero(); + let p = Level::param(Name::str(Name::anon(), "u".into())); + let m = Level::max(z, p.clone()); + assert_eq!(simplify(&m), p); + } + + #[test] + fn test_simplify_imax_zero_right() { + let p = Level::param(Name::str(Name::anon(), "u".into())); + let z = Level::zero(); + let im = Level::imax(p, z.clone()); + assert_eq!(simplify(&im), z); + } + + #[test] + fn test_simplify_imax_succ_right() { + let p = Level::param(Name::str(Name::anon(), "u".into())); + let one = Level::succ(Level::zero()); + let im = Level::imax(p.clone(), one.clone()); + let simplified = simplify(&im); + // imax(p, 1) where p is nonzero → combining(p, 1) + // Actually: imax(u, 1) simplifies since a_s = u, b_s = 1 = Succ(0) + // → combining(u, 1) = max(u, 1) since u is Param, 1 is Succ + let expected = Level::max(p, one); + assert_eq!(simplified, expected); + } + + #[test] + fn test_simplify_idempotent() { + let p = Level::param(Name::str(Name::anon(), "u".into())); + let q = Level::param(Name::str(Name::anon(), "v".into())); + let l = Level::max( + Level::imax(p.clone(), q.clone()), + Level::succ(Level::zero()), + ); + let s1 = simplify(&l); + let s2 = simplify(&s1); + assert_eq!(s1, s2); + } + + #[test] + fn test_leq_reflexive() { + let p = Level::param(Name::str(Name::anon(), "u".into())); + assert!(leq(&p, &p)); + assert!(leq(&Level::zero(), &Level::zero())); + } + + #[test] + fn test_leq_zero_anything() { + let p = Level::param(Name::str(Name::anon(), "u".into())); + assert!(leq(&Level::zero(), &p)); + assert!(leq(&Level::zero(), &Level::succ(Level::zero()))); + } + + #[test] + fn test_leq_succ_not_zero() { + let one = Level::succ(Level::zero()); + assert!(!leq(&one, &Level::zero())); + } + + #[test] + fn 
test_eq_antisymm_identity() { + let p = Level::param(Name::str(Name::anon(), "u".into())); + assert!(eq_antisymm(&p, &p)); + } + + #[test] + fn test_eq_antisymm_max_comm() { + let p = Level::param(Name::str(Name::anon(), "u".into())); + let q = Level::param(Name::str(Name::anon(), "v".into())); + let m1 = Level::max(p.clone(), q.clone()); + let m2 = Level::max(q, p); + assert!(eq_antisymm(&m1, &m2)); + } + + #[test] + fn test_subst_level() { + let u_name = Name::str(Name::anon(), "u".into()); + let p = Level::param(u_name.clone()); + let one = Level::succ(Level::zero()); + let result = subst_level(&p, &[u_name], &[one.clone()]); + assert_eq!(result, one); + } + + #[test] + fn test_subst_level_nested() { + let u_name = Name::str(Name::anon(), "u".into()); + let p = Level::param(u_name.clone()); + let l = Level::succ(p); + let zero = Level::zero(); + let result = subst_level(&l, &[u_name], &[zero]); + let expected = Level::succ(Level::zero()); + assert_eq!(result, expected); + } +} diff --git a/src/ix/kernel/mod.rs b/src/ix/kernel/mod.rs new file mode 100644 index 00000000..d6a5750e --- /dev/null +++ b/src/ix/kernel/mod.rs @@ -0,0 +1,11 @@ +pub mod convert; +pub mod dag; +pub mod def_eq; +pub mod dll; +pub mod error; +pub mod inductive; +pub mod level; +pub mod quot; +pub mod tc; +pub mod upcopy; +pub mod whnf; diff --git a/src/ix/kernel/quot.rs b/src/ix/kernel/quot.rs new file mode 100644 index 00000000..51a1e070 --- /dev/null +++ b/src/ix/kernel/quot.rs @@ -0,0 +1,291 @@ +use crate::ix::env::*; + +use super::error::TcError; + +type TcResult = Result; + +/// Verify that the quotient declarations are consistent with the environment. +/// Checks that Quot is an inductive, Quot.mk is its constructor, and +/// Quot.lift and Quot.ind exist. 
+pub fn check_quot(env: &Env) -> TcResult<()> { + let quot_name = Name::str(Name::anon(), "Quot".into()); + let quot_mk_name = + Name::str(Name::str(Name::anon(), "Quot".into()), "mk".into()); + let quot_lift_name = + Name::str(Name::str(Name::anon(), "Quot".into()), "lift".into()); + let quot_ind_name = + Name::str(Name::str(Name::anon(), "Quot".into()), "ind".into()); + + // Check Quot exists and is an inductive + let quot = + env.get("_name).ok_or(TcError::UnknownConst { name: quot_name })?; + match quot { + ConstantInfo::InductInfo(_) => {}, + _ => { + return Err(TcError::KernelException { + msg: "Quot is not an inductive type".into(), + }) + }, + } + + // Check Quot.mk exists and is a constructor of Quot + let mk = env + .get("_mk_name) + .ok_or(TcError::UnknownConst { name: quot_mk_name })?; + match mk { + ConstantInfo::CtorInfo(c) + if c.induct + == Name::str(Name::anon(), "Quot".into()) => {}, + _ => { + return Err(TcError::KernelException { + msg: "Quot.mk is not a constructor of Quot".into(), + }) + }, + } + + // Check Eq exists as an inductive with exactly 1 universe param and 1 ctor + let eq_name = Name::str(Name::anon(), "Eq".into()); + if let Some(eq_ci) = env.get(&eq_name) { + match eq_ci { + ConstantInfo::InductInfo(iv) => { + if iv.cnst.level_params.len() != 1 { + return Err(TcError::KernelException { + msg: format!( + "Eq should have 1 universe parameter, found {}", + iv.cnst.level_params.len() + ), + }); + } + if iv.ctors.len() != 1 { + return Err(TcError::KernelException { + msg: format!( + "Eq should have 1 constructor, found {}", + iv.ctors.len() + ), + }); + } + }, + _ => { + return Err(TcError::KernelException { + msg: "Eq is not an inductive type".into(), + }) + }, + } + } else { + return Err(TcError::KernelException { + msg: "Eq not found in environment (required for quotient types)".into(), + }); + } + + // Check Quot has exactly 1 level param + match quot { + ConstantInfo::InductInfo(iv) if iv.cnst.level_params.len() != 1 => { + return 
Err(TcError::KernelException { + msg: format!( + "Quot should have 1 universe parameter, found {}", + iv.cnst.level_params.len() + ), + }) + }, + _ => {}, + } + + // Check Quot.mk has 1 level param + match mk { + ConstantInfo::CtorInfo(c) if c.cnst.level_params.len() != 1 => { + return Err(TcError::KernelException { + msg: format!( + "Quot.mk should have 1 universe parameter, found {}", + c.cnst.level_params.len() + ), + }) + }, + _ => {}, + } + + // Check Quot.lift exists and has 2 level params + let lift = env + .get("_lift_name) + .ok_or(TcError::UnknownConst { name: quot_lift_name })?; + if lift.get_level_params().len() != 2 { + return Err(TcError::KernelException { + msg: format!( + "Quot.lift should have 2 universe parameters, found {}", + lift.get_level_params().len() + ), + }); + } + + // Check Quot.ind exists and has 1 level param + let ind = env + .get("_ind_name) + .ok_or(TcError::UnknownConst { name: quot_ind_name })?; + if ind.get_level_params().len() != 1 { + return Err(TcError::KernelException { + msg: format!( + "Quot.ind should have 1 universe parameter, found {}", + ind.get_level_params().len() + ), + }); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::lean::nat::Nat; + + fn mk_name(s: &str) -> Name { + Name::str(Name::anon(), s.into()) + } + + fn mk_name2(a: &str, b: &str) -> Name { + Name::str(Name::str(Name::anon(), a.into()), b.into()) + } + + /// Build a well-formed quotient environment. 
+ fn mk_quot_env() -> Env { + let mut env = Env::default(); + let u = mk_name("u"); + let v = mk_name("v"); + let dummy_ty = Expr::sort(Level::param(u.clone())); + + // Eq.{u} — 1 uparam, 1 ctor + let eq_name = mk_name("Eq"); + let eq_refl = mk_name2("Eq", "refl"); + env.insert( + eq_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: eq_name.clone(), + level_params: vec![u.clone()], + typ: dummy_ty.clone(), + }, + num_params: Nat::from(2u64), + num_indices: Nat::from(1u64), + all: vec![eq_name], + ctors: vec![eq_refl.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: true, + }), + ); + env.insert( + eq_refl.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: eq_refl, + level_params: vec![u.clone()], + typ: dummy_ty.clone(), + }, + induct: mk_name("Eq"), + cidx: Nat::from(0u64), + num_params: Nat::from(2u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + + // Quot.{u} — 1 uparam + let quot_name = mk_name("Quot"); + let quot_mk = mk_name2("Quot", "mk"); + env.insert( + quot_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: quot_name.clone(), + level_params: vec![u.clone()], + typ: dummy_ty.clone(), + }, + num_params: Nat::from(2u64), + num_indices: Nat::from(0u64), + all: vec![quot_name], + ctors: vec![quot_mk.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + env.insert( + quot_mk.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: quot_mk, + level_params: vec![u.clone()], + typ: dummy_ty.clone(), + }, + induct: mk_name("Quot"), + cidx: Nat::from(0u64), + num_params: Nat::from(2u64), + num_fields: Nat::from(1u64), + is_unsafe: false, + }), + ); + + // Quot.lift.{u,v} — 2 uparams + let quot_lift = mk_name2("Quot", "lift"); + env.insert( + quot_lift.clone(), + ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + 
name: quot_lift, + level_params: vec![u.clone(), v.clone()], + typ: dummy_ty.clone(), + }, + is_unsafe: false, + }), + ); + + // Quot.ind.{u} — 1 uparam + let quot_ind = mk_name2("Quot", "ind"); + env.insert( + quot_ind.clone(), + ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: quot_ind, + level_params: vec![u], + typ: dummy_ty, + }, + is_unsafe: false, + }), + ); + + env + } + + #[test] + fn check_quot_well_formed() { + let env = mk_quot_env(); + assert!(check_quot(&env).is_ok()); + } + + #[test] + fn check_quot_missing_eq() { + let mut env = mk_quot_env(); + env.remove(&mk_name("Eq")); + assert!(check_quot(&env).is_err()); + } + + #[test] + fn check_quot_wrong_lift_levels() { + let mut env = mk_quot_env(); + // Replace Quot.lift with 1 level param instead of 2 + let quot_lift = mk_name2("Quot", "lift"); + env.insert( + quot_lift.clone(), + ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: quot_lift, + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + is_unsafe: false, + }), + ); + assert!(check_quot(&env).is_err()); + } +} diff --git a/src/ix/kernel/tc.rs b/src/ix/kernel/tc.rs new file mode 100644 index 00000000..e80416fd --- /dev/null +++ b/src/ix/kernel/tc.rs @@ -0,0 +1,1694 @@ +use crate::ix::env::*; +use crate::lean::nat::Nat; +use rustc_hash::FxHashMap; + +use super::def_eq::def_eq; +use super::error::TcError; +use super::level::{all_expr_uparams_defined, no_dupes_all_params}; +use super::whnf::*; + +type TcResult = Result; + +/// The kernel type checker. 
pub struct TypeChecker<'env> {
  /// The global environment of declarations being checked against.
  pub env: &'env Env,
  /// Memoized weak-head normal forms.
  pub whnf_cache: FxHashMap<Expr, Expr>,
  /// Memoized inferred types.
  pub infer_cache: FxHashMap<Expr, Expr>,
  /// Counter used to mint fresh local (free-variable) names.
  pub local_counter: u64,
  /// Types of the locals minted by `mk_local`.
  pub local_types: FxHashMap<Name, Expr>,
}

impl<'env> TypeChecker<'env> {
  /// Create a checker over `env` with empty caches.
  pub fn new(env: &'env Env) -> Self {
    TypeChecker {
      env,
      whnf_cache: FxHashMap::default(),
      infer_cache: FxHashMap::default(),
      local_counter: 0,
      local_types: FxHashMap::default(),
    }
  }

  // ==========================================================================
  // WHNF with caching
  // ==========================================================================

  /// Weak-head-normalize `e`, consulting the memo cache first.
  pub fn whnf(&mut self, e: &Expr) -> Expr {
    match self.whnf_cache.get(e) {
      Some(cached) => cached.clone(),
      None => {
        let reduced = whnf(e, self.env);
        self.whnf_cache.insert(e.clone(), reduced.clone());
        reduced
      },
    }
  }

  // ==========================================================================
  // Local context management
  // ==========================================================================

  /// Create a fresh free variable for entering a binder.
  pub fn mk_local(&mut self, name: &Name, ty: &Expr) -> Expr {
    let id = self.local_counter;
    self.local_counter += 1;
    // Fresh locals are named `name.<id>`, so distinct binders never collide.
    let fresh = Name::num(name.clone(), Nat::from(id));
    self.local_types.insert(fresh.clone(), ty.clone());
    Expr::fvar(fresh)
  }

  // ==========================================================================
  // Ensure helpers
  // ==========================================================================

  /// Extract the universe level of `e`, normalizing when it is not already
  /// a `Sort`.
  pub fn ensure_sort(&mut self, e: &Expr) -> TcResult<Level> {
    // Fast path: already a Sort.
    if let ExprData::Sort(level, _) = e.as_data() {
      return Ok(level.clone());
    }
    // Slow path: reduce to weak-head normal form and retry.
    let reduced = self.whnf(e);
    if let ExprData::Sort(level, _) = reduced.as_data() {
      Ok(level.clone())
    } else {
      Err(TcError::TypeExpected { expr: e.clone(), inferred: reduced })
    }
  }

  /// Ensure `e` is (or normalizes to) a Pi type.
  pub fn ensure_pi(&mut self, e: &Expr) -> TcResult<Expr> {
    if let ExprData::ForallE(..)
= e.as_data() { + return Ok(e.clone()); + } + let whnfd = self.whnf(e); + match whnfd.as_data() { + ExprData::ForallE(..) => Ok(whnfd), + _ => Err(TcError::FunctionExpected { + expr: e.clone(), + inferred: whnfd, + }), + } + } + + /// Infer the type of `e` and ensure it's a sort; return the universe level. + pub fn infer_sort_of(&mut self, e: &Expr) -> TcResult { + let ty = self.infer(e)?; + let whnfd = self.whnf(&ty); + self.ensure_sort(&whnfd) + } + + // ========================================================================== + // Type inference + // ========================================================================== + + pub fn infer(&mut self, e: &Expr) -> TcResult { + if let Some(cached) = self.infer_cache.get(e) { + return Ok(cached.clone()); + } + let result = self.infer_core(e)?; + self.infer_cache.insert(e.clone(), result.clone()); + Ok(result) + } + + fn infer_core(&mut self, e: &Expr) -> TcResult { + match e.as_data() { + ExprData::Sort(level, _) => self.infer_sort(level), + ExprData::Const(name, levels, _) => self.infer_const(name, levels), + ExprData::App(..) => self.infer_app(e), + ExprData::Lam(..) => self.infer_lambda(e), + ExprData::ForallE(..) => self.infer_pi(e), + ExprData::LetE(_, typ, val, body, _, _) => { + self.infer_let(typ, val, body) + }, + ExprData::Lit(lit, _) => self.infer_lit(lit), + ExprData::Proj(type_name, idx, structure, _) => { + self.infer_proj(type_name, idx, structure) + }, + ExprData::Mdata(_, inner, _) => self.infer(inner), + ExprData::Fvar(name, _) => { + match self.local_types.get(name) { + Some(ty) => Ok(ty.clone()), + None => Err(TcError::KernelException { + msg: "cannot infer type of free variable without context".into(), + }), + } + }, + ExprData::Bvar(idx, _) => Err(TcError::FreeBoundVariable { + idx: idx.to_u64().unwrap_or(u64::MAX), + }), + ExprData::Mvar(..) 
=> Err(TcError::KernelException { + msg: "cannot infer type of metavariable".into(), + }), + } + } + + fn infer_sort(&mut self, level: &Level) -> TcResult { + Ok(Expr::sort(Level::succ(level.clone()))) + } + + fn infer_const( + &mut self, + name: &Name, + levels: &[Level], + ) -> TcResult { + let ci = self + .env + .get(name) + .ok_or_else(|| TcError::UnknownConst { name: name.clone() })?; + + let decl_params = ci.get_level_params(); + if levels.len() != decl_params.len() { + return Err(TcError::KernelException { + msg: format!( + "universe parameter count mismatch for {}", + name.pretty() + ), + }); + } + + let ty = ci.get_type(); + Ok(subst_expr_levels(ty, decl_params, levels)) + } + + fn infer_app(&mut self, e: &Expr) -> TcResult { + let (fun, args) = unfold_apps(e); + let mut fun_ty = self.infer(&fun)?; + + for arg in &args { + let pi = self.ensure_pi(&fun_ty)?; + match pi.as_data() { + ExprData::ForallE(_, binder_type, body, _, _) => { + // Check argument type matches binder + let arg_ty = self.infer(arg)?; + self.assert_def_eq(&arg_ty, binder_type)?; + fun_ty = inst(body, &[arg.clone()]); + }, + _ => unreachable!(), + } + } + + Ok(fun_ty) + } + + fn infer_lambda(&mut self, e: &Expr) -> TcResult { + let mut cursor = e.clone(); + let mut locals = Vec::new(); + let mut binder_types = Vec::new(); + let mut binder_infos = Vec::new(); + let mut binder_names = Vec::new(); + + while let ExprData::Lam(name, binder_type, body, bi, _) = + cursor.as_data() + { + let binder_type_inst = inst(binder_type, &locals); + self.infer_sort_of(&binder_type_inst)?; + + let local = self.mk_local(name, &binder_type_inst); + locals.push(local); + binder_types.push(binder_type_inst); + binder_infos.push(bi.clone()); + binder_names.push(name.clone()); + cursor = body.clone(); + } + + let body_inst = inst(&cursor, &locals); + let body_ty = self.infer(&body_inst)?; + + // Abstract back: build Pi telescope + let mut result = abstr(&body_ty, &locals); + for i in (0..locals.len()).rev() { + 
let binder_type_abstrd = abstr(&binder_types[i], &locals[..i]); + result = Expr::all( + binder_names[i].clone(), + binder_type_abstrd, + result, + binder_infos[i].clone(), + ); + } + + Ok(result) + } + + fn infer_pi(&mut self, e: &Expr) -> TcResult { + let mut cursor = e.clone(); + let mut locals = Vec::new(); + let mut universes = Vec::new(); + + while let ExprData::ForallE(name, binder_type, body, _bi, _) = + cursor.as_data() + { + let binder_type_inst = inst(binder_type, &locals); + let dom_univ = self.infer_sort_of(&binder_type_inst)?; + universes.push(dom_univ); + + let local = self.mk_local(name, &binder_type_inst); + locals.push(local); + cursor = body.clone(); + } + + let body_inst = inst(&cursor, &locals); + let mut result_level = self.infer_sort_of(&body_inst)?; + + for univ in universes.into_iter().rev() { + result_level = Level::imax(univ, result_level); + } + + Ok(Expr::sort(result_level)) + } + + fn infer_let( + &mut self, + typ: &Expr, + val: &Expr, + body: &Expr, + ) -> TcResult { + // Verify value matches declared type + let val_ty = self.infer(val)?; + self.assert_def_eq(&val_ty, typ)?; + let body_inst = inst(body, &[val.clone()]); + self.infer(&body_inst) + } + + fn infer_lit(&mut self, lit: &Literal) -> TcResult { + match lit { + Literal::NatVal(_) => { + Ok(Expr::cnst(Name::str(Name::anon(), "Nat".into()), vec![])) + }, + Literal::StrVal(_) => { + Ok(Expr::cnst(Name::str(Name::anon(), "String".into()), vec![])) + }, + } + } + + fn infer_proj( + &mut self, + type_name: &Name, + idx: &Nat, + structure: &Expr, + ) -> TcResult { + let structure_ty = self.infer(structure)?; + let structure_ty_whnf = self.whnf(&structure_ty); + + let (_, struct_ty_args) = unfold_apps(&structure_ty_whnf); + let struct_ty_head = match unfold_apps(&structure_ty_whnf).0.as_data() { + ExprData::Const(name, levels, _) => (name.clone(), levels.clone()), + _ => { + return Err(TcError::KernelException { + msg: "projection structure type is not a constant".into(), + }) + }, + 
}; + + let ind = self.env.get(&struct_ty_head.0).ok_or_else(|| { + TcError::UnknownConst { name: struct_ty_head.0.clone() } + })?; + + let (num_params, ctor_name) = match ind { + ConstantInfo::InductInfo(iv) => { + let ctor = iv.ctors.first().ok_or_else(|| { + TcError::KernelException { + msg: "inductive has no constructors".into(), + } + })?; + (iv.num_params.to_u64().unwrap(), ctor.clone()) + }, + _ => { + return Err(TcError::KernelException { + msg: "projection type is not an inductive".into(), + }) + }, + }; + + let ctor_ci = self.env.get(&ctor_name).ok_or_else(|| { + TcError::UnknownConst { name: ctor_name.clone() } + })?; + + let mut ctor_ty = subst_expr_levels( + ctor_ci.get_type(), + ctor_ci.get_level_params(), + &struct_ty_head.1, + ); + + // Skip params + for i in 0..num_params as usize { + let whnf_ty = self.whnf(&ctor_ty); + match whnf_ty.as_data() { + ExprData::ForallE(_, _, body, _, _) => { + ctor_ty = inst(body, &[struct_ty_args[i].clone()]); + }, + _ => { + return Err(TcError::KernelException { + msg: "ran out of constructor telescope (params)".into(), + }) + }, + } + } + + // Walk to the idx-th field + let idx_usize = idx.to_u64().unwrap() as usize; + for i in 0..idx_usize { + let whnf_ty = self.whnf(&ctor_ty); + match whnf_ty.as_data() { + ExprData::ForallE(_, _, body, _, _) => { + let proj = + Expr::proj(type_name.clone(), Nat::from(i as u64), structure.clone()); + ctor_ty = inst(body, &[proj]); + }, + _ => { + return Err(TcError::KernelException { + msg: "ran out of constructor telescope (fields)".into(), + }) + }, + } + } + + let whnf_ty = self.whnf(&ctor_ty); + match whnf_ty.as_data() { + ExprData::ForallE(_, binder_type, _, _, _) => { + Ok(binder_type.clone()) + }, + _ => Err(TcError::KernelException { + msg: "ran out of constructor telescope (target field)".into(), + }), + } + } + + // ========================================================================== + // Definitional equality (delegated to def_eq module) + // 
==========================================================================

  /// Definitional equality of `x` and `y`, delegated to the `def_eq` module.
  pub fn def_eq(&mut self, x: &Expr, y: &Expr) -> bool {
    def_eq(x, y, self)
  }

  /// Like `def_eq`, but a mismatch becomes a `DefEqFailure` error.
  pub fn assert_def_eq(&mut self, x: &Expr, y: &Expr) -> TcResult<()> {
    if !self.def_eq(x, y) {
      return Err(TcError::DefEqFailure { lhs: x.clone(), rhs: y.clone() });
    }
    Ok(())
  }

  // ==========================================================================
  // Declaration checking
  // ==========================================================================

  /// Check that a declaration's type is well-formed: distinct universe
  /// parameters, no loose bound variables, only declared universe
  /// parameters, and a type that itself infers to some `Sort`.
  pub fn check_declar_info(
    &mut self,
    info: &ConstantVal,
  ) -> TcResult<()> {
    // Universe parameters must be pairwise distinct.
    if !no_dupes_all_params(&info.level_params) {
      return Err(TcError::KernelException {
        msg: format!(
          "duplicate universe parameters in {}",
          info.name.pretty()
        ),
      });
    }

    // The type must be closed: no dangling de Bruijn indices.
    if has_loose_bvars(&info.typ) {
      return Err(TcError::KernelException {
        msg: format!(
          "free bound variables in type of {}",
          info.name.pretty()
        ),
      });
    }

    // Every universe parameter mentioned in the type must be declared.
    if !all_expr_uparams_defined(&info.typ, &info.level_params) {
      return Err(TcError::KernelException {
        msg: format!(
          "undeclared universe parameters in type of {}",
          info.name.pretty()
        ),
      });
    }

    // Finally, the type must itself be a type, i.e. infer to a Sort.
    let inferred = self.infer(&info.typ)?;
    self.ensure_sort(&inferred)?;

    Ok(())
  }

  /// Check a single declaration.
+ pub fn check_declar( + &mut self, + ci: &ConstantInfo, + ) -> TcResult<()> { + match ci { + ConstantInfo::AxiomInfo(v) => { + self.check_declar_info(&v.cnst)?; + }, + ConstantInfo::DefnInfo(v) => { + self.check_declar_info(&v.cnst)?; + if !all_expr_uparams_defined(&v.value, &v.cnst.level_params) { + return Err(TcError::KernelException { + msg: format!( + "undeclared universe parameters in value of {}", + v.cnst.name.pretty() + ), + }); + } + let inferred_type = self.infer(&v.value)?; + self.assert_def_eq(&inferred_type, &v.cnst.typ)?; + }, + ConstantInfo::ThmInfo(v) => { + self.check_declar_info(&v.cnst)?; + if !all_expr_uparams_defined(&v.value, &v.cnst.level_params) { + return Err(TcError::KernelException { + msg: format!( + "undeclared universe parameters in value of {}", + v.cnst.name.pretty() + ), + }); + } + let inferred_type = self.infer(&v.value)?; + self.assert_def_eq(&inferred_type, &v.cnst.typ)?; + }, + ConstantInfo::OpaqueInfo(v) => { + self.check_declar_info(&v.cnst)?; + if !all_expr_uparams_defined(&v.value, &v.cnst.level_params) { + return Err(TcError::KernelException { + msg: format!( + "undeclared universe parameters in value of {}", + v.cnst.name.pretty() + ), + }); + } + let inferred_type = self.infer(&v.value)?; + self.assert_def_eq(&inferred_type, &v.cnst.typ)?; + }, + ConstantInfo::QuotInfo(v) => { + self.check_declar_info(&v.cnst)?; + super::quot::check_quot(self.env)?; + }, + ConstantInfo::InductInfo(v) => { + super::inductive::check_inductive(v, self)?; + }, + ConstantInfo::CtorInfo(v) => { + self.check_declar_info(&v.cnst)?; + // Verify the parent inductive exists + if self.env.get(&v.induct).is_none() { + return Err(TcError::UnknownConst { + name: v.induct.clone(), + }); + } + }, + ConstantInfo::RecInfo(v) => { + self.check_declar_info(&v.cnst)?; + for ind_name in &v.all { + if self.env.get(ind_name).is_none() { + return Err(TcError::UnknownConst { + name: ind_name.clone(), + }); + } + } + super::inductive::validate_k_flag(v, 
self.env)?; + }, + } + Ok(()) + } +} + +/// Check all declarations in an environment. +pub fn check_env(env: &Env) -> Vec<(Name, TcError)> { + let mut errors = Vec::new(); + for (name, ci) in env.iter() { + let mut tc = TypeChecker::new(env); + if let Err(e) = tc.check_declar(ci) { + errors.push((name.clone(), e)); + } + } + errors +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::lean::nat::Nat; + + fn mk_name(s: &str) -> Name { + Name::str(Name::anon(), s.into()) + } + + fn mk_name2(a: &str, b: &str) -> Name { + Name::str(Name::str(Name::anon(), a.into()), b.into()) + } + + fn nat_type() -> Expr { + Expr::cnst(mk_name("Nat"), vec![]) + } + + fn nat_zero() -> Expr { + Expr::cnst(mk_name2("Nat", "zero"), vec![]) + } + + fn prop() -> Expr { + Expr::sort(Level::zero()) + } + + fn type_u() -> Expr { + Expr::sort(Level::param(mk_name("u"))) + } + + /// Build a minimal environment with Nat, Nat.zero, and Nat.succ. + fn mk_nat_env() -> Env { + let mut env = Env::default(); + + let nat_name = mk_name("Nat"); + // Nat : Sort 1 + let nat = ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: nat_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![nat_name.clone()], + ctors: vec![mk_name2("Nat", "zero"), mk_name2("Nat", "succ")], + num_nested: Nat::from(0u64), + is_rec: true, + is_unsafe: false, + is_reflexive: false, + }); + env.insert(nat_name, nat); + + // Nat.zero : Nat + let zero_name = mk_name2("Nat", "zero"); + let zero = ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: zero_name.clone(), + level_params: vec![], + typ: nat_type(), + }, + induct: mk_name("Nat"), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }); + env.insert(zero_name, zero); + + // Nat.succ : Nat → Nat + let succ_name = mk_name2("Nat", "succ"); + let succ_ty = Expr::all( + 
mk_name("n"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let succ = ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: succ_name.clone(), + level_params: vec![], + typ: succ_ty, + }, + induct: mk_name("Nat"), + cidx: Nat::from(1u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(1u64), + is_unsafe: false, + }); + env.insert(succ_name, succ); + + env + } + + // ========================================================================== + // Infer: Sort + // ========================================================================== + + #[test] + fn infer_sort_zero() { + // Sort(0) : Sort(1) + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let e = prop(); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, Expr::sort(Level::succ(Level::zero()))); + } + + #[test] + fn infer_sort_succ() { + // Sort(1) : Sort(2) + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let e = Expr::sort(Level::succ(Level::zero())); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, Expr::sort(Level::succ(Level::succ(Level::zero())))); + } + + #[test] + fn infer_sort_param() { + // Sort(u) : Sort(u+1) + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let u = Level::param(mk_name("u")); + let e = Expr::sort(u.clone()); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, Expr::sort(Level::succ(u))); + } + + // ========================================================================== + // Infer: Const + // ========================================================================== + + #[test] + fn infer_const_nat() { + // Nat : Sort 1 + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let e = Expr::cnst(mk_name("Nat"), vec![]); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, Expr::sort(Level::succ(Level::zero()))); + } + + #[test] + fn infer_const_nat_zero() { + // Nat.zero : Nat + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let e = nat_zero(); + let ty = 
tc.infer(&e).unwrap(); + assert_eq!(ty, nat_type()); + } + + #[test] + fn infer_const_nat_succ() { + // Nat.succ : Nat → Nat + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let e = Expr::cnst(mk_name2("Nat", "succ"), vec![]); + let ty = tc.infer(&e).unwrap(); + let expected = Expr::all( + mk_name("n"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + assert_eq!(ty, expected); + } + + #[test] + fn infer_const_unknown() { + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let e = Expr::cnst(mk_name("NonExistent"), vec![]); + assert!(tc.infer(&e).is_err()); + } + + #[test] + fn infer_const_universe_mismatch() { + // Nat has 0 universe params; passing 1 should fail + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let e = Expr::cnst(mk_name("Nat"), vec![Level::zero()]); + assert!(tc.infer(&e).is_err()); + } + + // ========================================================================== + // Infer: Lit + // ========================================================================== + + #[test] + fn infer_nat_lit() { + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let e = Expr::lit(Literal::NatVal(Nat::from(42u64))); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, nat_type()); + } + + #[test] + fn infer_string_lit() { + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let e = Expr::lit(Literal::StrVal("hello".into())); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, Expr::cnst(mk_name("String"), vec![])); + } + + // ========================================================================== + // Infer: Lambda + // ========================================================================== + + #[test] + fn infer_identity_lambda() { + // fun (x : Nat) => x : Nat → Nat + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let id_fn = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let ty = 
tc.infer(&id_fn).unwrap(); + let expected = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + assert_eq!(ty, expected); + } + + #[test] + fn infer_const_lambda() { + // fun (x : Nat) (y : Nat) => x : Nat → Nat → Nat + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let body = Expr::lam( + mk_name("y"), + nat_type(), + Expr::bvar(Nat::from(1u64)), // x + BinderInfo::Default, + ); + let k_fn = Expr::lam( + mk_name("x"), + nat_type(), + body, + BinderInfo::Default, + ); + let ty = tc.infer(&k_fn).unwrap(); + // Nat → Nat → Nat + let expected = Expr::all( + mk_name("x"), + nat_type(), + Expr::all( + mk_name("y"), + nat_type(), + nat_type(), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + assert_eq!(ty, expected); + } + + // ========================================================================== + // Infer: App + // ========================================================================== + + #[test] + fn infer_app_succ_zero() { + // Nat.succ Nat.zero : Nat + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let e = Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + nat_zero(), + ); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, nat_type()); + } + + #[test] + fn infer_app_identity() { + // (fun x : Nat => x) Nat.zero : Nat + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let id_fn = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let e = Expr::app(id_fn, nat_zero()); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, nat_type()); + } + + // ========================================================================== + // Infer: Pi + // ========================================================================== + + #[test] + fn infer_pi_nat_to_nat() { + // (Nat → Nat) : Sort 1 + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let pi = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + 
BinderInfo::Default, + ); + let ty = tc.infer(&pi).unwrap(); + // Sort(imax(1, 1)) which simplifies to Sort(1) + if let ExprData::Sort(level, _) = ty.as_data() { + assert!( + super::super::level::eq_antisymm( + level, + &Level::succ(Level::zero()) + ), + "Nat → Nat should live in Sort 1, got {:?}", + level + ); + } else { + panic!("Expected Sort, got {:?}", ty); + } + } + + #[test] + fn infer_pi_prop_to_prop() { + // (Prop → Prop) : Sort 1 + // An axiom P : Prop, then P → P : Sort 1 + let mut env = Env::default(); + let p_name = mk_name("P"); + env.insert( + p_name.clone(), + ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: p_name.clone(), + level_params: vec![], + typ: prop(), + }, + is_unsafe: false, + }), + ); + + let mut tc = TypeChecker::new(&env); + let p = Expr::cnst(p_name, vec![]); + let pi = Expr::all(mk_name("x"), p.clone(), p.clone(), BinderInfo::Default); + let ty = tc.infer(&pi).unwrap(); + // Sort(imax(0, 0)) = Sort(0) = Prop + if let ExprData::Sort(level, _) = ty.as_data() { + assert!( + super::super::level::is_zero(level), + "Prop → Prop should live in Prop, got {:?}", + level + ); + } else { + panic!("Expected Sort, got {:?}", ty); + } + } + + // ========================================================================== + // Infer: Let + // ========================================================================== + + #[test] + fn infer_let_simple() { + // let x : Nat := Nat.zero in x : Nat + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let e = Expr::letE( + mk_name("x"), + nat_type(), + nat_zero(), + Expr::bvar(Nat::from(0u64)), + false, + ); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, nat_type()); + } + + // ========================================================================== + // Infer: errors + // ========================================================================== + + #[test] + fn infer_free_bvar_fails() { + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let e = 
Expr::bvar(Nat::from(0u64)); + assert!(tc.infer(&e).is_err()); + } + + #[test] + fn infer_fvar_fails() { + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let e = Expr::fvar(mk_name("x")); + assert!(tc.infer(&e).is_err()); + } + + #[test] + fn infer_app_wrong_arg_type() { + // Nat.succ expects Nat, but we pass Sort(0) — should fail with DefEqFailure + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let e = Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + prop(), // Sort(0), not Nat + ); + assert!(tc.infer(&e).is_err()); + } + + #[test] + fn infer_let_type_mismatch() { + // let x : Nat → Nat := Nat.zero in x + // Nat.zero : Nat, but annotation says Nat → Nat — should fail + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let nat_to_nat = Expr::all( + mk_name("n"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let e = Expr::letE( + mk_name("x"), + nat_to_nat, + nat_zero(), + Expr::bvar(Nat::from(0u64)), + false, + ); + assert!(tc.infer(&e).is_err()); + } + + // ========================================================================== + // check_declar + // ========================================================================== + + #[test] + fn check_axiom_declar() { + // axiom myAxiom : Nat → Nat + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let ax_ty = Expr::all( + mk_name("n"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let ax = ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: mk_name("myAxiom"), + level_params: vec![], + typ: ax_ty, + }, + is_unsafe: false, + }); + assert!(tc.check_declar(&ax).is_ok()); + } + + #[test] + fn check_defn_declar() { + // def myId : Nat → Nat := fun x => x + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let fun_ty = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let body = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + 
BinderInfo::Default, + ); + let defn = ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: mk_name("myId"), + level_params: vec![], + typ: fun_ty, + }, + value: body, + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![mk_name("myId")], + }); + assert!(tc.check_declar(&defn).is_ok()); + } + + #[test] + fn check_defn_type_mismatch() { + // def bad : Nat := Nat.succ (wrong: Nat.succ : Nat → Nat, not Nat) + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let defn = ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: mk_name("bad"), + level_params: vec![], + typ: nat_type(), + }, + value: Expr::cnst(mk_name2("Nat", "succ"), vec![]), + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![mk_name("bad")], + }); + assert!(tc.check_declar(&defn).is_err()); + } + + #[test] + fn check_declar_loose_bvar() { + // Type with a dangling bound variable should fail + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let ax = ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: mk_name("bad"), + level_params: vec![], + typ: Expr::bvar(Nat::from(0u64)), + }, + is_unsafe: false, + }); + assert!(tc.check_declar(&ax).is_err()); + } + + #[test] + fn check_declar_duplicate_uparams() { + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let u = mk_name("u"); + let ax = ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: mk_name("bad"), + level_params: vec![u.clone(), u], + typ: type_u(), + }, + is_unsafe: false, + }); + assert!(tc.check_declar(&ax).is_err()); + } + + // ========================================================================== + // check_env + // ========================================================================== + + #[test] + fn check_nat_env() { + let env = mk_nat_env(); + let errors = check_env(&env); + assert!(errors.is_empty(), "Expected no errors, got: {:?}", errors); + } + + // 
========================================================================== + // Polymorphic constants + // ========================================================================== + + #[test] + fn infer_polymorphic_const() { + // axiom A.{u} : Sort u + // A.{0} should give Sort(0) + let mut env = Env::default(); + let a_name = mk_name("A"); + let u_name = mk_name("u"); + env.insert( + a_name.clone(), + ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: a_name.clone(), + level_params: vec![u_name.clone()], + typ: Expr::sort(Level::param(u_name)), + }, + is_unsafe: false, + }), + ); + let mut tc = TypeChecker::new(&env); + // A.{0} : Sort(0) + let e = Expr::cnst(a_name, vec![Level::zero()]); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, Expr::sort(Level::zero())); + } + + // ========================================================================== + // Infer: whnf caching + // ========================================================================== + + #[test] + fn whnf_cache_works() { + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let e = Expr::sort(Level::zero()); + let r1 = tc.whnf(&e); + let r2 = tc.whnf(&e); + assert_eq!(r1, r2); + } + + // ========================================================================== + // check_declar: Theorem + // ========================================================================== + + #[test] + fn check_theorem_declar() { + // theorem myThm : Nat → Nat := fun x => x + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let fun_ty = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let body = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let thm = ConstantInfo::ThmInfo(TheoremVal { + cnst: ConstantVal { + name: mk_name("myThm"), + level_params: vec![], + typ: fun_ty, + }, + value: body, + all: vec![mk_name("myThm")], + }); + assert!(tc.check_declar(&thm).is_ok()); + } + + 
#[test] + fn check_theorem_type_mismatch() { + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let thm = ConstantInfo::ThmInfo(TheoremVal { + cnst: ConstantVal { + name: mk_name("badThm"), + level_params: vec![], + typ: nat_type(), // claims : Nat + }, + value: Expr::cnst(mk_name2("Nat", "succ"), vec![]), // but is : Nat → Nat + all: vec![mk_name("badThm")], + }); + assert!(tc.check_declar(&thm).is_err()); + } + + // ========================================================================== + // check_declar: Opaque + // ========================================================================== + + #[test] + fn check_opaque_declar() { + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let opaque = ConstantInfo::OpaqueInfo(OpaqueVal { + cnst: ConstantVal { + name: mk_name("myOpaque"), + level_params: vec![], + typ: nat_type(), + }, + value: nat_zero(), + is_unsafe: false, + all: vec![mk_name("myOpaque")], + }); + assert!(tc.check_declar(&opaque).is_ok()); + } + + // ========================================================================== + // check_declar: Ctor (parent existence check) + // ========================================================================== + + #[test] + fn check_ctor_missing_parent() { + // A constructor whose parent inductive doesn't exist + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let ctor = ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: mk_name2("Fake", "mk"), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + induct: mk_name("Fake"), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }); + assert!(tc.check_declar(&ctor).is_err()); + } + + #[test] + fn check_ctor_with_parent() { + // Nat.zero : Nat, with Nat in env + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let ctor = ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: 
mk_name2("Nat", "zero"), + level_params: vec![], + typ: nat_type(), + }, + induct: mk_name("Nat"), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }); + assert!(tc.check_declar(&ctor).is_ok()); + } + + // ========================================================================== + // check_declar: Rec (mutual reference check) + // ========================================================================== + + #[test] + fn check_rec_missing_inductive() { + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let rec = ConstantInfo::RecInfo(RecursorVal { + cnst: ConstantVal { + name: mk_name2("Fake", "rec"), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + all: vec![mk_name("Fake")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(0u64), + rules: vec![], + k: false, + is_unsafe: false, + }); + assert!(tc.check_declar(&rec).is_err()); + } + + #[test] + fn check_rec_with_inductive() { + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let rec = ConstantInfo::RecInfo(RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![], + k: false, + is_unsafe: false, + }); + assert!(tc.check_declar(&rec).is_ok()); + } + + // ========================================================================== + // Infer: App with delta (definition in head) + // ========================================================================== + + #[test] + fn infer_app_through_delta() { + // def myId : Nat → Nat := fun x => x + // myId Nat.zero : Nat + let mut env = mk_nat_env(); + let fun_ty = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), 
+ BinderInfo::Default, + ); + let body = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + env.insert( + mk_name("myId"), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: mk_name("myId"), + level_params: vec![], + typ: fun_ty, + }, + value: body, + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![mk_name("myId")], + }), + ); + let mut tc = TypeChecker::new(&env); + let e = Expr::app( + Expr::cnst(mk_name("myId"), vec![]), + nat_zero(), + ); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, nat_type()); + } + + // ========================================================================== + // Infer: Proj + // ========================================================================== + + /// Build an env with a simple Prod.{u,v} structure type. + fn mk_prod_env() -> Env { + let mut env = mk_nat_env(); + let u_name = mk_name("u"); + let v_name = mk_name("v"); + let prod_name = mk_name("Prod"); + let mk_name_prod = mk_name2("Prod", "mk"); + + // Prod.{u,v} : Sort u → Sort v → Sort (max u v) + // Simplified: Prod (α : Sort u) (β : Sort v) : Sort (max u v) + let prod_type = Expr::all( + mk_name("α"), + Expr::sort(Level::param(u_name.clone())), + Expr::all( + mk_name("β"), + Expr::sort(Level::param(v_name.clone())), + Expr::sort(Level::max( + Level::param(u_name.clone()), + Level::param(v_name.clone()), + )), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + + env.insert( + prod_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: prod_name.clone(), + level_params: vec![u_name.clone(), v_name.clone()], + typ: prod_type, + }, + num_params: Nat::from(2u64), + num_indices: Nat::from(0u64), + all: vec![prod_name.clone()], + ctors: vec![mk_name_prod.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + + // Prod.mk.{u,v} (α : Sort u) (β : Sort v) (fst : α) (snd : β) : Prod 
α β + // Type: (α : Sort u) → (β : Sort v) → α → β → Prod α β + let ctor_type = Expr::all( + mk_name("α"), + Expr::sort(Level::param(u_name.clone())), + Expr::all( + mk_name("β"), + Expr::sort(Level::param(v_name.clone())), + Expr::all( + mk_name("fst"), + Expr::bvar(Nat::from(1u64)), // α + Expr::all( + mk_name("snd"), + Expr::bvar(Nat::from(1u64)), // β + Expr::app( + Expr::app( + Expr::cnst( + prod_name.clone(), + vec![ + Level::param(u_name.clone()), + Level::param(v_name.clone()), + ], + ), + Expr::bvar(Nat::from(3u64)), // α + ), + Expr::bvar(Nat::from(2u64)), // β + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + + env.insert( + mk_name_prod.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: mk_name_prod, + level_params: vec![u_name, v_name], + typ: ctor_type, + }, + induct: prod_name, + cidx: Nat::from(0u64), + num_params: Nat::from(2u64), + num_fields: Nat::from(2u64), + is_unsafe: false, + }), + ); + + env + } + + #[test] + fn infer_proj_fst() { + // Given p : Prod Nat Nat, (Prod.1 p) : Nat + // Build: Prod.mk Nat Nat Nat.zero Nat.zero, then project field 0 + let env = mk_prod_env(); + let mut tc = TypeChecker::new(&env); + + let one = Level::succ(Level::zero()); + let pair = Expr::app( + Expr::app( + Expr::app( + Expr::app( + Expr::cnst( + mk_name2("Prod", "mk"), + vec![one.clone(), one.clone()], + ), + nat_type(), + ), + nat_type(), + ), + nat_zero(), + ), + nat_zero(), + ); + + let proj = Expr::proj(mk_name("Prod"), Nat::from(0u64), pair); + let ty = tc.infer(&proj).unwrap(); + assert_eq!(ty, nat_type()); + } + + // ========================================================================== + // Infer: nested let + // ========================================================================== + + #[test] + fn infer_nested_let() { + // let x := Nat.zero in let y := x in y : Nat + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let inner = 
Expr::letE( + mk_name("y"), + nat_type(), + Expr::bvar(Nat::from(0u64)), // x + Expr::bvar(Nat::from(0u64)), // y + false, + ); + let e = Expr::letE( + mk_name("x"), + nat_type(), + nat_zero(), + inner, + false, + ); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, nat_type()); + } + + // ========================================================================== + // Infer caching + // ========================================================================== + + #[test] + fn infer_cache_hit() { + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let e = nat_zero(); + let ty1 = tc.infer(&e).unwrap(); + let ty2 = tc.infer(&e).unwrap(); + assert_eq!(ty1, ty2); + assert_eq!(tc.infer_cache.len(), 1); + } + + // ========================================================================== + // Universe parameter validation + // ========================================================================== + + #[test] + fn check_axiom_undeclared_uparam_in_type() { + // axiom bad.{u} : Sort v — v is not declared + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let ax = ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: mk_name("bad"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("v"))), + }, + is_unsafe: false, + }); + assert!(tc.check_declar(&ax).is_err()); + } + + #[test] + fn check_axiom_declared_uparam_in_type() { + // axiom good.{u} : Sort u — u is declared + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let ax = ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: mk_name("good"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + is_unsafe: false, + }); + assert!(tc.check_declar(&ax).is_ok()); + } + + #[test] + fn check_defn_undeclared_uparam_in_value() { + // def bad.{u} : Sort 1 := Sort v — v not declared, in value + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let defn = 
ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: mk_name("bad"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::succ(Level::zero())), + }, + value: Expr::sort(Level::param(mk_name("v"))), + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![mk_name("bad")], + }); + assert!(tc.check_declar(&defn).is_err()); + } + + // ========================================================================== + // K-flag validation + // ========================================================================== + + /// Build an env with a Prop inductive + single zero-field ctor (Eq-like). + fn mk_eq_like_env() -> Env { + let mut env = mk_nat_env(); + let u = mk_name("u"); + let eq_name = mk_name("MyEq"); + let eq_refl = mk_name2("MyEq", "refl"); + + // MyEq.{u} (α : Sort u) (a : α) : α → Prop + // Simplified: type lives in Prop (Sort 0) + let eq_ty = Expr::all( + mk_name("α"), + Expr::sort(Level::param(u.clone())), + Expr::all( + mk_name("a"), + Expr::bvar(Nat::from(0u64)), + Expr::all( + mk_name("b"), + Expr::bvar(Nat::from(1u64)), + Expr::sort(Level::zero()), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + env.insert( + eq_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: eq_name.clone(), + level_params: vec![u.clone()], + typ: eq_ty, + }, + num_params: Nat::from(2u64), + num_indices: Nat::from(1u64), + all: vec![eq_name.clone()], + ctors: vec![eq_refl.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: true, + }), + ); + // MyEq.refl.{u} (α : Sort u) (a : α) : MyEq α a a + // zero fields + let refl_ty = Expr::all( + mk_name("α"), + Expr::sort(Level::param(u.clone())), + Expr::all( + mk_name("a"), + Expr::bvar(Nat::from(0u64)), + Expr::app( + Expr::app( + Expr::app( + Expr::cnst(eq_name.clone(), vec![Level::param(u.clone())]), + Expr::bvar(Nat::from(1u64)), + ), + Expr::bvar(Nat::from(0u64)), + ), + 
Expr::bvar(Nat::from(0u64)), + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + env.insert( + eq_refl.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: eq_refl, + level_params: vec![u], + typ: refl_ty, + }, + induct: eq_name, + cidx: Nat::from(0u64), + num_params: Nat::from(2u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + + env + } + + #[test] + fn check_rec_k_flag_valid() { + let env = mk_eq_like_env(); + let mut tc = TypeChecker::new(&env); + let rec = ConstantInfo::RecInfo(RecursorVal { + cnst: ConstantVal { + name: mk_name2("MyEq", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("MyEq")], + num_params: Nat::from(2u64), + num_indices: Nat::from(1u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(1u64), + rules: vec![], + k: true, + is_unsafe: false, + }); + assert!(tc.check_declar(&rec).is_ok()); + } + + #[test] + fn check_rec_k_flag_invalid_2_ctors() { + // Nat has 2 constructors — K should fail + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let rec = ConstantInfo::RecInfo(RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![], + k: true, // invalid: Nat is not in Prop and has 2 ctors + is_unsafe: false, + }); + assert!(tc.check_declar(&rec).is_err()); + } +} diff --git a/src/ix/kernel/upcopy.rs b/src/ix/kernel/upcopy.rs new file mode 100644 index 00000000..89dae8a0 --- /dev/null +++ b/src/ix/kernel/upcopy.rs @@ -0,0 +1,659 @@ +use core::ptr::NonNull; + +use crate::ix::env::{BinderInfo, Name}; + +use super::dag::*; +use super::dll::DLL; + +// ============================================================================ +// Upcopy +// 
============================================================================ + +pub fn upcopy(new_child: DAGPtr, cc: ParentPtr) { + unsafe { + match cc { + ParentPtr::Root => {}, + ParentPtr::LamBod(link) => { + let lam = &*link.as_ptr(); + let var = &lam.var; + let new_lam = alloc_lam(var.depth, new_child, None); + let new_lam_ref = &mut *new_lam.as_ptr(); + let bod_ref_ptr = + NonNull::new(&mut new_lam_ref.bod_ref as *mut Parents).unwrap(); + add_to_parents(new_child, bod_ref_ptr); + let new_var_ptr = + NonNull::new(&mut new_lam_ref.var as *mut Var).unwrap(); + for parent in DLL::iter_option(var.parents) { + upcopy(DAGPtr::Var(new_var_ptr), *parent); + } + for parent in DLL::iter_option(lam.parents) { + upcopy(DAGPtr::Lam(new_lam), *parent); + } + }, + ParentPtr::AppFun(link) => { + let app = &mut *link.as_ptr(); + match app.copy { + Some(cache) => { + (*cache.as_ptr()).fun = new_child; + }, + None => { + let new_app = alloc_app_no_uplinks(new_child, app.arg); + app.copy = Some(new_app); + for parent in DLL::iter_option(app.parents) { + upcopy(DAGPtr::App(new_app), *parent); + } + }, + } + }, + ParentPtr::AppArg(link) => { + let app = &mut *link.as_ptr(); + match app.copy { + Some(cache) => { + (*cache.as_ptr()).arg = new_child; + }, + None => { + let new_app = alloc_app_no_uplinks(app.fun, new_child); + app.copy = Some(new_app); + for parent in DLL::iter_option(app.parents) { + upcopy(DAGPtr::App(new_app), *parent); + } + }, + } + }, + ParentPtr::FunDom(link) => { + let fun = &mut *link.as_ptr(); + match fun.copy { + Some(cache) => { + (*cache.as_ptr()).dom = new_child; + }, + None => { + let new_fun = alloc_fun_no_uplinks( + fun.binder_name.clone(), + fun.binder_info.clone(), + new_child, + fun.img, + ); + fun.copy = Some(new_fun); + for parent in DLL::iter_option(fun.parents) { + upcopy(DAGPtr::Fun(new_fun), *parent); + } + }, + } + }, + ParentPtr::FunImg(link) => { + let fun = &mut *link.as_ptr(); + // new_child must be a Lam + let new_lam = match new_child { 
+ DAGPtr::Lam(p) => p, + _ => panic!("FunImg parent expects Lam child"), + }; + match fun.copy { + Some(cache) => { + (*cache.as_ptr()).img = new_lam; + }, + None => { + let new_fun = alloc_fun_no_uplinks( + fun.binder_name.clone(), + fun.binder_info.clone(), + fun.dom, + new_lam, + ); + fun.copy = Some(new_fun); + for parent in DLL::iter_option(fun.parents) { + upcopy(DAGPtr::Fun(new_fun), *parent); + } + }, + } + }, + ParentPtr::PiDom(link) => { + let pi = &mut *link.as_ptr(); + match pi.copy { + Some(cache) => { + (*cache.as_ptr()).dom = new_child; + }, + None => { + let new_pi = alloc_pi_no_uplinks( + pi.binder_name.clone(), + pi.binder_info.clone(), + new_child, + pi.img, + ); + pi.copy = Some(new_pi); + for parent in DLL::iter_option(pi.parents) { + upcopy(DAGPtr::Pi(new_pi), *parent); + } + }, + } + }, + ParentPtr::PiImg(link) => { + let pi = &mut *link.as_ptr(); + let new_lam = match new_child { + DAGPtr::Lam(p) => p, + _ => panic!("PiImg parent expects Lam child"), + }; + match pi.copy { + Some(cache) => { + (*cache.as_ptr()).img = new_lam; + }, + None => { + let new_pi = alloc_pi_no_uplinks( + pi.binder_name.clone(), + pi.binder_info.clone(), + pi.dom, + new_lam, + ); + pi.copy = Some(new_pi); + for parent in DLL::iter_option(pi.parents) { + upcopy(DAGPtr::Pi(new_pi), *parent); + } + }, + } + }, + ParentPtr::LetTyp(link) => { + let let_node = &mut *link.as_ptr(); + match let_node.copy { + Some(cache) => { + (*cache.as_ptr()).typ = new_child; + }, + None => { + let new_let = alloc_let_no_uplinks( + let_node.binder_name.clone(), + let_node.non_dep, + new_child, + let_node.val, + let_node.bod, + ); + let_node.copy = Some(new_let); + for parent in DLL::iter_option(let_node.parents) { + upcopy(DAGPtr::Let(new_let), *parent); + } + }, + } + }, + ParentPtr::LetVal(link) => { + let let_node = &mut *link.as_ptr(); + match let_node.copy { + Some(cache) => { + (*cache.as_ptr()).val = new_child; + }, + None => { + let new_let = alloc_let_no_uplinks( + 
let_node.binder_name.clone(), + let_node.non_dep, + let_node.typ, + new_child, + let_node.bod, + ); + let_node.copy = Some(new_let); + for parent in DLL::iter_option(let_node.parents) { + upcopy(DAGPtr::Let(new_let), *parent); + } + }, + } + }, + ParentPtr::LetBod(link) => { + let let_node = &mut *link.as_ptr(); + let new_lam = match new_child { + DAGPtr::Lam(p) => p, + _ => panic!("LetBod parent expects Lam child"), + }; + match let_node.copy { + Some(cache) => { + (*cache.as_ptr()).bod = new_lam; + }, + None => { + let new_let = alloc_let_no_uplinks( + let_node.binder_name.clone(), + let_node.non_dep, + let_node.typ, + let_node.val, + new_lam, + ); + let_node.copy = Some(new_let); + for parent in DLL::iter_option(let_node.parents) { + upcopy(DAGPtr::Let(new_let), *parent); + } + }, + } + }, + ParentPtr::ProjExpr(link) => { + let proj = &*link.as_ptr(); + let new_proj = alloc_proj_no_uplinks( + proj.type_name.clone(), + proj.idx.clone(), + new_child, + ); + for parent in DLL::iter_option(proj.parents) { + upcopy(DAGPtr::Proj(new_proj), *parent); + } + }, + } + } +} + +// ============================================================================ +// No-uplink allocators for upcopy +// ============================================================================ + +fn alloc_app_no_uplinks(fun: DAGPtr, arg: DAGPtr) -> NonNull { + let app_ptr = alloc_val(App { + fun, + arg, + fun_ref: DLL::singleton(ParentPtr::Root), + arg_ref: DLL::singleton(ParentPtr::Root), + copy: None, + parents: None, + }); + unsafe { + let app = &mut *app_ptr.as_ptr(); + app.fun_ref = DLL::singleton(ParentPtr::AppFun(app_ptr)); + app.arg_ref = DLL::singleton(ParentPtr::AppArg(app_ptr)); + } + app_ptr +} + +fn alloc_fun_no_uplinks( + binder_name: Name, + binder_info: BinderInfo, + dom: DAGPtr, + img: NonNull, +) -> NonNull { + let fun_ptr = alloc_val(Fun { + binder_name, + binder_info, + dom, + img, + dom_ref: DLL::singleton(ParentPtr::Root), + img_ref: DLL::singleton(ParentPtr::Root), + copy: 
None,
    parents: None,
  });
  unsafe {
    // The node's address is only known after allocation: rewrite the
    // placeholder parent cells so they point back at this node.
    let fun = &mut *fun_ptr.as_ptr();
    fun.dom_ref = DLL::singleton(ParentPtr::FunDom(fun_ptr));
    fun.img_ref = DLL::singleton(ParentPtr::FunImg(fun_ptr));
  }
  fun_ptr
}

/// Allocate a `Pi` node without uplinks from its children.
///
/// The `dom_ref`/`img_ref` parent cells are created with a placeholder
/// `Root` tag and rewritten once the node's address is known. The caller
/// is responsible for attaching these cells to the children's parent
/// lists (see `add_to_parents`).
fn alloc_pi_no_uplinks(
  binder_name: Name,
  binder_info: BinderInfo,
  dom: DAGPtr,
  img: NonNull<Lam>,
) -> NonNull<Pi> {
  let pi_ptr = alloc_val(Pi {
    binder_name,
    binder_info,
    dom,
    img,
    dom_ref: DLL::singleton(ParentPtr::Root),
    img_ref: DLL::singleton(ParentPtr::Root),
    copy: None,
    parents: None,
  });
  unsafe {
    let pi = &mut *pi_ptr.as_ptr();
    pi.dom_ref = DLL::singleton(ParentPtr::PiDom(pi_ptr));
    pi.img_ref = DLL::singleton(ParentPtr::PiImg(pi_ptr));
  }
  pi_ptr
}

/// Allocate a `LetNode` without uplinks from its children.
/// Same placeholder-then-rewrite scheme as `alloc_pi_no_uplinks`.
fn alloc_let_no_uplinks(
  binder_name: Name,
  non_dep: bool,
  typ: DAGPtr,
  val: DAGPtr,
  bod: NonNull<Lam>,
) -> NonNull<LetNode> {
  let let_ptr = alloc_val(LetNode {
    binder_name,
    non_dep,
    typ,
    val,
    bod,
    typ_ref: DLL::singleton(ParentPtr::Root),
    val_ref: DLL::singleton(ParentPtr::Root),
    bod_ref: DLL::singleton(ParentPtr::Root),
    copy: None,
    parents: None,
  });
  unsafe {
    let let_node = &mut *let_ptr.as_ptr();
    let_node.typ_ref = DLL::singleton(ParentPtr::LetTyp(let_ptr));
    let_node.val_ref = DLL::singleton(ParentPtr::LetVal(let_ptr));
    let_node.bod_ref = DLL::singleton(ParentPtr::LetBod(let_ptr));
  }
  let_ptr
}

/// Allocate a `ProjNode` without an uplink from its child expression.
fn alloc_proj_no_uplinks(
  type_name: Name,
  idx: crate::lean::nat::Nat,
  expr: DAGPtr,
) -> NonNull<ProjNode> {
  let proj_ptr = alloc_val(ProjNode {
    type_name,
    idx,
    expr,
    expr_ref: DLL::singleton(ParentPtr::Root),
    parents: None,
  });
  unsafe {
    let proj = &mut *proj_ptr.as_ptr();
    proj.expr_ref = DLL::singleton(ParentPtr::ProjExpr(proj_ptr));
  }
  proj_ptr
}

// ============================================================================
// Clean up: Clear copy caches after reduction
// ============================================================================

/// Walk upward from the parent pointer `cc`, committing and clearing the
/// `copy` caches left behind by an upcopy pass.
///
/// For every node that still holds a cached copy, the cache is cleared and
/// the copy's child back-links (`*_ref` cells) are attached to the copy's
/// children via `add_to_parents`; the walk then continues through the
/// original node's own parents. Nodes without a cached copy terminate the
/// walk on that branch (their ancestors were never copied).
pub fn clean_up(cc: &ParentPtr) {
  unsafe {
    match cc {
      // Reached the root of the DAG: nothing above to clean.
      ParentPtr::Root => {},
      ParentPtr::LamBod(link) => {
        // Lambdas carry no copy cache themselves; propagate the walk both
        // through the binder variable's parents and the lambda's parents.
        let lam = &*link.as_ptr();
        for parent in DLL::iter_option(lam.var.parents) {
          clean_up(parent);
        }
        for parent in DLL::iter_option(lam.parents) {
          clean_up(parent);
        }
      },
      ParentPtr::AppFun(link) | ParentPtr::AppArg(link) => {
        let app = &mut *link.as_ptr();
        if let Some(app_copy) = app.copy {
          let App { fun, arg, fun_ref, arg_ref, .. } =
            &mut *app_copy.as_ptr();
          app.copy = None;
          // Hook the copy's uplink cells into its children's parent lists.
          add_to_parents(*fun, NonNull::new(fun_ref).unwrap());
          add_to_parents(*arg, NonNull::new(arg_ref).unwrap());
          for parent in DLL::iter_option(app.parents) {
            clean_up(parent);
          }
        }
      },
      ParentPtr::FunDom(link) | ParentPtr::FunImg(link) => {
        let fun = &mut *link.as_ptr();
        if let Some(fun_copy) = fun.copy {
          let Fun { dom, img, dom_ref, img_ref, .. } =
            &mut *fun_copy.as_ptr();
          fun.copy = None;
          add_to_parents(*dom, NonNull::new(dom_ref).unwrap());
          // The image of a Fun is always a Lam node.
          add_to_parents(DAGPtr::Lam(*img), NonNull::new(img_ref).unwrap());
          for parent in DLL::iter_option(fun.parents) {
            clean_up(parent);
          }
        }
      },
      ParentPtr::PiDom(link) | ParentPtr::PiImg(link) => {
        let pi = &mut *link.as_ptr();
        if let Some(pi_copy) = pi.copy {
          let Pi { dom, img, dom_ref, img_ref, .. } =
            &mut *pi_copy.as_ptr();
          pi.copy = None;
          add_to_parents(*dom, NonNull::new(dom_ref).unwrap());
          add_to_parents(DAGPtr::Lam(*img), NonNull::new(img_ref).unwrap());
          for parent in DLL::iter_option(pi.parents) {
            clean_up(parent);
          }
        }
      },
      ParentPtr::LetTyp(link)
      | ParentPtr::LetVal(link)
      | ParentPtr::LetBod(link) => {
        let let_node = &mut *link.as_ptr();
        if let Some(let_copy) = let_node.copy {
          let LetNode { typ, val, bod, typ_ref, val_ref, bod_ref, .. } =
            &mut *let_copy.as_ptr();
          let_node.copy = None;
          add_to_parents(*typ, NonNull::new(typ_ref).unwrap());
          add_to_parents(*val, NonNull::new(val_ref).unwrap());
          add_to_parents(DAGPtr::Lam(*bod), NonNull::new(bod_ref).unwrap());
          for parent in DLL::iter_option(let_node.parents) {
            clean_up(parent);
          }
        }
      },
      ParentPtr::ProjExpr(link) => {
        // Projections carry no copy cache; just keep walking up.
        let proj = &*link.as_ptr();
        for parent in DLL::iter_option(proj.parents) {
          clean_up(parent);
        }
      },
    }
  }
}

// ============================================================================
// Replace child
// ============================================================================

/// Redirect every parent of `old` to point at `new`, then move `old`'s
/// entire parent list onto `new` (concatenating with any parents `new`
/// already has). `old` is left with an empty parent list but is NOT freed;
/// the caller decides whether it is dead (see `free_dead_node`).
pub fn replace_child(old: DAGPtr, new: DAGPtr) {
  unsafe {
    if let Some(parents) = get_parents(old) {
      for parent in DLL::iter_option(Some(parents)) {
        match parent {
          ParentPtr::Root => {},
          ParentPtr::LamBod(p) => (*p.as_ptr()).bod = new,
          ParentPtr::FunDom(p) => (*p.as_ptr()).dom = new,
          // Fun/Pi images and Let bodies are typed as Lam pointers, so the
          // replacement must itself be a Lam node.
          ParentPtr::FunImg(p) => match new {
            DAGPtr::Lam(lam) => (*p.as_ptr()).img = lam,
            _ => panic!("FunImg expects Lam"),
          },
          ParentPtr::PiDom(p) => (*p.as_ptr()).dom = new,
          ParentPtr::PiImg(p) => match new {
            DAGPtr::Lam(lam) => (*p.as_ptr()).img = lam,
            _ => panic!("PiImg expects Lam"),
          },
          ParentPtr::AppFun(p) => (*p.as_ptr()).fun = new,
          ParentPtr::AppArg(p) => (*p.as_ptr()).arg = new,
          ParentPtr::LetTyp(p) => (*p.as_ptr()).typ = new,
          ParentPtr::LetVal(p) => (*p.as_ptr()).val = new,
          ParentPtr::LetBod(p) => match new {
            DAGPtr::Lam(lam) => (*p.as_ptr()).bod = lam,
            _ => panic!("LetBod expects Lam"),
          },
          ParentPtr::ProjExpr(p) => (*p.as_ptr()).expr = new,
        }
      }
      set_parents(old, None);
      match get_parents(new) {
        None => set_parents(new, Some(parents)),
        Some(new_parents) => {
          DLL::concat(new_parents, Some(parents));
        },
      }
    }
  }
}

// ============================================================================
// Free dead nodes
// ============================================================================

/// Free a node that is known to have no remaining parents, recursively
/// freeing any child that becomes parentless in the process.
///
/// For each child: the node's uplink cell is unlinked from the child's
/// parent list; if other parents remain, the child's parent-list head is
/// updated, otherwise the child is dead and freed recursively. Finally the
/// node's own allocation is released via `Box::from_raw`.
///
/// Safety: the caller must guarantee `node` really is unreachable — any
/// dangling parent pointer into a freed node is a use-after-free.
pub fn free_dead_node(node: DAGPtr) {
  unsafe {
    match node {
      DAGPtr::Lam(link) => {
        let lam = &*link.as_ptr();
        let bod_ref_ptr = &lam.bod_ref as *const Parents;
        if let Some(remaining) = (*bod_ref_ptr).unlink_node() {
          set_parents(lam.bod, Some(remaining));
        } else {
          set_parents(lam.bod, None);
          free_dead_node(lam.bod);
        }
        drop(Box::from_raw(link.as_ptr()));
      },
      DAGPtr::App(link) => {
        let app = &*link.as_ptr();
        let fun_ref_ptr = &app.fun_ref as *const Parents;
        if let Some(remaining) = (*fun_ref_ptr).unlink_node() {
          set_parents(app.fun, Some(remaining));
        } else {
          set_parents(app.fun, None);
          free_dead_node(app.fun);
        }
        let arg_ref_ptr = &app.arg_ref as *const Parents;
        if let Some(remaining) = (*arg_ref_ptr).unlink_node() {
          set_parents(app.arg, Some(remaining));
        } else {
          set_parents(app.arg, None);
          free_dead_node(app.arg);
        }
        drop(Box::from_raw(link.as_ptr()));
      },
      DAGPtr::Fun(link) => {
        let fun = &*link.as_ptr();
        let dom_ref_ptr = &fun.dom_ref as *const Parents;
        if let Some(remaining) = (*dom_ref_ptr).unlink_node() {
          set_parents(fun.dom, Some(remaining));
        } else {
          set_parents(fun.dom, None);
          free_dead_node(fun.dom);
        }
        let img_ref_ptr = &fun.img_ref as *const Parents;
        if let Some(remaining) = (*img_ref_ptr).unlink_node() {
          // The image is stored as a raw Lam pointer; wrap it to recurse.
          set_parents(DAGPtr::Lam(fun.img), Some(remaining));
        } else {
          set_parents(DAGPtr::Lam(fun.img), None);
          free_dead_node(DAGPtr::Lam(fun.img));
        }
        drop(Box::from_raw(link.as_ptr()));
      },
      DAGPtr::Pi(link) => {
        let pi = &*link.as_ptr();
        let dom_ref_ptr = &pi.dom_ref as *const Parents;
        if let Some(remaining) = (*dom_ref_ptr).unlink_node() {
          set_parents(pi.dom, Some(remaining));
        } else {
          set_parents(pi.dom, None);
          free_dead_node(pi.dom);
        }
        let img_ref_ptr = &pi.img_ref as *const Parents;
        if let Some(remaining) = (*img_ref_ptr).unlink_node() {
          set_parents(DAGPtr::Lam(pi.img), Some(remaining));
        } else {
          set_parents(DAGPtr::Lam(pi.img), None);
          free_dead_node(DAGPtr::Lam(pi.img));
        }
        drop(Box::from_raw(link.as_ptr()));
      },
      DAGPtr::Let(link) => {
        let let_node = &*link.as_ptr();
        let typ_ref_ptr = &let_node.typ_ref as *const Parents;
        if let Some(remaining) = (*typ_ref_ptr).unlink_node() {
          set_parents(let_node.typ, Some(remaining));
        } else {
          set_parents(let_node.typ, None);
          free_dead_node(let_node.typ);
        }
        let val_ref_ptr = &let_node.val_ref as *const Parents;
        if let Some(remaining) = (*val_ref_ptr).unlink_node() {
          set_parents(let_node.val, Some(remaining));
        } else {
          set_parents(let_node.val, None);
          free_dead_node(let_node.val);
        }
        let bod_ref_ptr = &let_node.bod_ref as *const Parents;
        if let Some(remaining) = (*bod_ref_ptr).unlink_node() {
          set_parents(DAGPtr::Lam(let_node.bod), Some(remaining));
        } else {
          set_parents(DAGPtr::Lam(let_node.bod), None);
          free_dead_node(DAGPtr::Lam(let_node.bod));
        }
        drop(Box::from_raw(link.as_ptr()));
      },
      DAGPtr::Proj(link) => {
        let proj = &*link.as_ptr();
        let expr_ref_ptr = &proj.expr_ref as *const Parents;
        if let Some(remaining) = (*expr_ref_ptr).unlink_node() {
          set_parents(proj.expr, Some(remaining));
        } else {
          set_parents(proj.expr, None);
          free_dead_node(proj.expr);
        }
        drop(Box::from_raw(link.as_ptr()));
      },
      DAGPtr::Var(link) => {
        // Bound variables are owned by their binder (freed with the Lam);
        // only free-standing variables are freed here.
        let var = &*link.as_ptr();
        if let BinderPtr::Free = var.binder {
          drop(Box::from_raw(link.as_ptr()));
        }
      },
      // Leaf nodes own no children.
      DAGPtr::Sort(link) => drop(Box::from_raw(link.as_ptr())),
      DAGPtr::Cnst(link) => drop(Box::from_raw(link.as_ptr())),
      DAGPtr::Lit(link) => drop(Box::from_raw(link.as_ptr())),
    }
  }
}

// ============================================================================
// Lambda reduction
// ============================================================================

/// Contract a lambda redex: (Fun dom (Lam bod var)) arg → [arg/var]bod.
/// Returns the (possibly shared) body with `arg` substituted for the
/// binder variable. Fast paths: if the lambda has a single parent the
/// substitution may be done destructively in place; if the variable is
/// unused the body is returned unchanged.
///
/// NOTE(review): in the general (shared-lambda) case this returns the
/// ORIGINAL `lambda.bod` after upcopy/clean_up; verify that the committed
/// copies leave sharers of the lambda unaffected — this is the core BUBS
/// invariant and is not checkable from this file alone.
pub fn reduce_lam(redex: NonNull<App>, lam: NonNull<Lam>) -> DAGPtr {
  unsafe {
    let app = &*redex.as_ptr();
    let lambda = &*lam.as_ptr();
    let var = &lambda.var;
    let arg = app.arg;

    // Fast path: the lambda is uniquely referenced, so we may mutate it.
    if DLL::is_singleton(lambda.parents) {
      if DLL::is_empty(var.parents) {
        // Variable unused: the body is already the result.
        return lambda.bod;
      }
      // Destructively rewire every occurrence of the variable to `arg`.
      replace_child(DAGPtr::Var(NonNull::from(var)), arg);
      return lambda.bod;
    }

    // Variable unused in a shared lambda: no copying needed.
    if DLL::is_empty(var.parents) {
      return lambda.bod;
    }

    // General case: upcopy `arg` through the variable's parents, then
    // commit the copy caches upward.
    for parent in DLL::iter_option(var.parents) {
      upcopy(arg, *parent);
    }
    for parent in DLL::iter_option(var.parents) {
      clean_up(parent);
    }
    lambda.bod
  }
}

/// Contract a let redex: Let(typ, val, Lam(bod, var)) → [val/var]bod.
///
/// Mirrors `reduce_lam`, substituting the let-bound value for the binder
/// variable of the body lambda.
pub fn reduce_let(let_node: NonNull<LetNode>) -> DAGPtr {
  unsafe {
    let ln = &*let_node.as_ptr();
    let lam = &*ln.bod.as_ptr();
    let var = &lam.var;
    let val = ln.val;

    // Fast path: uniquely referenced body lambda — substitute in place.
    if DLL::is_singleton(lam.parents) {
      if DLL::is_empty(var.parents) {
        return lam.bod;
      }
      replace_child(DAGPtr::Var(NonNull::from(var)), val);
      return lam.bod;
    }

    if DLL::is_empty(var.parents) {
      return lam.bod;
    }

    // General case: upcopy + commit, as in `reduce_lam`.
    for parent in DLL::iter_option(var.parents) {
      upcopy(val, *parent);
    }
    for parent in DLL::iter_option(var.parents) {
      clean_up(parent);
    }
    lam.bod
  }
}
diff --git a/src/ix/kernel/whnf.rs b/src/ix/kernel/whnf.rs
new file mode 100644
index 00000000..4fdde07a
--- /dev/null
+++ b/src/ix/kernel/whnf.rs
@@ -0,0 +1,1420 @@
use core::ptr::NonNull;

use crate::ix::env::*;
use crate::lean::nat::Nat;
use num_bigint::BigUint;

use super::convert::{from_expr, to_expr};
use super::dag::*;
use super::level::{simplify, subst_level};
use super::upcopy::{reduce_lam, reduce_let};


// ============================================================================
// Expression helpers (inst, unfold_apps, foldl_apps, subst_expr_levels)
// ============================================================================

/// Instantiate bound variables:
/// `body[0 := substs[0], 1 := substs[1], ...]`.
/// `substs[0]` replaces `Bvar(0)` (innermost).
///
/// NOTE(review): substituted terms are inserted without lifting their own
/// loose bvars by `offset`, and indices beyond `substs.len()` are NOT
/// shifted down — both are only sound if `substs` contains closed terms
/// and the caller instantiates all loose indices. Confirm callers uphold
/// this.
pub fn inst(body: &Expr, substs: &[Expr]) -> Expr {
  if substs.is_empty() {
    return body.clone();
  }
  inst_aux(body, substs, 0)
}

/// Worker for [`inst`]; `offset` counts the binders crossed so far.
fn inst_aux(e: &Expr, substs: &[Expr], offset: u64) -> Expr {
  match e.as_data() {
    ExprData::Bvar(idx, _) => {
      let idx_u64 = idx.to_u64().unwrap_or(u64::MAX);
      if idx_u64 >= offset {
        // Index points past the binders crossed: it targets a subst slot.
        let adjusted = (idx_u64 - offset) as usize;
        if adjusted < substs.len() {
          return substs[adjusted].clone();
        }
      }
      e.clone()
    },
    ExprData::App(f, a, _) => {
      let f2 = inst_aux(f, substs, offset);
      let a2 = inst_aux(a, substs, offset);
      Expr::app(f2, a2)
    },
    ExprData::Lam(n, t, b, bi, _) => {
      let t2 = inst_aux(t, substs, offset);
      let b2 = inst_aux(b, substs, offset + 1);
      Expr::lam(n.clone(), t2, b2, bi.clone())
    },
    ExprData::ForallE(n, t, b, bi, _) => {
      let t2 = inst_aux(t, substs, offset);
      let b2 = inst_aux(b, substs, offset + 1);
      Expr::all(n.clone(), t2, b2, bi.clone())
    },
    ExprData::LetE(n, t, v, b, nd, _) => {
      let t2 = inst_aux(t, substs, offset);
      let v2 = inst_aux(v, substs, offset);
      let b2 = inst_aux(b, substs, offset + 1);
      Expr::letE(n.clone(), t2, v2, b2, *nd)
    },
    ExprData::Proj(n, i, s, _) => {
      let s2 = inst_aux(s, substs, offset);
      Expr::proj(n.clone(), i.clone(), s2)
    },
    ExprData::Mdata(kvs, inner, _) => {
      let inner2 = inst_aux(inner, substs, offset);
      Expr::mdata(kvs.clone(), inner2)
    },
    // Terminals with no bound vars
    ExprData::Sort(..)
    | ExprData::Const(..)
    | ExprData::Fvar(..)
    | ExprData::Mvar(..)
    | ExprData::Lit(..) => e.clone(),
  }
}

/// Abstract: replace each free variable in `fvars` with a bound variable;
/// `fvars[i]` becomes `Bvar(i + offset)` at the current binder depth.
pub fn abstr(e: &Expr, fvars: &[Expr]) -> Expr {
  if fvars.is_empty() {
    return e.clone();
  }
  abstr_aux(e, fvars, 0)
}

/// Worker for [`abstr`]; `offset` counts the binders crossed so far.
fn abstr_aux(e: &Expr, fvars: &[Expr], offset: u64) -> Expr {
  match e.as_data() {
    ExprData::Fvar(..) => {
      // Reverse scan: the LAST matching entry in `fvars` wins if the same
      // fvar appears more than once.
      // NOTE(review): `fvars[i]` maps to `Bvar(i + offset)`; verify this
      // orientation is the inverse of `inst` (index 0 = innermost) at the
      // call sites.
      for (i, fv) in fvars.iter().enumerate().rev() {
        if e == fv {
          return Expr::bvar(Nat::from(i as u64 + offset));
        }
      }
      e.clone()
    },
    ExprData::App(f, a, _) => {
      let f2 = abstr_aux(f, fvars, offset);
      let a2 = abstr_aux(a, fvars, offset);
      Expr::app(f2, a2)
    },
    ExprData::Lam(n, t, b, bi, _) => {
      let t2 = abstr_aux(t, fvars, offset);
      let b2 = abstr_aux(b, fvars, offset + 1);
      Expr::lam(n.clone(), t2, b2, bi.clone())
    },
    ExprData::ForallE(n, t, b, bi, _) => {
      let t2 = abstr_aux(t, fvars, offset);
      let b2 = abstr_aux(b, fvars, offset + 1);
      Expr::all(n.clone(), t2, b2, bi.clone())
    },
    ExprData::LetE(n, t, v, b, nd, _) => {
      let t2 = abstr_aux(t, fvars, offset);
      let v2 = abstr_aux(v, fvars, offset);
      let b2 = abstr_aux(b, fvars, offset + 1);
      Expr::letE(n.clone(), t2, v2, b2, *nd)
    },
    ExprData::Proj(n, i, s, _) => {
      let s2 = abstr_aux(s, fvars, offset);
      Expr::proj(n.clone(), i.clone(), s2)
    },
    ExprData::Mdata(kvs, inner, _) => {
      let inner2 = abstr_aux(inner, fvars, offset);
      Expr::mdata(kvs.clone(), inner2)
    },
    ExprData::Bvar(..)
    | ExprData::Sort(..)
    | ExprData::Const(..)
    | ExprData::Mvar(..)
    | ExprData::Lit(..) => e.clone(),
  }
}

/// Decompose `f a1 a2 ... an` into `(f, [a1, a2, ..., an])`.
pub fn unfold_apps(e: &Expr) -> (Expr, Vec<Expr>) {
  let mut args = Vec::new();
  let mut cursor = e.clone();
  loop {
    match cursor.as_data() {
      ExprData::App(f, a, _) => {
        args.push(a.clone());
        cursor = f.clone();
      },
      _ => break,
    }
  }
  // Arguments were collected innermost-first; restore application order.
  args.reverse();
  (cursor, args)
}

/// Reconstruct `f a1 a2 ... an`.
pub fn foldl_apps(mut fun: Expr, args: impl Iterator<Item = Expr>) -> Expr {
  for arg in args {
    fun = Expr::app(fun, arg);
  }
  fun
}

/// Substitute universe level parameters in an expression.
/// Substitute each level parameter `params[i]` with `values[i]` in every
/// `Sort` and `Const` occurring in `e`. Returns `e` unchanged (cheap
/// clone) when there are no parameters to substitute.
pub fn subst_expr_levels(
  e: &Expr,
  params: &[Name],
  values: &[Level],
) -> Expr {
  if params.is_empty() {
    return e.clone();
  }
  subst_expr_levels_aux(e, params, values)
}

/// Structural recursion for [`subst_expr_levels`]; only `Sort` and
/// `Const` nodes actually carry levels.
fn subst_expr_levels_aux(
  e: &Expr,
  params: &[Name],
  values: &[Level],
) -> Expr {
  match e.as_data() {
    ExprData::Sort(level, _) => {
      Expr::sort(subst_level(level, params, values))
    },
    ExprData::Const(name, levels, _) => {
      let new_levels: Vec<Level> =
        levels.iter().map(|l| subst_level(l, params, values)).collect();
      Expr::cnst(name.clone(), new_levels)
    },
    ExprData::App(f, a, _) => {
      let f2 = subst_expr_levels_aux(f, params, values);
      let a2 = subst_expr_levels_aux(a, params, values);
      Expr::app(f2, a2)
    },
    ExprData::Lam(n, t, b, bi, _) => {
      let t2 = subst_expr_levels_aux(t, params, values);
      let b2 = subst_expr_levels_aux(b, params, values);
      Expr::lam(n.clone(), t2, b2, bi.clone())
    },
    ExprData::ForallE(n, t, b, bi, _) => {
      let t2 = subst_expr_levels_aux(t, params, values);
      let b2 = subst_expr_levels_aux(b, params, values);
      Expr::all(n.clone(), t2, b2, bi.clone())
    },
    ExprData::LetE(n, t, v, b, nd, _) => {
      let t2 = subst_expr_levels_aux(t, params, values);
      let v2 = subst_expr_levels_aux(v, params, values);
      let b2 = subst_expr_levels_aux(b, params, values);
      Expr::letE(n.clone(), t2, v2, b2, *nd)
    },
    ExprData::Proj(n, i, s, _) => {
      let s2 = subst_expr_levels_aux(s, params, values);
      Expr::proj(n.clone(), i.clone(), s2)
    },
    ExprData::Mdata(kvs, inner, _) => {
      let inner2 = subst_expr_levels_aux(inner, params, values);
      Expr::mdata(kvs.clone(), inner2)
    },
    // No levels to substitute
    ExprData::Bvar(..)
    | ExprData::Fvar(..)
    | ExprData::Mvar(..)
    | ExprData::Lit(..) => e.clone(),
  }
}

/// Check if an expression has any loose bound variables.
+pub fn has_loose_bvars(e: &Expr) -> bool { + has_loose_bvars_aux(e, 0) +} + +fn has_loose_bvars_aux(e: &Expr, depth: u64) -> bool { + match e.as_data() { + ExprData::Bvar(idx, _) => idx.to_u64().unwrap_or(u64::MAX) >= depth, + ExprData::App(f, a, _) => { + has_loose_bvars_aux(f, depth) || has_loose_bvars_aux(a, depth) + }, + ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { + has_loose_bvars_aux(t, depth) || has_loose_bvars_aux(b, depth + 1) + }, + ExprData::LetE(_, t, v, b, _, _) => { + has_loose_bvars_aux(t, depth) + || has_loose_bvars_aux(v, depth) + || has_loose_bvars_aux(b, depth + 1) + }, + ExprData::Proj(_, _, s, _) => has_loose_bvars_aux(s, depth), + ExprData::Mdata(_, inner, _) => has_loose_bvars_aux(inner, depth), + _ => false, + } +} + +/// Check if expression contains any free variables (Fvar). +pub fn has_fvars(e: &Expr) -> bool { + match e.as_data() { + ExprData::Fvar(..) => true, + ExprData::App(f, a, _) => has_fvars(f) || has_fvars(a), + ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { + has_fvars(t) || has_fvars(b) + }, + ExprData::LetE(_, t, v, b, _, _) => { + has_fvars(t) || has_fvars(v) || has_fvars(b) + }, + ExprData::Proj(_, _, s, _) => has_fvars(s), + ExprData::Mdata(_, inner, _) => has_fvars(inner), + _ => false, + } +} + +// ============================================================================ +// Name helpers +// ============================================================================ + +pub(crate) fn mk_name2(a: &str, b: &str) -> Name { + Name::str(Name::str(Name::anon(), a.into()), b.into()) +} + +// ============================================================================ +// WHNF +// ============================================================================ + +/// Weak head normal form reduction. 
///
/// Uses DAG-based reduction internally: converts Expr to DAG, reduces using
/// BUBS (reduce_lam/reduce_let) for beta/zeta, falls back to Expr level for
/// iota/quot/nat/projection, and uses DAG-level splicing for delta.
pub fn whnf(e: &Expr, env: &Env) -> Expr {
  let mut dag = from_expr(e);
  whnf_dag(&mut dag, env);
  let result = to_expr(&dag);
  // Release the whole working DAG; the result is an independent Expr.
  free_dag(dag);
  result
}

/// Trail-based WHNF on DAG. Walks down the App spine collecting a trail,
/// then dispatches on the head node.
fn whnf_dag(dag: &mut DAG, env: &Env) {
  loop {
    // Build trail of App nodes by walking down the fun chain
    let mut trail: Vec<NonNull<App>> = Vec::new();
    let mut cursor = dag.head;

    loop {
      match cursor {
        DAGPtr::App(app) => {
          trail.push(app);
          cursor = unsafe { (*app.as_ptr()).fun };
        },
        _ => break,
      }
    }

    match cursor {
      // Beta: Fun at head with args on trail
      DAGPtr::Fun(fun_ptr) if !trail.is_empty() => {
        // Pop the innermost App (the redex) and contract it.
        let app = trail.pop().unwrap();
        let lam = unsafe { (*fun_ptr.as_ptr()).img };
        let result = reduce_lam(app, lam);
        set_dag_head(dag, result, &trail);
        continue;
      },

      // Zeta: Let at head
      DAGPtr::Let(let_ptr) => {
        let result = reduce_let(let_ptr);
        set_dag_head(dag, result, &trail);
        continue;
      },

      // Const: try iota, quot, nat, then delta
      DAGPtr::Cnst(_) => {
        // Try iota, quot, nat at Expr level
        if try_expr_reductions(dag, env) {
          continue;
        }
        // Try delta (definition unfolding) on DAG
        if try_dag_delta(dag, &trail, env) {
          continue;
        }
        return; // stuck
      },

      // Proj: try projection reduction (Expr-level fallback)
      DAGPtr::Proj(_) => {
        if try_expr_reductions(dag, env) {
          continue;
        }
        return; // stuck
      },

      // Sort: simplify level in place
      DAGPtr::Sort(sort_ptr) => {
        unsafe {
          let sort = &mut *sort_ptr.as_ptr();
          sort.level = simplify(&sort.level);
        }
        return;
      },

      // Lit applied to arguments: defer to the Expr-level reducers.
      // NOTE(review): with a Lit at the head, try_expr_reductions matches
      // neither its Const nor Proj branch, so this appears to always
      // return false — confirm whether a `Lit ... args` head can ever
      // reduce here.
      DAGPtr::Lit(_) => {
        // Check if this is a Nat literal that could be a Nat.succ application
        // by trying Expr-level reductions (which handles nat ops)
        if !trail.is_empty() {
          if try_expr_reductions(dag, env) {
            continue;
          }
        }
        return;
      },

      // Everything else (Var, Pi, Lam without args, etc.): already WHNF
      _ => return,
    }
  }
}

/// Set the DAG head after a reduction step.
/// If trail is empty, the result becomes the new head.
/// If trail is non-empty, splice result into the innermost remaining App.
fn set_dag_head(
  dag: &mut DAG,
  result: DAGPtr,
  trail: &[NonNull<App>],
) {
  if trail.is_empty() {
    dag.head = result;
  } else {
    unsafe {
      (*trail.last().unwrap().as_ptr()).fun = result;
    }
    dag.head = DAGPtr::App(trail[0]);
  }
}

/// Try iota/quot/nat/projection reductions at Expr level.
/// Converts current DAG to Expr, attempts reduction, converts back if
/// successful.
///
/// NOTE(review): on success the previous `dag.head` subgraph is abandoned
/// without being freed — this leaks the old nodes until `free_dag`
/// semantics (not visible here) reclaim them; confirm.
fn try_expr_reductions(dag: &mut DAG, env: &Env) -> bool {
  let current_expr = to_expr(&DAG { head: dag.head });

  let (head, args) = unfold_apps(&current_expr);

  let reduced = match head.as_data() {
    ExprData::Const(name, levels, _) => {
      // Try iota (recursor) reduction
      if let Some(result) = try_reduce_rec(name, levels, &args, env) {
        Some(result)
      }
      // Try quotient reduction
      else if let Some(result) = try_reduce_quot(name, &args, env) {
        Some(result)
      }
      // Try nat reduction
      else if let Some(result) = try_reduce_nat(&current_expr, env) {
        Some(result)
      } else {
        None
      }
    },
    ExprData::Proj(type_name, idx, structure, _) => {
      // Re-apply the spine arguments to the reduced projection.
      reduce_proj(type_name, idx, structure, env)
        .map(|result| foldl_apps(result, args.into_iter()))
    },
    ExprData::Mdata(_, inner, _) => {
      // Strip metadata wrappers.
      Some(foldl_apps(inner.clone(), args.into_iter()))
    },
    _ => None,
  };

  if let Some(result_expr) = reduced {
    let result_dag = from_expr(&result_expr);
    dag.head = result_dag.head;
    true
  } else {
    false
  }
}

/// Try delta (definition) unfolding on DAG.
/// Looks up the constant, substitutes universe levels in the definition body,
/// converts it to a DAG, and splices it into the current DAG.
///
/// Returns `false` when the head is not an unfoldable definition (opaque,
/// missing, not a definition, or level-arity mismatch).
fn try_dag_delta(
  dag: &mut DAG,
  trail: &[NonNull<App>],
  env: &Env,
) -> bool {
  // Extract constant info from head
  let cnst_ref = match dag_head_past_trail(dag, trail) {
    DAGPtr::Cnst(cnst) => unsafe { &*cnst.as_ptr() },
    _ => return false,
  };

  let ci = match env.get(&cnst_ref.name) {
    Some(c) => c,
    None => return false,
  };
  let (def_params, def_value) = match ci {
    // Opaque definitions must never be unfolded by the kernel.
    ConstantInfo::DefnInfo(d)
      if d.hints != ReducibilityHints::Opaque =>
    {
      (&d.cnst.level_params, &d.value)
    },
    _ => return false,
  };

  // A level-arity mismatch means the constant is ill-applied; leave stuck.
  if cnst_ref.levels.len() != def_params.len() {
    return false;
  }

  // Substitute levels at Expr level, then convert to DAG
  let val = subst_expr_levels(def_value, def_params, &cnst_ref.levels);
  let body_dag = from_expr(&val);

  // Splice body into the working DAG
  set_dag_head(dag, body_dag.head, trail);
  true
}

/// Get the head node past the trail (the non-App node at the bottom).
fn dag_head_past_trail(
  dag: &DAG,
  trail: &[NonNull<App>],
) -> DAGPtr {
  if trail.is_empty() {
    dag.head
  } else {
    unsafe { (*trail.last().unwrap().as_ptr()).fun }
  }
}

/// Try to unfold a definition at the head.
///
/// Expr-level counterpart of `try_dag_delta`: returns the definition body
/// with levels substituted and the original arguments re-applied, or
/// `None` when the head is not an unfoldable definition.
pub fn try_unfold_def(e: &Expr, env: &Env) -> Option<Expr> {
  let (head, args) = unfold_apps(e);
  let (name, levels) = match head.as_data() {
    ExprData::Const(name, levels, _) => (name, levels),
    _ => return None,
  };

  let ci = env.get(name)?;
  let (def_params, def_value) = match ci {
    ConstantInfo::DefnInfo(d) => {
      if d.hints == ReducibilityHints::Opaque {
        return None;
      }
      (&d.cnst.level_params, &d.value)
    },
    _ => return None,
  };

  if levels.len() != def_params.len() {
    return None;
  }

  let val = subst_expr_levels(def_value, def_params, levels);
  Some(foldl_apps(val, args.into_iter()))
}

/// Try to reduce a recursor application (iota reduction).
/// Looks up the recursor, brings the major premise to WHNF, locates the
/// matching recursion rule by constructor name, and assembles
/// `rhs params motives minors fields rest`.
///
/// NOTE(review): no special handling is visible here for K-like reduction
/// or for struct-eta of the major premise — confirm whether those are
/// handled elsewhere or intentionally omitted.
fn try_reduce_rec(
  name: &Name,
  levels: &[Level],
  args: &[Expr],
  env: &Env,
) -> Option<Expr> {
  let ci = env.get(name)?;
  let rec = match ci {
    ConstantInfo::RecInfo(r) => r,
    _ => return None,
  };

  // The major premise sits after params, motives, minors, and indices.
  let major_idx = rec.num_params.to_u64().unwrap() as usize
    + rec.num_motives.to_u64().unwrap() as usize
    + rec.num_minors.to_u64().unwrap() as usize
    + rec.num_indices.to_u64().unwrap() as usize;

  let major = args.get(major_idx)?;

  // WHNF the major premise
  let major_whnf = whnf(major, env);

  // Handle nat literal → constructor
  let major_ctor = match major_whnf.as_data() {
    ExprData::Lit(Literal::NatVal(n), _) => nat_lit_to_constructor(n),
    _ => major_whnf.clone(),
  };

  let (ctor_head, ctor_args) = unfold_apps(&major_ctor);

  // Find the matching rec rule
  let ctor_name = match ctor_head.as_data() {
    ExprData::Const(name, _, _) => name,
    _ => return None,
  };

  let rule = rec.rules.iter().find(|r| &r.ctor == ctor_name)?;

  let n_fields = rule.n_fields.to_u64().unwrap() as usize;
  let num_params = rec.num_params.to_u64().unwrap() as usize;
  let num_motives = rec.num_motives.to_u64().unwrap() as usize;
  let num_minors = rec.num_minors.to_u64().unwrap() as usize;

  // The constructor args may have extra params for nested inductives,
  // so take only the trailing `n_fields` arguments.
  let ctor_args_wo_params =
    if ctor_args.len() >= n_fields {
      &ctor_args[ctor_args.len() - n_fields..]
    } else {
      return None;
    };

  // Substitute universe levels in the rule's RHS
  let rhs = subst_expr_levels(
    &rule.rhs,
    &rec.cnst.level_params,
    levels,
  );

  // Apply: params, motives, minors
  let prefix_count = num_params + num_motives + num_minors;
  let mut result = rhs;
  for arg in args.iter().take(prefix_count) {
    result = Expr::app(result, arg.clone());
  }

  // Apply constructor fields
  for arg in ctor_args_wo_params {
    result = Expr::app(result, arg.clone());
  }

  // Apply remaining args after major
  for arg in args.iter().skip(major_idx + 1) {
    result = Expr::app(result, arg.clone());
  }

  Some(result)
}

/// Convert a Nat literal to its constructor form:
/// `0` → `Nat.zero`, `n+1` → `Nat.succ n` (with `n` kept as a literal).
fn nat_lit_to_constructor(n: &Nat) -> Expr {
  if n.0 == BigUint::ZERO {
    Expr::cnst(mk_name2("Nat", "zero"), vec![])
  } else {
    let pred = Nat(n.0.clone() - BigUint::from(1u64));
    let pred_expr = Expr::lit(Literal::NatVal(pred));
    Expr::app(Expr::cnst(mk_name2("Nat", "succ"), vec![]), pred_expr)
  }
}

/// Convert a string literal to its constructor form:
/// `"hello"` → `String.mk (List.cons 'h' (List.cons 'e' ... List.nil))`
/// where chars are represented as `Char.ofNat n`.
+fn string_lit_to_constructor(s: &str) -> Expr { + let list_name = Name::str(Name::anon(), "List".into()); + let char_name = Name::str(Name::anon(), "Char".into()); + let char_type = Expr::cnst(char_name.clone(), vec![]); + + // Build the list from right to left + // List.nil.{0} : List Char + let nil = Expr::app( + Expr::cnst( + Name::str(list_name.clone(), "nil".into()), + vec![Level::succ(Level::zero())], + ), + char_type.clone(), + ); + + let result = s.chars().rev().fold(nil, |acc, c| { + let char_val = Expr::app( + Expr::cnst(Name::str(char_name.clone(), "ofNat".into()), vec![]), + Expr::lit(Literal::NatVal(Nat::from(c as u64))), + ); + // List.cons.{0} Char char_val acc + Expr::app( + Expr::app( + Expr::app( + Expr::cnst( + Name::str(list_name.clone(), "cons".into()), + vec![Level::succ(Level::zero())], + ), + char_type.clone(), + ), + char_val, + ), + acc, + ) + }); + + // String.mk list + Expr::app( + Expr::cnst( + Name::str(Name::str(Name::anon(), "String".into()), "mk".into()), + vec![], + ), + result, + ) +} + +/// Try to reduce a projection. +fn reduce_proj( + _type_name: &Name, + idx: &Nat, + structure: &Expr, + env: &Env, +) -> Option { + let structure_whnf = whnf(structure, env); + + // Handle string literal → constructor + let structure_ctor = match structure_whnf.as_data() { + ExprData::Lit(Literal::StrVal(s), _) => { + string_lit_to_constructor(s) + }, + _ => structure_whnf, + }; + + let (ctor_head, ctor_args) = unfold_apps(&structure_ctor); + + let ctor_name = match ctor_head.as_data() { + ExprData::Const(name, _, _) => name, + _ => return None, + }; + + // Look up constructor to get num_params + let ci = env.get(ctor_name)?; + let num_params = match ci { + ConstantInfo::CtorInfo(c) => c.num_params.to_u64().unwrap() as usize, + _ => return None, + }; + + let field_idx = num_params + idx.to_u64().unwrap() as usize; + ctor_args.get(field_idx).cloned() +} + +/// Try to reduce a quotient operation. 
+fn try_reduce_quot( + name: &Name, + args: &[Expr], + env: &Env, +) -> Option { + let ci = env.get(name)?; + let kind = match ci { + ConstantInfo::QuotInfo(q) => &q.kind, + _ => return None, + }; + + let (qmk_idx, rest_idx) = match kind { + QuotKind::Lift => (5, 6), + QuotKind::Ind => (4, 5), + _ => return None, + }; + + let qmk = args.get(qmk_idx)?; + let qmk_whnf = whnf(qmk, env); + + // Check that the head is Quot.mk + let (qmk_head, _) = unfold_apps(&qmk_whnf); + match qmk_head.as_data() { + ExprData::Const(n, _, _) if *n == mk_name2("Quot", "mk") => {}, + _ => return None, + } + + let f = args.get(3)?; + + // Extract the argument of Quot.mk + let qmk_arg = match qmk_whnf.as_data() { + ExprData::App(_, arg, _) => arg, + _ => return None, + }; + + let mut result = Expr::app(f.clone(), qmk_arg.clone()); + for arg in args.iter().skip(rest_idx) { + result = Expr::app(result, arg.clone()); + } + + Some(result) +} + +/// Try to reduce nat operations. +fn try_reduce_nat(e: &Expr, env: &Env) -> Option { + if has_fvars(e) { + return None; + } + + let (head, args) = unfold_apps(e); + let name = match head.as_data() { + ExprData::Const(name, _, _) => name, + _ => return None, + }; + + match args.len() { + 1 => { + if *name == mk_name2("Nat", "succ") { + let arg_whnf = whnf(&args[0], env); + let n = get_nat_value(&arg_whnf)?; + Some(Expr::lit(Literal::NatVal(Nat(n + BigUint::from(1u64))))) + } else { + None + } + }, + 2 => { + let a_whnf = whnf(&args[0], env); + let b_whnf = whnf(&args[1], env); + let a = get_nat_value(&a_whnf)?; + let b = get_nat_value(&b_whnf)?; + + let result = if *name == mk_name2("Nat", "add") { + Some(Expr::lit(Literal::NatVal(Nat(a + b)))) + } else if *name == mk_name2("Nat", "sub") { + Some(Expr::lit(Literal::NatVal(Nat(if a >= b { + a - b + } else { + BigUint::ZERO + })))) + } else if *name == mk_name2("Nat", "mul") { + Some(Expr::lit(Literal::NatVal(Nat(a * b)))) + } else if *name == mk_name2("Nat", "div") { + 
Some(Expr::lit(Literal::NatVal(Nat(if b == BigUint::ZERO { + BigUint::ZERO + } else { + a / b + })))) + } else if *name == mk_name2("Nat", "mod") { + Some(Expr::lit(Literal::NatVal(Nat(if b == BigUint::ZERO { + a + } else { + a % b + })))) + } else if *name == mk_name2("Nat", "beq") { + bool_to_expr(a == b) + } else if *name == mk_name2("Nat", "ble") { + bool_to_expr(a <= b) + } else if *name == mk_name2("Nat", "pow") { + let exp = u32::try_from(&b).unwrap_or(u32::MAX); + Some(Expr::lit(Literal::NatVal(Nat(a.pow(exp))))) + } else if *name == mk_name2("Nat", "land") { + Some(Expr::lit(Literal::NatVal(Nat(a & b)))) + } else if *name == mk_name2("Nat", "lor") { + Some(Expr::lit(Literal::NatVal(Nat(a | b)))) + } else if *name == mk_name2("Nat", "xor") { + Some(Expr::lit(Literal::NatVal(Nat(a ^ b)))) + } else if *name == mk_name2("Nat", "shiftLeft") { + let shift = u64::try_from(&b).unwrap_or(u64::MAX); + Some(Expr::lit(Literal::NatVal(Nat(a << shift)))) + } else if *name == mk_name2("Nat", "shiftRight") { + let shift = u64::try_from(&b).unwrap_or(u64::MAX); + Some(Expr::lit(Literal::NatVal(Nat(a >> shift)))) + } else if *name == mk_name2("Nat", "blt") { + bool_to_expr(a < b) + } else { + None + }; + result + }, + _ => None, + } +} + +fn get_nat_value(e: &Expr) -> Option { + match e.as_data() { + ExprData::Lit(Literal::NatVal(n), _) => Some(n.0.clone()), + ExprData::Const(name, _, _) if *name == mk_name2("Nat", "zero") => { + Some(BigUint::ZERO) + }, + _ => None, + } +} + +fn bool_to_expr(b: bool) -> Option { + let name = if b { + mk_name2("Bool", "true") + } else { + mk_name2("Bool", "false") + }; + Some(Expr::cnst(name, vec![])) +} + +// ============================================================================ +// Tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + + fn mk_name(s: &str) -> Name { + Name::str(Name::anon(), s.into()) + } + + fn nat_type() -> Expr { + 
    Expr::cnst(mk_name("Nat"), vec![])
  }

  fn nat_zero() -> Expr {
    Expr::cnst(mk_name2("Nat", "zero"), vec![])
  }

  #[test]
  fn test_inst_bvar() {
    // Instantiating Bvar(0) with a closed term yields that term.
    let body = Expr::bvar(Nat::from(0));
    let arg = nat_zero();
    let result = inst(&body, &[arg.clone()]);
    assert_eq!(result, arg);
  }

  #[test]
  fn test_inst_nested() {
    // body = Lam(_, Nat, Bvar(1)) — references outer binder
    // After inst with [zero], should become Lam(_, Nat, zero)
    let body = Expr::lam(
      Name::anon(),
      nat_type(),
      Expr::bvar(Nat::from(1)),
      BinderInfo::Default,
    );
    let result = inst(&body, &[nat_zero()]);
    let expected = Expr::lam(
      Name::anon(),
      nat_type(),
      nat_zero(),
      BinderInfo::Default,
    );
    assert_eq!(result, expected);
  }

  #[test]
  fn test_unfold_apps() {
    // `f a b` decomposes into head `f` and args `[a, b]` in order.
    let f = Expr::cnst(mk_name("f"), vec![]);
    let a = nat_zero();
    let b = Expr::cnst(mk_name2("Nat", "succ"), vec![]);
    let e = Expr::app(Expr::app(f.clone(), a.clone()), b.clone());
    let (head, args) = unfold_apps(&e);
    assert_eq!(head, f);
    assert_eq!(args.len(), 2);
    assert_eq!(args[0], a);
    assert_eq!(args[1], b);
  }

  #[test]
  fn test_beta_reduce_identity() {
    // (fun x : Nat => x) Nat.zero
    let id = Expr::lam(
      Name::str(Name::anon(), "x".into()),
      nat_type(),
      Expr::bvar(Nat::from(0)),
      BinderInfo::Default,
    );
    let e = Expr::app(id, nat_zero());
    let env = Env::default();
    let result = whnf(&e, &env);
    assert_eq!(result, nat_zero());
  }

  #[test]
  fn test_zeta_reduce() {
    // let x : Nat := Nat.zero in x
    let e = Expr::letE(
      Name::str(Name::anon(), "x".into()),
      nat_type(),
      nat_zero(),
      Expr::bvar(Nat::from(0)),
      false,
    );
    let env = Env::default();
    let result = whnf(&e, &env);
    assert_eq!(result, nat_zero());
  }

  // ==========================================================================
  // Delta reduction
  // ==========================================================================

  // Build an environment containing a single non-opaque definition
  // `name := value : typ` (with Abbrev reducibility, so it unfolds).
  fn mk_defn_env(name: &str, value: Expr, typ: Expr) -> Env {
    let mut env = Env::default();
    let n = mk_name(name);
    env.insert(
      n.clone(),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: n.clone(),
          level_params: vec![],
          typ,
        },
        value,
        hints: ReducibilityHints::Abbrev,
        safety: DefinitionSafety::Safe,
        all: vec![n],
      }),
    );
    env
  }

  #[test]
  fn test_delta_unfold() {
    // def myZero := Nat.zero
    // whnf(myZero) = Nat.zero
    let env = mk_defn_env("myZero", nat_zero(), nat_type());
    let e = Expr::cnst(mk_name("myZero"), vec![]);
    let result = whnf(&e, &env);
    assert_eq!(result, nat_zero());
  }

  #[test]
  fn test_delta_opaque_no_unfold() {
    // An opaque definition should NOT unfold
    let mut env = Env::default();
    let n = mk_name("opaqueVal");
    env.insert(
      n.clone(),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: n.clone(),
          level_params: vec![],
          typ: nat_type(),
        },
        value: nat_zero(),
        hints: ReducibilityHints::Opaque,
        safety: DefinitionSafety::Safe,
        all: vec![n.clone()],
      }),
    );
    let e = Expr::cnst(n.clone(), vec![]);
    let result = whnf(&e, &env);
    // Should still be the constant, not unfolded
    assert_eq!(result, e);
  }

  #[test]
  fn test_delta_chained() {
    // def a := Nat.zero, def b := a => whnf(b) = Nat.zero
    let mut env = Env::default();
    let a = mk_name("a");
    let b = mk_name("b");
    env.insert(
      a.clone(),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: a.clone(),
          level_params: vec![],
          typ: nat_type(),
        },
        value: nat_zero(),
        hints: ReducibilityHints::Abbrev,
        safety: DefinitionSafety::Safe,
        all: vec![a.clone()],
      }),
    );
    env.insert(
      b.clone(),
      ConstantInfo::DefnInfo(DefinitionVal {
        cnst: ConstantVal {
          name: b.clone(),
          level_params: vec![],
          typ: nat_type(),
        },
        value: Expr::cnst(a, vec![]),
        hints: ReducibilityHints::Abbrev,
        safety: DefinitionSafety::Safe,
        all: vec![b.clone()],
      }),
    );
    let e = Expr::cnst(b, vec![]);
    let result = whnf(&e, &env);
assert_eq!(result, nat_zero()); + } + + // ========================================================================== + // Nat arithmetic reduction + // ========================================================================== + + fn nat_lit(n: u64) -> Expr { + Expr::lit(Literal::NatVal(Nat::from(n))) + } + + #[test] + fn test_nat_add() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "add"), vec![]), nat_lit(3)), + nat_lit(4), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(7)); + } + + #[test] + fn test_nat_sub() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "sub"), vec![]), nat_lit(10)), + nat_lit(3), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(7)); + } + + #[test] + fn test_nat_sub_underflow() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "sub"), vec![]), nat_lit(3)), + nat_lit(10), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(0)); + } + + #[test] + fn test_nat_mul() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "mul"), vec![]), nat_lit(6)), + nat_lit(7), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(42)); + } + + #[test] + fn test_nat_div() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "div"), vec![]), nat_lit(10)), + nat_lit(3), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(3)); + } + + #[test] + fn test_nat_div_by_zero() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "div"), vec![]), nat_lit(10)), + nat_lit(0), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(0)); + } + + #[test] + fn test_nat_mod() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "mod"), vec![]), nat_lit(10)), + nat_lit(3), + ); + let result = whnf(&e, &env); + assert_eq!(result, 
nat_lit(1)); + } + + #[test] + fn test_nat_beq_true() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "beq"), vec![]), nat_lit(5)), + nat_lit(5), + ); + let result = whnf(&e, &env); + assert_eq!(result, Expr::cnst(mk_name2("Bool", "true"), vec![])); + } + + #[test] + fn test_nat_beq_false() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "beq"), vec![]), nat_lit(5)), + nat_lit(3), + ); + let result = whnf(&e, &env); + assert_eq!(result, Expr::cnst(mk_name2("Bool", "false"), vec![])); + } + + #[test] + fn test_nat_ble_true() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "ble"), vec![]), nat_lit(3)), + nat_lit(5), + ); + let result = whnf(&e, &env); + assert_eq!(result, Expr::cnst(mk_name2("Bool", "true"), vec![])); + } + + #[test] + fn test_nat_ble_false() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "ble"), vec![]), nat_lit(5)), + nat_lit(3), + ); + let result = whnf(&e, &env); + assert_eq!(result, Expr::cnst(mk_name2("Bool", "false"), vec![])); + } + + #[test] + fn test_nat_pow() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "pow"), vec![]), nat_lit(2)), + nat_lit(10), + ); + assert_eq!(whnf(&e, &env), nat_lit(1024)); + } + + #[test] + fn test_nat_land() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "land"), vec![]), nat_lit(0b1100)), + nat_lit(0b1010), + ); + assert_eq!(whnf(&e, &env), nat_lit(0b1000)); + } + + #[test] + fn test_nat_lor() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "lor"), vec![]), nat_lit(0b1100)), + nat_lit(0b1010), + ); + assert_eq!(whnf(&e, &env), nat_lit(0b1110)); + } + + #[test] + fn test_nat_xor() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "xor"), vec![]), nat_lit(0b1100)), + nat_lit(0b1010), 
+ ); + assert_eq!(whnf(&e, &env), nat_lit(0b0110)); + } + + #[test] + fn test_nat_shift_left() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "shiftLeft"), vec![]), nat_lit(1)), + nat_lit(8), + ); + assert_eq!(whnf(&e, &env), nat_lit(256)); + } + + #[test] + fn test_nat_shift_right() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "shiftRight"), vec![]), nat_lit(256)), + nat_lit(4), + ); + assert_eq!(whnf(&e, &env), nat_lit(16)); + } + + #[test] + fn test_nat_blt_true() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "blt"), vec![]), nat_lit(3)), + nat_lit(5), + ); + assert_eq!(whnf(&e, &env), Expr::cnst(mk_name2("Bool", "true"), vec![])); + } + + #[test] + fn test_nat_blt_false() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "blt"), vec![]), nat_lit(5)), + nat_lit(3), + ); + assert_eq!(whnf(&e, &env), Expr::cnst(mk_name2("Bool", "false"), vec![])); + } + + // ========================================================================== + // Sort simplification in WHNF + // ========================================================================== + + #[test] + fn test_string_lit_proj_reduces() { + // Build an env with String, String.mk ctor, List, List.cons, List.nil, Char + let mut env = Env::default(); + let string_name = mk_name("String"); + let string_mk = mk_name2("String", "mk"); + let list_name = mk_name("List"); + let char_name = mk_name("Char"); + + // String : Sort 1 + env.insert( + string_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: string_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![string_name.clone()], + ctors: vec![string_mk.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + 
}), + ); + // String.mk : List Char → String (1 field, 0 params) + let list_char = Expr::app( + Expr::cnst(list_name, vec![Level::succ(Level::zero())]), + Expr::cnst(char_name, vec![]), + ); + env.insert( + string_mk.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: string_mk, + level_params: vec![], + typ: Expr::all( + mk_name("data"), + list_char, + Expr::cnst(string_name.clone(), vec![]), + BinderInfo::Default, + ), + }, + induct: string_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(1u64), + is_unsafe: false, + }), + ); + + // Proj String 0 "hi" should reduce (not return None) + let proj = Expr::proj( + string_name, + Nat::from(0u64), + Expr::lit(Literal::StrVal("hi".into())), + ); + let result = whnf(&proj, &env); + // The result should NOT be a Proj anymore (it should have reduced) + assert!( + !matches!(result.as_data(), ExprData::Proj(..)), + "String projection should reduce, got: {:?}", + result + ); + } + + #[test] + fn test_whnf_sort_simplifies() { + // Sort(max 0 u) should simplify to Sort(u) + let env = Env::default(); + let u = Level::param(mk_name("u")); + let e = Expr::sort(Level::max(Level::zero(), u.clone())); + let result = whnf(&e, &env); + assert_eq!(result, Expr::sort(u)); + } + + // ========================================================================== + // Already-WHNF terms + // ========================================================================== + + #[test] + fn test_whnf_sort_unchanged() { + let env = Env::default(); + let e = Expr::sort(Level::zero()); + let result = whnf(&e, &env); + assert_eq!(result, e); + } + + #[test] + fn test_whnf_lambda_unchanged() { + // A lambda without applied arguments is already WHNF + let env = Env::default(); + let e = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0)), + BinderInfo::Default, + ); + let result = whnf(&e, &env); + assert_eq!(result, e); + } + + #[test] + fn test_whnf_pi_unchanged() { + 
let env = Env::default(); + let e = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let result = whnf(&e, &env); + assert_eq!(result, e); + } + + // ========================================================================== + // Helper function tests + // ========================================================================== + + #[test] + fn test_has_loose_bvars_true() { + assert!(has_loose_bvars(&Expr::bvar(Nat::from(0)))); + } + + #[test] + fn test_has_loose_bvars_false_under_binder() { + // fun x : Nat => x — bvar(0) is bound, not loose + let e = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0)), + BinderInfo::Default, + ); + assert!(!has_loose_bvars(&e)); + } + + #[test] + fn test_has_loose_bvars_const() { + assert!(!has_loose_bvars(&nat_zero())); + } + + #[test] + fn test_has_fvars_true() { + assert!(has_fvars(&Expr::fvar(mk_name("x")))); + } + + #[test] + fn test_has_fvars_false() { + assert!(!has_fvars(&nat_zero())); + } + + #[test] + fn test_foldl_apps_roundtrip() { + let f = Expr::cnst(mk_name("f"), vec![]); + let a = nat_zero(); + let b = nat_type(); + let e = Expr::app(Expr::app(f.clone(), a.clone()), b.clone()); + let (head, args) = unfold_apps(&e); + let rebuilt = foldl_apps(head, args.into_iter()); + assert_eq!(rebuilt, e); + } + + #[test] + fn test_abstr_simple() { + // abstr(fvar("x"), [fvar("x")]) = bvar(0) + let x = Expr::fvar(mk_name("x")); + let result = abstr(&x, &[x.clone()]); + assert_eq!(result, Expr::bvar(Nat::from(0))); + } + + #[test] + fn test_abstr_not_found() { + // abstr(Nat.zero, [fvar("x")]) = Nat.zero (unchanged) + let x = Expr::fvar(mk_name("x")); + let result = abstr(&nat_zero(), &[x]); + assert_eq!(result, nat_zero()); + } + + #[test] + fn test_subst_expr_levels_simple() { + // Sort(u) with u := 0 => Sort(0) + let u_name = mk_name("u"); + let e = Expr::sort(Level::param(u_name.clone())); + let result = subst_expr_levels(&e, &[u_name], &[Level::zero()]); + 
assert_eq!(result, Expr::sort(Level::zero())); + } +} From 13da42f245f403af3588018ab89cdadce4e1763f Mon Sep 17 00:00:00 2001 From: "John C. Burnham" Date: Fri, 20 Feb 2026 07:45:54 -0500 Subject: [PATCH 2/5] WIP kernel --- Ix/Address.lean | 1 + Ix/Cli/CheckCmd.lean | 122 ++ Ix/CompileM.lean | 16 +- Ix/DecompileM.lean | 24 +- Ix/Ixon.lean | 16 +- Ix/Kernel.lean | 44 + Ix/Kernel/Convert.lean | 841 +++++++++++ Ix/Kernel/Datatypes.lean | 181 +++ Ix/Kernel/DecompileM.lean | 254 ++++ Ix/Kernel/Equal.lean | 168 +++ Ix/Kernel/Eval.lean | 530 +++++++ Ix/Kernel/Infer.lean | 406 +++++ Ix/Kernel/Level.lean | 131 ++ Ix/Kernel/TypecheckM.lean | 180 +++ Ix/Kernel/Types.lean | 569 +++++++ Main.lean | 2 + Tests/Ix/Check.lean | 107 ++ Tests/Ix/Compile.lean | 73 +- Tests/Ix/KernelTests.lean | 761 ++++++++++ Tests/Ix/PP.lean | 333 +++++ Tests/Main.lean | 17 + docs/Ixon.md | 5 +- src/ix/decompile.rs | 57 +- src/ix/ixon/env.rs | 47 +- src/ix/ixon/serialize.rs | 2 - src/ix/kernel/convert.rs | 835 +++++++---- src/ix/kernel/dag.rs | 645 +++++++- src/ix/kernel/dag_tc.rs | 2857 ++++++++++++++++++++++++++++++++++++ src/ix/kernel/def_eq.rs | 480 +++++- src/ix/kernel/inductive.rs | 121 +- src/ix/kernel/level.rs | 58 +- src/ix/kernel/mod.rs | 1 + src/ix/kernel/tc.rs | 663 +++++++-- src/ix/kernel/upcopy.rs | 872 +++++------ src/ix/kernel/whnf.rs | 1674 +++++++++++++++------ src/lean/ffi.rs | 1 + src/lean/ffi/check.rs | 182 +++ src/lean/ffi/lean_env.rs | 6 +- 38 files changed, 11748 insertions(+), 1534 deletions(-) create mode 100644 Ix/Cli/CheckCmd.lean create mode 100644 Ix/Kernel.lean create mode 100644 Ix/Kernel/Convert.lean create mode 100644 Ix/Kernel/Datatypes.lean create mode 100644 Ix/Kernel/DecompileM.lean create mode 100644 Ix/Kernel/Equal.lean create mode 100644 Ix/Kernel/Eval.lean create mode 100644 Ix/Kernel/Infer.lean create mode 100644 Ix/Kernel/Level.lean create mode 100644 Ix/Kernel/TypecheckM.lean create mode 100644 Ix/Kernel/Types.lean create mode 100644 Tests/Ix/Check.lean 
create mode 100644 Tests/Ix/KernelTests.lean create mode 100644 Tests/Ix/PP.lean create mode 100644 src/ix/kernel/dag_tc.rs create mode 100644 src/lean/ffi/check.rs diff --git a/Ix/Address.lean b/Ix/Address.lean index ee11eb85..562dd028 100644 --- a/Ix/Address.lean +++ b/Ix/Address.lean @@ -14,6 +14,7 @@ structure Address where /-- Compute the Blake3 hash of a `ByteArray`, returning an `Address`. -/ def Address.blake3 (x: ByteArray) : Address := ⟨(Blake3.hash x).val⟩ + /-- Convert a nibble (0--15) to its lowercase hexadecimal character. -/ def hexOfNat : Nat -> Option Char | 0 => .some '0' diff --git a/Ix/Cli/CheckCmd.lean b/Ix/Cli/CheckCmd.lean new file mode 100644 index 00000000..f8e388f0 --- /dev/null +++ b/Ix/Cli/CheckCmd.lean @@ -0,0 +1,122 @@ +import Cli +import Ix.Common +import Ix.Kernel +import Ix.Meta +import Ix.CompileM +import Lean + +open System (FilePath) + +/-- If the project depends on Mathlib, download the Mathlib cache. -/ +private def fetchMathlibCache (cwd : Option FilePath) : IO Unit := do + let root := cwd.getD "." + let manifest := root / "lake-manifest.json" + let contents ← IO.FS.readFile manifest + if contents.containsSubstr "leanprover-community/mathlib4" then + let mathlibBuild := root / ".lake" / "packages" / "mathlib" / ".lake" / "build" + if ← mathlibBuild.pathExists then + println! "Mathlib cache already present, skipping fetch." + return + println! "Detected Mathlib dependency. Fetching Mathlib cache..." + let child ← IO.Process.spawn { + cmd := "lake" + args := #["exe", "cache", "get"] + cwd := cwd + stdout := .inherit + stderr := .inherit + } + let exitCode ← child.wait + if exitCode != 0 then + throw $ IO.userError "lake exe cache get failed" + +/-- Build the Lean module at the given file path using Lake. 
-/ +private def buildFile (path : FilePath) : IO Unit := do + let path ← IO.FS.realPath path + let some moduleName := path.fileStem + | throw $ IO.userError s!"cannot determine module name from {path}" + fetchMathlibCache path.parent + let child ← IO.Process.spawn { + cmd := "lake" + args := #["build", moduleName] + cwd := path.parent + stdout := .inherit + stderr := .inherit + } + let exitCode ← child.wait + if exitCode != 0 then + throw $ IO.userError "lake build failed" + +/-- Run the Lean NbE kernel checker. -/ +private def runLeanCheck (leanEnv : Lean.Environment) : IO UInt32 := do + println! "Compiling to Ixon..." + let compileStart ← IO.monoMsNow + let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv + let compileElapsed := (← IO.monoMsNow) - compileStart + let numConsts := ixonEnv.consts.size + println! "Compiled {numConsts} constants in {compileElapsed.formatMs}" + + println! "Converting Ixon → Kernel..." + let convertStart ← IO.monoMsNow + match Ix.Kernel.Convert.convertEnv .meta ixonEnv with + | .error e => + println! "Conversion error: {e}" + return 1 + | .ok (kenv, prims, quotInit) => + let convertElapsed := (← IO.monoMsNow) - convertStart + println! "Converted {kenv.size} constants in {convertElapsed.formatMs}" + + println! "Typechecking..." + let checkStart ← IO.monoMsNow + match Ix.Kernel.typecheckAll kenv prims quotInit with + | .error e => + let elapsed := (← IO.monoMsNow) - checkStart + println! "Kernel check failed in {elapsed.formatMs}: {e}" + return 1 + | .ok () => + let elapsed := (← IO.monoMsNow) - checkStart + println! "Checked {kenv.size} constants in {elapsed.formatMs}" + return 0 + +/-- Run the Rust kernel checker. -/ +private def runRustCheck (leanEnv : Lean.Environment) : IO UInt32 := do + let totalConsts := leanEnv.constants.toList.length + println! "Total constants: {totalConsts}" + + let start ← IO.monoMsNow + let errors ← Ix.Kernel.rsCheckEnv leanEnv + let elapsed := (← IO.monoMsNow) - start + + if errors.isEmpty then + println! 
"Checked {totalConsts} constants in {elapsed.formatMs}" + return 0 + else + println! "Kernel check failed with {errors.size} error(s) in {elapsed.formatMs}:" + for (name, err) in errors[:min 50 errors.size] do + println! " {repr name}: {repr err}" + return 1 + +def runCheckCmd (p : Cli.Parsed) : IO UInt32 := do + let some path := p.flag? "path" + | p.printError "error: must specify --path" + return 1 + let pathStr := path.as! String + let useLean := p.hasFlag "lean" + + buildFile pathStr + let leanEnv ← getFileEnv pathStr + + if useLean then + println! "Running Lean NbE kernel checker on {pathStr}" + runLeanCheck leanEnv + else + println! "Running Rust kernel checker on {pathStr}" + runRustCheck leanEnv + +def checkCmd : Cli.Cmd := `[Cli| + check VIA runCheckCmd; + "Type-check Lean file with kernel" + + FLAGS: + path : String; "Path to file to check" + lean; "Use Lean NbE kernel instead of Rust kernel" +] diff --git a/Ix/CompileM.lean b/Ix/CompileM.lean index e527f62c..efd8abd2 100644 --- a/Ix/CompileM.lean +++ b/Ix/CompileM.lean @@ -1604,11 +1604,10 @@ def compileEnv (env : Ix.Environment) (blocks : Ix.CondensedBlocks) (dbg : Bool -- Build reverse index and names map, storing name string components as blobs -- Seed with blockNames collected during compilation (binder names, level params, etc.) 
- let (addrToNameMap, namesMap, nameBlobs) := - compileEnv.nameToNamed.fold (init := ({}, blockNames, {})) fun (addrMap, namesMap, blobs) name named => - let addrMap := addrMap.insert named.addr name + let (namesMap, nameBlobs) := + compileEnv.nameToNamed.fold (init := (blockNames, {})) fun (namesMap, blobs) name _named => let (namesMap, blobs) := Ixon.RawEnv.addNameComponentsWithBlobs namesMap blobs name - (addrMap, namesMap, blobs) + (namesMap, blobs) -- Merge name string blobs into the main blobs map let allBlobs := nameBlobs.fold (fun m k v => m.insert k v) compileEnv.blobs @@ -1619,7 +1618,6 @@ def compileEnv (env : Ix.Environment) (blocks : Ix.CondensedBlocks) (dbg : Bool blobs := allBlobs names := namesMap comms := {} - addrToName := addrToNameMap } return .ok (ixonEnv, compileEnv.totalBytes) @@ -1890,11 +1888,10 @@ def compileEnvParallel (env : Ix.Environment) (blocks : Ix.CondensedBlocks) -- Build reverse index and names map, storing name string components as blobs -- Seed with blockNames collected during compilation (binder names, level params, etc.) 
- let (addrToNameMap, namesMap, nameBlobs) := - nameToNamed.fold (init := ({}, blockNames, {})) fun (addrMap, namesMap, nameBlobs) name named => - let addrMap := addrMap.insert named.addr name + let (namesMap, nameBlobs) := + nameToNamed.fold (init := (blockNames, {})) fun (namesMap, nameBlobs) name _named => let (namesMap, nameBlobs) := Ixon.RawEnv.addNameComponentsWithBlobs namesMap nameBlobs name - (addrMap, namesMap, nameBlobs) + (namesMap, nameBlobs) -- Merge name string blobs into the main blobs map let blockBlobCount := blobs.size @@ -1912,7 +1909,6 @@ def compileEnvParallel (env : Ix.Environment) (blocks : Ix.CondensedBlocks) blobs := allBlobs names := namesMap comms := {} - addrToName := addrToNameMap } return .ok (ixonEnv, totalBytes) diff --git a/Ix/DecompileM.lean b/Ix/DecompileM.lean index d22fb8f7..e1e8050b 100644 --- a/Ix/DecompileM.lean +++ b/Ix/DecompileM.lean @@ -117,12 +117,6 @@ def lookupNameAddrOrAnon (addr : Address) : DecompileM Ix.Name := do | some n => pure n | none => pure Ix.Name.mkAnon -/-- Resolve constant Address → Ix.Name via addrToName. -/ -def lookupConstName (addr : Address) : DecompileM Ix.Name := do - match (← getEnv).ixonEnv.addrToName.get? addr with - | some n => pure n - | none => throw (.missingAddress addr) - def lookupBlob (addr : Address) : DecompileM ByteArray := do match (← getEnv).ixonEnv.blobs.get? addr with | some blob => pure blob @@ -390,18 +384,14 @@ partial def decompileExpr (e : Ixon.Expr) (arenaIdx : UInt64) : DecompileM Ix.Ex pure (applyMdata (Ix.Expr.mkLit (.strVal s)) mdataLayers) -- Ref with arena metadata - | .ref nameAddr, .ref refIdx univIndices => do - let name ← match (← getEnv).ixonEnv.names.get? 
nameAddr with - | some n => pure n - | none => getRef refIdx >>= lookupConstName + | .ref nameAddr, .ref _refIdx univIndices => do + let name ← lookupNameAddr nameAddr let lvls ← decompileUnivIndices univIndices pure (applyMdata (Ix.Expr.mkConst name lvls) mdataLayers) -- Ref without arena metadata - | _, .ref refIdx univIndices => do - let name ← getRef refIdx >>= lookupConstName - let lvls ← decompileUnivIndices univIndices - pure (applyMdata (Ix.Expr.mkConst name lvls) mdataLayers) + | _, .ref _refIdx _univIndices => do + throw (.badConstantFormat "ref without arena metadata") -- Rec with arena metadata | .ref nameAddr, .recur recIdx univIndices => do @@ -472,10 +462,8 @@ partial def decompileExpr (e : Ixon.Expr) (arenaIdx : UInt64) : DecompileM Ix.Ex let valExpr ← decompileExpr val child pure (applyMdata (Ix.Expr.mkProj typeName fieldIdx.toNat valExpr) mdataLayers) - | _, .prj typeRefIdx fieldIdx val => do - let typeName ← getRef typeRefIdx >>= lookupConstName - let valExpr ← decompileExpr val UInt64.MAX - pure (applyMdata (Ix.Expr.mkProj typeName fieldIdx.toNat valExpr) mdataLayers) + | _, .prj _typeRefIdx _fieldIdx _val => do + throw (.badConstantFormat "prj without arena metadata") | _, .share _ => throw (.badConstantFormat "unexpected Share in decompileExpr") diff --git a/Ix/Ixon.lean b/Ix/Ixon.lean index 5432d12c..cc4d1d11 100644 --- a/Ix/Ixon.lean +++ b/Ix/Ixon.lean @@ -1380,12 +1380,10 @@ structure Env where named : Std.HashMap Ix.Name Named := {} /-- Raw data blobs: Address → bytes -/ blobs : Std.HashMap Address ByteArray := {} - /-- Hash-consed name components: Address → Ix.Name -/ - names : Std.HashMap Address Ix.Name := {} /-- Cryptographic commitments: Address → Comm -/ comms : Std.HashMap Address Comm := {} - /-- Reverse index: constant Address → Ix.Name -/ - addrToName : Std.HashMap Address Ix.Name := {} + /-- Hash-consed name components: Address → Ix.Name -/ + names : Std.HashMap Address Ix.Name := {} deriving Inhabited namespace Env @@ -1401,8 
+1399,7 @@ def getConst? (env : Env) (addr : Address) : Option Constant := /-- Register a name with full Named metadata. -/ def registerName (env : Env) (name : Ix.Name) (named : Named) : Env := { env with - named := env.named.insert name named - addrToName := env.addrToName.insert named.addr name } + named := env.named.insert name named } /-- Register a name with just an address (empty metadata). -/ def registerNameAddr (env : Env) (name : Ix.Name) (addr : Address) : Env := @@ -1416,10 +1413,6 @@ def getAddr? (env : Env) (name : Ix.Name) : Option Address := def getNamed? (env : Env) (name : Ix.Name) : Option Named := env.named.get? name -/-- Look up an address's name. -/ -def getName? (env : Env) (addr : Address) : Option Ix.Name := - env.addrToName.get? addr - /-- Store a blob and return its content address. -/ def storeBlob (env : Env) (bytes : ByteArray) : Env × Address := let addr := Address.blake3 bytes @@ -1742,8 +1735,7 @@ def getEnv : GetM Env := do | some name => let namedEntry : Named := ⟨constAddr, constMeta⟩ env := { env with - named := env.named.insert name namedEntry - addrToName := env.addrToName.insert constAddr name } + named := env.named.insert name namedEntry } | none => throw s!"getEnv: named entry references unknown name address {reprStr (toString nameAddr)}" diff --git a/Ix/Kernel.lean b/Ix/Kernel.lean new file mode 100644 index 00000000..cbb6c467 --- /dev/null +++ b/Ix/Kernel.lean @@ -0,0 +1,44 @@ +import Lean +import Ix.Environment +import Ix.Kernel.Types +import Ix.Kernel.Datatypes +import Ix.Kernel.Level +import Ix.Kernel.TypecheckM +import Ix.Kernel.Eval +import Ix.Kernel.Equal +import Ix.Kernel.Infer +import Ix.Kernel.Convert + +namespace Ix.Kernel + +/-- Type-checking errors from the Rust kernel, mirroring `TcError` in Rust. 
-/ +inductive CheckError where + | typeExpected (expr : Ix.Expr) (inferred : Ix.Expr) + | functionExpected (expr : Ix.Expr) (inferred : Ix.Expr) + | typeMismatch (expected : Ix.Expr) (found : Ix.Expr) (expr : Ix.Expr) + | defEqFailure (lhs : Ix.Expr) (rhs : Ix.Expr) + | unknownConst (name : Ix.Name) + | duplicateUniverse (name : Ix.Name) + | freeBoundVariable (idx : UInt64) + | kernelException (msg : String) + deriving Repr + +/-- FFI: Run Rust kernel type-checker over all declarations in a Lean environment. -/ +@[extern "rs_check_env"] +opaque rsCheckEnvFFI : @& List (Lean.Name × Lean.ConstantInfo) → IO (Array (Ix.Name × CheckError)) + +/-- Check all declarations in a Lean environment using the Rust kernel. + Returns an array of (name, error) pairs for any declarations that fail. -/ +def rsCheckEnv (leanEnv : Lean.Environment) : IO (Array (Ix.Name × CheckError)) := + rsCheckEnvFFI leanEnv.constants.toList + +/-- FFI: Type-check a single constant by dotted name string. -/ +@[extern "rs_check_const"] +opaque rsCheckConstFFI : @& List (Lean.Name × Lean.ConstantInfo) → @& String → IO (Option CheckError) + +/-- Check a single constant by name using the Rust kernel. + Returns `none` on success, `some err` on failure. -/ +def rsCheckConst (leanEnv : Lean.Environment) (name : String) : IO (Option CheckError) := + rsCheckConstFFI leanEnv.constants.toList name + +end Ix.Kernel diff --git a/Ix/Kernel/Convert.lean b/Ix/Kernel/Convert.lean new file mode 100644 index 00000000..369ffca2 --- /dev/null +++ b/Ix/Kernel/Convert.lean @@ -0,0 +1,841 @@ +/- + Kernel Convert: Ixon.Env → Kernel.Env conversion. + + Two modes: + - `convert` produces `Kernel.Env .meta` with full names and binder info + - `convertAnon` produces `Kernel.Env .anon` with all metadata as () + + Much simpler than DecompileM: no Blake3 hash computation, no mdata reconstruction. 
+-/ +import Ix.Kernel.Types +import Ix.Ixon + +namespace Ix.Kernel.Convert + +open Ix (Name) +open Ixon (Constant ConstantInfo ConstantMeta MutConst Named) + +/-! ## Universe conversion -/ + +partial def convertUniv (m : MetaMode) (levelParamNames : Array (MetaField m Ix.Name) := #[]) + : Ixon.Univ → Level m + | .zero => .zero + | .succ l => .succ (convertUniv m levelParamNames l) + | .max l₁ l₂ => .max (convertUniv m levelParamNames l₁) (convertUniv m levelParamNames l₂) + | .imax l₁ l₂ => .imax (convertUniv m levelParamNames l₁) (convertUniv m levelParamNames l₂) + | .var idx => + let name := if h : idx.toNat < levelParamNames.size then levelParamNames[idx.toNat] else default + .param idx.toNat name + +/-! ## Expression conversion monad -/ + +structure ConvertEnv (m : MetaMode) where + sharing : Array Ixon.Expr + refs : Array Address + univs : Array Ixon.Univ + blobs : Std.HashMap Address ByteArray + recurAddrs : Array Address := #[] + arena : Ixon.ExprMetaArena := {} + names : Std.HashMap Address Ix.Name := {} + levelParamNames : Array (MetaField m Ix.Name) := #[] + binderNames : List (MetaField m Ix.Name) := [] + +structure ConvertState (m : MetaMode) where + exprCache : Std.HashMap (UInt64 × UInt64) (Expr m) := {} + +inductive ConvertError where + | refOutOfBounds (refIdx : UInt64) (refsSize : Nat) + | recurOutOfBounds (recIdx : UInt64) (recurAddrsSize : Nat) + | prjRefOutOfBounds (typeRefIdx : UInt64) (refsSize : Nat) + | missingMemberAddr (memberIdx : Nat) (numMembers : Nat) + | unresolvableCtxAddr (addr : Address) + | missingName (nameAddr : Address) + +instance : ToString ConvertError where + toString + | .refOutOfBounds idx sz => s!"ref index {idx} out of bounds (refs.size={sz})" + | .recurOutOfBounds idx sz => s!"recur index {idx} out of bounds (recurAddrs.size={sz})" + | .prjRefOutOfBounds idx sz => s!"proj type ref index {idx} out of bounds (refs.size={sz})" + | .missingMemberAddr idx n => s!"no address for member {idx} (numMembers={n})" + | 
-- NOTE(review): the next two lines are the tail of a ConvertError message
-- function whose head lies before this chunk; kept verbatim.
  .unresolvableCtxAddr addr => s!"unresolvable ctx address {addr}"
  | .missingName addr => s!"missing name for address {addr}"

/-- Conversion monad: read-only `ConvertEnv`, threaded `ConvertState`
    (the expression cache), failing with `ConvertError`. -/
abbrev ConvertM (m : MetaMode) := ReaderT (ConvertEnv m) (StateT (ConvertState m) (ExceptT ConvertError Id))

/-- Fresh (empty) conversion state; the env argument is currently unused. -/
def ConvertState.init (_ : ConvertEnv m) : ConvertState m := {}

/-- Run a `ConvertM` computation from an empty state, discarding the final state. -/
def ConvertM.run (env : ConvertEnv m) (x : ConvertM m α) : Except ConvertError α :=
  match x env |>.run (ConvertState.init env) with
  | .ok (a, _) => .ok a
  | .error e => .error e

/-- Run a ConvertM computation with existing state, return result and final state. -/
def ConvertM.runWith (env : ConvertEnv m) (st : ConvertState m) (x : ConvertM m α)
    : Except ConvertError (α × ConvertState m) :=
  x env |>.run st

/-! ## Expression conversion -/

/-- Resolve universe indices against the env's `univs` table.
    Out-of-range indices silently become `.zero`, mirroring the `.sort`
    fallback in `convertExpr`. -/
def resolveUnivs (m : MetaMode) (idxs : Array UInt64) : ConvertM m (Array (Level m)) := do
  let ctx ← read
  return idxs.map fun i =>
    if h : i.toNat < ctx.univs.size
    then convertUniv m ctx.levelParamNames ctx.univs[i.toNat]
    else .zero

/-- Decode a little-endian arbitrary-precision natural from raw bytes.
    NOTE(review): recomputes `256 ^ i` per byte, so this is O(n²) in blob
    length — fine for small literals, worth revisiting for huge nats. -/
def decodeBlobNat (bytes : ByteArray) : Nat := Id.run do
  let mut acc := 0
  for i in [:bytes.size] do
    acc := acc + bytes[i]!.toNat * 256 ^ i
  return acc

/-- Decode a UTF-8 string blob.
    NOTE(review): `fromUTF8!` panics on invalid UTF-8 — presumably blobs are
    always produced by the matching encoder; confirm. -/
def decodeBlobStr (bytes : ByteArray) : String :=
  String.fromUTF8! bytes

/-- Look up an arena node by index, automatically unwrapping `.mdata` wrappers. -/
partial def getArenaNode (idx : Option UInt64) : ConvertM m (Option Ixon.ExprMetaData) := do
  match idx with
  | none => return none
  | some i =>
    let ctx ← read
    if h : i.toNat < ctx.arena.nodes.size
    then match ctx.arena.nodes[i.toNat] with
      -- transparently skip metadata wrappers, following the child pointer
      | .mdata _ child => getArenaNode (some child)
      | node => return some node
    else return none

/-- In `.meta` mode keep the optional name (defaulting when absent);
    in `.anon` mode every name erases to `()`. -/
def mkMetaName (m : MetaMode) (name? : Option Ix.Name) : MetaField m Ix.Name :=
  match m with
  | .meta => name?.getD default
  | .anon => ()

/-- Resolve a name hash Address to a MetaField name via the names table. -/
def resolveName (nameAddr : Address) : ConvertM m (MetaField m Ix.Name) := do
  let ctx ← read
  match ctx.names.get? nameAddr with
  | some name => return (mkMetaName m (some name))
  | none => throw (.missingName nameAddr)

/-- Convert an Ixon expression (plus optional metadata-arena index) into a
    kernel `Expr`. `.share` nodes are expanded transparently; results are
    memoized on (expression hash, arena index). `.var` is resolved before the
    cache because binder names depend on the current binder context. -/
partial def convertExpr (m : MetaMode) (expr : Ixon.Expr) (metaIdx : Option UInt64 := none)
    : ConvertM m (Expr m) := do
  -- 1. Expand share transparently, passing arena index through (same as DecompileM)
  match expr with
  | .share idx =>
    let ctx ← read
    if h : idx.toNat < ctx.sharing.size then
      convertExpr m ctx.sharing[idx.toNat] metaIdx
    else return default
  | _ =>

  -- 1b. Handle .var before cache (binder names are context-dependent)
  if let .var idx := expr then
    let name := match (← read).binderNames[idx.toNat]? with
      | some n => n | none => default
    return (.bvar idx.toNat name)

  -- 2. Check cache (keyed on expression hash + arena index)
  let cacheKey := (hash expr, metaIdx.getD UInt64.MAX)
  if let some cached := (← get).exprCache.get? cacheKey then return cached

  -- 3. Resolve arena node
  let node ← getArenaNode metaIdx

  -- 4. Convert expression
  let result ← match expr with
    | .sort idx => do
      let ctx ← read
      if h : idx.toNat < ctx.univs.size
      then pure (.sort (convertUniv m ctx.levelParamNames ctx.univs[idx.toNat]))
      else pure (.sort .zero)
    | .var _ => pure default -- unreachable, handled above
    | .ref refIdx univIdxs => do
      let ctx ← read
      let levels ← resolveUnivs m univIdxs
      let addr ← match ctx.refs[refIdx.toNat]? with
        | some a => pure a
        | none => throw (.refOutOfBounds refIdx ctx.refs.size)
      let name ← match node with
        | some (.ref nameAddr) => resolveName nameAddr
        | _ => pure default
      pure (.const addr levels name)
    | .recur recIdx univIdxs => do
      -- a reference into the enclosing mutual block, resolved via recurAddrs
      let ctx ← read
      let levels ← resolveUnivs m univIdxs
      let addr ← match ctx.recurAddrs[recIdx.toNat]? with
        | some a => pure a
        | none => throw (.recurOutOfBounds recIdx ctx.recurAddrs.size)
      let name ← match node with
        | some (.ref nameAddr) => resolveName nameAddr
        | _ => pure default
      pure (.const addr levels name)
    | .prj typeRefIdx fieldIdx struct => do
      let ctx ← read
      let typeAddr ← match ctx.refs[typeRefIdx.toNat]? with
        | some a => pure a
        | none => throw (.prjRefOutOfBounds typeRefIdx ctx.refs.size)
      let (structChild, typeName) ← match node with
        | some (.prj structNameAddr child) => do
          let n ← resolveName structNameAddr
          pure (some child, n)
        | _ => pure (none, default)
      let s ← convertExpr m struct structChild
      pure (.proj typeAddr fieldIdx.toNat s typeName)
    | .str blobRefIdx => do
      -- string literal stored as a blob; missing refs/blobs degrade to ""
      let ctx ← read
      if h : blobRefIdx.toNat < ctx.refs.size then
        let blobAddr := ctx.refs[blobRefIdx.toNat]
        match ctx.blobs.get? blobAddr with
        | some bytes => pure (.lit (.strVal (decodeBlobStr bytes)))
        | none => pure (.lit (.strVal ""))
      else pure (.lit (.strVal ""))
    | .nat blobRefIdx => do
      -- nat literal stored as a blob; missing refs/blobs degrade to 0
      let ctx ← read
      if h : blobRefIdx.toNat < ctx.refs.size then
        let blobAddr := ctx.refs[blobRefIdx.toNat]
        match ctx.blobs.get? blobAddr with
        | some bytes => pure (.lit (.natVal (decodeBlobNat bytes)))
        | none => pure (.lit (.natVal 0))
      else pure (.lit (.natVal 0))
    | .app fn arg => do
      let (fnChild, argChild) := match node with
        | some (.app f a) => (some f, some a)
        | _ => (none, none)
      let f ← convertExpr m fn fnChild
      let a ← convertExpr m arg argChild
      pure (.app f a)
    | .lam ty body => do
      let (name, bi, tyChild, bodyChild) ← match node with
        | some (.binder nameAddr info tyC bodyC) => do
          let n ← resolveName nameAddr
          let i : MetaField m Lean.BinderInfo := match m with | .meta => info | .anon => ()
          pure (n, i, some tyC, some bodyC)
        | _ => pure (default, default, none, none)
      let t ← convertExpr m ty tyChild
      -- push the binder name so nested .var occurrences resolve to it
      let b ← withReader (fun env => { env with binderNames := name :: env.binderNames })
        (convertExpr m body bodyChild)
      pure (.lam t b name bi)
    | .all ty body => do
      let (name, bi, tyChild, bodyChild) ← match node with
        | some (.binder nameAddr info tyC bodyC) => do
          let n ← resolveName nameAddr
          let i : MetaField m Lean.BinderInfo := match m with | .meta => info | .anon => ()
          pure (n, i, some tyC, some bodyC)
        | _ => pure (default, default, none, none)
      let t ← convertExpr m ty tyChild
      let b ← withReader (fun env => { env with binderNames := name :: env.binderNames })
        (convertExpr m body bodyChild)
      pure (.forallE t b name bi)
    | .letE _nonDep ty val body => do
      let (name, tyChild, valChild, bodyChild) ← match node with
        | some (.letBinder nameAddr tyC valC bodyC) => do
          let n ← resolveName nameAddr
          pure (n, some tyC, some valC, some bodyC)
        | _ => pure (default, none, none, none)
      let t ← convertExpr m ty tyChild
      let v ← convertExpr m val valChild
      let b ← withReader (fun env => { env with binderNames := name :: env.binderNames })
        (convertExpr m body bodyChild)
      pure (.letE t v b name)
    | .share _ => pure default -- unreachable, handled above

  -- 5. Cache and return
  modify fun s => { s with exprCache := s.exprCache.insert cacheKey result }
  pure result

/-! ## Enum conversions -/

def convertHints : Lean.ReducibilityHints → ReducibilityHints
  | .opaque => .opaque
  | .abbrev => .abbrev
  | .regular h => .regular h

def convertSafety : Ix.DefinitionSafety → DefinitionSafety
  | .unsaf => .unsafe
  | .safe => .safe
  | .part => .partial

def convertQuotKind : Ix.QuotKind → QuotKind
  | .type => .type
  | .ctor => .ctor
  | .lift => .lift
  | .ind => .ind

/-! ## Constant conversion helpers -/

/-- Build a `ConvertEnv` from a constant's own tables plus the shared
    blob/name tables; optional fields default to empty. -/
def mkConvertEnv (m : MetaMode) (c : Constant) (blobs : Std.HashMap Address ByteArray)
    (recurAddrs : Array Address := #[])
    (arena : Ixon.ExprMetaArena := {})
    (names : Std.HashMap Address Ix.Name := {})
    (levelParamNames : Array (MetaField m Ix.Name) := #[]) : ConvertEnv m :=
  { sharing := c.sharing, refs := c.refs, univs := c.univs, blobs, recurAddrs, arena, names,
    levelParamNames }

def mkConstantVal (m : MetaMode) (numLvls : UInt64) (typ : Expr m)
    (name : MetaField m Ix.Name := default)
    (levelParams : MetaField m (Array Ix.Name) := default) : ConstantVal m :=
  { numLevels := numLvls.toNat, type := typ, name, levelParams }

/-! ## Factored constant conversion helpers -/

/-- Extract arena from ConstantMeta. -/
def metaArena : ConstantMeta → Ixon.ExprMetaArena
  | .defn _ _ _ _ _ a _ _ => a
  | .axio _ _ a _ => a
  | .quot _ _ a _ => a
  | .indc _ _ _ _ _ a _ => a
  | .ctor _ _ _ a _ => a
  | .recr _ _ _ _ _ a _ _ => a
  | .empty => {}

/-- Extract type root index from ConstantMeta. -/
def metaTypeRoot? : ConstantMeta → Option UInt64
  | .defn _ _ _ _ _ _ r _ => some r
  | .axio _ _ _ r => some r
  | .quot _ _ _ r => some r
  | .indc _ _ _ _ _ _ r => some r
  | .ctor _ _ _ _ r => some r
  | .recr _ _ _ _ _ _ r _ => some r
  | .empty => none

/-- Extract value root index from ConstantMeta (defn only). -/
def metaValueRoot? : ConstantMeta → Option UInt64
  | .defn _ _ _ _ _ _ _ r => some r
  | .empty => none  -- NOTE(review): subsumed by the catch-all below; kept for clarity
  | _ => none

/-- Extract level param name addresses from ConstantMeta. -/
def metaLvlAddrs : ConstantMeta → Array Address
  | .defn _ lvls _ _ _ _ _ _ => lvls
  | .axio _ lvls _ _ => lvls
  | .quot _ lvls _ _ => lvls
  | .indc _ lvls _ _ _ _ _ => lvls
  | .ctor _ lvls _ _ _ => lvls
  | .recr _ lvls _ _ _ _ _ _ => lvls
  | .empty => #[]

/-- Resolve level param addresses to MetaField names via the names table. -/
def resolveLevelParams (m : MetaMode) (names : Std.HashMap Address Ix.Name)
    (lvlAddrs : Array Address) : Array (MetaField m Ix.Name) :=
  match m with
  | .anon => lvlAddrs.map fun _ => ()
  | .meta => lvlAddrs.map fun addr => names.getD addr default

/-- Build the MetaField levelParams value from resolved names. -/
def mkLevelParams (m : MetaMode) (names : Std.HashMap Address Ix.Name)
    (lvlAddrs : Array Address) : MetaField m (Array Ix.Name) :=
  match m with
  | .anon => ()
  | .meta => lvlAddrs.map fun addr => names.getD addr default

/-- Resolve an array of name-hash addresses to a MetaField array of names. -/
def resolveMetaNames (m : MetaMode) (names : Std.HashMap Address Ix.Name)
    (addrs : Array Address) : MetaField m (Array Ix.Name) :=
  match m with | .anon => () | .meta => addrs.map fun a => names.getD a default

/-- Resolve a single name-hash address to a MetaField name. -/
def resolveMetaName (m : MetaMode) (names : Std.HashMap Address Ix.Name)
    (addr : Address) : MetaField m Ix.Name :=
  match m with | .anon => () | .meta => names.getD addr default

/-- Extract rule root indices from ConstantMeta (recr only).
-/
def metaRuleRoots : ConstantMeta → Array UInt64
  | .recr _ _ _ _ _ _ _ rs => rs
  | _ => #[]

/-- Convert one recursor rule; `ruleRoot` is the arena index of the rule's
    right-hand side, when metadata is present. -/
def convertRule (m : MetaMode) (rule : Ixon.RecursorRule) (ctorAddr : Address)
    (ctorName : MetaField m Ix.Name := default)
    (ruleRoot : Option UInt64 := none) :
    ConvertM m (Ix.Kernel.RecursorRule m) := do
  let rhs ← convertExpr m rule.rhs ruleRoot
  return { ctor := ctorAddr, ctorName, nfields := rule.fields.toNat, rhs }

/-- Convert a definition-like constant; `d.kind` decides whether it becomes a
    defn, opaque, or theorem `ConstantInfo`. -/
def convertDefinition (m : MetaMode) (d : Ixon.Definition)
    (hints : ReducibilityHints) (all : Array Address)
    (name : MetaField m Ix.Name := default)
    (levelParams : MetaField m (Array Ix.Name) := default)
    (cMeta : ConstantMeta := .empty)
    (allNames : MetaField m (Array Ix.Name) := default) : ConvertM m (Ix.Kernel.ConstantInfo m) := do
  let typ ← convertExpr m d.typ (metaTypeRoot? cMeta)
  let value ← convertExpr m d.value (metaValueRoot? cMeta)
  let cv := mkConstantVal m d.lvls typ name levelParams
  match d.kind with
  | .defn => return .defnInfo { toConstantVal := cv, value, hints, safety := convertSafety d.safety, all, allNames }
  | .opaq => return .opaqueInfo { toConstantVal := cv, value, isUnsafe := d.safety == .unsaf, all, allNames }
  | .thm => return .thmInfo { toConstantVal := cv, value, all, allNames }

/-- Convert an axiom constant. -/
def convertAxiom (m : MetaMode) (a : Ixon.Axiom)
    (name : MetaField m Ix.Name := default)
    (levelParams : MetaField m (Array Ix.Name) := default)
    (cMeta : ConstantMeta := .empty) : ConvertM m (Ix.Kernel.ConstantInfo m) := do
  let typ ← convertExpr m a.typ (metaTypeRoot? cMeta)
  let cv := mkConstantVal m a.lvls typ name levelParams
  return .axiomInfo { toConstantVal := cv, isUnsafe := a.isUnsafe }

/-- Convert a quotient primitive constant. -/
def convertQuotient (m : MetaMode) (q : Ixon.Quotient)
    (name : MetaField m Ix.Name := default)
    (levelParams : MetaField m (Array Ix.Name) := default)
    (cMeta : ConstantMeta := .empty) : ConvertM m (Ix.Kernel.ConstantInfo m) := do
  let typ ← convertExpr m q.typ (metaTypeRoot? cMeta)
  let cv := mkConstantVal m q.lvls typ name levelParams
  return .quotInfo { toConstantVal := cv, kind := convertQuotKind q.kind }

/-- Convert an inductive type; `ctorAddrs`/`all` are the cross-references
    derived from the enclosing block (see BlockIndex). -/
def convertInductive (m : MetaMode) (ind : Ixon.Inductive)
    (ctorAddrs all : Array Address)
    (name : MetaField m Ix.Name := default)
    (levelParams : MetaField m (Array Ix.Name) := default)
    (cMeta : ConstantMeta := .empty)
    (allNames : MetaField m (Array Ix.Name) := default)
    (ctorNames : MetaField m (Array Ix.Name) := default) : ConvertM m (Ix.Kernel.ConstantInfo m) := do
  let typ ← convertExpr m ind.typ (metaTypeRoot? cMeta)
  let cv := mkConstantVal m ind.lvls typ name levelParams
  let v : Ix.Kernel.InductiveVal m :=
    { toConstantVal := cv, numParams := ind.params.toNat,
      numIndices := ind.indices.toNat, all, ctors := ctorAddrs, allNames, ctorNames,
      numNested := ind.nested.toNat, isRec := ind.recr, isUnsafe := ind.isUnsafe,
      isReflexive := ind.refl }
  return .inductInfo v

/-- Convert a constructor; `inductAddr` is its parent inductive's address. -/
def convertConstructor (m : MetaMode) (c : Ixon.Constructor)
    (inductAddr : Address)
    (name : MetaField m Ix.Name := default)
    (levelParams : MetaField m (Array Ix.Name) := default)
    (cMeta : ConstantMeta := .empty)
    (inductName : MetaField m Ix.Name := default) : ConvertM m (Ix.Kernel.ConstantInfo m) := do
  let typ ← convertExpr m c.typ (metaTypeRoot? cMeta)
  let cv := mkConstantVal m c.lvls typ name levelParams
  let v : Ix.Kernel.ConstructorVal m :=
    { toConstantVal := cv, induct := inductAddr, inductName,
      cidx := c.cidx.toNat, numParams := c.params.toNat, numFields := c.fields.toNat,
      isUnsafe := c.isUnsafe }
  return .ctorInfo v

/-- Convert a recursor. `ruleCtorAddrs`/`ruleCtorNames`/rule roots are indexed
    positionally against `r.rules`; missing entries fall back to `default`. -/
def convertRecursor (m : MetaMode) (r : Ixon.Recursor)
    (all ruleCtorAddrs : Array Address)
    (name : MetaField m Ix.Name := default)
    (levelParams : MetaField m (Array Ix.Name) := default)
    (cMeta : ConstantMeta := .empty)
    (allNames : MetaField m (Array Ix.Name) := default)
    (ruleCtorNames : Array (MetaField m Ix.Name) := #[]) : ConvertM m (Ix.Kernel.ConstantInfo m) := do
  let typ ← convertExpr m r.typ (metaTypeRoot? cMeta)
  let cv := mkConstantVal m r.lvls typ name levelParams
  let ruleRoots := (metaRuleRoots cMeta)
  let mut rules : Array (Ix.Kernel.RecursorRule m) := #[]
  for i in [:r.rules.size] do
    let ctorAddr := if h : i < ruleCtorAddrs.size then ruleCtorAddrs[i] else default
    let ctorName := if h : i < ruleCtorNames.size then ruleCtorNames[i] else default
    let ruleRoot := if h : i < ruleRoots.size then some ruleRoots[i] else none
    rules := rules.push (← convertRule m r.rules[i]! ctorAddr ctorName ruleRoot)
  let v : Ix.Kernel.RecursorVal m :=
    { toConstantVal := cv, all, allNames,
      numParams := r.params.toNat, numIndices := r.indices.toNat,
      numMotives := r.motives.toNat, numMinors := r.minors.toNat,
      rules, k := r.k, isUnsafe := r.isUnsafe }
  return .recInfo v

/-! ## Metadata helpers -/

/-- Build a direct name-hash Address → constant Address lookup table. -/
def buildHashToAddr (ixonEnv : Ixon.Env) : Std.HashMap Address Address := Id.run do
  let mut acc : Std.HashMap Address Address := {}
  for (nameHash, name) in ixonEnv.names do
    match ixonEnv.named.get? name with
    | some entry => acc := acc.insert nameHash entry.addr
    | none => pure ()  -- dangling name hash: skip silently
  return acc

/-- Extract block address from a projection constant, if it is one.
-/
def projBlockAddr : Ixon.ConstantInfo → Option Address
  | .iPrj p => some p.block
  | .cPrj p => some p.block
  | .rPrj p => some p.block
  | .dPrj p => some p.block
  | _ => none

/-! ## BlockIndex -/

/-- Cross-reference index for projections within a single muts block.
    Built from the block group before conversion so we can derive addresses
    without relying on metadata. -/
structure BlockIndex where
  /-- memberIdx → iPrj address (inductive type address) -/
  inductAddrs : Std.HashMap UInt64 Address := {}
  /-- memberIdx → Array of cPrj addresses, ordered by cidx -/
  ctorAddrs : Std.HashMap UInt64 (Array Address) := {}
  /-- All iPrj addresses in the block (the `all` array for inductives/recursors) -/
  allInductAddrs : Array Address := #[]
  /-- memberIdx → primary projection address (for .recur resolution).
      iPrj for inductives, dPrj for definitions. -/
  memberAddrs : Std.HashMap UInt64 Address := {}

/-- Build a BlockIndex from a group of projections. -/
def buildBlockIndex (projections : Array (Address × Constant)) : BlockIndex := Id.run do
  let mut indByIdx : Std.HashMap UInt64 Address := {}
  let mut rawCtors : Std.HashMap UInt64 (Array (UInt64 × Address)) := {}
  let mut indOrder : Array Address := #[]
  let mut primary : Std.HashMap UInt64 Address := {}
  for (addr, projConst) in projections do
    match projConst.info with
    | .iPrj prj =>
      indByIdx := indByIdx.insert prj.idx addr
      indOrder := indOrder.push addr
      primary := primary.insert prj.idx addr
    | .cPrj prj =>
      -- accumulate (cidx, addr) pairs per member; sorted after the scan
      rawCtors := rawCtors.insert prj.idx ((rawCtors.getD prj.idx #[]).push (prj.cidx, addr))
    | .dPrj prj =>
      primary := primary.insert prj.idx addr
    | .rPrj prj =>
      -- an rPrj only becomes the member's primary address when no
      -- iPrj/dPrj has already claimed that member index
      if !primary.contains prj.idx then
        primary := primary.insert prj.idx addr
    | _ => pure ()
  -- Order each member's constructors by cidx, then keep only the addresses
  let mut ctorsByIdx : Std.HashMap UInt64 (Array Address) := {}
  for (idx, entries) in rawCtors do
    ctorsByIdx := ctorsByIdx.insert idx ((entries.insertionSort (fun a b => a.1 < b.1)).map (·.2))
  return { inductAddrs := indByIdx, ctorAddrs := ctorsByIdx,
           allInductAddrs := indOrder, memberAddrs := primary }

/-- All constructor addresses in declaration order (by inductive member index, then cidx).
    This matches the order of RecursorVal.rules in the Lean kernel. -/
def BlockIndex.allCtorAddrsInOrder (bIdx : BlockIndex) : Array Address := Id.run do
  let byMember := bIdx.inductAddrs.toArray.insertionSort (fun a b => a.1 < b.1)
  let mut out : Array Address := #[]
  for (memberIdx, _) in byMember do
    out := out ++ (bIdx.ctorAddrs.getD memberIdx #[])
  return out

/-- Build recurAddrs array from BlockIndex. Maps member index → projection address.
-/
def buildRecurAddrs (bIdx : BlockIndex) (numMembers : Nat) : Except ConvertError (Array Address) := do
  let mut addrs : Array Address := #[]
  for i in [:numMembers] do
    match bIdx.memberAddrs.get? i.toUInt64 with
    | some addr => addrs := addrs.push addr
    | none => throw (.missingMemberAddr i numMembers)
  return addrs

/-! ## Projection conversion -/

/-- Convert a single projection constant as a ConvertM action.
    Uses BlockIndex for cross-references instead of metadata.
    Returns the deferred `ConvertM` action (or a classification error) so the
    caller can run it with shared state — see `convertWorkBlock`. -/
def convertProjAction (m : MetaMode)
    (addr : Address) (c : Constant)
    (blockConst : Constant) (bIdx : BlockIndex)
    (name : MetaField m Ix.Name := default)
    (levelParams : MetaField m (Array Ix.Name) := default)
    (cMeta : ConstantMeta := .empty)
    (names : Std.HashMap Address Ix.Name := {})
    : Except String (ConvertM m (Ix.Kernel.ConstantInfo m)) := do
  let .muts members := blockConst.info
    | .error s!"projection block is not a muts at {addr}"
  match c.info with
  | .iPrj prj =>
    if h : prj.idx.toNat < members.size then
      match members[prj.idx.toNat] with
      | .indc ind =>
        let ctorAs := bIdx.ctorAddrs.getD prj.idx #[]
        let allNs := resolveMetaNames m names (match cMeta with | .indc _ _ _ a _ _ _ => a | _ => #[])
        let ctorNs := resolveMetaNames m names (match cMeta with | .indc _ _ c _ _ _ _ => c | _ => #[])
        .ok (convertInductive m ind ctorAs bIdx.allInductAddrs name levelParams cMeta allNs ctorNs)
      | _ => .error s!"iPrj at {addr} does not point to an inductive"
    else .error s!"iPrj index out of bounds at {addr}"
  | .cPrj prj =>
    if h : prj.idx.toNat < members.size then
      match members[prj.idx.toNat] with
      | .indc ind =>
        if h2 : prj.cidx.toNat < ind.ctors.size then
          let ctor := ind.ctors[prj.cidx.toNat]
          let inductAddr := bIdx.inductAddrs.getD prj.idx default
          let inductNm := resolveMetaName m names (match cMeta with | .ctor _ _ i _ _ => i | _ => default)
          .ok (convertConstructor m ctor inductAddr name levelParams cMeta inductNm)
        else .error s!"cPrj cidx out of bounds at {addr}"
      | _ => .error s!"cPrj at {addr} does not point to an inductive"
    else .error s!"cPrj index out of bounds at {addr}"
  | .rPrj prj =>
    if h : prj.idx.toNat < members.size then
      match members[prj.idx.toNat] with
      | .recr r =>
        -- rule ctor addresses come from the whole block, in declaration order
        let ruleCtorAs := bIdx.allCtorAddrsInOrder
        let allNs := resolveMetaNames m names (match cMeta with | .recr _ _ _ a _ _ _ _ => a | _ => #[])
        let metaRules := match cMeta with | .recr _ _ rules _ _ _ _ _ => rules | _ => #[]
        let ruleCtorNs := metaRules.map fun x => resolveMetaName m names x
        .ok (convertRecursor m r bIdx.allInductAddrs ruleCtorAs name levelParams cMeta allNs ruleCtorNs)
      | _ => .error s!"rPrj at {addr} does not point to a recursor"
    else .error s!"rPrj index out of bounds at {addr}"
  | .dPrj prj =>
    if h : prj.idx.toNat < members.size then
      match members[prj.idx.toNat] with
      | .defn d =>
        let hints := match cMeta with
          | .defn _ _ h _ _ _ _ _ => convertHints h
          | _ => .opaque
        let allNs := resolveMetaNames m names (match cMeta with | .defn _ _ _ a _ _ _ _ => a | _ => #[])
        .ok (convertDefinition m d hints bIdx.allInductAddrs name levelParams cMeta allNs)
      | _ => .error s!"dPrj at {addr} does not point to a definition"
    else .error s!"dPrj index out of bounds at {addr}"
  | _ => .error s!"not a projection at {addr}"

/-! ## Work items -/

/-- An entry to convert: address, constant, name, and metadata. -/
structure ConvertEntry (m : MetaMode) where
  addr : Address
  const : Constant
  name : MetaField m Ix.Name
  constMeta : ConstantMeta

/-- A work item: either a standalone constant or a complete block group. -/
inductive WorkItem (m : MetaMode) where
  | standalone (entry : ConvertEntry m)
  | block (blockAddr : Address) (entries : Array (ConvertEntry m))

/-- Extract ctx addresses from ConstantMeta (mutual context for .recur resolution). -/
def metaCtxAddrs : ConstantMeta → Array Address
  | .defn _ _ _ _ ctx .. => ctx
  | .indc _ _ _ _ ctx .. => ctx
  | .recr _ _ _ _ ctx .. => ctx
  | _ => #[]

/-- Extract parent inductive name-hash address from ConstantMeta (ctor only). -/
def metaInductAddr : ConstantMeta → Address
  | .ctor _ _ induct _ _ => induct
  | _ => default

/-- Resolve ctx name-hash addresses to constant addresses for recurAddrs. -/
def resolveCtxAddrs (hashToAddr : Std.HashMap Address Address) (ctx : Array Address)
    : Except ConvertError (Array Address) :=
  ctx.mapM fun x =>
    match hashToAddr.get? x with
    | some addr => .ok addr
    | none => .error (.unresolvableCtxAddr x)

/-- Convert a standalone (non-projection) constant.
    Returns `none` for muts blocks and projections (handled elsewhere). -/
def convertStandalone (m : MetaMode) (hashToAddr : Std.HashMap Address Address)
    (ixonEnv : Ixon.Env) (entry : ConvertEntry m) :
    Except String (Option (Ix.Kernel.ConstantInfo m)) := do
  let cMeta := entry.constMeta
  let recurAddrs ← (resolveCtxAddrs hashToAddr (metaCtxAddrs cMeta)).mapError toString
  let lvlNames := resolveLevelParams m ixonEnv.names (metaLvlAddrs cMeta)
  let lps := mkLevelParams m ixonEnv.names (metaLvlAddrs cMeta)
  let cEnv := mkConvertEnv m entry.const ixonEnv.blobs
    (recurAddrs := recurAddrs) (arena := (metaArena cMeta)) (names := ixonEnv.names)
    (levelParamNames := lvlNames)
  match entry.const.info with
  | .defn d =>
    let hints := match cMeta with
      | .defn _ _ h _ _ _ _ _ => convertHints h
      | _ => .opaque
    let allHashAddrs := match cMeta with
      | .defn _ _ _ a _ _ _ _ => a
      | _ => #[]
    -- name-hash addresses fall back to themselves when unresolvable
    let all := allHashAddrs.map fun x => hashToAddr.getD x x
    let allNames := resolveMetaNames m ixonEnv.names allHashAddrs
    let ci ← (ConvertM.run cEnv (convertDefinition m d hints all entry.name lps cMeta allNames)).mapError toString
    return some ci
  | .axio a =>
    let ci ← (ConvertM.run cEnv (convertAxiom m a entry.name lps cMeta)).mapError toString
    return some ci
  | .quot q =>
    let ci ← (ConvertM.run cEnv (convertQuotient m q entry.name lps cMeta)).mapError toString
    return some ci
  | .recr r =>
    let pair : Array Address × Array Address := match cMeta with
      | .recr _ _ rules all _ _ _ _ => (all, rules)
      | _ => (#[entry.addr], #[])
    let (metaAll, metaRules) := pair
    let all := metaAll.map fun x => hashToAddr.getD x x
    let ruleCtorAddrs := metaRules.map fun x => hashToAddr.getD x x
    let allNames := resolveMetaNames m ixonEnv.names metaAll
    let ruleCtorNames := metaRules.map fun x => resolveMetaName m ixonEnv.names x
    let ci ← (ConvertM.run cEnv (convertRecursor m r all ruleCtorAddrs entry.name lps cMeta allNames ruleCtorNames)).mapError toString
    return some ci
  | .muts _ => return none
  | _ => return none -- projections handled separately

/-- Convert a complete block group (all projections share cache + recurAddrs).
    In `.anon` mode the ConvertState (expr cache) is shared across entries;
    in `.meta` mode it is reset per entry since arena indices differ. -/
def convertWorkBlock (m : MetaMode)
    (ixonEnv : Ixon.Env) (blockAddr : Address)
    (entries : Array (ConvertEntry m))
    (results : Array (Address × Ix.Kernel.ConstantInfo m)) (errors : Array (Address × String))
    : Array (Address × Ix.Kernel.ConstantInfo m) × Array (Address × String) := Id.run do
  let mut results := results
  let mut errors := errors
  match ixonEnv.getConst? blockAddr with
  | some blockConst =>
    -- Dedup projections by address for buildBlockIndex (avoid duplicate allInductAddrs)
    let mut canonicalProjs : Array (Address × Constant) := #[]
    let mut seenAddrs : Std.HashSet Address := {}
    for e in entries do
      if !seenAddrs.contains e.addr then
        canonicalProjs := canonicalProjs.push (e.addr, e.const)
        seenAddrs := seenAddrs.insert e.addr
    let bIdx := buildBlockIndex canonicalProjs
    let numMembers := match blockConst.info with
      | .muts members => members.size
      | _ => 0
    let recurAddrs ← match buildRecurAddrs bIdx numMembers with
      | .ok addrs => pure addrs
      | .error e =>
        -- a missing member address fails the whole group
        for entry in entries do
          errors := errors.push (entry.addr, toString e)
        return (results, errors)
    -- Base env (no arena/levelParamNames — each projection sets its own)
    let baseEnv := mkConvertEnv m blockConst ixonEnv.blobs recurAddrs (names := ixonEnv.names)
    let mut state := ConvertState.init baseEnv
    let shareCache := match m with | .anon => true | .meta => false
    for entry in entries do
      if !shareCache then
        state := ConvertState.init baseEnv
      let cMeta := entry.constMeta
      let lvlNames := resolveLevelParams m ixonEnv.names (metaLvlAddrs cMeta)
      let lps := mkLevelParams m ixonEnv.names (metaLvlAddrs cMeta)
      let cEnv := { baseEnv with arena := (metaArena cMeta), levelParamNames := lvlNames }
      match convertProjAction m entry.addr entry.const blockConst bIdx entry.name lps cMeta ixonEnv.names with
      | .ok action =>
        match ConvertM.runWith cEnv state action with
        | .ok (ci, state') =>
          -- thread the cache forward so shared subterms convert once
          state := state'
          results := results.push (entry.addr, ci)
        | .error e =>
          errors := errors.push (entry.addr, toString e)
      | .error e => errors := errors.push (entry.addr, e)
  | none =>
    for entry in entries do
      errors := errors.push (entry.addr, s!"block not found: {blockAddr}")
  (results, errors)

/-- Convert a chunk of work items.
-/
def convertChunk (m : MetaMode) (hashToAddr : Std.HashMap Address Address)
    (ixonEnv : Ixon.Env) (chunk : Array (WorkItem m))
    : Array (Address × Ix.Kernel.ConstantInfo m) × Array (Address × String) := Id.run do
  -- Accumulate successes and failures side by side; block items thread the
  -- accumulators through convertWorkBlock so errors keep their addresses.
  let mut oks : Array (Address × Ix.Kernel.ConstantInfo m) := #[]
  let mut errs : Array (Address × String) := #[]
  for item in chunk do
    match item with
    | .standalone entry =>
      match convertStandalone m hashToAddr ixonEnv entry with
      | .ok (some ci) => oks := oks.push (entry.addr, ci)
      | .ok none => pure ()  -- muts blocks / projections: handled elsewhere
      | .error e => errs := errs.push (entry.addr, e)
    | .block blockAddr entries =>
      let (oks', errs') := convertWorkBlock m ixonEnv blockAddr entries oks errs
      oks := oks'
      errs := errs'
  return (oks, errs)

/-! ## Top-level conversion -/

/-- Convert an entire Ixon.Env to a Kernel.Env with primitives and quotInit flag.
    Iterates named constants first (with full metadata), then picks up anonymous
    constants not in named. Groups projections by block and parallelizes.
-/
def convertEnv (m : MetaMode) (ixonEnv : Ixon.Env) (numWorkers : Nat := 32)
    : Except String (Ix.Kernel.Env m × Primitives × Bool) :=
  -- Build primitives with quot addresses
  let prims : Primitives := Id.run do
    let mut p := buildPrimitives
    for (addr, c) in ixonEnv.consts do
      match c.info with
      | .quot q => match q.kind with
        | .type => p := { p with quotType := addr }
        | .ctor => p := { p with quotCtor := addr }
        | .lift => p := { p with quotLift := addr }
        | .ind => p := { p with quotInd := addr }
      | _ => pure ()
    return p
  -- quotInit: whether any quotient primitive exists in the env
  let quotInit := Id.run do
    for (_, c) in ixonEnv.consts do
      if let .quot _ := c.info then return true
    return false
  let hashToAddr := buildHashToAddr ixonEnv
  let (constants, allErrors) := Id.run do
    -- Phase 1: Build entries from named constants (have names + metadata)
    let mut entries : Array (ConvertEntry m) := #[]
    let mut seen : Std.HashSet Address := {}
    for (ixName, named) in ixonEnv.named do
      let addr := named.addr
      match ixonEnv.consts.get? addr with
      | some c =>
        let name := mkMetaName m (some ixName)
        entries := entries.push { addr, const := c, name, constMeta := named.constMeta }
        seen := seen.insert addr
      | none => pure ()
    -- Phase 2: Pick up anonymous constants not covered by named
    for (addr, c) in ixonEnv.consts do
      if !seen.contains addr then
        entries := entries.push { addr, const := c, name := default, constMeta := .empty }
    -- Phase 2.5: In .anon mode, dedup all entries by address (copies identical).
    -- In .meta mode, keep all entries (named variants have distinct metadata).
    let shouldDedup := match m with | .anon => true | .meta => false
    if shouldDedup then
      let mut dedupedEntries : Array (ConvertEntry m) := #[]
      let mut seenDedup : Std.HashSet Address := {}
      for entry in entries do
        if !seenDedup.contains entry.addr then
          dedupedEntries := dedupedEntries.push entry
          seenDedup := seenDedup.insert entry.addr
      entries := dedupedEntries
    -- Phase 3: Group into standalones and block groups
    -- Use (blockAddr, ctxKey) to disambiguate colliding block addresses
    let mut standalones : Array (ConvertEntry m) := #[]
    -- Pass 1: Build nameHash → ctx map from entries with ctx
    let mut nameHashToCtx : Std.HashMap Address (Array Address) := {}
    let mut projEntries : Array (Address × ConvertEntry m) := #[]
    for entry in entries do
      match projBlockAddr entry.const.info with
      | some blockAddr =>
        projEntries := projEntries.push (blockAddr, entry)
        let ctx := metaCtxAddrs entry.constMeta
        if ctx.size > 0 then
          for nameHash in ctx do
            nameHashToCtx := nameHashToCtx.insert nameHash ctx
      | none => standalones := standalones.push entry
    -- Pass 2: Group by (blockAddr, ctxKey) to avoid collisions
    let mut blockGroups : Std.HashMap (Address × UInt64) (Array (ConvertEntry m)) := {}
    for (blockAddr, entry) in projEntries do
      let ctx0 := metaCtxAddrs entry.constMeta
      -- ctors carry no ctx; recover it via their parent inductive's hash
      let ctx := if ctx0.size > 0 then ctx0
        else nameHashToCtx.getD (metaInductAddr entry.constMeta) #[]
      let ctxKey := hash ctx
      let key := (blockAddr, ctxKey)
      blockGroups := blockGroups.insert key
        ((blockGroups.getD key #[]).push entry)
    -- Phase 4: Build work items
    let mut workItems : Array (WorkItem m) := #[]
    for entry in standalones do
      workItems := workItems.push (.standalone entry)
    for ((blockAddr, _), blockEntries) in blockGroups do
      workItems := workItems.push (.block blockAddr blockEntries)
    -- Phase 5: Chunk work items and parallelize
    let total := workItems.size
    -- FIX: clamp chunkSize to ≥ 1. Nat division by zero yields 0 in Lean, so
    -- `numWorkers = 0` previously gave chunkSize = 0 and the while-loop below
    -- never advanced `offset` — an infinite loop. With the clamp we simply
    -- fall back to one chunk per item.
    let chunkSize := max 1 ((total + numWorkers - 1) / numWorkers)
    let mut tasks : Array (Task (Array (Address × Ix.Kernel.ConstantInfo m) × Array (Address × String))) := #[]
    let mut offset := 0
    while offset < total do
      let endIdx := min (offset + chunkSize) total
      let chunk := workItems[offset:endIdx]
      let task := Task.spawn (prio := .dedicated) fun () =>
        convertChunk m hashToAddr ixonEnv chunk.toArray
      tasks := tasks.push task
      offset := endIdx
    -- Phase 6: Collect results
    let mut constants : Ix.Kernel.Env m := default
    let mut allErrors : Array (Address × String) := #[]
    for task in tasks do
      let (chunkResults, chunkErrors) := task.get
      for (addr, ci) in chunkResults do
        constants := constants.insert addr ci
      allErrors := allErrors ++ chunkErrors
    (constants, allErrors)
  if !allErrors.isEmpty then
    -- report at most the first 10 failures
    let msgs := allErrors[:min 10 allErrors.size].toArray.map fun (addr, e) => s!" {addr}: {e}"
    .error s!"conversion errors ({allErrors.size}):\n{"\n".intercalate msgs.toList}"
  else
    .ok (constants, prims, quotInit)

/-- Convert an Ixon.Env to a Kernel.Env with full metadata. -/
def convert (ixonEnv : Ixon.Env) : Except String (Ix.Kernel.Env .meta × Primitives × Bool) :=
  convertEnv .meta ixonEnv

/-- Convert an Ixon.Env to a Kernel.Env without metadata. -/
def convertAnon (ixonEnv : Ixon.Env) : Except String (Ix.Kernel.Env .anon × Primitives × Bool) :=
  convertEnv .anon ixonEnv

end Ix.Kernel.Convert
diff --git a/Ix/Kernel/Datatypes.lean b/Ix/Kernel/Datatypes.lean
new file mode 100644
index 00000000..d94d8701
--- /dev/null
+++ b/Ix/Kernel/Datatypes.lean
@@ -0,0 +1,181 @@
/-
  Kernel Datatypes: Value, Neutral, SusValue, TypedExpr, Env, TypedConst.

  Closure-based semantic domain for NbE typechecking.
  Parameterized over MetaMode for compile-time metadata erasure.
-/
import Ix.Kernel.Types

namespace Ix.Kernel

/-! ## TypeInfo -/

inductive TypeInfo (m : MetaMode) where
  | unit | proof | none
  | sort : Level m → TypeInfo m
  deriving Inhabited

/-!
## AddInfo -/

/-- A body paired with auxiliary info (typically a `TypeInfo` annotation). -/
structure AddInfo (Info Body : Type) where
  info : Info
  body : Body
  deriving Inhabited

/-! ## Forward declarations for mutual types -/

/-- An expression annotated with its `TypeInfo`. -/
abbrev TypedExpr (m : MetaMode) := AddInfo (TypeInfo m) (Expr m)

/-! ## Value / Neutral / SusValue -/

mutual
  /-- Semantic values: WHNF results of evaluation. `lam`/`pi` are closures
      carrying an unevaluated body plus its environment. -/
  inductive Value (m : MetaMode) where
    | sort : Level m → Value m
    -- neutral head, spine of (lazy) arguments, and each argument's TypeInfo
    | app : Neutral m → List (AddInfo (TypeInfo m) (Thunk (Value m))) → List (TypeInfo m) → Value m
    | lam : AddInfo (TypeInfo m) (Thunk (Value m)) → TypedExpr m → ValEnv m
        → MetaField m Ix.Name → MetaField m Lean.BinderInfo → Value m
    | pi : AddInfo (TypeInfo m) (Thunk (Value m)) → TypedExpr m → ValEnv m
        → MetaField m Ix.Name → MetaField m Lean.BinderInfo → Value m
    | lit : Lean.Literal → Value m
    -- evaluation failure carried as a value (inspected by the typechecker)
    | exception : String → Value m

  /-- Stuck terms: free variables, opaque constants, and projections of
      stuck structures. -/
  inductive Neutral (m : MetaMode) where
    | fvar : Nat → MetaField m Ix.Name → Neutral m
    | const : Address → Array (Level m) → MetaField m Ix.Name → Neutral m
    | proj : Address → Nat → AddInfo (TypeInfo m) (Value m) → MetaField m Ix.Name → Neutral m

  /-- Evaluation environment: lazy values for bound variables plus the
      universe-level instantiation. -/
  inductive ValEnv (m : MetaMode) where
    | mk : List (AddInfo (TypeInfo m) (Thunk (Value m))) → List (Level m) → ValEnv m
end

instance : Inhabited (Value m) where default := .exception "uninit"
instance : Inhabited (Neutral m) where default := .fvar 0 default
instance : Inhabited (ValEnv m) where default := .mk [] []

/-- A suspended (lazily evaluated) value with its TypeInfo. -/
abbrev SusValue (m : MetaMode) := AddInfo (TypeInfo m) (Thunk (Value m))

instance : Inhabited (SusValue m) where
  default := .mk default { fn := fun _ => default }

/-!
## TypedConst -/

/-- A constant whose type (and value, where applicable) has been checked. -/
inductive TypedConst (m : MetaMode) where
  | «axiom» : (type : TypedExpr m) → TypedConst m
  | «theorem» : (type value : TypedExpr m) → TypedConst m
  | «inductive» : (type : TypedExpr m) → (struct : Bool) → TypedConst m
  | «opaque» : (type value : TypedExpr m) → TypedConst m
  | definition : (type value : TypedExpr m) → (part : Bool) → TypedConst m
  | constructor : (type : TypedExpr m) → (idx fields : Nat) → TypedConst m
  | recursor : (type : TypedExpr m) → (params motives minors indices : Nat) → (k : Bool)
      → (indAddr : Address) → (rules : Array (Nat × TypedExpr m)) → TypedConst m
  | quotient : (type : TypedExpr m) → (kind : QuotKind) → TypedConst m
  deriving Inhabited

/-- The checked type of any `TypedConst` (first field in every variant). -/
def TypedConst.type : TypedConst m → TypedExpr m
  | «axiom» type ..
  | «theorem» type ..
  | «inductive» type ..
  | «opaque» type ..
  | definition type ..
  | constructor type ..
  | recursor type ..
  | quotient type .. => type

/-! ## Accessors -/

namespace AddInfo

/-- Underlying expression of a typed expression. -/
def expr (t : TypedExpr m) : Expr m := t.body
/-- Underlying thunk of a suspended value. -/
def thunk (sus : SusValue m) : Thunk (Value m) := sus.body
/-- Force a suspended value. -/
def get (sus : SusValue m) : Value m := sus.body.get
/-- Force a suspended value, keeping its TypeInfo. -/
def getTyped (sus : SusValue m) : AddInfo (TypeInfo m) (Value m) := ⟨sus.info, sus.body.get⟩
/-- Underlying value of a typed value. -/
def value (val : AddInfo (TypeInfo m) (Value m)) : Value m := val.body
/-- Re-suspend an already-forced typed value. -/
def sus (val : AddInfo (TypeInfo m) (Value m)) : SusValue m := ⟨val.info, val.body⟩

end AddInfo

/-! ## TypedExpr helpers -/

/-- Strip leading lambda binders, returning the innermost body.
    The body's TypeInfo is reset to `default` at each step. -/
partial def TypedExpr.toImplicitLambda : TypedExpr m → TypedExpr m
  | .mk _ (.lam _ body _ _) => toImplicitLambda ⟨default, body⟩
  | x => x

/-! ## Value helpers -/

/-- Inject a neutral term as a value with an empty spine. -/
def Value.neu (n : Neutral m) : Value m := .app n [] []

def Value.ctorName : Value m → String
  | .sort .. => "sort"
  | .app .. => "app"
  | .lam .. => "lam"
  | .pi .. => "pi"
  | .lit .. => "lit"
  | .exception .. => "exception"

def Neutral.summary : Neutral m → String
  | .fvar idx name => s!"fvar({name}, {idx})"
  | .const addr _ name => s!"const({name}, {addr})"
  | .proj _ idx _ name => s!"proj({name}, {idx})"

/-- One-line human-readable rendering for debugging. -/
def Value.summary : Value m → String
  | .sort _ => "Sort"
  | .app neu args _ => s!"{neu.summary} applied to {args.length} args"
  | .lam .. => "lam"
  | .pi .. => "Pi"
  | .lit (.natVal n) => s!"natLit({n})"
  | .lit (.strVal s) => s!"strLit(\"{s}\")"
  | .exception e => s!"exception({e})"

def TypeInfo.pp : TypeInfo m → String
  | .unit => ".unit"
  | .proof => ".proof"
  | .none => ".none"
  | .sort _ => ".sort"

/-- Safe positional list lookup.
    FIX: delegate to the stdlib `GetElem?` instance (`l[i]?`) instead of a
    hand-rolled structural recursion — same semantics, less code. -/
private def listGetOpt (l : List α) (i : Nat) : Option α := l[i]?

/-- Deep structural dump (one level into args) for debugging stuck terms. -/
def Value.dump : Value m → String
  | .sort _ => "Sort"
  | .app neu args infos =>
    let argStrs := args.zipIdx.map fun (a, i) =>
      let info := match listGetOpt infos i with | some ti => TypeInfo.pp ti | none => "?"
      s!" [{i}] info={info} val={a.get.summary}"
    s!"{neu.summary} applied to {args.length} args:\n" ++ String.intercalate "\n" argStrs
  | .lam dom _ _ _ _ => s!"lam(dom={dom.get.summary}, info={dom.info.pp})"
  | .pi dom _ _ _ _ => s!"Pi(dom={dom.get.summary}, info={dom.info.pp})"
  | .lit (.natVal n) => s!"natLit({n})"
  | .lit (.strVal s) => s!"strLit(\"{s}\")"
  | .exception e => s!"exception({e})"

/-! ## ValEnv helpers -/

namespace ValEnv

/-- Bound-variable thunks, innermost first. -/
def exprs : ValEnv m → List (SusValue m)
  | .mk es _ => es

/-- Universe-level instantiation. -/
def univs : ValEnv m → List (Level m)
  | .mk _ us => us

/-- Push one suspended value onto the environment. -/
def extendWith (env : ValEnv m) (thunk : SusValue m) : ValEnv m :=
  .mk (thunk :: env.exprs) env.univs

/-- Replace the expression list, keeping universes. -/
def withExprs (env : ValEnv m) (exprs : List (SusValue m)) : ValEnv m :=
  .mk exprs env.univs

end ValEnv

/-!
## Smart constructors -/ + +def mkConst (addr : Address) (univs : Array (Level m)) (name : MetaField m Ix.Name := default) : Value m := + .neu (.const addr univs name) + +def mkSusVar (info : TypeInfo m) (idx : Nat) (name : MetaField m Ix.Name := default) : SusValue m := + .mk info (.mk fun _ => .neu (.fvar idx name)) + +end Ix.Kernel diff --git a/Ix/Kernel/DecompileM.lean b/Ix/Kernel/DecompileM.lean new file mode 100644 index 00000000..d52bda4a --- /dev/null +++ b/Ix/Kernel/DecompileM.lean @@ -0,0 +1,254 @@ +/- + Kernel DecompileM: Kernel.Expr/ConstantInfo → Lean.Expr/ConstantInfo decompilation. + + Used for roundtrip validation: Lean.Environment → Ixon.Env → Kernel.Env → Lean.ConstantInfo. + Comparing the roundtripped Lean.ConstantInfo against the original catches conversion bugs. +-/ +import Ix.Kernel.Types + +namespace Ix.Kernel.Decompile + +/-! ## Name conversion -/ + +/-- Convert Ix.Name to Lean.Name by stripping embedded hashes. -/ +def ixNameToLean : Ix.Name → Lean.Name + | .anonymous _ => .anonymous + | .str parent s _ => .str (ixNameToLean parent) s + | .num parent n _ => .num (ixNameToLean parent) n + +/-! ## Level conversion -/ + +/-- Convert a Kernel.Level back to Lean.Level. + Level param names are synthetic (`u_0`, `u_1`, ...) since Convert.lean + stores `default` for both param names and levelParams. -/ +partial def decompileLevel (levelParams : Array Ix.Name) : Level .meta → Lean.Level + | .zero => .zero + | .succ l => .succ (decompileLevel levelParams l) + | .max l₁ l₂ => .max (decompileLevel levelParams l₁) (decompileLevel levelParams l₂) + | .imax l₁ l₂ => .imax (decompileLevel levelParams l₁) (decompileLevel levelParams l₂) + | .param idx name => + let ixName := if name != default then name + else if h : idx < levelParams.size then levelParams[idx] + else Ix.Name.mkStr Ix.Name.mkAnon s!"u_{idx}" + .param (ixNameToLean ixName) + +/-! 
## Expression conversion -/

@[inline] def kernelExprPtr (e : Expr .meta) : USize := unsafe ptrAddrUnsafe e

/-- Convert a Kernel.Expr back to Lean.Expr with pointer-based caching.
    Known lossy fields:
    - `letE.nonDep` is always `true` (lost in Kernel conversion)
    - Binder names/info come from metadata (may be `default` if missing) -/
partial def decompileExprCached (levelParams : Array Ix.Name) (e : Expr .meta)
    : StateM (Std.HashMap USize Lean.Expr) Lean.Expr := do
  let ptr := kernelExprPtr e
  if let some cached := (← get).get? ptr then return cached
  let result ← match e with
    | .bvar idx _ => pure (.bvar idx)
    | .sort lvl => pure (.sort (decompileLevel levelParams lvl))
    | .const _addr levels name =>
      pure (.const (ixNameToLean name) (levels.toList.map (decompileLevel levelParams)))
    | .app fn arg => do
      let f ← decompileExprCached levelParams fn
      let a ← decompileExprCached levelParams arg
      pure (.app f a)
    | .lam ty body name bi => do
      let t ← decompileExprCached levelParams ty
      let b ← decompileExprCached levelParams body
      pure (.lam (ixNameToLean name) t b bi)
    | .forallE ty body name bi => do
      let t ← decompileExprCached levelParams ty
      let b ← decompileExprCached levelParams body
      pure (.forallE (ixNameToLean name) t b bi)
    | .letE ty val body name => do
      let t ← decompileExprCached levelParams ty
      let v ← decompileExprCached levelParams val
      let b ← decompileExprCached levelParams body
      -- nonDep flag is not preserved by the Kernel representation
      pure (.letE (ixNameToLean name) t v b true)
    | .lit lit => pure (.lit lit)
    | .proj _typeAddr idx struct typeName => do
      let s ← decompileExprCached levelParams struct
      pure (.proj (ixNameToLean typeName) idx s)
  modify (·.insert ptr result)
  pure result

def decompileExpr (levelParams : Array Ix.Name) (e : Expr .meta) : Lean.Expr :=
  (decompileExprCached levelParams e |>.run {}).1

/-! ## ConstantInfo conversion -/

/-- Convert Kernel.DefinitionSafety to Lean.DefinitionSafety. -/
def decompileSafety : DefinitionSafety → Lean.DefinitionSafety
  | .safe => .safe
  | .unsafe => .unsafe
  | .partial => .partial

/-- Convert Kernel.ReducibilityHints to Lean.ReducibilityHints. -/
def decompileHints : ReducibilityHints → Lean.ReducibilityHints
  | .opaque => .opaque
  | .abbrev => .abbrev
  | .regular h => .regular h

/-- Synthetic level params: `[u_0, u_1, ..., u_{n-1}]`. -/
def syntheticLevelParams (n : Nat) : List Lean.Name :=
  (List.range n).map fun i => .str .anonymous s!"u_{i}"

/-- Convert a Kernel.ConstantInfo (.meta) back to Lean.ConstantInfo.
    Name fields are resolved directly from the MetaField name fields
    on the sub-structures (allNames, ctorNames, inductName, ctorName). -/
def decompileConstantInfo (ci : ConstantInfo .meta) : Lean.ConstantInfo :=
  let cv := ci.cv
  let lps := syntheticLevelParams cv.numLevels
  let lpArr := cv.levelParams -- Array Ix.Name
  let decompTy := decompileExpr lpArr cv.type
  let decompVal (e : Expr .meta) := decompileExpr lpArr e
  let name := ixNameToLean cv.name
  match ci with
  | .axiomInfo v =>
    .axiomInfo {
      name, levelParams := lps, type := decompTy, isUnsafe := v.isUnsafe
    }
  | .defnInfo v =>
    .defnInfo {
      name, levelParams := lps, type := decompTy
      value := decompVal v.value
      hints := decompileHints v.hints
      safety := decompileSafety v.safety
    }
  | .thmInfo v =>
    .thmInfo {
      name, levelParams := lps, type := decompTy
      value := decompVal v.value
    }
  | .opaqueInfo v =>
    .opaqueInfo {
      name, levelParams := lps, type := decompTy
      value := decompVal v.value, isUnsafe := v.isUnsafe
    }
  | .quotInfo v =>
    let leanKind : Lean.QuotKind := match v.kind with
      | .type => .type | .ctor => .ctor | .lift => .lift | .ind => .ind
    .quotInfo {
      name, levelParams := lps, type := decompTy, kind := leanKind
    }
  | .inductInfo v =>
    .inductInfo {
      name, levelParams := lps, type := decompTy
      numParams := v.numParams, numIndices := v.numIndices
      isRec := v.isRec, isUnsafe := v.isUnsafe, isReflexive := v.isReflexive
      all := v.allNames.toList.map ixNameToLean
      ctors := v.ctorNames.toList.map ixNameToLean
      numNested := v.numNested
    }
  | .ctorInfo v =>
    .ctorInfo {
      name, levelParams := lps, type := decompTy
      induct := ixNameToLean v.inductName
      cidx := v.cidx, numParams := v.numParams, numFields := v.numFields
      isUnsafe := v.isUnsafe
    }
  | .recInfo v =>
    .recInfo {
      name, levelParams := lps, type := decompTy
      all := v.allNames.toList.map ixNameToLean
      numParams := v.numParams, numIndices := v.numIndices
      numMotives := v.numMotives, numMinors := v.numMinors
      k := v.k, isUnsafe := v.isUnsafe
      rules := v.rules.toList.map fun r => {
        ctor := ixNameToLean r.ctorName
        nfields := r.nfields
        rhs := decompVal r.rhs
      }
    }

/-! ## Structural comparison -/

@[inline] def leanExprPtr (e : Lean.Expr) : USize := unsafe ptrAddrUnsafe e

structure ExprPtrPair where
  a : USize
  b : USize
  deriving Hashable, BEq

/-- Compare two Lean.Exprs structurally, ignoring binder names and binder info.
    Uses pointer-pair caching to avoid exponential blowup on shared subexpressions.
    Returns `none` if structurally equal, `some (path, lhs, rhs)` on first mismatch. -/
partial def exprStructEq (a b : Lean.Expr) (path : String := "")
    : StateM (Std.HashSet ExprPtrPair) (Option (String × String × String)) := do
  let ptrA := leanExprPtr a
  let ptrB := leanExprPtr b
  if ptrA == ptrB then return none
  let pair := ExprPtrPair.mk ptrA ptrB
  if (← get).contains pair then return none
  let result ← match a, b with
    | .bvar i, .bvar j =>
      pure (if i == j then none else some (path, s!"bvar({i})", s!"bvar({j})"))
    | .sort l₁, .sort l₂ =>
      pure (if Lean.Level.isEquiv l₁ l₂ then none else some (path, s!"sort", s!"sort"))
    | .const n₁ ls₁, .const n₂ ls₂ =>
      pure (if n₁ != n₂ then some (path, s!"const({n₁})", s!"const({n₂})")
        else if ls₁.length != ls₂.length then
          some (path, s!"const({n₁}) {ls₁.length} lvls", s!"const({n₂}) {ls₂.length} lvls")
        else none)
    | .app f₁ a₁, .app f₂ a₂ => do
      match ← exprStructEq f₁ f₂ (path ++ ".app.fn") with
      | some m => pure (some m)
      | none => exprStructEq a₁ a₂ (path ++ ".app.arg")
    | .lam _ t₁ b₁ _, .lam _ t₂ b₂ _ => do
      match ← exprStructEq t₁ t₂ (path ++ ".lam.ty") with
      | some m => pure (some m)
      | none => exprStructEq b₁ b₂ (path ++ ".lam.body")
    | .forallE _ t₁ b₁ _, .forallE _ t₂ b₂ _ => do
      match ← exprStructEq t₁ t₂ (path ++ ".pi.ty") with
      | some m => pure (some m)
      | none => exprStructEq b₁ b₂ (path ++ ".pi.body")
    | .letE _ t₁ v₁ b₁ _, .letE _ t₂ v₂ b₂ _ => do
      match ← exprStructEq t₁ t₂ (path ++ ".let.ty") with
      | some m => pure (some m)
      | none => match ← exprStructEq v₁ v₂ (path ++ ".let.val") with
        | some m => pure (some m)
        | none => exprStructEq b₁ b₂ (path ++ ".let.body")
    | .lit l₁, .lit l₂ =>
      pure (if l₁ == l₂ then none
        else
          let showLit : Lean.Literal → String
            | .natVal n => s!"natLit({n})"
            | .strVal s => s!"strLit({s})"
          some (path, showLit l₁, showLit l₂))
    | .proj t₁ i₁ s₁, .proj t₂ i₂ s₂ =>
      if t₁ != t₂ then pure (some (path, s!"proj({t₁}.{i₁})", s!"proj({t₂}.{i₂})"))
      else if i₁ != i₂ then pure (some (path, s!"proj.idx({i₁})", s!"proj.idx({i₂})"))
      else exprStructEq s₁ s₂ (path ++ ".proj.struct")
    -- mdata is transparent for comparison purposes
    | .mdata _ e₁, _ => exprStructEq e₁ b path
    | _, .mdata _ e₂ => exprStructEq a e₂ path
    | _, _ =>
      let tag (e : Lean.Expr) : String := match e with
        | .bvar _ => "bvar" | .sort _ => "sort" | .const .. => "const"
        | .app .. => "app" | .lam .. => "lam" | .forallE .. => "forallE"
        | .letE .. => "letE" | .lit .. => "lit" | .proj .. => "proj"
        | .fvar .. => "fvar" | .mvar .. => "mvar" | .mdata .. => "mdata"
      pure (some (path, tag a, tag b))
  -- Only cache proven-equal pairs; mismatches short-circuit anyway.
  if result.isNone then modify (·.insert pair)
  pure result

/-- Compare two Lean.ConstantInfos structurally. Returns list of mismatches. -/
def constInfoStructEq (a b : Lean.ConstantInfo)
    : Array (String × String × String) :=
  let check : StateM (Std.HashSet ExprPtrPair) (Array (String × String × String)) := do
    let mut mismatches : Array (String × String × String) := #[]
    -- Compare types
    if let some m ← exprStructEq a.type b.type "type" then
      mismatches := mismatches.push m
    -- Compare values if both have them
    match a.value?, b.value? with
    | some va, some vb =>
      if let some m ← exprStructEq va vb "value" then
        mismatches := mismatches.push m
    | none, some _ => mismatches := mismatches.push ("value", "none", "some")
    | some _, none => mismatches := mismatches.push ("value", "some", "none")
    | none, none => pure ()
    return mismatches
  (check.run {}).1

end Ix.Kernel.Decompile
diff --git a/Ix/Kernel/Equal.lean b/Ix/Kernel/Equal.lean
new file mode 100644
index 00000000..4f219b7c
--- /dev/null
+++ b/Ix/Kernel/Equal.lean
@@ -0,0 +1,168 @@
/-
  Kernel Equal: Definitional equality checking.

  Handles proof irrelevance, unit types, eta expansion.
  In NbE, all non-partial definitions are eagerly unfolded by `eval`, so there
  is no lazy delta reduction here — different const-headed values are genuinely
  unequal (they are stuck constructors, recursors, axioms, or partial defs)
+ Adapted from Yatima.Typechecker.Equal, parameterized over MetaMode. +-/ +import Ix.Kernel.Eval + +namespace Ix.Kernel + +/-- Pointer equality on thunks: if two thunks share the same pointer, they must + produce the same value. Returns false conservatively when pointers differ. -/ +@[inline] private def susValuePtrEq (a b : SusValue m) : Bool := + unsafe ptrAddrUnsafe a.body == ptrAddrUnsafe b.body + +/-- Compare two arrays of levels for equality. -/ +private def equalUnivArrays (us us' : Array (Level m)) : Bool := + us.size == us'.size && Id.run do + let mut i := 0 + while i < us.size do + if !Level.equalLevel us[i]! us'[i]! then return false + i := i + 1 + return true + +/-- Construct a canonicalized cache key for two SusValues using their pointer addresses. + The smaller pointer always comes first, making the key symmetric: key(a,b) == key(b,a). -/ +@[inline] private def susValueCacheKey (a b : SusValue m) : USize × USize := + let pa := unsafe ptrAddrUnsafe a.body + let pb := unsafe ptrAddrUnsafe b.body + if pa ≤ pb then (pa, pb) else (pb, pa) + +mutual + /-- Try eta expansion for structure-like types. -/ + partial def tryEtaStruct (lvl : Nat) (term term' : SusValue m) : TypecheckM m Bool := do + match term'.get with + | .app (.const k _ _) args _ => + match (← get).typedConsts.get? k with + | some (.constructor type ..) => + match ← applyType (← eval type) args with + | .app (.const tk _ _) targs _ => + match (← get).typedConsts.get? tk with + | some (.inductive _ struct ..) 
=> + -- Skip struct eta for Prop types (proof irrelevance handles them) + let isProp := match term'.info with | .proof => true | _ => false + if struct && !isProp then + targs.zipIdx.foldlM (init := true) fun acc (arg, i) => do + match arg.get with + | .app (.proj _ idx val _) _ _ => + pure (acc && i == idx && (← equal lvl term val.sus)) + | _ => pure false + else pure false + | _ => pure false + | _ => pure false + | _ => pure false + | _ => pure false + + /-- Check if two suspended values are definitionally equal at the given level. + Assumes both have the same type and live in the same context. -/ + partial def equal (lvl : Nat) (term term' : SusValue m) : TypecheckM m Bool := + match term.info, term'.info with + | .unit, .unit => pure true + | .proof, .proof => pure true + | _, _ => withFuelCheck do + if (← read).trace then dbg_trace s!"equal: {term.get.ctorName} vs {term'.get.ctorName}" + -- Fast path: pointer equality on thunks + if susValuePtrEq term term' then return true + -- Check equality cache + let key := susValueCacheKey term term' + if let some true := (← get).equalCache.get? 
key then return true + let tv := term.get + let tv' := term'.get + let result ← match tv, tv' with + | .lit lit, .lit lit' => pure (lit == lit') + | .sort u, .sort u' => pure (Level.equalLevel u u') + | .pi dom img env _ _, .pi dom' img' env' _ _ => do + let res ← equal lvl dom dom' + let ctx ← read + let stt ← get + let img := suspend img { ctx with env := env.extendWith (mkSusVar dom.info lvl) } stt + let img' := suspend img' { ctx with env := env'.extendWith (mkSusVar dom'.info lvl) } stt + let res' ← equal (lvl + 1) img img' + if !res' then + dbg_trace s!"equal Pi images FAILED at lvl={lvl}: lhs={img.get.dump} rhs={img'.get.dump}" + pure (res && res') + | .lam dom bod env _ _, .lam dom' bod' env' _ _ => do + let res ← equal lvl dom dom' + let ctx ← read + let stt ← get + let bod := suspend bod { ctx with env := env.extendWith (mkSusVar dom.info lvl) } stt + let bod' := suspend bod' { ctx with env := env'.extendWith (mkSusVar dom'.info lvl) } stt + let res' ← equal (lvl + 1) bod bod' + pure (res && res') + | .lam dom bod env _ _, .app neu' args' infos' => do + let var := mkSusVar dom.info lvl + let ctx ← read + let stt ← get + let bod := suspend bod { ctx with env := env.extendWith var } stt + let app := Value.app neu' (var :: args') (term'.info :: infos') + equal (lvl + 1) bod (.mk bod.info app) + | .app neu args infos, .lam dom bod env _ _ => do + let var := mkSusVar dom.info lvl + let ctx ← read + let stt ← get + let bod := suspend bod { ctx with env := env.extendWith var } stt + let app := Value.app neu (var :: args) (term.info :: infos) + equal (lvl + 1) (.mk bod.info app) bod + | .app (.fvar idx _) args _, .app (.fvar idx' _) args' _ => + if idx == idx' then equalThunks lvl args args' + else pure false + | .app (.const k us _) args _, .app (.const k' us' _) args' _ => + if k == k' && equalUnivArrays us us' then + equalThunks lvl args args' + else + -- In NbE, eval eagerly unfolds all non-partial definitions. 
+ -- Different const heads here are stuck terms that can't reduce further. + pure false + -- Nat literal vs constructor expansion + | .lit (.natVal _), .app (.const _ _ _) _ _ => do + let prims := (← read).prims + let expanded ← toCtorIfLit prims tv + equal lvl (.mk term.info (.mk fun _ => expanded)) term' + | .app (.const _ _ _) _ _, .lit (.natVal _) => do + let prims := (← read).prims + let expanded ← toCtorIfLit prims tv' + equal lvl term (.mk term'.info (.mk fun _ => expanded)) + -- String literal vs constructor expansion + | .lit (.strVal _), .app (.const _ _ _) _ _ => do + let prims := (← read).prims + let expanded ← strLitToCtorVal prims (match tv with | .lit (.strVal s) => s | _ => "") + equal lvl (.mk term.info (.mk fun _ => expanded)) term' + | .app (.const _ _ _) _ _, .lit (.strVal _) => do + let prims := (← read).prims + let expanded ← strLitToCtorVal prims (match tv' with | .lit (.strVal s) => s | _ => "") + equal lvl term (.mk term'.info (.mk fun _ => expanded)) + | _, .app (.const _ _ _) _ _ => + tryEtaStruct lvl term term' + | .app (.const _ _ _) _ _, _ => + tryEtaStruct lvl term' term + | .app (.proj ind idx val _) args _, .app (.proj ind' idx' val' _) args' _ => + if ind == ind' && idx == idx' then do + let eqVal ← equal lvl val.sus val'.sus + let eqThunks ← equalThunks lvl args args' + pure (eqVal && eqThunks) + else pure false + | .exception e, _ | _, .exception e => + throw s!"exception in equal: {e}" + | _, _ => + dbg_trace s!"equal FALLTHROUGH at lvl={lvl}: lhs={tv.dump} rhs={tv'.dump}" + pure false + if result then + modify fun stt => { stt with equalCache := stt.equalCache.insert key true } + return result + + /-- Check if two lists of suspended values are pointwise equal. 
-/ + partial def equalThunks (lvl : Nat) (vals vals' : List (SusValue m)) : TypecheckM m Bool := + match vals, vals' with + | val :: vals, val' :: vals' => do + let eq ← equal lvl val val' + let eq' ← equalThunks lvl vals vals' + pure (eq && eq') + | [], [] => pure true + | _, _ => pure false +end + +end Ix.Kernel diff --git a/Ix/Kernel/Eval.lean b/Ix/Kernel/Eval.lean new file mode 100644 index 00000000..9fa74125 --- /dev/null +++ b/Ix/Kernel/Eval.lean @@ -0,0 +1,530 @@ +/- + Kernel Eval: Expression evaluation, constant/recursor/quot/nat reduction. + + Adapted from Yatima.Typechecker.Eval, parameterized over MetaMode. +-/ +import Ix.Kernel.TypecheckM + +namespace Ix.Kernel + +open Level (instBulkReduce reduceIMax) + +def TypeInfo.update (univs : Array (Level m)) : TypeInfo m → TypeInfo m + | .sort lvl => .sort (instBulkReduce univs lvl) + | .unit => .unit + | .proof => .proof + | .none => .none + +/-! ## Helpers (needed by mutual block) -/ + +/-- Check if an address is a primitive operation that takes arguments. -/ +private def isPrimOp (prims : Primitives) (addr : Address) : Bool := + addr == prims.natAdd || addr == prims.natSub || addr == prims.natMul || + addr == prims.natPow || addr == prims.natGcd || addr == prims.natMod || + addr == prims.natDiv || addr == prims.natBeq || addr == prims.natBle || + addr == prims.natLand || addr == prims.natLor || addr == prims.natXor || + addr == prims.natShiftLeft || addr == prims.natShiftRight || + addr == prims.natSucc + +/-- Look up element in a list by index. -/ +def listGet? (l : List α) (n : Nat) : Option α := + match l, n with + | [], _ => none + | a :: _, 0 => some a + | _ :: l, n+1 => listGet? l n + +/-- Try to reduce a primitive operation if all arguments are available. 
-/
private def tryPrimOp (prims : Primitives) (addr : Address)
    (args : List (SusValue m)) : TypecheckM m (Option (Value m)) := do
  -- Nat.succ: 1 arg
  if addr == prims.natSucc then
    if args.length >= 1 then
      match args.head!.get with
      | .lit (.natVal n) => return some (.lit (.natVal (n + 1)))
      | _ => return none
    else return none
  -- Binary nat operations: 2 args
  else if args.length >= 2 then
    let a := args[0]!.get
    let b := args[1]!.get
    match a, b with
    | .lit (.natVal x), .lit (.natVal y) =>
      if addr == prims.natAdd then return some (.lit (.natVal (x + y)))
      else if addr == prims.natSub then return some (.lit (.natVal (x - y)))
      else if addr == prims.natMul then return some (.lit (.natVal (x * y)))
      else if addr == prims.natPow then
        -- Guard against huge exponents blowing up evaluation
        if y > 16777216 then return none
        return some (.lit (.natVal (Nat.pow x y)))
      else if addr == prims.natMod then return some (.lit (.natVal (x % y)))
      else if addr == prims.natDiv then return some (.lit (.natVal (x / y)))
      else if addr == prims.natGcd then return some (.lit (.natVal (Nat.gcd x y)))
      else if addr == prims.natBeq then
        let boolAddr := if x == y then prims.boolTrue else prims.boolFalse
        let boolName ← lookupName boolAddr
        return some (mkConst boolAddr #[] boolName)
      else if addr == prims.natBle then
        let boolAddr := if x ≤ y then prims.boolTrue else prims.boolFalse
        let boolName ← lookupName boolAddr
        return some (mkConst boolAddr #[] boolName)
      else if addr == prims.natLand then return some (.lit (.natVal (Nat.land x y)))
      else if addr == prims.natLor then return some (.lit (.natVal (Nat.lor x y)))
      else if addr == prims.natXor then return some (.lit (.natVal (Nat.xor x y)))
      else if addr == prims.natShiftLeft then return some (.lit (.natVal (Nat.shiftLeft x y)))
      else if addr == prims.natShiftRight then return some (.lit (.natVal (Nat.shiftRight x y)))
      else return none
    | _, _ => return none
  else return none

/-- Expand a string literal to its constructor form: String.mk (list-of-chars).
    Each character is represented as Char.ofNat n, and the list uses
    List.cons/List.nil at universe level 0. -/
def strLitToCtorVal (prims : Primitives) (s : String) : TypecheckM m (Value m) := do
  let charMkName ← lookupName prims.charMk
  let charName ← lookupName prims.char
  let listNilName ← lookupName prims.listNil
  let listConsName ← lookupName prims.listCons
  let stringMkName ← lookupName prims.stringMk
  let mkCharOfNat (c : Char) : SusValue m :=
    ⟨.none, .mk fun _ =>
      Value.app (.const prims.charMk #[] charMkName)
        [⟨.none, .mk fun _ => .lit (.natVal c.toNat)⟩] [.none]⟩
  let charType : SusValue m :=
    ⟨.none, .mk fun _ => Value.neu (.const prims.char #[] charName)⟩
  let nilVal : Value m :=
    Value.app (.const prims.listNil #[.zero] listNilName) [charType] [.none]
  -- Argument spines are stored reversed, so the head appears after the tail.
  let listVal := s.toList.foldr (fun c acc =>
    let tail : SusValue m := ⟨.none, .mk fun _ => acc⟩
    let head := mkCharOfNat c
    Value.app (.const prims.listCons #[.zero] listConsName)
      [tail, head, charType] [.none, .none, .none]
    ) nilVal
  let data : SusValue m := ⟨.none, .mk fun _ => listVal⟩
  pure (Value.app (.const prims.stringMk #[] stringMkName) [data] [.none])

/-! ## Eval / Apply mutual block -/

mutual
  /-- Evaluate a typed expression to a value. -/
  partial def eval (t : TypedExpr m) : TypecheckM m (Value m) := withFuelCheck do
    if (← read).trace then dbg_trace s!"eval: {t.body.tag}"
    match t.body with
    | .app fnc arg => do
      let ctx ← read
      let stt ← get
      let argThunk := suspend ⟨default, arg⟩ ctx stt
      let fnc ← evalTyped ⟨default, fnc⟩
      try apply fnc argThunk
      catch e =>
        throw s!"{e}\n in app: ({fnc.body.summary}) applied to ({arg.pp})"
    | .lam ty body name bi => do
      let ctx ← read
      let stt ← get
      let dom := suspend ⟨default, ty⟩ ctx stt
      pure (.lam dom ⟨default, body⟩ ctx.env name bi)
    | .bvar idx _ => do
      let some thunk := listGet? (← read).env.exprs idx
        | throw s!"Index {idx} is out of range for expression environment"
      pure thunk.get
    | .const addr levels name => do
      let env := (← read).env
      let levels := levels.map (instBulkReduce env.univs.toArray)
      try evalConst addr levels name
      catch e =>
        let nameStr := match (← read).kenv.find? addr with
          | some c => s!"{c.cv.name}" | none => s!"{addr}"
        throw s!"{e}\n in eval const {nameStr}"
    | .letE _ val body _ => do
      let ctx ← read
      let stt ← get
      let thunk := suspend ⟨default, val⟩ ctx stt
      withExtendedEnv thunk (eval ⟨default, body⟩)
    | .forallE ty body name bi => do
      let ctx ← read
      let stt ← get
      let dom := suspend ⟨default, ty⟩ ctx stt
      pure (.pi dom ⟨default, body⟩ ctx.env name bi)
    | .sort univ => do
      let env := (← read).env
      pure (.sort (instBulkReduce env.univs.toArray univ))
    | .lit lit =>
      pure (.lit lit)
    | .proj typeAddr idx struct typeName => do
      let raw ← eval ⟨default, struct⟩
      -- Expand string literals to constructor form before projecting
      let val ← match raw with
        | .lit (.strVal s) => strLitToCtorVal (← read).prims s
        | v => pure v
      match val with
      | .app (.const ctorAddr _ _) args _ =>
        let ctx ← read
        match ctx.kenv.find? ctorAddr with
        | some (.ctorInfo v) =>
          -- Skip the parameters; args are stored reversed.
          let idx := v.numParams + idx
          let some arg := listGet? args.reverse idx
            | throw s!"Invalid projection of index {idx} but constructor has only {args.length} arguments"
          pure arg.get
        | _ => do
          let ti := TypeInfo.update (← read).env.univs.toArray (default : TypeInfo m)
          pure (.app (.proj typeAddr idx ⟨ti, val⟩ typeName) [] [])
      | .app _ _ _ => do
        let ti := TypeInfo.update (← read).env.univs.toArray (default : TypeInfo m)
        pure (.app (.proj typeAddr idx ⟨ti, val⟩ typeName) [] [])
      | e => throw s!"Value is impossible to project: {e.ctorName}"

  partial def evalTyped (t : TypedExpr m) : TypecheckM m (AddInfo (TypeInfo m) (Value m)) := do
    let reducedInfo := t.info.update (← read).env.univs.toArray
    let value ← eval t
    pure ⟨reducedInfo, value⟩

  /-- Evaluate a constant that is not a primitive.
      Theorems are treated as opaque (not unfolded) — proof irrelevance handles
      equality of proof terms, and this avoids deep recursion through proof bodies.
      Caches evaluated definition bodies to avoid redundant evaluation. -/
  partial def evalConst' (addr : Address) (univs : Array (Level m)) (name : MetaField m Ix.Name := default) : TypecheckM m (Value m) := do
    match (← read).kenv.find? addr with
    | some (.defnInfo _) =>
      -- Check eval cache (must also match universe parameters)
      if let some (cachedUnivs, cachedVal) := (← get).evalCache.get? addr then
        if cachedUnivs == univs then return cachedVal
      ensureTypedConst addr
      match (← get).typedConsts.get? addr with
      | some (.definition _ deref part) =>
        if part then pure (mkConst addr univs name)
        else
          let val ← withEnv (.mk [] univs.toList) (eval deref)
          modify fun stt => { stt with evalCache := stt.evalCache.insert addr (univs, val) }
          pure val
      | _ => throw "Invalid const kind for evaluation"
    | _ => pure (mkConst addr univs name)

  /-- Evaluate a constant: check if it's Nat.zero, a primitive op, or unfold it. -/
  partial def evalConst (addr : Address) (univs : Array (Level m)) (name : MetaField m Ix.Name := default) : TypecheckM m (Value m) := do
    let prims := (← read).prims
    if addr == prims.natZero then pure (.lit (.natVal 0))
    else if isPrimOp prims addr then pure (mkConst addr univs name)
    else evalConst' addr univs name

  /-- Create a suspended value from a typed expression, capturing context. -/
  partial def suspend (expr : TypedExpr m) (ctx : TypecheckCtx m) (stt : TypecheckState m) : SusValue m :=
    let thunk : Thunk (Value m) := .mk fun _ =>
      match TypecheckM.run ctx stt (eval expr) with
      | .ok a => a
      | .error e => .exception e
    let reducedInfo := expr.info.update ctx.env.univs.toArray
    ⟨reducedInfo, thunk⟩

  /-- Apply a value to an argument. -/
  partial def apply (val : AddInfo (TypeInfo m) (Value m)) (arg : SusValue m) : TypecheckM m (Value m) := do
    if (← read).trace then dbg_trace s!"apply: {val.body.ctorName}"
    match val.body with
    | .lam _ bod lamEnv _ _ =>
      withNewExtendedEnv lamEnv arg (eval bod)
    | .pi dom img piEnv _ _ =>
      -- Propagate TypeInfo: if domain is Prop, argument is a proof
      let enrichedArg : SusValue m := match arg.info, dom.info with
        | .none, .sort (.zero) => ⟨.proof, arg.body⟩
        | _, _ => arg
      withNewExtendedEnv piEnv enrichedArg (eval img)
    | .app (.const addr univs name) args infos => applyConst addr univs arg args val.info infos name
    | .app neu args infos => pure (.app neu (arg :: args) (val.info :: infos))
    | v =>
      throw s!"Invalid case for apply: got {v.ctorName} ({v.summary})"

  /-- Apply a named constant to arguments, handling recursors, quotients, and primitives. -/
  partial def applyConst (addr : Address) (univs : Array (Level m)) (arg : SusValue m)
      (args : List (SusValue m)) (info : TypeInfo m) (infos : List (TypeInfo m))
      (name : MetaField m Ix.Name := default) : TypecheckM m (Value m) := do
    let prims := (← read).prims
    -- Try primitive operations
    if let some result ← tryPrimOp prims addr (arg :: args) then
      return result

    ---- Try recursor/quotient (ensure provisional entry exists for eval-time lookups)
    ensureTypedConst addr
    match (← get).typedConsts.get? addr with
    | some (.recursor _ params motives minors indices isK indAddr rules) =>
      let majorIdx := params + motives + minors + indices
      if args.length != majorIdx then
        pure (.app (.const addr univs name) (arg :: args) (info :: infos))
      else if isK then
        -- K-reduce when major is a constructor, or shortcut via proof irrelevance
        let isKCtor ← match ← toCtorIfLit prims (arg.get) with
          | .app (.const ctorAddr _ _) _ _ =>
            match (← get).typedConsts.get? ctorAddr with
            | some (.constructor ..) => pure true
            | _ => match (← read).kenv.find? ctorAddr with
              | some (.ctorInfo _) => pure true
              | _ => pure false
          | _ => pure false
        -- Also check if the inductive lives in Prop, since eval doesn't track TypeInfo
        let isPropInd := match (← read).kenv.find? indAddr with
          | some (.inductInfo v) =>
            let rec getSort : Expr m → Bool
              | .forallE _ body _ _ => getSort body
              | .sort (.zero) => true
              | _ => false
            getSort v.type
          | _ => false
        if isKCtor || isPropInd || (match arg.info with | .proof => true | _ => false) then
          let nArgs := args.length
          let nDrop := params + motives + 1
          if nArgs < nDrop then throw s!"Too few arguments ({nArgs}). At least {nDrop} needed"
          let minorIdx := nArgs - nDrop
          let some minor := listGet? args minorIdx | throw s!"Index {minorIdx} is out of range"
          pure minor.get
        else
          pure (.app (.const addr univs name) (arg :: args) (info :: infos))
      else
        -- Skip Nat.rec reduction on large literals to avoid O(n) eval overhead
        let skipLargeNat := match arg.get with
          | .lit (.natVal n) => indAddr == prims.nat && n > 256
          | _ => false
        if skipLargeNat then
          pure (.app (.const addr univs name) (arg :: args) (info :: infos))
        else
          match ← toCtorIfLit prims (arg.get) with
          | .app (.const ctorAddr _ _) ctorArgs _ =>
            let st ← get
            let ctx ← read
            let ctorInfo? := match st.typedConsts.get? ctorAddr with
              | some (.constructor _ ctorIdx numFields) => some (ctorIdx, numFields)
              | _ => match ctx.kenv.find? ctorAddr with
                | some (.ctorInfo cv) => some (cv.cidx, cv.numFields)
                | _ => none
            match ctorInfo? with
            | some (ctorIdx, _) =>
              match rules[ctorIdx]? with
              | some (fields, rhs) =>
                let exprs := (ctorArgs.take fields) ++ (args.drop indices)
                withEnv (.mk exprs univs.toList) (eval rhs.toImplicitLambda)
              | none => throw s!"Constructor has no associated recursion rule"
            | none => pure (.app (.const addr univs name) (arg :: args) (info :: infos))
          | _ =>
            -- Structure eta: expand struct-like major via projections
            let kenv := (← read).kenv
            let doStructEta := match arg.info with
              | .proof => false
              | _ => kenv.isStructureLike indAddr
            if doStructEta then
              match rules[0]? with
              | some (fields, rhs) =>
                let mut projArgs : List (SusValue m) := []
                for i in [:fields] do
                  let proj : SusValue m := ⟨.none, .mk fun _ =>
                    Value.app (.proj indAddr i ⟨.none, arg.get⟩ default) [] []⟩
                  projArgs := proj :: projArgs
                let exprs := projArgs ++ (args.drop indices)
                withEnv (.mk exprs univs.toList) (eval rhs.toImplicitLambda)
              | none => pure (.app (.const addr univs name) (arg :: args) (info :: infos))
            else
              pure (.app (.const addr univs name) (arg :: args) (info :: infos))
    | some (.quotient _ kind) => match kind with
      | .lift => applyQuot prims arg args 6 1 (.app (.const addr univs name) (arg :: args) (info :: infos))
      | .ind => applyQuot prims arg args 5 0 (.app (.const addr univs name) (arg :: args) (info :: infos))
      | _ => pure (.app (.const addr univs name) (arg :: args) (info :: infos))
    | _ => pure (.app (.const addr univs name) (arg :: args) (info :: infos))

  /-- Apply a quotient to a value. -/
  partial def applyQuot (_prims : Primitives) (major : SusValue m) (args : List (SusValue m))
      (reduceSize argPos : Nat) (default : Value m) : TypecheckM m (Value m) :=
    let argsLength := args.length + 1
    if argsLength == reduceSize then
      match major.get with
      | .app (.const majorFn _ _) majorArgs _ => do
        match (← get).typedConsts.get? majorFn with
        | some (.quotient _ .ctor) =>
          if majorArgs.length != 3 then throw "majorArgs should have size 3"
          let some majorArg := majorArgs.head? | throw "majorArgs can't be empty"
          let some head := listGet? args argPos | throw s!"{argPos} is an invalid index for args"
          apply head.getTyped majorArg
        | _ => pure default
      | _ => pure default
    else if argsLength < reduceSize then pure default
    else throw s!"argsLength {argsLength} can't be greater than reduceSize {reduceSize}"

  /-- Convert a nat literal to Nat.succ/Nat.zero constructors. -/
  partial def toCtorIfLit (prims : Primitives) : Value m → TypecheckM m (Value m)
    | .lit (.natVal 0) => do
      let name ← lookupName prims.natZero
      pure (Value.neu (.const prims.natZero #[] name))
    | .lit (.natVal (n+1)) => do
      let name ← lookupName prims.natSucc
      let thunk : SusValue m := ⟨.none, Thunk.mk fun _ => .lit (.natVal n)⟩
      pure (.app (.const prims.natSucc #[] name) [thunk] [.none])
    | v => pure v
end

/-! ## Quoting (read-back from Value to Expr) -/

mutual
  partial def quote (lvl : Nat) : Value m → TypecheckM m (Expr m)
    | .sort univ => do
      let env := (← read).env
      pure (.sort (instBulkReduce env.univs.toArray univ))
    | .app neu args infos => do
      -- Spine is reversed, so foldr rebuilds applications in order.
      let argsInfos := args.zip infos
      argsInfos.foldrM (init := ← quoteNeutral lvl neu) fun (arg, _info) acc => do
        let argExpr ← quoteTyped lvl arg.getTyped
        pure (.app acc argExpr.body)
    | .lam dom bod env name bi => do
      let dom ← quoteTyped lvl dom.getTyped
      let var := mkSusVar (default : TypeInfo m) lvl name
      let bod ← quoteTypedExpr (lvl+1) bod (env.extendWith var)
      pure (.lam dom.body bod.body name bi)
    | .pi dom img env name bi => do
      let dom ← quoteTyped lvl dom.getTyped
      let var := mkSusVar (default : TypeInfo m) lvl name
      let img ← quoteTypedExpr (lvl+1) img (env.extendWith var)
      pure (.forallE dom.body img.body name bi)
    | .lit lit => pure (.lit lit)
    | .exception e => throw e

  partial def quoteTyped (lvl : Nat) (val : AddInfo (TypeInfo m) (Value m)) : TypecheckM m (TypedExpr m) := do
    pure ⟨val.info, ← quote lvl val.body⟩

  partial def quoteTypedExpr (lvl : Nat) (t : TypedExpr m) (env : ValEnv m) : TypecheckM m (TypedExpr m) := do
    let e ← quoteExpr lvl t.body env
    pure ⟨t.info, e⟩

  partial def quoteExpr (lvl : Nat) (expr : Expr m) (env : ValEnv m) : TypecheckM m (Expr m) :=
    match expr with
    | .bvar idx _ => do
      match listGet?
env.exprs idx with + | some val => quote lvl val.get + | none => throw s!"Unbound variable _@{idx}" + | .app fnc arg => do + let fnc ← quoteExpr lvl fnc env + let arg ← quoteExpr lvl arg env + pure (.app fnc arg) + | .lam ty body n bi => do + let ty ← quoteExpr lvl ty env + let var := mkSusVar (default : TypeInfo m) lvl n + let body ← quoteExpr (lvl+1) body (env.extendWith var) + pure (.lam ty body n bi) + | .forallE ty body n bi => do + let ty ← quoteExpr lvl ty env + let var := mkSusVar (default : TypeInfo m) lvl n + let body ← quoteExpr (lvl+1) body (env.extendWith var) + pure (.forallE ty body n bi) + | .letE ty val body n => do + let ty ← quoteExpr lvl ty env + let val ← quoteExpr lvl val env + let var := mkSusVar (default : TypeInfo m) lvl n + let body ← quoteExpr (lvl+1) body (env.extendWith var) + pure (.letE ty val body n) + | .const addr levels name => + pure (.const addr (levels.map (instBulkReduce env.univs.toArray)) name) + | .sort univ => + pure (.sort (instBulkReduce env.univs.toArray univ)) + | .proj typeAddr idx struct name => do + let struct ← quoteExpr lvl struct env + pure (.proj typeAddr idx struct name) + | .lit .. => pure expr + + partial def quoteNeutral (lvl : Nat) : Neutral m → TypecheckM m (Expr m) + | .fvar idx name => do + pure (.bvar (lvl - idx - 1) name) + | .const addr univs name => do + let env := (← read).env + pure (.const addr (univs.map (instBulkReduce env.univs.toArray)) name) + | .proj typeAddr idx val name => do + let te ← quoteTyped lvl val + pure (.proj typeAddr idx te.body name) +end + +/-! ## Literal folding for pretty printing -/ + +/-- Try to extract a Char from a Char.ofNat application in an Expr. -/ +private partial def tryFoldChar (prims : Primitives) (e : Expr m) : Option Char := + match e.getAppFn with + | .const addr _ _ => + if addr == prims.charMk then + let args := e.getAppArgs + if args.size == 1 then + match args[0]! 
with + | .lit (.natVal n) => some (Char.ofNat n) + | _ => none + else none + else none + | _ => none + +/-- Try to extract a List Char from a List.cons/List.nil chain in an Expr. -/ +private partial def tryFoldCharList (prims : Primitives) (e : Expr m) : Option (List Char) := + match e.getAppFn with + | .const addr _ _ => + if addr == prims.listNil then some [] + else if addr == prims.listCons then + let args := e.getAppArgs + -- args = [type, head, tail] + if args.size == 3 then + match tryFoldChar prims args[1]!, tryFoldCharList prims args[2]! with + | some c, some cs => some (c :: cs) + | _, _ => none + else none + else none + | _ => none + +/-- Walk an Expr and fold Nat.zero/Nat.succ chains to nat literals, + and String.mk (char list) to string literals. -/ +partial def foldLiterals (prims : Primitives) : Expr m → Expr m + | .const addr lvls name => + if addr == prims.natZero then .lit (.natVal 0) + else .const addr lvls name + | .app fn arg => + let fn' := foldLiterals prims fn + let arg' := foldLiterals prims arg + let e := Expr.app fn' arg' + -- Try folding the fully-reconstructed app + match e.getAppFn with + | .const addr _ _ => + if addr == prims.natSucc && e.getAppNumArgs == 1 then + match e.appArg! with + | .lit (.natVal n) => .lit (.natVal (n + 1)) + | _ => e + else if addr == prims.stringMk && e.getAppNumArgs == 1 then + match tryFoldCharList prims e.appArg! with + | some cs => .lit (.strVal (String.ofList cs)) + | none => e + else e + | _ => e + | .lam ty body n bi => + .lam (foldLiterals prims ty) (foldLiterals prims body) n bi + | .forallE ty body n bi => + .forallE (foldLiterals prims ty) (foldLiterals prims body) n bi + | .letE ty val body n => + .letE (foldLiterals prims ty) (foldLiterals prims val) (foldLiterals prims body) n + | .proj ta idx s tn => + .proj ta idx (foldLiterals prims s) tn + | e => e + +/-! ## Value pretty printing -/ + +/-- Pretty-print a value by quoting it back to an Expr, then using Expr.pp. 
+ Folds Nat/String constructor chains back to literals for readability. -/ +partial def ppValue (lvl : Nat) (v : Value m) : TypecheckM m String := do + let expr ← quote lvl v + let expr := foldLiterals (← read).prims expr + return expr.pp + +/-- Pretty-print a suspended value. -/ +partial def ppSusValue (lvl : Nat) (sv : SusValue m) : TypecheckM m String := + ppValue lvl sv.get + +/-- Pretty-print a value, falling back to the shallow summary on error. -/ +partial def tryPpValue (lvl : Nat) (v : Value m) : TypecheckM m String := do + try ppValue lvl v + catch _ => return v.summary + +/-- Apply a value to a list of arguments. -/ +def applyType (v : Value m) (args : List (SusValue m)) : TypecheckM m (Value m) := + match args with + | [] => pure v + | arg :: rest => do + let info : TypeInfo m := .none + let v' ← try apply ⟨info, v⟩ arg + catch e => + let ppV ← tryPpValue (← read).lvl v + throw s!"{e}\n in applyType: {ppV} with {args.length} remaining args" + applyType v' rest + +end Ix.Kernel diff --git a/Ix/Kernel/Infer.lean b/Ix/Kernel/Infer.lean new file mode 100644 index 00000000..1d0b0159 --- /dev/null +++ b/Ix/Kernel/Infer.lean @@ -0,0 +1,406 @@ +/- + Kernel Infer: Type inference and declaration checking. + + Adapted from Yatima.Typechecker.Infer, parameterized over MetaMode. +-/ +import Ix.Kernel.Equal + +namespace Ix.Kernel + +/-! 
## Type info helpers -/

/-- Info for a lambda whose body has the given info: proofs stay proofs. -/
def lamInfo : TypeInfo m → TypeInfo m
  | .proof => .proof
  | _ => .none

/-- Info for a Pi type from the infos of its domain and image sorts. -/
def piInfo (dom img : TypeInfo m) : TypecheckM m (TypeInfo m) := match dom, img with
  | .sort lvl, .sort lvl' => pure (.sort (Level.reduceIMax lvl lvl'))
  | _, _ => pure .none

/-- Fast sort-level comparison via cached `TypeInfo`; permissive when absent. -/
def eqSortInfo (inferType expectType : SusValue m) : TypecheckM m Bool := do
  match inferType.info, expectType.info with
  | .sort lvl, .sort lvl' => pure (Level.equalLevel lvl lvl')
  | _, _ => pure true -- info unavailable; defer to structural equality

/-- Compute the `TypeInfo` a term gets from its type: `.proof` for Prop
    inhabitants, `.unit` for unit-like inductives, `.sort` for sorts. -/
def infoFromType (typ : SusValue m) : TypecheckM m (TypeInfo m) :=
  match typ.info with
  | .sort (.zero) => pure .proof
  | _ =>
    match typ.get with
    | .app (.const addr _ _) _ _ => do
      match (← read).kenv.find? addr with
      | some (.inductInfo v) =>
        -- Check if it's unit-like: one constructor with zero fields
        if v.ctors.size == 1 then
          match (← read).kenv.find? v.ctors[0]! with
          | some (.ctorInfo cv) =>
            if cv.numFields == 0 then pure .unit else pure .none
          | _ => pure .none
        else pure .none
      | _ => pure .none
    | .sort lvl => pure (.sort lvl)
    | _ => pure .none

/-! ## Inference / Checking -/

mutual
  /-- Check that a term has a given type. -/
  -- Infers the term's type, then compares infos and runs definitional `equal`;
  -- on mismatch the error carries pretty-printed and raw dumps of both sides.
  partial def check (term : Expr m) (type : SusValue m) : TypecheckM m (TypedExpr m) := do
    if (← read).trace then dbg_trace s!"check: {term.tag}"
    let (te, inferType) ← infer term
    if !(← eqSortInfo inferType type) then
      throw s!"Info mismatch on {term.tag}"
    if !(← equal (← read).lvl type inferType) then
      let lvl := (← read).lvl
      let ppInferred ← tryPpValue lvl inferType.get
      let ppExpected ← tryPpValue lvl type.get
      let dumpInferred := inferType.get.dump
      let dumpExpected := type.get.dump
      throw s!"Type mismatch on {term.tag}\n inferred: {ppInferred}\n expected: {ppExpected}\n inferred dump: {dumpInferred}\n expected dump: {dumpExpected}\n inferred info: {inferType.info.pp}\n expected info: {type.info.pp}"
    pure te

  /-- Infer the type of an expression, returning the typed expression and its type. -/
  partial def infer (term : Expr m) : TypecheckM m (TypedExpr m × SusValue m) := withFuelCheck do
    if (← read).trace then dbg_trace s!"infer: {term.tag}"
    match term with
    | .bvar idx bvarName => do
      let ctx ← read
      if idx < ctx.lvl then
        -- Local variable: its type lives in the context's type stack.
        let some type := listGet? ctx.types idx
          | throw s!"var@{idx} out of environment range (size {ctx.types.length})"
        let te : TypedExpr m := ⟨← infoFromType type, .bvar idx bvarName⟩
        pure (te, type)
      else
        -- Mutual reference
        match ctx.mutTypes.get? (idx - ctx.lvl) with
        | some (addr, typeValFn) =>
          -- Reject direct self-reference of the constant under check.
          if some addr == ctx.recAddr? then
            throw s!"Invalid recursion"
          let univs := ctx.env.univs.toArray
          let type := typeValFn univs
          let name ← lookupName addr
          let te : TypedExpr m := ⟨← infoFromType type, .const addr univs name⟩
          pure (te, type)
        | none =>
          throw s!"var@{idx} out of environment range and does not represent a mutual constant"
    | .sort lvl => do
      -- `Sort u : Sort (u+1)`.
      let univs := (← read).env.univs.toArray
      let lvl := Level.instBulkReduce univs lvl
      let lvl' := Level.succ lvl
      let typ : SusValue m := .mk (.sort (Level.succ lvl')) (.mk fun _ => .sort lvl')
      let te : TypedExpr m := ⟨.sort lvl', .sort lvl⟩
      pure (te, typ)
    | .app fnc arg => do
      let (fnTe, fncType) ← infer fnc
      match fncType.get with
      | .pi dom img piEnv _ _ => do
        let argTe ← check arg dom
        let ctx ← read
        let stt ← get
        -- Result type: the Pi image with the (suspended) argument substituted in.
        let typ := suspend img { ctx with env := piEnv.extendWith (suspend argTe ctx stt) } stt
        let te : TypedExpr m := ⟨← infoFromType typ, .app fnTe.body argTe.body⟩
        pure (te, typ)
      | v =>
        let ppV ← tryPpValue (← read).lvl v
        throw s!"Expected a pi type, got {ppV}\n dump: {v.dump}\n fncType info: {fncType.info.pp}\n function: {fnc.pp}\n argument: {arg.pp}"
    | .lam ty body lamName lamBi => do
      let (domTe, _) ← isSort ty
      let ctx ← read
      let stt ← get
      let domVal := suspend domTe ctx stt
      let var := mkSusVar (← infoFromType domVal) ctx.lvl lamName
      let (bodTe, imgVal) ← withExtendedCtx var domVal (infer body)
      let te : TypedExpr m := ⟨lamInfo bodTe.info, .lam domTe.body bodTe.body lamName lamBi⟩
      -- Quote the inferred image back so the Pi closure captures an Expr.
      let imgTE ← quoteTyped (ctx.lvl+1) imgVal.getTyped
      let typ : SusValue m := ⟨← piInfo domVal.info imgVal.info,
        Thunk.mk fun _ => Value.pi domVal imgTE ctx.env lamName lamBi⟩
      pure (te, typ)
    | .forallE ty body piName _ => do
      let (domTe, domLvl) ← isSort ty
      let ctx ← read
      let stt ← get
      let domVal := suspend domTe ctx stt
      let domSusVal := mkSusVar (← infoFromType domVal) ctx.lvl piName
      withExtendedCtx domSusVal domVal do
        let (imgTe, imgLvl) ← isSort body
        -- Pi sorts combine by imax (Prop-valued codomain collapses to Prop).
        let sortLvl := Level.reduceIMax domLvl imgLvl
        let typ : SusValue m := .mk (.sort (Level.succ sortLvl)) (.mk fun _ => .sort sortLvl)
        let te : TypedExpr m := ⟨← infoFromType typ, .forallE domTe.body imgTe.body piName default⟩
        pure (te, typ)
    | .letE ty val body letName => do
      let (tyTe, _) ← isSort ty
      let ctx ← read
      let stt ← get
      let tyVal := suspend tyTe ctx stt
      let valTe ← check val tyVal
      let valVal := suspend valTe ctx stt
      let (bodTe, typ) ← withExtendedCtx valVal tyVal (infer body)
      let te : TypedExpr m := ⟨bodTe.info, .letE tyTe.body valTe.body bodTe.body letName⟩
      pure (te, typ)
    | .lit (.natVal _) => do
      let prims := (← read).prims
      let typ : SusValue m := .mk (.sort (Level.succ .zero)) (.mk fun _ => mkConst prims.nat #[])
      let te : TypedExpr m := ⟨.none, term⟩
      pure (te, typ)
    | .lit (.strVal _) => do
      let prims := (← read).prims
      let typ : SusValue m := .mk (.sort (Level.succ .zero)) (.mk fun _ => mkConst prims.string #[])
      let te : TypedExpr m := ⟨.none, term⟩
      pure (te, typ)
    | .const addr constUnivs _ => do
      ensureTypedConst addr
      let ctx ← read
      let univs := ctx.env.univs.toArray
      let reducedUnivs := constUnivs.toList.map (Level.instBulkReduce univs)
      -- Check const type cache (must also match universe parameters)
      -- Cache is keyed by address only; a different universe instantiation
      -- recomputes and overwrites the single cached entry.
      match (← get).constTypeCache.get? addr with
      | some (cachedUnivs, cachedTyp) =>
        if cachedUnivs == reducedUnivs then
          let te : TypedExpr m := ⟨← infoFromType cachedTyp, term⟩
          pure (te, cachedTyp)
        else
          let tconst ← derefTypedConst addr
          let env : ValEnv m := .mk [] reducedUnivs
          let stt ← get
          let typ := suspend tconst.type { ctx with env := env } stt
          modify fun stt => { stt with constTypeCache := stt.constTypeCache.insert addr (reducedUnivs, typ) }
          let te : TypedExpr m := ⟨← infoFromType typ, term⟩
          pure (te, typ)
      | none =>
        let tconst ← derefTypedConst addr
        let env : ValEnv m := .mk [] reducedUnivs
        let stt ← get
        let typ := suspend tconst.type { ctx with env := env } stt
        modify fun stt => { stt with constTypeCache := stt.constTypeCache.insert addr (reducedUnivs, typ) }
        let te : TypedExpr m := ⟨← infoFromType typ, term⟩
        pure (te, typ)
    | .proj typeAddr idx struct _ => do
      let (structTe, structType) ← infer struct
      let (ctorType, univs, params) ← getStructInfo structType.get
      -- Specialize the constructor type with the structure's parameters, then
      -- peel `idx` Pi binders, substituting earlier projections as we go.
      let mut ct ← applyType (← withEnv (.mk [] univs) (eval ctorType)) params.reverse
      for i in [:idx] do
        match ct with
        | .pi dom img piEnv _ _ => do
          let info ← infoFromType dom
          let ctx ← read
          let stt ← get
          let proj := suspend ⟨info, .proj typeAddr i structTe.body default⟩ ctx stt
          ct ← withNewExtendedEnv piEnv proj (eval img)
        | _ => pure ()
      match ct with
      | .pi dom _ _ _ _ =>
        let te : TypedExpr m := ⟨← infoFromType dom, .proj typeAddr idx structTe.body default⟩
        pure (te, dom)
      | _ => throw "Impossible case: structure type does not have enough fields"

  /-- Check if an expression is a Sort, returning the typed expr and the universe level. -/
  partial def isSort (expr : Expr m) : TypecheckM m (TypedExpr m × Level m) := do
    let (te, typ) ← infer expr
    match typ.get with
    | .sort u => pure (te, u)
    | v =>
      let ppV ← tryPpValue (← read).lvl v
      throw s!"Expected a sort type, got {ppV}\n expr: {expr.pp}"

  /-- Get structure info from a value that should be a structure type. -/
  -- Returns (constructor type, universe instantiation, parameter spine) for a
  -- one-constructor inductive applied to exactly its parameters.
  partial def getStructInfo (v : Value m) :
      TypecheckM m (TypedExpr m × List (Level m) × List (SusValue m)) := do
    match v with
    | .app (.const indAddr univs _) params _ =>
      match (← read).kenv.find? indAddr with
      | some (.inductInfo v) =>
        if v.ctors.size != 1 || params.length != v.numParams then
          throw s!"Expected a structure type, but {v.name} ({indAddr}) has {v.ctors.size} ctors and {params.length}/{v.numParams} params"
        ensureTypedConst indAddr
        let ctorAddr := v.ctors[0]!
        ensureTypedConst ctorAddr
        match (← get).typedConsts.get? ctorAddr with
        | some (.constructor type _ _) =>
          return (type, univs.toList, params)
        | _ => throw s!"Constructor {ctorAddr} is not in typed consts"
      | some ci => throw s!"Expected a structure type, but {indAddr} is a {ci.kindName}"
      | none => throw s!"Expected a structure type, but {indAddr} not found in env"
    | _ =>
      let ppV ← tryPpValue (← read).lvl v
      throw s!"Expected a structure type, got {ppV}"

  /-- Typecheck a constant. With fresh state per declaration, dependencies get
      provisional entries via `ensureTypedConst` and are assumed well-typed. -/
  partial def checkConst (addr : Address) : TypecheckM m Unit := withResetCtx do
    -- Reset fuel and per-constant caches
    modify fun stt => { stt with
      fuel := defaultFuel
      evalCache := {}
      equalCache := {}
      constTypeCache := {} }
    -- Skip if already in typedConsts (provisional entry is fine — dependency assumed well-typed)
    if (← get).typedConsts.get? addr |>.isSome then
      return ()
    let ci ← derefConst addr
    let univs := ci.cv.mkUnivParams
    withEnv (.mk [] univs.toList) do
      let newConst ← match ci with
        | .axiomInfo _ =>
          let (type, _) ← isSort ci.type
          pure (TypedConst.axiom type)
        | .opaqueInfo _ =>
          let (type, _) ← isSort ci.type
          let typeSus := suspend type (← read) (← get)
          let value ← withRecAddr addr (check ci.value?.get! typeSus)
          pure (TypedConst.opaque type value)
        | .thmInfo _ =>
          let (type, lvl) ← isSort ci.type
          -- Theorems must live in Prop.
          if !Level.isZero lvl then
            throw s!"theorem type must be a proposition (Sort 0)"
          let typeSus := suspend type (← read) (← get)
          let value ← withRecAddr addr (check ci.value?.get! typeSus)
          pure (TypedConst.theorem type value)
        | .defnInfo v =>
          let (type, _) ← isSort ci.type
          let ctx ← read
          let stt ← get
          let typeSus := suspend type ctx stt
          let part := v.safety == .partial
          let value ←
            -- Partial definitions may refer to themselves; expose the constant's
            -- own type through mutTypes at index 0 while checking the body.
            if part then
              let typeSusFn := suspend type { ctx with env := ValEnv.mk ctx.env.exprs ctx.env.univs } stt
              let mutTypes : Std.TreeMap Nat (Address × (Array (Level m) → SusValue m)) compare :=
                (Std.TreeMap.empty).insert 0 (addr, fun _ => typeSusFn)
              withMutTypes mutTypes (withRecAddr addr (check v.value typeSus))
            else withRecAddr addr (check v.value typeSus)
          pure (TypedConst.definition type value part)
        | .quotInfo v =>
          let (type, _) ← isSort ci.type
          pure (TypedConst.quotient type v.kind)
        | .inductInfo _ =>
          checkIndBlock addr
          return ()
        | .ctorInfo v =>
          checkIndBlock v.induct
          return ()
        | .recInfo v => do
          -- Extract the major premise's inductive from the recursor type
          let indAddr := getMajorInduct ci.type v.numParams v.numMotives v.numMinors v.numIndices
            |>.getD default
          -- Ensure the inductive has a provisional entry (assumed well-typed with fresh state per decl)
          ensureTypedConst indAddr
          -- Check recursor type
          let (type, _) ← isSort ci.type
          -- Check recursor rules
          let typedRules ← v.rules.mapM fun rule => do
            let (rhs, _) ← infer rule.rhs
            pure (rule.nfields, rhs)
          pure (TypedConst.recursor type v.numParams v.numMotives v.numMinors v.numIndices v.k indAddr typedRules)
      modify fun stt => { stt with typedConsts := stt.typedConsts.insert addr newConst }

  /-- Walk a Pi chain to extract the return sort level (the universe of the result type).
      Assumes the expression ends in `Sort u` after `numBinders` forall binders. -/
  partial def getReturnSort (expr : Expr m) (numBinders : Nat) : TypecheckM m (Level m) :=
    match numBinders, expr with
    | 0, .sort u => do
      let univs := (← read).env.univs.toArray
      pure (Level.instBulkReduce univs u)
    | 0, _ => do
      -- Not syntactically a sort; try to infer
      let (_, typ) ← infer expr
      match typ.get with
      | .sort u => pure u
      | _ => throw "inductive return type is not a sort"
    | n+1, .forallE dom body _ _ => do
      let (domTe, _) ← isSort dom
      let ctx ← read
      let stt ← get
      let domVal := suspend domTe ctx stt
      let var := mkSusVar (← infoFromType domVal) ctx.lvl
      withExtendedCtx var domVal (getReturnSort body n)
    | _, _ => throw "inductive type has fewer binders than expected"

  /-- Typecheck a mutual inductive block starting from one of its addresses. -/
  partial def checkIndBlock (addr : Address) : TypecheckM m Unit := do
    let ci ← derefConst addr
    -- Find the inductive info
    let indInfo ← match ci with
      | .inductInfo _ => pure ci
      | .ctorInfo v =>
        match (← read).kenv.find? v.induct with
        | some ind@(.inductInfo ..) => pure ind
        | _ => throw "Constructor's inductive not found"
      | _ => throw "Expected an inductive"
    let .inductInfo iv := indInfo | throw "unreachable"
    -- Check if already done
    if (← get).typedConsts.get? addr |>.isSome then return ()
    -- Check the inductive type
    let univs := iv.toConstantVal.mkUnivParams
    let (type, _) ← withEnv (.mk [] univs.toList) (isSort iv.type)
    -- Struct-like: non-recursive, no indices, single constructor with fields.
    let isStruct := !iv.isRec && iv.numIndices == 0 && iv.ctors.size == 1 &&
      match (← read).kenv.find? iv.ctors[0]! with
      | some (.ctorInfo cv) => cv.numFields > 0
      | _ => false
    modify fun stt => { stt with typedConsts := stt.typedConsts.insert addr (TypedConst.inductive type isStruct) }
    -- Check constructors
    for (ctorAddr, cidx) in iv.ctors.toList.zipIdx do
      match (← read).kenv.find? ctorAddr with
      | some (.ctorInfo cv) => do
        let ctorUnivs := cv.toConstantVal.mkUnivParams
        let (ctorType, _) ← withEnv (.mk [] ctorUnivs.toList) (isSort cv.type)
        modify fun stt => { stt with typedConsts := stt.typedConsts.insert ctorAddr (TypedConst.constructor ctorType cidx cv.numFields) }
      | _ => throw s!"Constructor {ctorAddr} not found"
    -- Note: recursors are checked individually via checkConst's .recInfo branch,
    -- which calls checkConst on the inductives first then checks rules.
end -- mutual

/-! ## Top-level entry points -/

/-- Typecheck a single constant by address. -/
def typecheckConst (kenv : Env m) (prims : Primitives) (addr : Address)
    (quotInit : Bool := true) : Except String Unit := do
  let ctx : TypecheckCtx m := {
    lvl := 0, env := default, types := [], kenv := kenv,
    prims := prims, safety := .safe, quotInit := quotInit,
    mutTypes := default, recAddr? := none
  }
  let stt : TypecheckState m := { typedConsts := default }
  TypecheckM.run ctx stt (checkConst addr)

/-- Typecheck all constants in a kernel environment.
    Uses fresh state per declaration — dependencies are assumed well-typed. -/
def typecheckAll (kenv : Env m) (prims : Primitives) (quotInit : Bool := true)
    : Except String Unit := do
  for (addr, ci) in kenv do
    match typecheckConst kenv prims addr quotInit with
    | .ok () => pure ()
    | .error e =>
      -- Enrich the failure with the constant's header, type and value.
      let header := s!"constant {ci.cv.name} ({ci.kindName}, {addr})"
      let typ := ci.type.pp
      let val := match ci.value? with
        | some v => s!"\n value: {v.pp}"
        | none => ""
      throw s!"{header}: {e}\n type: {typ}{val}"

/-- Typecheck all constants with IO progress reporting.
    Uses fresh state per declaration — dependencies are assumed well-typed. -/
def typecheckAllIO (kenv : Env m) (prims : Primitives) (quotInit : Bool := true)
    : IO (Except String Unit) := do
  let mut items : Array (Address × ConstantInfo m) := #[]
  for (addr, ci) in kenv do
    items := items.push (addr, ci)
  let total := items.size
  for h : idx in [:total] do
    let (addr, ci) := items[idx]
    -- Pretty-printing every type/value is expensive; disabled, "_" placeholders used.
    --let typ := ci.type.pp
    --let val := match ci.value? with
    -- | some v => s!"\n value: {v.pp}"
    -- | none => ""
    let (typ, val) := ("_", "_")
    (← IO.getStdout).putStrLn s!" [{idx + 1}/{total}] {ci.cv.name} ({ci.kindName})\n type: {typ}{val}"
    (← IO.getStdout).flush
    match typecheckConst kenv prims addr quotInit with
    | .ok () =>
      (← IO.getStdout).putStrLn s!" ✓ {ci.cv.name}"
      (← IO.getStdout).flush
    | .error e =>
      let header := s!"constant {ci.cv.name} ({ci.kindName}, {addr})"
      return .error s!"{header}: {e}\n type: {typ}{val}"
  return .ok ()

end Ix.Kernel
diff --git a/Ix/Kernel/Level.lean b/Ix/Kernel/Level.lean
new file mode 100644
index 00000000..f22bcb53
--- /dev/null
+++ b/Ix/Kernel/Level.lean
@@ -0,0 +1,131 @@
/-
  Level normalization and comparison for `Level m`.

  Generic over MetaMode — metadata on `.param` is ignored.
  Adapted from Yatima.Datatypes.Univ + Ix.IxVM.Level.
-/
import Init.Data.Int
import Ix.Kernel.Types

namespace Ix.Kernel

namespace Level

/-! ## Reduction -/

/-- Reduce `max a b` assuming `a` and `b` are already reduced. -/
def reduceMax (a b : Level m) : Level m :=
  match a, b with
  | .zero, _ => b
  | _, .zero => a
  -- max (a+1) (b+1) = (max a b)+1
  | .succ a, .succ b => .succ (reduceMax a b)
  | .param idx _, .param idx' _ => if idx == idx' then a else .max a b
  | _, _ => .max a b

/-- Reduce `imax a b` assuming `a` and `b` are already reduced. -/
-- imax collapses to 0 when the right side is 0, and to max when the right
-- side is definitely nonzero (a `succ`); otherwise it stays symbolic.
def reduceIMax (a b : Level m) : Level m :=
  match b with
  | .zero => .zero
  | .succ _ => reduceMax a b
  | .param idx _ => match a with
    | .param idx' _ => if idx == idx' then a else .imax a b
    | _ => .imax a b
  | _ => .imax a b

/-- Reduce a level to normal form. -/
def reduce : Level m → Level m
  | .succ u => .succ (reduce u)
  | .max a b => reduceMax (reduce a) (reduce b)
  | .imax a b =>
    -- Reduce the right side first: it decides whether imax collapses.
    let b' := reduce b
    match b' with
    | .zero => .zero
    | .succ _ => reduceMax (reduce a) b'
    | _ => .imax (reduce a) b'
  | u => u

/-! ## Instantiation -/

/-- Instantiate a single variable and reduce. Assumes `subst` is already reduced.
    Does not shift variables (used only in comparison algorithm). -/
def instReduce (u : Level m) (idx : Nat) (subst : Level m) : Level m :=
  match u with
  | .succ u => .succ (instReduce u idx subst)
  | .max a b => reduceMax (instReduce a idx subst) (instReduce b idx subst)
  | .imax a b =>
    let a' := instReduce a idx subst
    let b' := instReduce b idx subst
    match b' with
    | .zero => .zero
    | .succ _ => reduceMax a' b'
    | _ => .imax a' b'
  | .param idx' _ => if idx' == idx then subst else u
  | .zero => u

/-- Instantiate multiple variables at once and reduce. Substitutes `.param idx` by `substs[idx]`.
    Assumes already reduced `substs`. -/
def instBulkReduce (substs : Array (Level m)) : Level m → Level m
  | z@(.zero ..) => z
  | .succ u => .succ (instBulkReduce substs u)
  | .max a b => reduceMax (instBulkReduce substs a) (instBulkReduce substs b)
  | .imax a b =>
    let b' := instBulkReduce substs b
    match b' with
    | .zero => .zero
    | .succ _ => reduceMax (instBulkReduce substs a) b'
    | _ => .imax (instBulkReduce substs a) b'
  | .param idx name =>
    -- Params beyond the substitution array are free: shift them down past it.
    if h : idx < substs.size then substs[idx]
    else .param (idx - substs.size) name

/-! ## Comparison -/

/-- Comparison algorithm: `a <= b + diff`. Assumes `a` and `b` are already reduced. -/
-- `diff` accumulates the succ-offset between the two sides; imax on either side
-- is eliminated by case-splitting a param on zero/succ, or by distributing
-- over max/imax before recursing.
partial def leq (a b : Level m) (diff : _root_.Int) : Bool :=
  if diff >= 0 && match a with | .zero => true | _ => false then true
  else match a, b with
    | .zero, .zero => diff >= 0
    -- Succ cases
    | .succ a, _ => leq a b (diff - 1)
    | _, .succ b => leq a b (diff + 1)
    | .param .., .zero => false
    | .zero, .param .. => diff >= 0
    | .param x _, .param y _ => x == y && diff >= 0
    -- IMax cases
    | .imax _ (.param idx _), _ =>
      -- Split on the param being 0 or a successor.
      leq .zero (instReduce b idx .zero) diff &&
      let s := .succ (.param idx default)
      leq (instReduce a idx s) (instReduce b idx s) diff
    | .imax c (.max e f), _ =>
      -- imax c (max e f) = max (imax c e) (imax c f)
      let newMax := reduceMax (reduceIMax c e) (reduceIMax c f)
      leq newMax b diff
    | .imax c (.imax e f), _ =>
      -- imax c (imax e f) = max (imax c f) (imax e f)
      let newMax := reduceMax (reduceIMax c f) (.imax e f)
      leq newMax b diff
    | _, .imax _ (.param idx _) =>
      leq (instReduce a idx .zero) .zero diff &&
      let s := .succ (.param idx default)
      leq (instReduce a idx s) (instReduce b idx s) diff
    | _, .imax c (.max e f) =>
      let newMax := reduceMax (reduceIMax c e) (reduceIMax c f)
      leq a newMax diff
    | _, .imax c (.imax e f) =>
      let newMax := reduceMax (reduceIMax c f) (.imax e f)
      leq a newMax diff
    -- Max cases
    | .max c d, _ => leq c b diff && leq d b diff
    | _, .max c d => leq a c diff || leq a d diff
    | _, _ => false

/-- Semantic equality of levels. Assumes `a` and `b` are already reduced. -/
def equalLevel (a b : Level m) : Bool :=
  leq a b 0 && leq b a 0

/-- Faster equality for zero, assumes input is already reduced. -/
def isZero : Level m → Bool
  | .zero => true
  | _ => false

end Level

end Ix.Kernel
diff --git a/Ix/Kernel/TypecheckM.lean b/Ix/Kernel/TypecheckM.lean
new file mode 100644
index 00000000..8b1a93ba
--- /dev/null
+++ b/Ix/Kernel/TypecheckM.lean
@@ -0,0 +1,180 @@
/-
  TypecheckM: Monad stack, context, state, and utilities for the kernel typechecker.
-/
import Ix.Kernel.Datatypes
import Ix.Kernel.Level

namespace Ix.Kernel

/-!
## Typechecker Context -/ + +structure TypecheckCtx (m : MetaMode) where + lvl : Nat + env : ValEnv m + types : List (SusValue m) + kenv : Env m + prims : Primitives + safety : DefinitionSafety + quotInit : Bool + /-- Maps a variable index (mutual reference) to (address, type-value function). -/ + mutTypes : Std.TreeMap Nat (Address × (Array (Level m) → SusValue m)) compare + /-- Tracks the address of the constant currently being checked, for recursion detection. -/ + recAddr? : Option Address + /-- Depth fuel: bounds the call-stack depth to prevent native stack overflow. + Decremented via the reader on each entry to eval/equal/infer. + Thunks inherit the depth from their capture point. -/ + depth : Nat := 3000 + /-- Enable dbg_trace on major entry points for debugging. -/ + trace : Bool := false + deriving Inhabited + +/-! ## Typechecker State -/ + +/-- Default fuel for bounding total recursive work per constant. -/ +def defaultFuel : Nat := 100000 + +structure TypecheckState (m : MetaMode) where + typedConsts : Std.TreeMap Address (TypedConst m) Address.compare + /-- Fuel counter for bounding total recursive work. Decremented on each entry to + eval/equal/infer. Reset at the start of each `checkConst` call. -/ + fuel : Nat := defaultFuel + /-- Cache for evaluated constant definitions. Maps an address to its universe + parameters and evaluated value. Universe-polymorphic constants produce different + values for different universe instantiations, so we store and check univs. -/ + evalCache : Std.HashMap Address (Array (Level m) × Value m) := {} + /-- Cache for definitional equality results. Maps `(ptrAddrUnsafe a, ptrAddrUnsafe b)` + (canonicalized so smaller pointer comes first) to `Bool`. Only `true` results are + cached (monotone under state growth). -/ + equalCache : Std.HashMap (USize × USize) Bool := {} + /-- Cache for constant type SusValues. 
When `infer (.const addr _)` computes a + suspended type, it is cached here so repeated references to the same constant + share the same SusValue pointer, enabling fast-path pointer equality in `equal`. + Stores universe parameters alongside the value for correctness with polymorphic constants. -/ + constTypeCache : Std.HashMap Address (List (Level m) × SusValue m) := {} + deriving Inhabited + +/-! ## TypecheckM monad -/ + +abbrev TypecheckM (m : MetaMode) := ReaderT (TypecheckCtx m) (StateT (TypecheckState m) (Except String)) + +def TypecheckM.run (ctx : TypecheckCtx m) (stt : TypecheckState m) (x : TypecheckM m α) : Except String α := + match (StateT.run (ReaderT.run x ctx) stt) with + | .error e => .error e + | .ok (a, _) => .ok a + +def TypecheckM.runState (ctx : TypecheckCtx m) (stt : TypecheckState m) (x : TypecheckM m α) + : Except String (α × TypecheckState m) := + StateT.run (ReaderT.run x ctx) stt + +/-! ## Context modifiers -/ + +def withEnv (env : ValEnv m) : TypecheckM m α → TypecheckM m α := + withReader fun ctx => { ctx with env := env } + +def withResetCtx : TypecheckM m α → TypecheckM m α := + withReader fun ctx => { ctx with + lvl := 0, env := default, types := default, mutTypes := default, recAddr? 
:= none } + +def withMutTypes (mutTypes : Std.TreeMap Nat (Address × (Array (Level m) → SusValue m)) compare) : + TypecheckM m α → TypecheckM m α := + withReader fun ctx => { ctx with mutTypes := mutTypes } + +def withExtendedCtx (val typ : SusValue m) : TypecheckM m α → TypecheckM m α := + withReader fun ctx => { ctx with + lvl := ctx.lvl + 1, + types := typ :: ctx.types, + env := ctx.env.extendWith val } + +def withExtendedEnv (thunk : SusValue m) : TypecheckM m α → TypecheckM m α := + withReader fun ctx => { ctx with env := ctx.env.extendWith thunk } + +def withNewExtendedEnv (env : ValEnv m) (thunk : SusValue m) : + TypecheckM m α → TypecheckM m α := + withReader fun ctx => { ctx with env := env.extendWith thunk } + +def withRecAddr (addr : Address) : TypecheckM m α → TypecheckM m α := + withReader fun ctx => { ctx with recAddr? := some addr } + +/-- Check both fuel counters, decrement them, and run the action. + - State fuel bounds total work (prevents exponential blowup / hanging). + - Reader depth bounds call-stack depth (prevents native stack overflow). -/ +def withFuelCheck (action : TypecheckM m α) : TypecheckM m α := do + let ctx ← read + if ctx.depth == 0 then + throw "deep recursion depth limit reached" + let stt ← get + if stt.fuel == 0 then throw "deep recursion work limit reached" + set { stt with fuel := stt.fuel - 1 } + withReader (fun ctx => { ctx with depth := ctx.depth - 1 }) action + +/-! ## Name lookup -/ + +/-- Look up the MetaField name for a constant address from the kernel environment. -/ +def lookupName (addr : Address) : TypecheckM m (MetaField m Ix.Name) := do + match (← read).kenv.find? addr with + | some ci => pure ci.cv.name + | none => pure default + +/-! ## Const dereferencing -/ + +def derefConst (addr : Address) : TypecheckM m (ConstantInfo m) := do + let ctx ← read + match ctx.kenv.find? 
addr with
+  | some ci => pure ci
+  | none => throw s!"unknown constant {addr}"
+
+/-- Look up a constant's `TypedConst` entry in the typechecker state, throwing on a miss. -/
+def derefTypedConst (addr : Address) : TypecheckM m (TypedConst m) := do
+  match (← get).typedConsts.get? addr with
+  | some tc => pure tc
+  | none => throw s!"typed constant not found: {addr}"
+
+/-! ## Provisional TypedConst -/
+
+/-- Extract the major premise's inductive address from a recursor type.
+    Skips numParams + numMotives + numMinors + numIndices foralls,
+    then the next forall's domain's app head is the inductive const.
+    Returns `none` (instead of panicking) when the telescope is shorter than
+    expected or the major premise's head is not a `const`. -/
+def getMajorInduct (type : Expr m) (numParams numMotives numMinors numIndices : Nat) : Option Address :=
+  go (numParams + numMotives + numMinors + numIndices) type
+where
+  go : Nat → Expr m → Option Address
+  | 0, e => match e with
+    | .forallE dom _ _ _ =>
+      -- Match the app head structurally; a malformed recursor type now
+      -- yields `none` instead of tripping the `constAddr!` panic.
+      match dom.getAppFn with
+      | .const a _ _ => some a
+      | _ => none
+    | _ => none
+  | n+1, e => match e with
+    | .forallE _ body _ _ => go n body
+    | _ => none
+
+/-- Build a provisional TypedConst entry from raw ConstantInfo.
+    Used when `infer` encounters a `.const` reference before the constant
+    has been fully typechecked. The entry uses default TypeInfo and raw
+    expressions directly from the kernel environment.
-/ +def provisionalTypedConst (ci : ConstantInfo m) : TypedConst m := + let rawType : TypedExpr m := ⟨default, ci.type⟩ + match ci with + | .axiomInfo _ => .axiom rawType + | .thmInfo v => .theorem rawType ⟨default, v.value⟩ + | .defnInfo v => + .definition rawType ⟨default, v.value⟩ (v.safety == .partial) + | .opaqueInfo v => .opaque rawType ⟨default, v.value⟩ + | .quotInfo v => .quotient rawType v.kind + | .inductInfo v => + let isStruct := v.ctors.size == 1 -- approximate; refined by checkIndBlock + .inductive rawType isStruct + | .ctorInfo v => .constructor rawType v.cidx v.numFields + | .recInfo v => + let indAddr := getMajorInduct ci.type v.numParams v.numMotives v.numMinors v.numIndices + |>.getD default + let typedRules := v.rules.map fun r => (r.nfields, (⟨default, r.rhs⟩ : TypedExpr m)) + .recursor rawType v.numParams v.numMotives v.numMinors v.numIndices v.k indAddr typedRules + +/-- Ensure a constant has a TypedConst entry. If not already present, build a + provisional one from raw ConstantInfo. This avoids the deep recursion of + `checkConst` when called from `infer`. -/ +def ensureTypedConst (addr : Address) : TypecheckM m Unit := do + if (← get).typedConsts.get? addr |>.isSome then return () + let ci ← derefConst addr + let tc := provisionalTypedConst ci + modify fun stt => { stt with + typedConsts := stt.typedConsts.insert addr tc } + +end Ix.Kernel diff --git a/Ix/Kernel/Types.lean b/Ix/Kernel/Types.lean new file mode 100644 index 00000000..fba45b00 --- /dev/null +++ b/Ix/Kernel/Types.lean @@ -0,0 +1,569 @@ +/- + Kernel Types: Closure-based typechecker types with compile-time metadata erasure. + + The MetaMode flag controls whether name/binder metadata is present: + - `Expr .meta` carries full names and binder info (for debugging) + - `Expr .anon` has Unit fields (proven no metadata leakage) +-/ +import Ix.Address +import Ix.Environment + +namespace Ix.Kernel + +/-! 
## MetaMode and MetaField -/ + +inductive MetaMode where | «meta» | anon + +def MetaField (m : MetaMode) (α : Type) : Type := + match m with + | .meta => α + | .anon => Unit + +instance {m : MetaMode} {α : Type} [Inhabited α] : Inhabited (MetaField m α) := + match m with + | .meta => inferInstanceAs (Inhabited α) + | .anon => ⟨()⟩ + +instance {m : MetaMode} {α : Type} [BEq α] : BEq (MetaField m α) := + match m with + | .meta => inferInstanceAs (BEq α) + | .anon => ⟨fun _ _ => true⟩ + +instance {m : MetaMode} {α : Type} [Repr α] : Repr (MetaField m α) := + match m with + | .meta => inferInstanceAs (Repr α) + | .anon => ⟨fun _ _ => "()".toFormat⟩ + +instance {m : MetaMode} {α : Type} [ToString α] : ToString (MetaField m α) := + match m with + | .meta => inferInstanceAs (ToString α) + | .anon => ⟨fun _ => "()"⟩ + +instance {m : MetaMode} {α : Type} [Ord α] : Ord (MetaField m α) := + match m with + | .meta => inferInstanceAs (Ord α) + | .anon => ⟨fun _ _ => .eq⟩ + +/-! ## Level -/ + +inductive Level (m : MetaMode) where + | zero + | succ (l : Level m) + | max (l₁ l₂ : Level m) + | imax (l₁ l₂ : Level m) + | param (idx : Nat) (name : MetaField m Ix.Name) + deriving Inhabited, BEq + +/-! ## Expr -/ + +inductive Expr (m : MetaMode) where + | bvar (idx : Nat) (name : MetaField m Ix.Name) + | sort (level : Level m) + | const (addr : Address) (levels : Array (Level m)) + (name : MetaField m Ix.Name) + | app (fn arg : Expr m) + | lam (ty body : Expr m) + (name : MetaField m Ix.Name) (bi : MetaField m Lean.BinderInfo) + | forallE (ty body : Expr m) + (name : MetaField m Ix.Name) (bi : MetaField m Lean.BinderInfo) + | letE (ty val body : Expr m) + (name : MetaField m Ix.Name) + | lit (l : Lean.Literal) + | proj (typeAddr : Address) (idx : Nat) (struct : Expr m) + (typeName : MetaField m Ix.Name) + deriving Inhabited, BEq + +/-! 
## Pretty printing helpers -/ + +private def succCount : Level m → Nat → Nat × Level m + | .succ l, n => succCount l (n + 1) + | l, n => (n, l) + +private partial def ppLevel : Level m → String + | .zero => "0" + | .succ l => + let (n, base) := succCount l 1 + match base with + | .zero => toString n + | _ => s!"{ppLevel base} + {n}" + | .max l₁ l₂ => s!"max ({ppLevel l₁}) ({ppLevel l₂})" + | .imax l₁ l₂ => s!"imax ({ppLevel l₁}) ({ppLevel l₂})" + | .param idx name => + let s := s!"{name}" + if s == "()" then s!"u_{idx}" else s + +private def ppSort (l : Level m) : String := + match l with + | .zero => "Prop" + | .succ .zero => "Type" + | .succ l' => + let s := ppLevel l' + if s.any (· == ' ') then s!"Type ({s})" else s!"Type {s}" + | _ => + let s := ppLevel l + if s.any (· == ' ') then s!"Sort ({s})" else s!"Sort {s}" + +private def ppBinderName (name : MetaField m Ix.Name) : String := + let s := s!"{name}" + if s == "()" then "_" + else if s.isEmpty then "???" + else s + +private def ppVarName (name : MetaField m Ix.Name) (idx : Nat) : String := + let s := s!"{name}" + if s == "()" then s!"^{idx}" + else if s.isEmpty then "???" + else s + +private def ppConstName (name : MetaField m Ix.Name) (addr : Address) : String := + let s := s!"{name}" + if s == "()" then s!"#{String.ofList ((toString addr).toList.take 8)}" + else if s.isEmpty then s!"{addr}" + else s + +/-! 
## Expr smart constructors -/ + +namespace Expr + +def mkBVar (idx : Nat) : Expr m := .bvar idx default +def mkSort (level : Level m) : Expr m := .sort level +def mkConst (addr : Address) (levels : Array (Level m)) : Expr m := + .const addr levels default +def mkApp (fn arg : Expr m) : Expr m := .app fn arg +def mkLam (ty body : Expr m) : Expr m := .lam ty body default default +def mkForallE (ty body : Expr m) : Expr m := .forallE ty body default default +def mkLetE (ty val body : Expr m) : Expr m := .letE ty val body default +def mkLit (l : Lean.Literal) : Expr m := .lit l +def mkProj (typeAddr : Address) (idx : Nat) (struct : Expr m) : Expr m := + .proj typeAddr idx struct default + +/-! ### Predicates -/ + +def isSort : Expr m → Bool | sort .. => true | _ => false +def isForall : Expr m → Bool | forallE .. => true | _ => false +def isLambda : Expr m → Bool | lam .. => true | _ => false +def isApp : Expr m → Bool | app .. => true | _ => false +def isLit : Expr m → Bool | lit .. => true | _ => false +def isConst : Expr m → Bool | const .. => true | _ => false +def isBVar : Expr m → Bool | bvar .. => true | _ => false + +def isConstOf (e : Expr m) (addr : Address) : Bool := + match e with | const a _ _ => a == addr | _ => false + +/-! ### Accessors -/ + +def bvarIdx! : Expr m → Nat | bvar i _ => i | _ => panic! "bvarIdx!" +def sortLevel! : Expr m → Level m | sort l => l | _ => panic! "sortLevel!" +def bindingDomain! : Expr m → Expr m + | forallE ty _ _ _ => ty | lam ty _ _ _ => ty | _ => panic! "bindingDomain!" +def bindingBody! : Expr m → Expr m + | forallE _ b _ _ => b | lam _ b _ _ => b | _ => panic! "bindingBody!" +def appFn! : Expr m → Expr m | app f _ => f | _ => panic! "appFn!" +def appArg! : Expr m → Expr m | app _ a => a | _ => panic! "appArg!" +def constAddr! : Expr m → Address | const a _ _ => a | _ => panic! "constAddr!" +def constLevels! : Expr m → Array (Level m) | const _ ls _ => ls | _ => panic! "constLevels!" +def litValue! 
: Expr m → Lean.Literal | lit l => l | _ => panic! "litValue!" +def projIdx! : Expr m → Nat | proj _ i _ _ => i | _ => panic! "projIdx!" +def projStruct! : Expr m → Expr m | proj _ _ s _ => s | _ => panic! "projStruct!" +def projTypeAddr! : Expr m → Address | proj a _ _ _ => a | _ => panic! "projTypeAddr!" + +/-! ### App Spine -/ + +def getAppFn : Expr m → Expr m + | app f _ => getAppFn f + | e => e + +def getAppNumArgs : Expr m → Nat + | app f _ => getAppNumArgs f + 1 + | _ => 0 + +partial def getAppRevArgs (e : Expr m) : Array (Expr m) := + go e #[] +where + go : Expr m → Array (Expr m) → Array (Expr m) + | app f a, acc => go f (acc.push a) + | _, acc => acc + +def getAppArgs (e : Expr m) : Array (Expr m) := + e.getAppRevArgs.reverse + +def mkAppN (fn : Expr m) (args : Array (Expr m)) : Expr m := + args.foldl (fun acc a => mkApp acc a) fn + +def mkAppRange (fn : Expr m) (start stop : Nat) (args : Array (Expr m)) : Expr m := Id.run do + let mut r := fn + for i in [start:stop] do + r := mkApp r args[i]! 
+ return r + +def prop : Expr m := mkSort .zero + +partial def pp (atom : Bool := false) : Expr m → String + | .bvar idx name => ppVarName name idx + | .sort level => ppSort level + | .const addr _ name => ppConstName name addr + | .app fn arg => + let s := s!"{pp false fn} {pp true arg}" + if atom then s!"({s})" else s + | .lam ty body name _ => + let s := ppLam s!"({ppBinderName name} : {pp false ty})" body + if atom then s!"({s})" else s + | .forallE ty body name _ => + let s := ppPi s!"({ppBinderName name} : {pp false ty})" body + if atom then s!"({s})" else s + | .letE ty val body name => + let s := s!"let {ppBinderName name} : {pp false ty} := {pp false val}; {pp false body}" + if atom then s!"({s})" else s + | .lit (.natVal n) => toString n + | .lit (.strVal s) => s!"\"{s}\"" + | .proj _ idx struct _ => s!"{pp true struct}.{idx}" +where + ppLam (acc : String) : Expr m → String + | .lam ty body name _ => + ppLam s!"{acc} ({ppBinderName name} : {pp false ty})" body + | e => s!"λ {acc} => {pp false e}" + ppPi (acc : String) : Expr m → String + | .forallE ty body name _ => + ppPi s!"{acc} ({ppBinderName name} : {pp false ty})" body + | e => s!"∀ {acc}, {pp false e}" + +/-- Short constructor tag for tracing (no recursion into subterms). -/ +def tag : Expr m → String + | .bvar idx _ => s!"bvar({idx})" + | .sort _ => "sort" + | .const _ _ name => s!"const({name})" + | .app .. => "app" + | .lam .. => "lam" + | .forallE .. => "forallE" + | .letE .. => "letE" + | .lit (.natVal n) => s!"natLit({n})" + | .lit (.strVal s) => s!"strLit({s})" + | .proj _ idx _ _ => s!"proj({idx})" + +end Expr + +/-! 
## Enums -/
+
+inductive DefinitionSafety where
+  | safe | «unsafe» | «partial»
+  deriving BEq, Repr, Inhabited
+
+inductive ReducibilityHints where
+  | opaque | abbrev | regular (height : UInt32)
+  deriving BEq, Repr, Inhabited
+
+namespace ReducibilityHints
+
+/-- Strict "unfolds earlier" order on hints, mirroring Lean core's
+    `ReducibilityHints.lt`: `abbrev` unfolds before everything else,
+    `regular` definitions compare by definitional height, and everything
+    unfolds before `opaque`. The previous version omitted the
+    `.abbrev, .regular _` case, so `abbrev` incorrectly failed to compare
+    below `regular` during lazy-delta unfold ordering. -/
+def lt' : ReducibilityHints → ReducibilityHints → Bool
+  | .abbrev, .abbrev => false
+  | .abbrev, _ => true
+  | .regular d₁, .regular d₂ => d₁ < d₂
+  | .regular _, .opaque => true
+  | _, _ => false
+
+def isRegular : ReducibilityHints → Bool
+  | .regular _ => true
+  | _ => false
+
+end ReducibilityHints
+
+inductive QuotKind where
+  | type | ctor | lift | ind
+  deriving BEq, Repr, Inhabited
+
+/-! ## ConstantInfo -/
+
+structure ConstantVal (m : MetaMode) where
+  numLevels : Nat
+  type : Expr m
+  name : MetaField m Ix.Name
+  levelParams : MetaField m (Array Ix.Name)
+  deriving Inhabited
+
+/-- Fresh universe parameters `.param 0, .param 1, …` for this constant;
+    in `.meta` mode each carries its declared name (or `default` past the
+    end of the name array). -/
+def ConstantVal.mkUnivParams (cv : ConstantVal m) : Array (Level m) :=
+  match m with
+  | .meta =>
+    let lps : Array Ix.Name := cv.levelParams
+    Array.ofFn (n := cv.numLevels) fun i =>
+      .param i.val (if h : i.val < lps.size then lps[i.val] else default)
+  | .anon => Array.ofFn (n := cv.numLevels) fun i => .param i.val ()
+
+structure AxiomVal (m : MetaMode) extends ConstantVal m where
+  isUnsafe : Bool
+
+structure DefinitionVal (m : MetaMode) extends ConstantVal m where
+  value : Expr m
+  hints : ReducibilityHints
+  safety : DefinitionSafety
+  all : Array Address
+  allNames : MetaField m (Array Ix.Name) := default
+
+structure TheoremVal (m : MetaMode) extends ConstantVal m where
+  value : Expr m
+  all : Array Address
+  allNames : MetaField m (Array Ix.Name) := default
+
+structure OpaqueVal (m : MetaMode) extends ConstantVal m where
+  value : Expr m
+  isUnsafe : Bool
+  all : Array Address
+  allNames : MetaField m (Array Ix.Name) := default
+
+structure QuotVal (m : MetaMode) extends ConstantVal m where
+  kind : QuotKind
+
+structure InductiveVal (m : MetaMode) extends ConstantVal m where
+  numParams : Nat
+  numIndices : Nat
+  all : Array
Address + ctors : Array Address + allNames : MetaField m (Array Ix.Name) := default + ctorNames : MetaField m (Array Ix.Name) := default + numNested : Nat + isRec : Bool + isUnsafe : Bool + isReflexive : Bool + +structure ConstructorVal (m : MetaMode) extends ConstantVal m where + induct : Address + inductName : MetaField m Ix.Name := default + cidx : Nat + numParams : Nat + numFields : Nat + isUnsafe : Bool + +structure RecursorRule (m : MetaMode) where + ctor : Address + ctorName : MetaField m Ix.Name := default + nfields : Nat + rhs : Expr m + +structure RecursorVal (m : MetaMode) extends ConstantVal m where + all : Array Address + allNames : MetaField m (Array Ix.Name) := default + numParams : Nat + numIndices : Nat + numMotives : Nat + numMinors : Nat + rules : Array (RecursorRule m) + k : Bool + isUnsafe : Bool + +inductive ConstantInfo (m : MetaMode) where + | axiomInfo (val : AxiomVal m) + | defnInfo (val : DefinitionVal m) + | thmInfo (val : TheoremVal m) + | opaqueInfo (val : OpaqueVal m) + | quotInfo (val : QuotVal m) + | inductInfo (val : InductiveVal m) + | ctorInfo (val : ConstructorVal m) + | recInfo (val : RecursorVal m) + +namespace ConstantInfo + +def cv : ConstantInfo m → ConstantVal m + | axiomInfo v => v.toConstantVal + | defnInfo v => v.toConstantVal + | thmInfo v => v.toConstantVal + | opaqueInfo v => v.toConstantVal + | quotInfo v => v.toConstantVal + | inductInfo v => v.toConstantVal + | ctorInfo v => v.toConstantVal + | recInfo v => v.toConstantVal + +def numLevels (ci : ConstantInfo m) : Nat := ci.cv.numLevels +def type (ci : ConstantInfo m) : Expr m := ci.cv.type + +def isUnsafe : ConstantInfo m → Bool + | axiomInfo v => v.isUnsafe + | defnInfo v => v.safety == .unsafe + | thmInfo _ => false + | opaqueInfo v => v.isUnsafe + | quotInfo _ => false + | inductInfo v => v.isUnsafe + | ctorInfo v => v.isUnsafe + | recInfo v => v.isUnsafe + +def hasValue : ConstantInfo m → Bool + | defnInfo .. | thmInfo .. | opaqueInfo .. 
=> true + | _ => false + +def value? : ConstantInfo m → Option (Expr m) + | defnInfo v => some v.value + | thmInfo v => some v.value + | opaqueInfo v => some v.value + | _ => none + +def hints : ConstantInfo m → ReducibilityHints + | defnInfo v => v.hints + | _ => .opaque + +def safety : ConstantInfo m → DefinitionSafety + | defnInfo v => v.safety + | _ => .safe + +def all? : ConstantInfo m → Option (Array Address) + | defnInfo v => some v.all + | thmInfo v => some v.all + | opaqueInfo v => some v.all + | inductInfo v => some v.all + | recInfo v => some v.all + | _ => none + +def kindName : ConstantInfo m → String + | axiomInfo .. => "axiom" + | defnInfo .. => "definition" + | thmInfo .. => "theorem" + | opaqueInfo .. => "opaque" + | quotInfo .. => "quotient" + | inductInfo .. => "inductive" + | ctorInfo .. => "constructor" + | recInfo .. => "recursor" + +end ConstantInfo + +/-! ## Kernel.Env -/ + +def Address.compare (a b : Address) : Ordering := Ord.compare a b + +structure EnvId (m : MetaMode) where + addr : Address + name : MetaField m Ix.Name + +instance : Inhabited (EnvId m) where + default := ⟨default, default⟩ + +instance : BEq (EnvId m) where + beq a b := a.addr == b.addr && a.name == b.name + +def EnvId.compare (a b : EnvId m) : Ordering := + match Address.compare a.addr b.addr with + | .eq => Ord.compare a.name b.name + | ord => ord + +structure Env (m : MetaMode) where + entries : Std.TreeMap (EnvId m) (ConstantInfo m) EnvId.compare + addrIndex : Std.TreeMap Address (EnvId m) Address.compare + +instance : Inhabited (Env m) where + default := { entries := .empty, addrIndex := .empty } + +instance : ForIn n (Env m) (Address × ConstantInfo m) where + forIn env init f := + ForIn.forIn env.entries init fun p acc => f (p.1.addr, p.2) acc + +namespace Env + +def find? (env : Env m) (addr : Address) : Option (ConstantInfo m) := + match env.addrIndex.get? addr with + | some id => env.entries.get? 
id + | none => none + +def findByEnvId (env : Env m) (id : EnvId m) : Option (ConstantInfo m) := + env.entries.get? id + +def get (env : Env m) (addr : Address) : Except String (ConstantInfo m) := + match env.find? addr with + | some ci => .ok ci + | none => .error s!"unknown constant {addr}" + +def insert (env : Env m) (addr : Address) (ci : ConstantInfo m) : Env m := + let id : EnvId m := ⟨addr, ci.cv.name⟩ + let entries := env.entries.insert id ci + let addrIndex := match env.addrIndex.get? addr with + | some _ => env.addrIndex + | none => env.addrIndex.insert addr id + { entries, addrIndex } + +def add (env : Env m) (addr : Address) (ci : ConstantInfo m) : Env m := + env.insert addr ci + +def size (env : Env m) : Nat := + env.addrIndex.size + +def contains (env : Env m) (addr : Address) : Bool := + env.addrIndex.get? addr |>.isSome + +def isStructureLike (env : Env m) (addr : Address) : Bool := + match env.find? addr with + | some (.inductInfo v) => + !v.isRec && v.numIndices == 0 && v.ctors.size == 1 && + match env.find? v.ctors[0]! with + | some (.ctorInfo cv) => cv.numFields > 0 + | _ => false + | _ => false + +end Env + +/-! ## Primitives -/ + +private def addr! (s : String) : Address := + match Address.fromString s with + | some a => a + | none => panic! 
s!"invalid hex address: {s}" + +structure Primitives where + nat : Address := default + natZero : Address := default + natSucc : Address := default + natAdd : Address := default + natSub : Address := default + natMul : Address := default + natPow : Address := default + natGcd : Address := default + natMod : Address := default + natDiv : Address := default + natBeq : Address := default + natBle : Address := default + natLand : Address := default + natLor : Address := default + natXor : Address := default + natShiftLeft : Address := default + natShiftRight : Address := default + bool : Address := default + boolTrue : Address := default + boolFalse : Address := default + string : Address := default + stringMk : Address := default + char : Address := default + charMk : Address := default + list : Address := default + listNil : Address := default + listCons : Address := default + quotType : Address := default + quotCtor : Address := default + quotLift : Address := default + quotInd : Address := default + deriving Repr, Inhabited + +def buildPrimitives : Primitives := + { nat := addr! "fc0e1e912f2d7f12049a5b315d76eec29562e34dc39ebca25287ae58807db137" + natZero := addr! "fac82f0d2555d6a63e1b8a1fe8d86bd293197f39c396fdc23c1275c60f182b37" + natSucc := addr! "7190ce56f6a2a847b944a355e3ec595a4036fb07e3c3db9d9064fc041be72b64" + natAdd := addr! "dcc96f3f914e363d1e906a8be4c8f49b994137bfdb077d07b6c8a4cf88a4f7bf" + natSub := addr! "6903e9bbd169b6c5515b27b3fc0c289ba2ff8e7e0c7f984747d572de4e6a7853" + natMul := addr! "8e641c3df8fe3878e5a219c888552802743b9251c3c37c32795f5b9b9e0818a5" + natPow := addr! "d9be78292bb4e79c03daaaad82e756c5eb4dd5535d33b155ea69e5cbce6bc056" + natGcd := addr! "e8a3be39063744a43812e1f7b8785e3f5a4d5d1a408515903aa05d1724aeb465" + natMod := addr! "14031083457b8411f655765167b1a57fcd542c621e0c391b15ff5ee716c22a67" + natDiv := addr! "863c18d3a5b100a5a5e423c20439d8ab4941818421a6bcf673445335cc559e55" + natBeq := addr! 
"127a9d47a15fc2bf91a36f7c2182028857133b881554ece4df63344ec93eb2ce" + natBle := addr! "6e4c17dc72819954d6d6afc412a3639a07aff6676b0813cdc419809cc4513df5" + natLand := addr! "e1425deee6279e2db2ff649964b1a66d4013cc08f9e968fb22cc0a64560e181a" + natLor := addr! "3649a28f945b281bd8657e55f93ae0b8f8313488fb8669992a1ba1373cbff8f6" + natXor := addr! "a711ef2cb4fa8221bebaa17ef8f4a965cf30678a89bc45ff18a13c902e683cc5" + natShiftLeft := addr! "16e4558f51891516843a5b30ddd9d9b405ec096d3e1c728d09ff152b345dd607" + natShiftRight := addr! "b9515e6c2c6b18635b1c65ebca18b5616483ebd53936f78e4ae123f6a27a089e" + bool := addr! "6405a455ba70c2b2179c7966c6f610bf3417bd0f3dd2ba7a522533c2cd9e1d0b" + boolTrue := addr! "420dead2168abd16a7050edfd8e17d45155237d3118782d0e68b6de87742cb8d" + boolFalse := addr! "c127f89f92e0481f7a3e0631c5615fe7f6cbbf439d5fd7eba400fb0603aedf2f" + string := addr! "591cf1c489d505d4082f2767500f123e29db5227eb1bae4721eeedd672f36190" + stringMk := addr! "f055b87da4265d980cdede04ce5c7d986866e55816dc94d32a5d90e805101230" + char := addr! "563b426b73cdf1538b767308d12d10d746e1f0b3b55047085bf690319a86f893" + charMk := addr! "7156fef44bc309789375d784e5c36e387f7119363dd9cd349226c52df43d2075" + list := addr! "abed9ff1aba4634abc0bd3af76ca544285a32dcfe43dc27b129aea8867457620" + listNil := addr! "0ebe345dc46917c824b6c3f6c42b101f2ac8c0e2c99f033a0ee3c60acb9cd84d" + listCons := addr! 
"f79842f10206598929e6ba60ce3ebaa00d11f201c99e80285f46cc0e90932832" + -- Quot primitives need to be computed; use default until wired up + } + +end Ix.Kernel diff --git a/Main.lean b/Main.lean index 3d111f56..d775bf88 100644 --- a/Main.lean +++ b/Main.lean @@ -1,5 +1,6 @@ --import Ix.Cli.ProveCmd --import Ix.Cli.StoreCmd +import Ix.Cli.CheckCmd import Ix.Cli.CompileCmd import Ix.Cli.ServeCmd import Ix.Cli.ConnectCmd @@ -15,6 +16,7 @@ def ixCmd : Cli.Cmd := `[Cli| SUBCOMMANDS: --proveCmd; --storeCmd; + checkCmd; compileCmd; serveCmd; connectCmd diff --git a/Tests/Ix/Check.lean b/Tests/Ix/Check.lean new file mode 100644 index 00000000..404b478d --- /dev/null +++ b/Tests/Ix/Check.lean @@ -0,0 +1,107 @@ +/- + Kernel type-checker integration tests. + Tests both the Rust kernel (via FFI) and the Lean NbE kernel. +-/ + +import Ix.Kernel +import Ix.Common +import Ix.Meta +import Ix.CompileM +import Lean +import LSpec + +open LSpec + +namespace Tests.Check + +/-! ## Rust kernel tests -/ + +def testCheckEnv : TestSeq := + .individualIO "Rust kernel check_env" (do + let leanEnv ← get_env! + let totalConsts := leanEnv.constants.toList.length + + IO.println s!"[Check] Environment has {totalConsts} constants" + + let start ← IO.monoMsNow + let errors ← Ix.Kernel.rsCheckEnv leanEnv + let elapsed := (← IO.monoMsNow) - start + + IO.println s!"[Check] Rust kernel checked {totalConsts} constants in {elapsed.formatMs}" + + if errors.isEmpty then + IO.println s!"[Check] All constants passed" + return (true, none) + else + IO.println s!"[Check] {errors.size} error(s):" + for (name, err) in errors[:min 20 errors.size] do + IO.println s!" {repr name}: {repr err}" + return (false, some s!"Kernel check failed with {errors.size} error(s)") + ) .done + +def testCheckConst (name : String) : TestSeq := + .individualIO s!"check {name}" (do + let leanEnv ← get_env! 
+ let start ← IO.monoMsNow + let result ← Ix.Kernel.rsCheckConst leanEnv name + let elapsed := (← IO.monoMsNow) - start + match result with + | none => + IO.println s!" [ok] {name} ({elapsed.formatMs})" + return (true, none) + | some err => + IO.println s!" [fail] {name}: {repr err} ({elapsed.formatMs})" + return (false, some s!"{name} failed: {repr err}") + ) .done + +/-! ## Lean NbE kernel tests -/ + +def testKernelCheckEnv : TestSeq := + .individualIO "Lean NbE kernel check_env" (do + let leanEnv ← get_env! + + IO.println s!"[Kernel-NbE] Compiling to Ixon..." + let compileStart ← IO.monoMsNow + let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv + let compileElapsed := (← IO.monoMsNow) - compileStart + let numConsts := ixonEnv.consts.size + IO.println s!"[Kernel-NbE] Compiled {numConsts} constants in {compileElapsed.formatMs}" + + IO.println s!"[Kernel-NbE] Converting..." + let convertStart ← IO.monoMsNow + match Ix.Kernel.Convert.convertEnv .meta ixonEnv with + | .error e => + IO.println s!"[Kernel-NbE] convertEnv error: {e}" + return (false, some e) + | .ok (kenv, prims, quotInit) => + let convertElapsed := (← IO.monoMsNow) - convertStart + IO.println s!"[Kernel-NbE] Converted {kenv.size} constants in {convertElapsed.formatMs}" + + IO.println s!"[Kernel-NbE] Typechecking {kenv.size} constants..." + let checkStart ← IO.monoMsNow + match ← Ix.Kernel.typecheckAllIO kenv prims quotInit with + | .error e => + let elapsed := (← IO.monoMsNow) - checkStart + IO.println s!"[Kernel-NbE] typecheckAll error in {elapsed.formatMs}: {e}" + return (false, some s!"Kernel NbE check failed: {e}") + | .ok () => + let elapsed := (← IO.monoMsNow) - checkStart + IO.println s!"[Kernel-NbE] All constants passed in {elapsed.formatMs}" + return (true, none) + ) .done + +/-! 
## Test suites -/ + +def checkSuiteIO : List TestSeq := [ + testCheckConst "Nat.add", +] + +def checkAllSuiteIO : List TestSeq := [ + testCheckEnv, +] + +def kernelSuiteIO : List TestSeq := [ + testKernelCheckEnv, +] + +end Tests.Check diff --git a/Tests/Ix/Compile.lean b/Tests/Ix/Compile.lean index fa6dadff..af14f820 100644 --- a/Tests/Ix/Compile.lean +++ b/Tests/Ix/Compile.lean @@ -9,6 +9,8 @@ import Ix.Address import Ix.Common import Ix.Meta import Ix.CompileM +import Ix.DecompileM +import Ix.CanonM import Ix.CondenseM import Ix.GraphM import Ix.Sharing @@ -458,10 +460,79 @@ def testCrossImpl : TestSeq := return (false, some s!"Found {result.mismatchedConstants.size} mismatches") ) .done -/-! ## Test Suite -/ +/-! ## Lean → Ixon → Ix → Lean full roundtrip -/ + +/-- Full roundtrip: Rust-compile Lean env to Ixon, decompile back to Ix, uncanon back to Lean, + then structurally compare every constant against the original. -/ +def testIxonFullRoundtrip : TestSeq := + .individualIO "Lean→Ixon→Ix→Lean full roundtrip" (do + let leanEnv ← get_env! + let totalConsts := leanEnv.constants.toList.length + IO.println s!"[ixon-roundtrip] Lean env: {totalConsts} constants" + + -- Step 1: Rust compile to Ixon.Env + IO.println s!"[ixon-roundtrip] Step 1: Rust compile..." + let compileStart ← IO.monoMsNow + let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv + let compileMs := (← IO.monoMsNow) - compileStart + IO.println s!"[ixon-roundtrip] {ixonEnv.named.size} named, {ixonEnv.consts.size} consts in {compileMs}ms" + + -- Step 2: Decompile Ixon → Ix + IO.println s!"[ixon-roundtrip] Step 2: Decompile Ixon→Ix (parallel)..." 
+ let decompStart ← IO.monoMsNow + let (ixConsts, decompErrors) := Ix.DecompileM.decompileAllParallel ixonEnv + let decompMs := (← IO.monoMsNow) - decompStart + IO.println s!"[ixon-roundtrip] {ixConsts.size} ok, {decompErrors.size} errors in {decompMs}ms" + if !decompErrors.isEmpty then + IO.println s!"[ixon-roundtrip] First errors:" + for (name, err) in decompErrors.toList.take 5 do + IO.println s!" {name}: {err}" + + -- Step 3: Uncanon Ix → Lean + IO.println s!"[ixon-roundtrip] Step 3: Uncanon Ix→Lean (parallel)..." + let uncanonStart ← IO.monoMsNow + let roundtripped := Ix.CanonM.uncanonEnvParallel ixConsts + let uncanonMs := (← IO.monoMsNow) - uncanonStart + IO.println s!"[ixon-roundtrip] {roundtripped.size} constants in {uncanonMs}ms" + + -- Step 4: Compare roundtripped Lean constants against originals + IO.println s!"[ixon-roundtrip] Step 4: Comparing against original..." + let compareStart ← IO.monoMsNow + let origMap : Std.HashMap Lean.Name Lean.ConstantInfo := + leanEnv.constants.fold (init := {}) fun acc name const => acc.insert name const + let (nMismatches, nMissing, mismatchNames, missingNames) := + Ix.CanonM.compareEnvsParallel origMap roundtripped + let compareMs := (← IO.monoMsNow) - compareStart + IO.println s!"[ixon-roundtrip] {nMissing} missing, {nMismatches} mismatches in {compareMs}ms" + + if !missingNames.isEmpty then + IO.println s!"[ixon-roundtrip] First missing:" + for name in missingNames.toList.take 10 do + IO.println s!" {name}" + + if !mismatchNames.isEmpty then + IO.println s!"[ixon-roundtrip] First mismatches:" + for name in mismatchNames.toList.take 20 do + IO.println s!" 
{name}" + + let totalMs := compileMs + decompMs + uncanonMs + compareMs + IO.println s!"[ixon-roundtrip] Total: {totalMs}ms" + + let success := decompErrors.size == 0 && nMismatches == 0 && nMissing == 0 + if success then + return (true, none) + else + return (false, some s!"{decompErrors.size} decompile errors, {nMismatches} mismatches, {nMissing} missing") + ) .done + +/-! ## Test Suites -/ def compileSuiteIO : List TestSeq := [ testCrossImpl, ] +def ixonRoundtripSuiteIO : List TestSeq := [ + testIxonFullRoundtrip, +] + end Tests.Compile diff --git a/Tests/Ix/KernelTests.lean b/Tests/Ix/KernelTests.lean new file mode 100644 index 00000000..f1ed3c55 --- /dev/null +++ b/Tests/Ix/KernelTests.lean @@ -0,0 +1,761 @@ +/- + Kernel test suite. + - Unit tests for Kernel types, expression operations, and level operations + - Convert tests (Ixon.Env → Kernel.Env) + - Targeted constant-checking tests (individual constants through the full pipeline) +-/ +import Ix.Kernel +import Ix.Kernel.DecompileM +import Ix.CompileM +import Ix.Common +import Ix.Meta +import LSpec + +open LSpec +open Ix.Kernel + +namespace Tests.KernelTests + +/-! 
## Unit tests: Expression equality -/
+
+-- Structural (BEq) equality of anon-mode Expr across every constructor used by
+-- the kernel: bvar, sort, app, lam, forallE, const (with/without levels), lit.
+-- Each positive case has a paired negative case with one differing component.
+def testExprHashEq : TestSeq :=
+  let bv0 : Expr .anon := Expr.mkBVar 0
+  let bv0' : Expr .anon := Expr.mkBVar 0
+  let bv1 : Expr .anon := Expr.mkBVar 1
+  test "mkBVar 0 == mkBVar 0" (bv0 == bv0') ++
+  test "mkBVar 0 != mkBVar 1" (bv0 != bv1) ++
+  -- Sort equality
+  let s0 : Expr .anon := Expr.mkSort Level.zero
+  let s0' : Expr .anon := Expr.mkSort Level.zero
+  let s1 : Expr .anon := Expr.mkSort (Level.succ Level.zero)
+  test "mkSort 0 == mkSort 0" (s0 == s0') ++
+  test "mkSort 0 != mkSort 1" (s0 != s1) ++
+  -- App equality
+  let app1 := Expr.mkApp bv0 bv1
+  let app1' := Expr.mkApp bv0 bv1
+  let app2 := Expr.mkApp bv1 bv0
+  test "mkApp bv0 bv1 == mkApp bv0 bv1" (app1 == app1') ++
+  test "mkApp bv0 bv1 != mkApp bv1 bv0" (app1 != app2) ++
+  -- Lambda equality
+  let lam1 := Expr.mkLam s0 bv0
+  let lam1' := Expr.mkLam s0 bv0
+  let lam2 := Expr.mkLam s1 bv0
+  test "mkLam s0 bv0 == mkLam s0 bv0" (lam1 == lam1') ++
+  test "mkLam s0 bv0 != mkLam s1 bv0" (lam1 != lam2) ++
+  -- Forall equality
+  let pi1 := Expr.mkForallE s0 s1
+  let pi1' := Expr.mkForallE s0 s1
+  test "mkForallE s0 s1 == mkForallE s0 s1" (pi1 == pi1') ++
+  -- Const equality
+  let addr1 := Address.blake3 (ByteArray.mk #[1])
+  let addr2 := Address.blake3 (ByteArray.mk #[2])
+  let c1 : Expr .anon := Expr.mkConst addr1 #[]
+  let c1' : Expr .anon := Expr.mkConst addr1 #[]
+  let c2 : Expr .anon := Expr.mkConst addr2 #[]
+  test "mkConst addr1 == mkConst addr1" (c1 == c1') ++
+  test "mkConst addr1 != mkConst addr2" (c1 != c2) ++
+  -- Const with levels
+  let c1l : Expr .anon := Expr.mkConst addr1 #[Level.zero]
+  let c1l' : Expr .anon := Expr.mkConst addr1 #[Level.zero]
+  let c1l2 : Expr .anon := Expr.mkConst addr1 #[Level.succ Level.zero]
+  test "mkConst addr1 [0] == mkConst addr1 [0]" (c1l == c1l') ++
+  test "mkConst addr1 [0] != mkConst addr1 [1]" (c1l != c1l2) ++
+  -- Literal equality
+  let nat0 : Expr .anon := Expr.mkLit (.natVal 0)
+  let nat0' : Expr .anon := Expr.mkLit (.natVal 0)
+  let nat1 : Expr .anon := Expr.mkLit (.natVal 1)
+  let str1 : Expr .anon := Expr.mkLit (.strVal "hello")
+  let str1' : Expr .anon := Expr.mkLit (.strVal "hello")
+  let str2 : Expr .anon := Expr.mkLit (.strVal "world")
+  test "lit nat 0 == lit nat 0" (nat0 == nat0') ++
+  test "lit nat 0 != lit nat 1" (nat0 != nat1) ++
+  test "lit str hello == lit str hello" (str1 == str1') ++
+  test "lit str hello != lit str world" (str1 != str2) ++
+  .done
+
+/-! ## Unit tests: Expression operations -/
+
+-- Application-spine helpers (getAppFn / getAppNumArgs / getAppArgs / mkAppN),
+-- constructor predicates (isApp, isSort, …), and accessors (sortLevel!, bvarIdx!)
+-- on anon-mode Expr.
+def testExprOps : TestSeq :=
+  -- getAppFn / getAppArgs
+  let bv0 : Expr .anon := Expr.mkBVar 0
+  let bv1 : Expr .anon := Expr.mkBVar 1
+  let bv2 : Expr .anon := Expr.mkBVar 2
+  let app := Expr.mkApp (Expr.mkApp bv0 bv1) bv2
+  test "getAppFn (app (app bv0 bv1) bv2) == bv0" (app.getAppFn == bv0) ++
+  test "getAppNumArgs == 2" (app.getAppNumArgs == 2) ++
+  test "getAppArgs[0] == bv1" (app.getAppArgs[0]! == bv1) ++
+  test "getAppArgs[1] == bv2" (app.getAppArgs[1]! == bv2) ++
+  -- mkAppN round-trips
+  let rebuilt := Expr.mkAppN bv0 #[bv1, bv2]
+  test "mkAppN round-trips" (rebuilt == app) ++
+  -- Predicates
+  test "isApp" app.isApp ++
+  test "isSort" (Expr.mkSort (Level.zero : Level .anon)).isSort ++
+  test "isLambda" (Expr.mkLam bv0 bv1).isLambda ++
+  test "isForall" (Expr.mkForallE bv0 bv1).isForall ++
+  test "isLit" (Expr.mkLit (.natVal 42) : Expr .anon).isLit ++
+  test "isBVar" bv0.isBVar ++
+  test "isConst" (Expr.mkConst (m := .anon) default #[]).isConst ++
+  -- Accessors
+  let s1 : Expr .anon := Expr.mkSort (Level.succ Level.zero)
+  test "sortLevel!" (s1.sortLevel! == Level.succ Level.zero) ++
+  test "bvarIdx!" (bv1.bvarIdx! == 1) ++
+  .done
+
+/-!
## Unit tests: Level operations -/
+
+-- Sanity checks for the three core universe-level operations:
+-- reduce (normalization), equalLevel (semantic equivalence, e.g. max commutes),
+-- and leq (comparison with an integer offset as its third argument).
+def testLevelOps : TestSeq :=
+  let l0 : Level .anon := Level.zero
+  let l1 : Level .anon := Level.succ Level.zero
+  let p0 : Level .anon := Level.param 0 default
+  let p1 : Level .anon := Level.param 1 default
+  -- reduce
+  test "reduce zero" (Level.reduce l0 == l0) ++
+  test "reduce (succ zero)" (Level.reduce l1 == l1) ++
+  -- equalLevel
+  test "zero equiv zero" (Level.equalLevel l0 l0) ++
+  test "succ zero equiv succ zero" (Level.equalLevel l1 l1) ++
+  test "max a b equiv max b a"
+    (Level.equalLevel (Level.max p0 p1) (Level.max p1 p0)) ++
+  test "zero not equiv succ zero" (!Level.equalLevel l0 l1) ++
+  -- leq
+  test "zero <= zero" (Level.leq l0 l0 0) ++
+  test "succ zero <= zero + 1" (Level.leq l1 l0 1) ++
+  test "not (succ zero <= zero)" (!Level.leq l1 l0 0) ++
+  test "param 0 <= param 0" (Level.leq p0 p0 0) ++
+  test "succ (param 0) <= param 0 + 1"
+    (Level.leq (Level.succ p0) p0 1) ++
+  test "not (succ (param 0) <= param 0)"
+    (!Level.leq (Level.succ p0) p0 0) ++
+  .done
+
+/-! ## Integration tests: Const pipeline -/
+
+/-- Parse a dotted name string like "Nat.add" into an Ix.Name. -/
+private def parseIxName (s : String) : Ix.Name :=
+  let parts := s.splitOn "."
+  parts.foldl (fun acc part => Ix.Name.mkStr acc part) Ix.Name.mkAnon
+
+/-- Convert a Lean.Name to an Ix.Name (reproducing the Blake3 hashing). -/
+private partial def leanNameToIx : Lean.Name → Ix.Name
+  | .anonymous => Ix.Name.mkAnon
+  | .str pre s => Ix.Name.mkStr (leanNameToIx pre) s
+  | .num pre n => Ix.Name.mkNat (leanNameToIx pre) n
+
+-- Compiles the ambient Lean environment to Ixon, converts it to a Kernel.Env,
+-- and verifies that every named Lean constant survives both steps.
+-- Prints timing for each stage; fails if any named constant is missing.
+def testConvertEnv : TestSeq :=
+  .individualIO "rsCompileEnv + convertEnv" (do
+    -- NOTE(review): get_env! appears to supply the current Lean environment —
+    -- confirm where/when it is captured.
+    let leanEnv ← get_env!
+    let leanCount := leanEnv.constants.toList.length
+    IO.println s!"[kernel] Lean env: {leanCount} constants"
+    let start ← IO.monoMsNow
+    let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv
+    let compileMs := (← IO.monoMsNow) - start
+    let ixonCount := ixonEnv.consts.size
+    let namedCount := ixonEnv.named.size
+    IO.println s!"[kernel] rsCompileEnv: {ixonCount} consts, {namedCount} named in {compileMs.formatMs}"
+    let convertStart ← IO.monoMsNow
+    match Ix.Kernel.Convert.convertEnv .meta ixonEnv with
+    | .error e =>
+      IO.println s!"[kernel] convertEnv error: {e}"
+      return (false, some e)
+    | .ok (kenv, _, _) =>
+      let convertMs := (← IO.monoMsNow) - convertStart
+      let kenvCount := kenv.size
+      IO.println s!"[kernel] convertEnv: {kenvCount} consts in {convertMs.formatMs} ({ixonCount - kenvCount} muts blocks)"
+      -- Verify every Lean constant is present in the Kernel.Env
+      let mut missing : Array String := #[]
+      let mut notCompiled : Array String := #[]
+      let mut checked := 0
+      for (leanName, _) in leanEnv.constants.toList do
+        let ixName := leanNameToIx leanName
+        match ixonEnv.named.get? ixName with
+        | none => notCompiled := notCompiled.push (toString leanName)
+        | some named =>
+          checked := checked + 1
+          if !kenv.contains named.addr then
+            missing := missing.push (toString leanName)
+      if !notCompiled.isEmpty then
+        IO.println s!"[kernel] {notCompiled.size} Lean constants not in ixonEnv.named (unexpected)"
+        -- Only report the first few to keep the log readable.
+        for n in notCompiled[:min 10 notCompiled.size] do
+          IO.println s!"  not compiled: {n}"
+      if missing.isEmpty then
+        IO.println s!"[kernel] All {checked} named constants found in Kernel.Env"
+        return (true, none)
+      else
+        IO.println s!"[kernel] {missing.size}/{checked} named constants missing from Kernel.Env"
+        for n in missing[:min 20 missing.size] do
+          IO.println s!"  missing: {n}"
+        return (false, some s!"{missing.size} constants missing from Kernel.Env")
+  ) .done
+
+/-- Const pipeline: compile, convert, typecheck specific constants.
-/
+def testConstPipeline : TestSeq :=
+  .individualIO "kernel const pipeline" (do
+    let leanEnv ← get_env!
+    let start ← IO.monoMsNow
+    let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv
+    let compileMs := (← IO.monoMsNow) - start
+    IO.println s!"[kernel] rsCompileEnv: {ixonEnv.consts.size} consts in {compileMs.formatMs}"
+
+    let convertStart ← IO.monoMsNow
+    match Ix.Kernel.Convert.convertEnv .meta ixonEnv with
+    | .error e =>
+      IO.println s!"[kernel] convertEnv error: {e}"
+      return (false, some e)
+    | .ok (kenv, prims, quotInit) =>
+      let convertMs := (← IO.monoMsNow) - convertStart
+      IO.println s!"[kernel] convertEnv: {kenv.size} consts in {convertMs.formatMs}"
+
+      -- Check specific constants
+      let constNames := #[
+        "Nat", "Nat.zero", "Nat.succ", "Nat.rec",
+        "Bool", "Bool.true", "Bool.false", "Bool.rec",
+        "Eq", "Eq.refl",
+        "List", "List.nil", "List.cons",
+        "Nat.below"
+      ]
+      let checkStart ← IO.monoMsNow
+      let mut passed := 0
+      let mut failures : Array String := #[]
+      for name in constNames do
+        let ixName := parseIxName name
+        -- Missing name is recorded as a failure, not a crash.
+        let some cNamed := ixonEnv.named.get? ixName
+          | do failures := failures.push s!"{name}: not found"; continue
+        let addr := cNamed.addr
+        match Ix.Kernel.typecheckConst kenv prims addr quotInit with
+        | .ok () => passed := passed + 1
+        | .error e => failures := failures.push s!"{name}: {e}"
+      let checkMs := (← IO.monoMsNow) - checkStart
+      IO.println s!"[kernel] {passed}/{constNames.size} passed in {checkMs.formatMs}"
+      if failures.isEmpty then
+        return (true, none)
+      else
+        for f in failures do IO.println s!"  [fail] {f}"
+        return (false, some s!"{failures.size} failure(s)")
+  ) .done
+
+/-! ## Primitive address verification -/
+
+/-- Look up a primitive address by name (for verification only). -/
+private def lookupPrim (ixonEnv : Ixon.Env) (name : String) : Address :=
+  let ixName := parseIxName name
+  match ixonEnv.named.get? ixName with
+  | some n => n.addr
+  -- NOTE(review): a missing name silently yields `default`; it only surfaces
+  -- indirectly as an address mismatch in the caller's comparison.
+  | none => default
+
+/-- Verify hardcoded primitive addresses match actual compiled addresses. -/
+def testVerifyPrimAddrs : TestSeq :=
+  .individualIO "verify primitive addresses" (do
+    let leanEnv ← get_env!
+    let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv
+    let hardcoded := Ix.Kernel.buildPrimitives
+    let mut failures : Array String := #[]
+    -- (field name in Primitives, Lean constant name, hardcoded address)
+    let checks : Array (String × String × Address) := #[
+      ("nat", "Nat", hardcoded.nat),
+      ("natZero", "Nat.zero", hardcoded.natZero),
+      ("natSucc", "Nat.succ", hardcoded.natSucc),
+      ("natAdd", "Nat.add", hardcoded.natAdd),
+      ("natSub", "Nat.sub", hardcoded.natSub),
+      ("natMul", "Nat.mul", hardcoded.natMul),
+      ("natPow", "Nat.pow", hardcoded.natPow),
+      ("natGcd", "Nat.gcd", hardcoded.natGcd),
+      ("natMod", "Nat.mod", hardcoded.natMod),
+      ("natDiv", "Nat.div", hardcoded.natDiv),
+      ("natBeq", "Nat.beq", hardcoded.natBeq),
+      ("natBle", "Nat.ble", hardcoded.natBle),
+      ("natLand", "Nat.land", hardcoded.natLand),
+      ("natLor", "Nat.lor", hardcoded.natLor),
+      ("natXor", "Nat.xor", hardcoded.natXor),
+      ("natShiftLeft", "Nat.shiftLeft", hardcoded.natShiftLeft),
+      ("natShiftRight", "Nat.shiftRight", hardcoded.natShiftRight),
+      ("bool", "Bool", hardcoded.bool),
+      ("boolTrue", "Bool.true", hardcoded.boolTrue),
+      ("boolFalse", "Bool.false", hardcoded.boolFalse),
+      ("string", "String", hardcoded.string),
+      ("stringMk", "String.mk", hardcoded.stringMk),
+      ("char", "Char", hardcoded.char),
+      ("charMk", "Char.ofNat", hardcoded.charMk),
+      ("list", "List", hardcoded.list),
+      ("listNil", "List.nil", hardcoded.listNil),
+      ("listCons", "List.cons", hardcoded.listCons)
+    ]
+    for (field, name, expected) in checks do
+      let actual := lookupPrim ixonEnv name
+      if actual != expected then
+        failures := failures.push s!"{field}: expected {expected}, got {actual}"
+        IO.println s!"  [MISMATCH] {field} ({name}): {actual} != {expected}"
+    if failures.isEmpty then
+      IO.println s!"[prims] All {checks.size} primitive addresses verified"
+      return (true, none)
+    else
+      return (false, some s!"{failures.size} primitive address mismatch(es). Run `lake test -- kernel-dump-prims` to update.")
+  ) .done
+
+/-- Dump all primitive addresses for hardcoding. Use this to update buildPrimitives. -/
+def testDumpPrimAddrs : TestSeq :=
+  .individualIO "dump primitive addresses" (do
+    let leanEnv ← get_env!
+    let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv
+    let names := #[
+      ("nat", "Nat"), ("natZero", "Nat.zero"), ("natSucc", "Nat.succ"),
+      ("natAdd", "Nat.add"), ("natSub", "Nat.sub"), ("natMul", "Nat.mul"),
+      ("natPow", "Nat.pow"), ("natGcd", "Nat.gcd"), ("natMod", "Nat.mod"),
+      ("natDiv", "Nat.div"), ("natBeq", "Nat.beq"), ("natBle", "Nat.ble"),
+      ("natLand", "Nat.land"), ("natLor", "Nat.lor"), ("natXor", "Nat.xor"),
+      ("natShiftLeft", "Nat.shiftLeft"), ("natShiftRight", "Nat.shiftRight"),
+      ("bool", "Bool"), ("boolTrue", "Bool.true"), ("boolFalse", "Bool.false"),
+      ("string", "String"), ("stringMk", "String.mk"),
+      ("char", "Char"), ("charMk", "Char.ofNat"),
+      ("list", "List"), ("listNil", "List.nil"), ("listCons", "List.cons")
+    ]
+    -- Output is formatted as `field := "address"` lines, ready to paste.
+    for (field, name) in names do
+      IO.println s!"{field} := \"{lookupPrim ixonEnv name}\""
+    return (true, none)
+  ) .done
+
+/-!
## Unit tests: Level reduce/imax edge cases -/
+
+-- reduceIMax edge cases: the kernel's imax simplification rules
+-- (imax u 0 = 0, imax u (succ v) = max u (succ v), imax u u = u,
+-- and irreducible imax on distinct parameters).
+def testLevelReduceIMax : TestSeq :=
+  let l0 : Level .anon := Level.zero
+  let l1 : Level .anon := Level.succ Level.zero
+  let p0 : Level .anon := Level.param 0 default
+  let p1 : Level .anon := Level.param 1 default
+  -- imax u 0 = 0
+  test "imax u 0 = 0" (Level.reduceIMax p0 l0 == l0) ++
+  -- imax u (succ v) = max u (succ v)
+  test "imax u (succ v) = max u (succ v)"
+    (Level.equalLevel (Level.reduceIMax p0 l1) (Level.reduceMax p0 l1)) ++
+  -- imax u u = u (same param)
+  test "imax u u = u" (Level.reduceIMax p0 p0 == p0) ++
+  -- imax u v stays imax (different params)
+  test "imax u v stays imax"
+    (Level.reduceIMax p0 p1 == Level.imax p0 p1) ++
+  -- nested: imax u (imax v 0) — reduce inner first, then outer
+  let inner := Level.reduceIMax p1 l0 -- = 0
+  test "imax u (imax v 0) = imax u 0 = 0"
+    (Level.reduceIMax p0 inner == l0) ++
+  .done
+
+-- reduceMax edge cases: zero identities, succ distribution, idempotence.
+def testLevelReduceMax : TestSeq :=
+  let l0 : Level .anon := Level.zero
+  let p0 : Level .anon := Level.param 0 default
+  let p1 : Level .anon := Level.param 1 default
+  -- max 0 u = u
+  test "max 0 u = u" (Level.reduceMax l0 p0 == p0) ++
+  -- max u 0 = u
+  test "max u 0 = u" (Level.reduceMax p0 l0 == p0) ++
+  -- max (succ u) (succ v) = succ (max u v)
+  test "max (succ u) (succ v) = succ (max u v)"
+    (Level.reduceMax (Level.succ p0) (Level.succ p1)
+      == Level.succ (Level.reduceMax p0 p1)) ++
+  -- max p0 p0 = p0
+  test "max p0 p0 = p0" (Level.reduceMax p0 p0 == p0) ++
+  .done
+
+-- leq on compound levels: max symmetry/monotonicity and imax-after-reduce laws.
+def testLevelLeqComplex : TestSeq :=
+  let l0 : Level .anon := Level.zero
+  let p0 : Level .anon := Level.param 0 default
+  let p1 : Level .anon := Level.param 1 default
+  -- max u v <= max v u (symmetry)
+  test "max u v <= max v u"
+    (Level.leq (Level.max p0 p1) (Level.max p1 p0) 0) ++
+  -- u <= max u v
+  test "u <= max u v"
+    (Level.leq p0 (Level.max p0 p1) 0) ++
+  -- imax u (succ v) <= max u (succ v) — after reduce they're equal
+  let lhs := Level.reduce (Level.imax p0 (.succ p1))
+  let rhs := Level.reduce (Level.max p0 (.succ p1))
+  test "imax u (succ v) <= max u (succ v)"
+    (Level.leq lhs rhs 0) ++
+  -- imax u 0 <= 0
+  test "imax u 0 <= 0"
+    (Level.leq (Level.reduce (.imax p0 l0)) l0 0) ++
+  -- not (succ (max u v) <= max u v)
+  test "not (succ (max u v) <= max u v)"
+    (!Level.leq (Level.succ (Level.max p0 p1)) (Level.max p0 p1) 0) ++
+  -- imax u u <= u
+  test "imax u u <= u"
+    (Level.leq (Level.reduce (Level.imax p0 p0)) p0 0) ++
+  -- imax 1 (imax 1 u) = u (nested imax decomposition)
+  let l1 : Level .anon := Level.succ Level.zero
+  let nested := Level.reduce (Level.imax l1 (Level.imax l1 p0))
+  test "imax 1 (imax 1 u) <= u"
+    (Level.leq nested p0 0) ++
+  test "u <= imax 1 (imax 1 u)"
+    (Level.leq p0 nested 0) ++
+  .done
+
+-- instBulkReduce: substituting a level-parameter environment into a level,
+-- including the out-of-bounds case where a param index beyond the array shifts.
+def testLevelInstBulkReduce : TestSeq :=
+  let l0 : Level .anon := Level.zero
+  let l1 : Level .anon := Level.succ Level.zero
+  let p0 : Level .anon := Level.param 0 default
+  let p1 : Level .anon := Level.param 1 default
+  -- Basic: param 0 with [zero] = zero
+  test "param 0 with [zero] = zero"
+    (Level.instBulkReduce #[l0] p0 == l0) ++
+  -- Multi: param 1 with [zero, succ zero] = succ zero
+  test "param 1 with [zero, succ zero] = succ zero"
+    (Level.instBulkReduce #[l0, l1] p1 == l1) ++
+  -- Out-of-bounds: param 2 with 2-element array shifts
+  let p2 : Level .anon := Level.param 2 default
+  test "param 2 with 2-elem array shifts to param 0"
+    (Level.instBulkReduce #[l0, l1] p2 == Level.param 0 default) ++
+  -- Compound: imax (param 0) (param 1) with [zero, succ zero]
+  let compound := Level.imax p0 p1
+  let result := Level.instBulkReduce #[l0, l1] compound
+  -- imax 0 (succ 0) = max 0 (succ 0) = succ 0
+  test "imax (param 0) (param 1) subst [zero, succ zero]"
+    (Level.equalLevel result l1) ++
+  .done
+
+-- ReducibilityHints.lt' ordering: regular hints compare by their number,
+-- regular/abbrev sort below opaque, and the relation is irreflexive.
+def testReducibilityHintsLt : TestSeq :=
+  test "regular 1 < regular 2" (ReducibilityHints.lt' (.regular 1) (.regular 2)) ++
+  test "not (regular 2 < regular 1)" (!ReducibilityHints.lt' (.regular 2) (.regular 1)) ++
+  test "regular _ < opaque" (ReducibilityHints.lt' (.regular 5) .opaque) ++
+  test "abbrev < opaque" (ReducibilityHints.lt' .abbrev .opaque) ++
+  test "not (opaque < opaque)" (!ReducibilityHints.lt' .opaque .opaque) ++
+  test "not (regular 5 < regular 5)" (!ReducibilityHints.lt' (.regular 5) (.regular 5)) ++
+  .done
+
+/-! ## Expanded integration tests -/
+
+/-- Expanded constant pipeline: more constants including quotients, recursors, projections. -/
+def testMoreConstants : TestSeq :=
+  .individualIO "expanded kernel const pipeline" (do
+    let leanEnv ← get_env!
+    let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv
+    match Ix.Kernel.Convert.convertEnv .meta ixonEnv with
+    | .error e => return (false, some e)
+    | .ok (kenv, prims, quotInit) =>
+      -- Constants grouped by the kernel feature they exercise.
+      let constNames := #[
+        -- Quotient types
+        "Quot", "Quot.mk", "Quot.lift", "Quot.ind",
+        -- K-reduction exercisers
+        "Eq.rec", "Eq.subst", "Eq.symm", "Eq.trans",
+        -- Proof irrelevance
+        "And.intro", "Or.inl", "Or.inr",
+        -- K-like reduction with congr
+        "congr", "congrArg", "congrFun",
+        -- Structure projections + eta
+        "Prod.fst", "Prod.snd", "Prod.mk", "Sigma.mk", "Subtype.mk",
+        -- Nat primitives
+        "Nat.add", "Nat.sub", "Nat.mul", "Nat.div", "Nat.mod",
+        "Nat.gcd", "Nat.beq", "Nat.ble",
+        "Nat.land", "Nat.lor", "Nat.xor",
+        "Nat.shiftLeft", "Nat.shiftRight", "Nat.pow",
+        -- Recursors
+        "Bool.rec", "List.rec",
+        -- Delta unfolding
+        "id", "Function.comp",
+        -- Various inductives
+        "Empty", "PUnit", "Fin", "Sigma", "Prod",
+        -- Proofs / proof irrelevance
+        "True", "False", "And", "Or",
+        -- Mutual/nested inductives
+        "List.map", "List.foldl", "List.append",
+        -- Universe polymorphism
+        "ULift", "PLift",
+        -- More complex
+        "Option", "Option.some", "Option.none",
+        "String", "String.mk", "Char",
+        -- Partial definitions
+        "WellFounded.fix"
+      ]
+      let mut passed := 0
+      let mut failures : Array String := #[]
+      for name in constNames do
+        let ixName := parseIxName name
+        let some cNamed := ixonEnv.named.get? ixName
+          | do failures := failures.push s!"{name}: not found"; continue
+        let addr := cNamed.addr
+        match Ix.Kernel.typecheckConst kenv prims addr quotInit with
+        | .ok () => passed := passed + 1
+        | .error e => failures := failures.push s!"{name}: {e}"
+      IO.println s!"[kernel-expanded] {passed}/{constNames.size} passed"
+      if failures.isEmpty then
+        return (true, none)
+      else
+        for f in failures do IO.println s!"  [fail] {f}"
+        return (false, some s!"{failures.size} failure(s)")
+  ) .done
+
+/-! ## Anon mode conversion test -/
+
+/-- Test that convertEnv in .anon mode produces the same number of constants. -/
+def testAnonConvert : TestSeq :=
+  .individualIO "anon mode conversion" (do
+    let leanEnv ← get_env!
+    let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv
+    let metaResult := Ix.Kernel.Convert.convertEnv .meta ixonEnv
+    let anonResult := Ix.Kernel.Convert.convertEnv .anon ixonEnv
+    match metaResult, anonResult with
+    | .ok (metaEnv, _, _), .ok (anonEnv, _, _) =>
+      let metaCount := metaEnv.size
+      let anonCount := anonEnv.size
+      IO.println s!"[kernel-anon] meta: {metaCount}, anon: {anonCount}"
+      if metaCount == anonCount then
+        return (true, none)
+      else
+        return (false, some s!"meta ({metaCount}) != anon ({anonCount})")
+    | .error e, _ => return (false, some s!"meta conversion failed: {e}")
+    | _, .error e => return (false, some s!"anon conversion failed: {e}")
+  ) .done
+
+/-! ## Negative tests -/
+
+/-- Negative test suite: verify that the typechecker rejects malformed declarations.
-/
+def negativeTests : TestSeq :=
+  .individualIO "kernel negative tests" (do
+    let testAddr := Address.blake3 (ByteArray.mk #[1, 0, 42])
+    let badAddr := Address.blake3 (ByteArray.mk #[99, 0, 42])
+    let prims := buildPrimitives
+    let mut passed := 0
+    let mut failures : Array String := #[]
+    -- NOTE(review): typecheckConst is called here with three arguments, while
+    -- the positive tests pass a fourth (quotInit) — presumably it is optional;
+    -- confirm against the signature.
+
+    -- Test 1: Theorem not in Prop (type = Sort 1, which is Type 0 not Prop)
+    do
+      let cv : ConstantVal .anon :=
+        { numLevels := 0, type := .sort (.succ .zero), name := (), levelParams := () }
+      let ci : ConstantInfo .anon := .thmInfo { toConstantVal := cv, value := .sort .zero, all := #[] }
+      let env := (default : Env .anon).insert testAddr ci
+      match typecheckConst env prims testAddr with
+      | .error _ => passed := passed + 1
+      | .ok () => failures := failures.push "theorem-not-prop: expected error"
+
+    -- Test 2: Type mismatch (definition type = Sort 0, value = Sort 1)
+    do
+      let cv : ConstantVal .anon :=
+        { numLevels := 0, type := .sort .zero, name := (), levelParams := () }
+      let ci : ConstantInfo .anon := .defnInfo { toConstantVal := cv, value := .sort (.succ .zero), hints := .opaque, safety := .safe, all := #[] }
+      let env := (default : Env .anon).insert testAddr ci
+      match typecheckConst env prims testAddr with
+      | .error _ => passed := passed + 1
+      | .ok () => failures := failures.push "type-mismatch: expected error"
+
+    -- Test 3: Unknown constant reference (type references non-existent address)
+    do
+      let cv : ConstantVal .anon :=
+        { numLevels := 0, type := .const badAddr #[] (), name := (), levelParams := () }
+      let ci : ConstantInfo .anon := .defnInfo { toConstantVal := cv, value := .sort .zero, hints := .opaque, safety := .safe, all := #[] }
+      let env := (default : Env .anon).insert testAddr ci
+      match typecheckConst env prims testAddr with
+      | .error _ => passed := passed + 1
+      | .ok () => failures := failures.push "unknown-const: expected error"
+
+    -- Test 4: Variable out of range (type = bvar 0 in empty context)
+    do
+      let cv : ConstantVal .anon :=
+        { numLevels := 0, type := .bvar 0 (), name := (), levelParams := () }
+      let ci : ConstantInfo .anon := .defnInfo { toConstantVal := cv, value := .sort .zero, hints := .opaque, safety := .safe, all := #[] }
+      let env := (default : Env .anon).insert testAddr ci
+      match typecheckConst env prims testAddr with
+      | .error _ => passed := passed + 1
+      | .ok () => failures := failures.push "var-out-of-range: expected error"
+
+    -- Test 5: Application of non-function (Sort 0 is not a function)
+    do
+      let cv : ConstantVal .anon :=
+        { numLevels := 0, type := .sort (.succ (.succ .zero)), name := (), levelParams := () }
+      let ci : ConstantInfo .anon := .defnInfo
+        { toConstantVal := cv, value := .app (.sort .zero) (.sort .zero), hints := .opaque, safety := .safe, all := #[] }
+      let env := (default : Env .anon).insert testAddr ci
+      match typecheckConst env prims testAddr with
+      | .error _ => passed := passed + 1
+      | .ok () => failures := failures.push "app-non-function: expected error"
+
+    -- Test 6: Let value type doesn't match annotation (Sort 1 : Sort 2, not Sort 0)
+    do
+      let cv : ConstantVal .anon :=
+        { numLevels := 0, type := .sort (.succ (.succ (.succ .zero))), name := (), levelParams := () }
+      let letVal : Expr .anon := .letE (.sort .zero) (.sort (.succ .zero)) (.bvar 0 ()) ()
+      let ci : ConstantInfo .anon := .defnInfo
+        { toConstantVal := cv, value := letVal, hints := .opaque, safety := .safe, all := #[] }
+      let env := (default : Env .anon).insert testAddr ci
+      match typecheckConst env prims testAddr with
+      | .error _ => passed := passed + 1
+      | .ok () => failures := failures.push "let-type-mismatch: expected error"
+
+    -- Test 7: Lambda applied to wrong type (domain expects Prop, given Type 0)
+    do
+      let cv : ConstantVal .anon :=
+        { numLevels := 0, type := .sort (.succ (.succ .zero)), name := (), levelParams := () }
+      let lam : Expr .anon := .lam (.sort .zero) (.bvar 0 ()) () ()
+      let ci : ConstantInfo .anon := .defnInfo
+        { toConstantVal := cv, value := .app lam (.sort (.succ .zero)), hints := .opaque, safety := .safe, all := #[] }
+      let env := (default : Env .anon).insert testAddr ci
+      match typecheckConst env prims testAddr with
+      | .error _ => passed := passed + 1
+      | .ok () => failures := failures.push "app-wrong-type: expected error"
+
+    -- Test 8: Axiom with non-sort type (type = App (Sort 0) (Sort 0), not a sort)
+    do
+      let cv : ConstantVal .anon :=
+        { numLevels := 0, type := .app (.sort .zero) (.sort .zero), name := (), levelParams := () }
+      let ci : ConstantInfo .anon := .axiomInfo { toConstantVal := cv, isUnsafe := false }
+      let env := (default : Env .anon).insert testAddr ci
+      match typecheckConst env prims testAddr with
+      | .error _ => passed := passed + 1
+      | .ok () => failures := failures.push "axiom-non-sort-type: expected error"
+
+    IO.println s!"[kernel-negative] {passed}/8 passed"
+    if failures.isEmpty then
+      return (true, none)
+    else
+      for f in failures do IO.println s!"  [fail] {f}"
+      return (false, some s!"{failures.size} failure(s)")
+  ) .done
+
+/-! ## Focused NbE constant tests -/
+
+/-- Test individual constants through the NbE kernel to isolate failures. -/
+def testNbeConsts : TestSeq :=
+  .individualIO "nbe focused const checks" (do
+    let leanEnv ← get_env!
+    let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv
+    match Ix.Kernel.Convert.convertEnv .meta ixonEnv with
+    | .error e => return (false, some s!"convertEnv: {e}")
+    | .ok (kenv, prims, quotInit) =>
+      let constNames := #[
+        -- Nat basics
+        "Nat", "Nat.zero", "Nat.succ", "Nat.rec",
+        -- Below / brecOn (well-founded recursion scaffolding)
+        "Nat.below", "Nat.brecOn",
+        -- PProd (used by Nat.below)
+        "PProd", "PProd.mk", "PProd.fst", "PProd.snd",
+        "PUnit", "PUnit.unit",
+        -- noConfusion (stuck neutral in fresh-state mode)
+        "Lean.Meta.Grind.Origin.noConfusionType",
+        "Lean.Meta.Grind.Origin.noConfusion",
+        "Lean.Meta.Grind.Origin.stx.noConfusion",
+        -- The previously-hanging constant
+        "Nat.Linear.Poly.of_denote_eq_cancel",
+        -- String theorem (fuel-sensitive)
+        "String.length_empty",
+      ]
+      let mut passed := 0
+      let mut failures : Array String := #[]
+      for name in constNames do
+        let ixName := parseIxName name
+        let some cNamed := ixonEnv.named.get? ixName
+          | do failures := failures.push s!"{name}: not found"; continue
+        let addr := cNamed.addr
+        IO.println s!"  checking {name} ..."
+        -- Flush before the (possibly slow) check so progress is visible live.
+        (← IO.getStdout).flush
+        let start ← IO.monoMsNow
+        match Ix.Kernel.typecheckConst kenv prims addr quotInit with
+        | .ok () =>
+          let ms := (← IO.monoMsNow) - start
+          IO.println s!"    ✓ {name} ({ms.formatMs})"
+          passed := passed + 1
+        | .error e =>
+          let ms := (← IO.monoMsNow) - start
+          IO.println s!"    ✗ {name} ({ms.formatMs}): {e}"
+          failures := failures.push s!"{name}: {e}"
+      IO.println s!"[nbe-focus] {passed}/{constNames.size} passed"
+      if failures.isEmpty then
+        return (true, none)
+      else
+        return (false, some s!"{failures.size} failure(s)")
+  ) .done
+
+def nbeFocusSuite : List TestSeq := [
+  testNbeConsts,
+]
+
+/-!
## Test suites -/
+
+def unitSuite : List TestSeq := [
+  testExprHashEq,
+  testExprOps,
+  testLevelOps,
+  testLevelReduceIMax,
+  testLevelReduceMax,
+  testLevelLeqComplex,
+  testLevelInstBulkReduce,
+  testReducibilityHintsLt,
+]
+
+def convertSuite : List TestSeq := [
+  testConvertEnv,
+]
+
+def constSuite : List TestSeq := [
+  testConstPipeline,
+  testMoreConstants,
+]
+
+def negativeSuite : List TestSeq := [
+  negativeTests,
+]
+
+def anonConvertSuite : List TestSeq := [
+  testAnonConvert,
+]
+
+/-! ## Roundtrip test: Lean → Ixon → Kernel → Lean -/
+
+/-- Roundtrip test: compile Lean env to Ixon, convert to Kernel, decompile back to Lean,
+    and structurally compare against the original. -/
+def testRoundtrip : TestSeq :=
+  .individualIO "kernel roundtrip Lean→Ixon→Kernel→Lean" (do
+    let leanEnv ← get_env!
+    let ixonEnv ← Ix.CompileM.rsCompileEnv leanEnv
+    match Ix.Kernel.Convert.convertEnv .meta ixonEnv with
+    | .error e =>
+      IO.println s!"[roundtrip] convertEnv error: {e}"
+      return (false, some e)
+    | .ok (kenv, _, _) =>
+      -- Build Lean.Name → EnvId map from ixonEnv.named (name-aware lookup)
+      let mut nameToEnvId : Std.HashMap Lean.Name (Ix.Kernel.EnvId .meta) := {}
+      for (ixName, named) in ixonEnv.named do
+        nameToEnvId := nameToEnvId.insert (Ix.Kernel.Decompile.ixNameToLean ixName) ⟨named.addr, ixName⟩
+      -- Build work items (filter to constants we can check)
+      let mut workItems : Array (Lean.Name × Lean.ConstantInfo × Ix.Kernel.ConstantInfo .meta) := #[]
+      let mut notFound := 0
+      for (leanName, origCI) in leanEnv.constants.toList do
+        let some envId := nameToEnvId.get? leanName
+          | do notFound := notFound + 1; continue
+        -- Constants without a kernel-side entry are silently skipped.
+        let some kernelCI := kenv.findByEnvId envId
+          | continue
+        workItems := workItems.push (leanName, origCI, kernelCI)
+      -- Chunked parallel comparison
+      -- Fixed 32 workers; chunkSize rounds up so every item is assigned.
+      let numWorkers := 32
+      let total := workItems.size
+      let chunkSize := (total + numWorkers - 1) / numWorkers
+      let mut tasks : Array (Task (Array (Lean.Name × Array (String × String × String)))) := #[]
+      let mut offset := 0
+      while offset < total do
+        let endIdx := min (offset + chunkSize) total
+        let chunk := workItems[offset:endIdx]
+        let task := Task.spawn (prio := .dedicated) fun () => Id.run do
+          -- Each task returns only the mismatching constants and their diffs.
+          let mut results : Array (Lean.Name × Array (String × String × String)) := #[]
+          for (leanName, origCI, kernelCI) in chunk.toArray do
+            let roundtrippedCI := Ix.Kernel.Decompile.decompileConstantInfo kernelCI
+            let diffs := Ix.Kernel.Decompile.constInfoStructEq origCI roundtrippedCI
+            if !diffs.isEmpty then
+              results := results.push (leanName, diffs)
+          results
+        tasks := tasks.push task
+        offset := endIdx
+      -- Collect results
+      let checked := total
+      let mut mismatches := 0
+      for task in tasks do
+        for (leanName, diffs) in task.get do
+          mismatches := mismatches + 1
+          let diffMsgs := diffs.toList.map fun (path, lhs, rhs) =>
+            s!"    {path}: {lhs} ≠ {rhs}"
+          IO.println s!"[roundtrip] MISMATCH {leanName}:"
+          for msg in diffMsgs do IO.println msg
+      IO.println s!"[roundtrip] checked {checked}, mismatches {mismatches}, not found {notFound}"
+      if mismatches == 0 then
+        return (true, none)
+      else
+        return (false, some s!"{mismatches}/{checked} constants have structural mismatches")
+  ) .done
+
+def roundtripSuite : List TestSeq := [
+  testRoundtrip,
+]
+
+end Tests.KernelTests
diff --git a/Tests/Ix/PP.lean b/Tests/Ix/PP.lean
new file mode 100644
index 00000000..d96bd0f1
--- /dev/null
+++ b/Tests/Ix/PP.lean
@@ -0,0 +1,333 @@
+/-
+  Pretty printer test suite.
+
+  Tests Expr.pp in both .meta and .anon modes, covering:
+  - Level/Sort display
+  - Binder/Var/Const name formatting
+  - App parenthesization
+  - Pi and Lambda chain collapsing
+  - Let expressions
+  - Literals and projections
+-/
+import Ix.Kernel
+import LSpec
+
+open LSpec
+open Ix.Kernel
+
+namespace Tests.PP
+
+/-! ## Helpers -/
+
+-- Single-segment Ix.Name, e.g. mkName "x" ≙ `x`.
+private def mkName (s : String) : Ix.Name :=
+  Ix.Name.mkStr Ix.Name.mkAnon s
+
+-- Two-segment Ix.Name, e.g. mkDottedName "Nat" "add" ≙ `Nat.add`.
+private def mkDottedName (a b : String) : Ix.Name :=
+  Ix.Name.mkStr (Ix.Name.mkStr Ix.Name.mkAnon a) b
+
+private def testAddr : Address := Address.blake3 (ByteArray.mk #[1, 2, 3])
+private def testAddr2 : Address := Address.blake3 (ByteArray.mk #[4, 5, 6])
+
+/-- First 8 hex chars of testAddr, for anon mode assertions. -/
+private def testAddrShort : String :=
+  String.ofList ((toString testAddr).toList.take 8)
+
+/-! ## Meta mode: Level / Sort display -/
+
+-- Sort rendering: Prop / Type / Type n for closed levels, and Sort u /
+-- Type u / Type (u + k) / Sort (max …) / Sort (imax …) for parameters.
+def testPpSortMeta : TestSeq :=
+  -- Sort display
+  let prop : Expr .meta := .sort .zero
+  let type : Expr .meta := .sort (.succ .zero)
+  let type1 : Expr .meta := .sort (.succ (.succ .zero))
+  let type2 : Expr .meta := .sort (.succ (.succ (.succ .zero)))
+  -- Universe params
+  let uName := mkName "u"
+  let vName := mkName "v"
+  let sortU : Expr .meta := .sort (.param 0 uName)
+  let typeU : Expr .meta := .sort (.succ (.param 0 uName))
+  let sortMax : Expr .meta := .sort (.max (.param 0 uName) (.param 1 vName))
+  let sortIMax : Expr .meta := .sort (.imax (.param 0 uName) (.param 1 vName))
+  -- Succ offset on param: Type (u + 1), Type (u + 2)
+  let typeU1 : Expr .meta := .sort (.succ (.succ (.param 0 uName)))
+  let typeU2 : Expr .meta := .sort (.succ (.succ (.succ (.param 0 uName))))
+  test "sort zero → Prop" (prop.pp == "Prop") ++
+  test "sort 1 → Type" (type.pp == "Type") ++
+  test "sort 2 → Type 1" (type1.pp == "Type 1") ++
+  test "sort 3 → Type 2" (type2.pp == "Type 2") ++
+  test "sort (param u) → Sort u" (sortU.pp == "Sort u") ++
+  test "sort (succ (param u)) → Type u" (typeU.pp == "Type u") ++
+  test "sort (succ^2 (param u)) → Type (u + 1)" (typeU1.pp == "Type (u + 1)") ++
+  test "sort (succ^3 (param u)) → Type (u + 2)" (typeU2.pp == "Type (u + 2)") ++
+  test "sort (max u v) → Sort (max (u) (v))" (sortMax.pp == "Sort (max (u) (v))") ++
+  test "sort (imax u v) → Sort (imax (u) (v))" (sortIMax.pp == "Sort (imax (u) (v))") ++
+  .done
+
+/-! ## Meta mode: Atoms (bvar, const, lit) -/
+
+-- Atomic expressions print as their meta names / literal values.
+def testPpAtomsMeta : TestSeq :=
+  let x := mkName "x"
+  let natAdd := mkDottedName "Nat" "add"
+  -- bvar with name
+  let bv : Expr .meta := .bvar 0 x
+  test "bvar with name → x" (bv.pp == "x") ++
+  -- const with name
+  let c : Expr .meta := .const testAddr #[] natAdd
+  test "const Nat.add → Nat.add" (c.pp == "Nat.add") ++
+  -- nat literal
+  let n : Expr .meta := .lit (.natVal 42)
+  test "natLit 42 → 42" (n.pp == "42") ++
+  -- string literal
+  let s : Expr .meta := .lit (.strVal "hello")
+  test "strLit hello → \"hello\"" (s.pp == "\"hello\"") ++
+  .done
+
+/-! ## Meta mode: App parenthesization -/
+
+-- Application printing: left-associated spines stay flat; arguments that are
+-- themselves applications get parens; `Expr.pp true` forces atom (parenthesized) mode.
+def testPpAppMeta : TestSeq :=
+  let f : Expr .meta := .const testAddr #[] (mkName "f")
+  let g : Expr .meta := .const testAddr2 #[] (mkName "g")
+  let a : Expr .meta := .bvar 0 (mkName "a")
+  let b : Expr .meta := .bvar 1 (mkName "b")
+  -- Simple application: no parens at top level
+  let fa := Expr.app f a
+  test "f a (no parens)" (fa.pp == "f a") ++
+  -- Nested left-assoc: f a b
+  let fab := Expr.app (Expr.app f a) b
+  test "f a b (left-assoc, no parens)" (fab.pp == "f a b") ++
+  -- Nested arg: f (g a) — arg needs parens
+  let fga := Expr.app f (Expr.app g a)
+  test "f (g a) (arg parens)" (fga.pp == "f (g a)") ++
+  -- Atom mode: (f a)
+  test "f a atom → (f a)" (Expr.pp true fa == "(f a)") ++
+  -- Deep nesting: f a (g b)
+  let fagb := Expr.app (Expr.app f a) (Expr.app g b)
+  test "f a (g b)" (fagb.pp == "f a (g b)") ++
+  .done
+
+/-!
## Meta mode: Lambda and Pi -/
+
+-- Binder printing: single and chained λ/∀ (nested same-kind binders collapse
+-- into one binder group), plus atom-mode parenthesization.
+def testPpBindersMeta : TestSeq :=
+  let nat : Expr .meta := .const testAddr #[] (mkName "Nat")
+  let bool : Expr .meta := .const testAddr2 #[] (mkName "Bool")
+  let body : Expr .meta := .bvar 0 (mkName "x")
+  let body2 : Expr .meta := .bvar 1 (mkName "y")
+  -- Single lambda
+  let lam1 : Expr .meta := .lam nat body (mkName "x") .default
+  test "λ (x : Nat) => x" (lam1.pp == "λ (x : Nat) => x") ++
+  -- Single forall
+  let pi1 : Expr .meta := .forallE nat body (mkName "x") .default
+  test "∀ (x : Nat), x" (pi1.pp == "∀ (x : Nat), x") ++
+  -- Chained lambdas
+  let lam2 : Expr .meta := .lam nat (.lam bool body2 (mkName "y") .default) (mkName "x") .default
+  test "λ (x : Nat) (y : Bool) => y" (lam2.pp == "λ (x : Nat) (y : Bool) => y") ++
+  -- Chained foralls
+  let pi2 : Expr .meta := .forallE nat (.forallE bool body2 (mkName "y") .default) (mkName "x") .default
+  test "∀ (x : Nat) (y : Bool), y" (pi2.pp == "∀ (x : Nat) (y : Bool), y") ++
+  -- Lambda in atom position
+  test "lambda atom → (λ ...)" (Expr.pp true lam1 == "(λ (x : Nat) => x)") ++
+  -- Forall in atom position
+  test "forall atom → (∀ ...)" (Expr.pp true pi1 == "(∀ (x : Nat), x)") ++
+  .done
+
+/-! ## Meta mode: Let -/
+
+-- Let printing, plain and in atom position.
+def testPpLetMeta : TestSeq :=
+  let nat : Expr .meta := .const testAddr #[] (mkName "Nat")
+  let zero : Expr .meta := .lit (.natVal 0)
+  let body : Expr .meta := .bvar 0 (mkName "x")
+  let letE : Expr .meta := .letE nat zero body (mkName "x")
+  test "let x : Nat := 0; x" (letE.pp == "let x : Nat := 0; x") ++
+  -- Let in atom position
+  test "let atom → (let ...)" (Expr.pp true letE == "(let x : Nat := 0; x)") ++
+  .done
+
+/-!
## Meta mode: Projection -/
+
+-- Projection printing: `struct.idx`, with the struct parenthesized when compound.
+def testPpProjMeta : TestSeq :=
+  let struct : Expr .meta := .bvar 0 (mkName "s")
+  let proj0 : Expr .meta := .proj testAddr 0 struct (mkName "Prod")
+  test "s.0" (proj0.pp == "s.0") ++
+  -- Projection of app (needs parens around struct)
+  let f : Expr .meta := .const testAddr #[] (mkName "f")
+  let a : Expr .meta := .bvar 0 (mkName "a")
+  let projApp : Expr .meta := .proj testAddr 1 (.app f a) (mkName "Prod")
+  test "(f a).1" (projApp.pp == "(f a).1") ++
+  .done
+
+/-! ## Anon mode -/
+
+-- Anon-mode printing: bvars as ^idx, consts as truncated #hash, level params
+-- as u_idx, and all binder names rendered as `_`.
+def testPpAnon : TestSeq :=
+  -- bvar: ^idx
+  let bv : Expr .anon := .bvar 3 ()
+  test "anon bvar 3 → ^3" (bv.pp == "^3") ++
+  -- const: #hash
+  let c : Expr .anon := .const testAddr #[] ()
+  test "anon const → #hash" (c.pp == s!"#{testAddrShort}") ++
+  -- sort
+  let prop : Expr .anon := .sort .zero
+  test "anon sort zero → Prop" (prop.pp == "Prop") ++
+  -- level param: u_idx
+  let sortU : Expr .anon := .sort (.param 0 ())
+  test "anon sort (param 0) → Sort u_0" (sortU.pp == "Sort u_0") ++
+  -- lambda: binder name = _
+  let lam : Expr .anon := .lam (.sort .zero) (.bvar 0 ()) () ()
+  test "anon lam → λ (_ : ...) => ..." (lam.pp == "λ (_ : Prop) => ^0") ++
+  -- forall: binder name = _
+  let pi : Expr .anon := .forallE (.sort .zero) (.bvar 0 ()) () ()
+  test "anon forall → ∀ (_ : ...), ..." (pi.pp == "∀ (_ : Prop), ^0") ++
+  -- let: binder name = _
+  let letE : Expr .anon := .letE (.sort .zero) (.lit (.natVal 0)) (.bvar 0 ()) ()
+  test "anon let → let _ : ..." (letE.pp == "let _ : Prop := 0; ^0") ++
+  -- chained anon lambdas
+  let lam2 : Expr .anon := .lam (.sort .zero) (.lam (.sort (.succ .zero)) (.bvar 0 ()) () ()) () ()
+  test "anon chained lam" (lam2.pp == "λ (_ : Prop) (_ : Type) => ^0") ++
+  .done
+
+/-! ## Meta mode: ??? detection (flags naming bugs) -/
+
+/-- In .meta mode, default/anonymous names produce "???" in binder positions
+    and full address hashes in const positions.
These indicate naming info was + never present in the source expression (e.g., anonymous Ix.Name). + + Binder names survive the eval/quote round-trip: Value.lam and Value.pi + carry MetaField name and binder info, which quote extracts. + + Remaining const-name loss: `strLitToCtorVal`/`toCtorIfLit` create + Neutral.const with default names for synthetic primitive constructors. +-/ +def testPpMetaDefaultNames : TestSeq := + let anonName := Ix.Name.mkAnon + -- bvar with anonymous name shows ??? + let bv : Expr .meta := .bvar 0 anonName + test "meta bvar with anonymous name → ???" (bv.pp == "???") ++ + -- const with anonymous name shows full hash + let c : Expr .meta := .const testAddr #[] anonName + test "meta const with anonymous name → full hash" (c.pp == s!"{testAddr}") ++ + -- lambda with anonymous binder name shows ??? + let lam : Expr .meta := .lam (.sort .zero) (.bvar 0 anonName) anonName .default + test "meta lam with anonymous binder → λ (??? : Prop) => ???" (lam.pp == "λ (??? : Prop) => ???") ++ + -- forall with anonymous binder name shows ??? + let pi : Expr .meta := .forallE (.sort .zero) (.bvar 0 anonName) anonName .default + test "meta forall with anonymous binder → ∀ (??? : Prop), ???" (pi.pp == "∀ (??? : Prop), ???") ++ + .done + +/-! 
## Complex expressions -/ + +def testPpComplex : TestSeq := + let nat : Expr .meta := .const testAddr #[] (mkName "Nat") + let bool : Expr .meta := .const testAddr2 #[] (mkName "Bool") + -- ∀ (n : Nat), Nat → Nat (arrow sugar approximation) + -- This is: forallE Nat (forallE Nat Nat) + let arrow : Expr .meta := .forallE nat (.forallE nat nat (mkName "m") .default) (mkName "n") .default + test "∀ (n : Nat) (m : Nat), Nat" (arrow.pp == "∀ (n : Nat) (m : Nat), Nat") ++ + -- fun (f : Nat → Bool) (x : Nat) => f x + let fType : Expr .meta := .forallE nat bool (mkName "a") .default + let fApp : Expr .meta := .app (.bvar 1 (mkName "f")) (.bvar 0 (mkName "x")) + let expr : Expr .meta := .lam fType (.lam nat fApp (mkName "x") .default) (mkName "f") .default + test "λ (f : ∀ ...) (x : Nat) => f x" + (expr.pp == "λ (f : ∀ (a : Nat), Bool) (x : Nat) => f x") ++ + -- Nested let: let x : Nat := 0; let y : Nat := x; y + let innerLet : Expr .meta := .letE nat (.bvar 0 (mkName "x")) (.bvar 0 (mkName "y")) (mkName "y") + let outerLet : Expr .meta := .letE nat (.lit (.natVal 0)) innerLet (mkName "x") + test "nested let" (outerLet.pp == "let x : Nat := 0; let y : Nat := x; y") ++ + .done + +/-! ## Quote round-trip: names survive eval → quote → pp -/ + +/-- Build a Value with named binders and verify names survive through quote → pp. + Uses a minimal TypecheckM context. 
-/ +def testQuoteRoundtrip : TestSeq := + .individualIO "quote round-trip preserves names" (do + let xName : MetaField .meta Ix.Name := mkName "x" + let yName : MetaField .meta Ix.Name := mkName "y" + let nat : Expr .meta := .const testAddr #[] (mkName "Nat") + -- Build Value.pi: ∀ (x : Nat), Nat + let domVal : SusValue .meta := ⟨.none, Thunk.mk fun _ => Value.neu (.const testAddr #[] (mkName "Nat"))⟩ + let imgTE : TypedExpr .meta := ⟨.none, nat⟩ + let piVal : Value .meta := .pi domVal imgTE (.mk [] []) xName .default + -- Build Value.lam: fun (y : Nat) => y + let bodyTE : TypedExpr .meta := ⟨.none, .bvar 0 yName⟩ + let lamVal : Value .meta := .lam domVal bodyTE (.mk [] []) yName .default + -- Quote and pp in a minimal TypecheckM context + let ctx : TypecheckCtx .meta := { + lvl := 0, env := .mk [] [], types := [], + kenv := default, prims := buildPrimitives, + safety := .safe, quotInit := true, mutTypes := default, recAddr? := none + } + let stt : TypecheckState .meta := { typedConsts := default } + -- Test pi + match TypecheckM.run ctx stt (ppValue 0 piVal) with + | .ok s => + if s != "∀ (x : Nat), Nat" then + return (false, some s!"pi round-trip: expected '∀ (x : Nat), Nat', got '{s}'") + else pure () + | .error e => return (false, some s!"pi round-trip error: {e}") + -- Test lam + match TypecheckM.run ctx stt (ppValue 0 lamVal) with + | .ok s => + if s != "λ (y : Nat) => y" then + return (false, some s!"lam round-trip: expected 'λ (y : Nat) => y', got '{s}'") + else pure () + | .error e => return (false, some s!"lam round-trip error: {e}") + return (true, none) + ) .done + +/-! 
## Literal folding: Nat/String constructor chains → literals in ppValue -/ + +def testFoldLiterals : TestSeq := + let prims := buildPrimitives + -- Nat.zero → 0 + let natZero : Expr .meta := .const prims.natZero #[] (mkName "Nat.zero") + let folded := foldLiterals prims natZero + test "fold Nat.zero → 0" (folded.pp == "0") ++ + -- Nat.succ Nat.zero → 1 + let natOne : Expr .meta := .app (.const prims.natSucc #[] (mkName "Nat.succ")) natZero + let folded := foldLiterals prims natOne + test "fold Nat.succ Nat.zero → 1" (folded.pp == "1") ++ + -- Nat.succ (Nat.succ Nat.zero) → 2 + let natTwo : Expr .meta := .app (.const prims.natSucc #[] (mkName "Nat.succ")) natOne + let folded := foldLiterals prims natTwo + test "fold Nat.succ^2 Nat.zero → 2" (folded.pp == "2") ++ + -- Nats inside types get folded: ∀ (n : Nat), Eq Nat n Nat.zero + let natType : Expr .meta := .const prims.nat #[] (mkName "Nat") + let eqAddr := Address.blake3 (ByteArray.mk #[99]) + let eq3 : Expr .meta := + .app (.app (.app (.const eqAddr #[] (mkName "Eq")) natType) (.bvar 0 (mkName "n"))) natZero + let piExpr : Expr .meta := .forallE natType eq3 (mkName "n") .default + let folded := foldLiterals prims piExpr + test "fold nat inside forall" (folded.pp == "∀ (n : Nat), Eq Nat n 0") ++ + -- String.mk (List.cons (Char.ofNat 104) (List.cons (Char.ofNat 105) List.nil)) → "hi" + let charH : Expr .meta := .app (.const prims.charMk #[] (mkName "Char.ofNat")) (.lit (.natVal 104)) + let charI : Expr .meta := .app (.const prims.charMk #[] (mkName "Char.ofNat")) (.lit (.natVal 105)) + let charType : Expr .meta := .const prims.char #[] (mkName "Char") + let nilExpr : Expr .meta := .app (.const prims.listNil #[.zero] (mkName "List.nil")) charType + let consI : Expr .meta := + .app (.app (.app (.const prims.listCons #[.zero] (mkName "List.cons")) charType) charI) nilExpr + let consH : Expr .meta := + .app (.app (.app (.const prims.listCons #[.zero] (mkName "List.cons")) charType) charH) consI + let strExpr : Expr 
.meta := .app (.const prims.stringMk #[] (mkName "String.mk")) consH + let folded := foldLiterals prims strExpr + test "fold String.mk char list → \"hi\"" (folded.pp == "\"hi\"") ++ + -- Nat.succ applied to a non-literal arg stays unfolded + let succX : Expr .meta := .app (.const prims.natSucc #[] (mkName "Nat.succ")) (.bvar 0 (mkName "x")) + let folded := foldLiterals prims succX + test "fold Nat.succ x → Nat.succ x (no fold)" (folded.pp == "Nat.succ x") ++ + .done + +/-! ## Suites -/ + +def suite : List TestSeq := [ + testPpSortMeta, + testPpAtomsMeta, + testPpAppMeta, + testPpBindersMeta, + testPpLetMeta, + testPpProjMeta, + testPpAnon, + testPpMetaDefaultNames, + testPpComplex, + testQuoteRoundtrip, + testFoldLiterals, +] + +end Tests.PP diff --git a/Tests/Main.lean b/Tests/Main.lean index e25300a8..e7ca61c2 100644 --- a/Tests/Main.lean +++ b/Tests/Main.lean @@ -9,6 +9,9 @@ import Tests.Ix.RustDecompile import Tests.Ix.Sharing import Tests.Ix.CanonM import Tests.Ix.GraphM +import Tests.Ix.Check +import Tests.Ix.KernelTests +import Tests.Ix.PP import Tests.Ix.CondenseM import Tests.FFI import Tests.Keccak @@ -32,6 +35,10 @@ def primarySuites : Std.HashMap String (List LSpec.TestSeq) := .ofList [ ("sharing", Tests.Sharing.suite), ("graph-unit", Tests.Ix.GraphM.suite), ("condense-unit", Tests.Ix.CondenseM.suite), + --("check", Tests.Check.checkSuiteIO), -- disable until rust kernel works + ("kernel-unit", Tests.KernelTests.unitSuite), + ("kernel-negative", Tests.KernelTests.negativeSuite), + ("pp", Tests.PP.suite), ] /-- Ignored test suites - expensive, run only when explicitly requested. 
These require significant RAM -/ @@ -47,6 +54,16 @@ def ignoredSuites : Std.HashMap String (List LSpec.TestSeq) := .ofList [ ("rust-serialize", Tests.RustSerialize.rustSerializeSuiteIO), ("rust-decompile", Tests.RustDecompile.rustDecompileSuiteIO), ("commit-io", Tests.Commit.suiteIO), + --("check-all", Tests.Check.checkAllSuiteIO), + ("kernel-check-env", Tests.Check.kernelSuiteIO), + ("kernel-convert", Tests.KernelTests.convertSuite), + ("kernel-anon-convert", Tests.KernelTests.anonConvertSuite), + ("kernel-const", Tests.KernelTests.constSuite), + ("kernel-verify-prims", [Tests.KernelTests.testVerifyPrimAddrs]), + ("kernel-dump-prims", [Tests.KernelTests.testDumpPrimAddrs]), + ("nbe-focus", Tests.KernelTests.nbeFocusSuite), + ("kernel-roundtrip", Tests.KernelTests.roundtripSuite), + ("ixon-full-roundtrip", Tests.Compile.ixonRoundtripSuiteIO), ] def main (args : List String) : IO UInt32 := do diff --git a/docs/Ixon.md b/docs/Ixon.md index 655f06d8..74509dfd 100644 --- a/docs/Ixon.md +++ b/docs/Ixon.md @@ -736,7 +736,6 @@ pub struct Env { pub blobs: DashMap>, // Raw data (strings, nats) pub names: DashMap, // Hash-consed Name components pub comms: DashMap, // Cryptographic commitments - pub addr_to_name: DashMap, // Reverse index } pub struct Named { @@ -1001,7 +1000,7 @@ Decompilation reconstructs Lean constants from Ixon format. 2. **Initialize tables** from `sharing`, `refs`, `univs` 3. **Load metadata** from `env.named` 4. **Reconstruct expressions** with names and binder info from metadata -5. **Resolve references**: `Ref(idx, _)` → lookup `refs[idx]`, get name from `addr_to_name` +5. **Resolve references**: `Ref(idx, _)` → lookup name from arena metadata via `names` table 6. **Expand shares**: `Share(idx)` → inline `sharing[idx]` (or cache result) ### Roundtrip Verification @@ -1145,7 +1144,7 @@ To reconstruct the Lean constant: 1. Load `Constant` from `consts[address]` 2. Load `Named` from `named["double"]` -3. 
Resolve `Ref(0, [])` → `refs[0]` → `Nat` (via `addr_to_name`) +3. Resolve `Ref(0, [])` → name from arena metadata → `Nat` (via `names` table) 4. Resolve `Ref(1, [])` → `refs[1]` → `Nat.add` 5. Attach names from metadata: the binder gets name "n" from `type_meta[0]` diff --git a/src/ix/decompile.rs b/src/ix/decompile.rs index 88082135..26bd3dc7 100644 --- a/src/ix/decompile.rs +++ b/src/ix/decompile.rs @@ -565,39 +565,19 @@ pub fn decompile_expr( // Ref: resolve name from arena Ref node or fallback ( ExprMetaData::Ref { name: name_addr }, - Expr::Ref(ref_idx, univ_indices), + Expr::Ref(_ref_idx, univ_indices), ) => { - let name = decompile_name(name_addr, stt).unwrap_or_else(|_| { - // Fallback: resolve from refs table - cache - .refs - .get(*ref_idx as usize) - .and_then(|addr| stt.env.get_name_by_addr(addr)) - .unwrap_or_else(Name::anon) - }); + let name = decompile_name(name_addr, stt)?; let levels = decompile_univ_indices(univ_indices, lvl_names, cache)?; let expr = apply_mdata(LeanExpr::cnst(name, levels), mdata_layers); results.push(expr); }, - (_, Expr::Ref(ref_idx, univ_indices)) => { - // No Ref metadata — resolve from refs table - let addr = cache.refs.get(*ref_idx as usize).ok_or_else(|| { - DecompileError::InvalidRefIndex { - idx: *ref_idx, - refs_len: cache.refs.len(), - constant: cache.current_const.clone(), - } - })?; - let name = stt - .env - .get_name_by_addr(addr) - .ok_or(DecompileError::MissingAddress(addr.clone()))?; - let levels = - decompile_univ_indices(univ_indices, lvl_names, cache)?; - let expr = apply_mdata(LeanExpr::cnst(name, levels), mdata_layers); - results.push(expr); + (_, Expr::Ref(_ref_idx, _univ_indices)) => { + return Err(DecompileError::BadConstantFormat { + msg: "ref without arena metadata".to_string(), + }); }, // Rec: resolve name from arena Ref node or fallback @@ -735,27 +715,10 @@ pub fn decompile_expr( stack.push(Frame::Decompile(struct_val.clone(), *child)); }, - (_, Expr::Prj(type_ref_idx, field_idx, struct_val)) => { 
- // Fallback: look up from refs table - let addr = - cache.refs.get(*type_ref_idx as usize).ok_or_else(|| { - DecompileError::InvalidRefIndex { - idx: *type_ref_idx, - refs_len: cache.refs.len(), - constant: cache.current_const.clone(), - } - })?; - let named = stt - .env - .get_named_by_addr(addr) - .ok_or(DecompileError::MissingAddress(addr.clone()))?; - let type_name = decompile_name_from_meta(&named.meta, stt)?; - stack.push(Frame::BuildProj( - type_name, - Nat::from(*field_idx), - mdata_layers, - )); - stack.push(Frame::Decompile(struct_val.clone(), u64::MAX)); + (_, Expr::Prj(_type_ref_idx, _field_idx, _struct_val)) => { + return Err(DecompileError::BadConstantFormat { + msg: "prj without arena metadata".to_string(), + }); }, (_, Expr::Share(_)) => unreachable!("Share handled above"), diff --git a/src/ix/ixon/env.rs b/src/ix/ixon/env.rs index b13ce571..80b4349c 100644 --- a/src/ix/ixon/env.rs +++ b/src/ix/ixon/env.rs @@ -36,7 +36,6 @@ impl Named { /// - `blobs`: Raw data (strings, nats, files) /// - `names`: Hash-consed Lean.Name components (Address -> Name) /// - `comms`: Cryptographic commitments (secrets) -/// - `addr_to_name`: Reverse index from constant address to name (for O(1) lookup) #[derive(Debug, Default)] pub struct Env { /// Alpha-invariant constants: Address -> Constant @@ -49,8 +48,6 @@ pub struct Env { pub names: DashMap, /// Cryptographic commitments: commitment Address -> Comm pub comms: DashMap, - /// Reverse index: constant Address -> Name (for fast lookup during decompile) - pub addr_to_name: DashMap, } impl Env { @@ -61,7 +58,6 @@ impl Env { blobs: DashMap::new(), names: DashMap::new(), comms: DashMap::new(), - addr_to_name: DashMap::new(), } } @@ -90,8 +86,6 @@ impl Env { /// Register a named constant. 
pub fn register_name(&self, name: Name, named: Named) { - // Also insert into reverse index for O(1) lookup by address - self.addr_to_name.insert(named.addr.clone(), name.clone()); self.named.insert(name, named); } @@ -100,16 +94,6 @@ impl Env { self.named.get(name).map(|r| r.clone()) } - /// Look up name by constant address (O(1) using reverse index). - pub fn get_name_by_addr(&self, addr: &Address) -> Option { - self.addr_to_name.get(addr).map(|r| r.clone()) - } - - /// Look up named entry by constant address (O(1) using reverse index). - pub fn get_named_by_addr(&self, addr: &Address) -> Option { - self.get_name_by_addr(addr).and_then(|name| self.lookup_name(&name)) - } - /// Store a hash-consed name component. pub fn store_name(&self, addr: Address, name: Name) { self.names.insert(addr, name); @@ -183,12 +167,7 @@ impl Clone for Env { comms.insert(entry.key().clone(), entry.value().clone()); } - let addr_to_name = DashMap::new(); - for entry in self.addr_to_name.iter() { - addr_to_name.insert(entry.key().clone(), entry.value().clone()); - } - - Env { consts, named, blobs, names, comms, addr_to_name } + Env { consts, named, blobs, names, comms } } } @@ -244,28 +223,6 @@ mod tests { assert_eq!(got.addr, addr); } - #[test] - fn get_name_by_addr_reverse_index() { - let env = Env::new(); - let name = n("Reverse"); - let addr = Address::hash(b"reverse-addr"); - let named = Named::with_addr(addr.clone()); - env.register_name(name.clone(), named); - let got_name = env.get_name_by_addr(&addr).unwrap(); - assert_eq!(got_name, name); - } - - #[test] - fn get_named_by_addr_resolves_through_reverse_index() { - let env = Env::new(); - let name = n("Through"); - let addr = Address::hash(b"through-addr"); - let named = Named::with_addr(addr.clone()); - env.register_name(name.clone(), named); - let got = env.get_named_by_addr(&addr).unwrap(); - assert_eq!(got.addr, addr); - } - #[test] fn store_and_get_name_component() { let env = Env::new(); @@ -322,8 +279,6 @@ mod tests { 
assert!(env.get_blob(&missing).is_none()); assert!(env.get_const(&missing).is_none()); assert!(env.lookup_name(&n("missing")).is_none()); - assert!(env.get_name_by_addr(&missing).is_none()); - assert!(env.get_named_by_addr(&missing).is_none()); assert!(env.get_name(&missing).is_none()); assert!(env.get_comm(&missing).is_none()); } diff --git a/src/ix/ixon/serialize.rs b/src/ix/ixon/serialize.rs index c0572160..aa56d9a2 100644 --- a/src/ix/ixon/serialize.rs +++ b/src/ix/ixon/serialize.rs @@ -1186,7 +1186,6 @@ impl Env { let name = names_lookup.get(&name_addr).cloned().ok_or_else(|| { format!("Env::get: missing name for addr {:?}", name_addr) })?; - env.addr_to_name.insert(named.addr.clone(), name.clone()); env.named.insert(name, named); } @@ -1456,7 +1455,6 @@ mod tests { let name = names[i % names.len()].clone(); let meta = ConstantMeta::default(); let named = Named { addr: addr.clone(), meta }; - env.addr_to_name.insert(addr, name.clone()); env.named.insert(name, named); } } diff --git a/src/ix/kernel/convert.rs b/src/ix/kernel/convert.rs index 90811948..c6f5af2c 100644 --- a/src/ix/kernel/convert.rs +++ b/src/ix/kernel/convert.rs @@ -1,7 +1,10 @@ use core::ptr::NonNull; use std::collections::BTreeMap; +use std::sync::Arc; -use crate::ix::env::{Expr, ExprData, Level, Name}; +use rustc_hash::FxHashMap; + +use crate::ix::env::{BinderInfo, Expr, ExprData, Level, Name}; use crate::lean::nat::Nat; use super::dag::*; @@ -23,208 +26,427 @@ fn from_expr_go( ctx: &BTreeMap>, parents: Option>, ) -> DAGPtr { - match expr.as_data() { - ExprData::Bvar(idx, _) => { - let idx_u64 = idx.to_u64().unwrap_or(u64::MAX); - if idx_u64 < depth { - let level = depth - 1 - idx_u64; - match ctx.get(&level) { - Some(&var_ptr) => { - if let Some(parent_link) = parents { - add_to_parents(DAGPtr::Var(var_ptr), parent_link); + // Frame-based iterative Expr → DAG conversion. 
+ // + // For compound nodes, we pre-allocate the DAG node with dangling child + // pointers, then push frames to fill in children after they're converted. + // + // The ctx is cloned at binder boundaries (Fun, Pi, Let) to track + // bound variable bindings. + enum Frame<'a> { + Visit { + expr: &'a Expr, + depth: u64, + ctx: BTreeMap>, + parents: Option>, + }, + SetAppFun(NonNull), + SetAppArg(NonNull), + SetFunDom(NonNull), + SetPiDom(NonNull), + SetLetTyp(NonNull), + SetLetVal(NonNull), + SetProjExpr(NonNull), + // After domain is set, wire up binder body with new ctx + FunBody { + lam_ptr: NonNull, + body: &'a Expr, + depth: u64, + ctx: BTreeMap>, + }, + PiBody { + lam_ptr: NonNull, + body: &'a Expr, + depth: u64, + ctx: BTreeMap>, + }, + LetBody { + lam_ptr: NonNull, + body: &'a Expr, + depth: u64, + ctx: BTreeMap>, + }, + SetLamBod(NonNull), + } + + let mut work: Vec> = vec![Frame::Visit { + expr, + depth, + ctx: ctx.clone(), + parents, + }]; + // Results stack holds DAGPtr for each completed subtree + let mut results: Vec = Vec::new(); + let mut visit_count: u64 = 0; + // Cache for context-independent leaf nodes (Cnst, Sort, Lit). + // Keyed by Arc pointer identity. Enables DAG sharing so the infer cache + // (keyed by DAGPtr address) can dedup repeated references to the same constant. 
+ let mut leaf_cache: FxHashMap<*const ExprData, DAGPtr> = FxHashMap::default(); + + while let Some(frame) = work.pop() { + visit_count += 1; + if visit_count % 100_000 == 0 { + eprintln!("[from_expr_go] visit_count={visit_count} work_len={}", work.len()); + } + match frame { + Frame::Visit { expr, depth, ctx, parents } => { + match expr.as_data() { + ExprData::Bvar(idx, _) => { + let idx_u64 = idx.to_u64().unwrap_or(u64::MAX); + if idx_u64 < depth { + let level = depth - 1 - idx_u64; + match ctx.get(&level) { + Some(&var_ptr) => { + if let Some(parent_link) = parents { + add_to_parents(DAGPtr::Var(var_ptr), parent_link); + } + results.push(DAGPtr::Var(var_ptr)); + }, + None => { + let var = alloc_val(Var { + depth: level, + binder: BinderPtr::Free, + fvar_name: None, + parents, + }); + results.push(DAGPtr::Var(var)); + }, + } + } else { + let var = alloc_val(Var { + depth: idx_u64, + binder: BinderPtr::Free, + fvar_name: None, + parents, + }); + results.push(DAGPtr::Var(var)); } - DAGPtr::Var(var_ptr) }, - None => { + + ExprData::Fvar(name, _) => { let var = alloc_val(Var { - depth: level, + depth: 0, binder: BinderPtr::Free, + fvar_name: Some(name.clone()), parents, }); - DAGPtr::Var(var) + results.push(DAGPtr::Var(var)); }, - } - } else { - // Free bound variable (dangling de Bruijn index) - let var = - alloc_val(Var { depth: idx_u64, binder: BinderPtr::Free, parents }); - DAGPtr::Var(var) - } - }, - ExprData::Fvar(_name, _) => { - // Encode fvar name into depth as a unique ID. - // We'll recover it during to_expr using a side table. 
- let var = alloc_val(Var { depth: 0, binder: BinderPtr::Free, parents }); - // Store name→var mapping (caller should manage the side table) - DAGPtr::Var(var) - }, + ExprData::Sort(level, _) => { + let key = Arc::as_ptr(&expr.0); + if let Some(&cached) = leaf_cache.get(&key) { + if let Some(parent_link) = parents { + add_to_parents(cached, parent_link); + } + results.push(cached); + } else { + let sort = alloc_val(Sort { level: level.clone(), parents }); + let ptr = DAGPtr::Sort(sort); + leaf_cache.insert(key, ptr); + results.push(ptr); + } + }, - ExprData::Sort(level, _) => { - let sort = alloc_val(Sort { level: level.clone(), parents }); - DAGPtr::Sort(sort) - }, + ExprData::Const(name, levels, _) => { + let key = Arc::as_ptr(&expr.0); + if let Some(&cached) = leaf_cache.get(&key) { + if let Some(parent_link) = parents { + add_to_parents(cached, parent_link); + } + results.push(cached); + } else { + let cnst = alloc_val(Cnst { + name: name.clone(), + levels: levels.clone(), + parents, + }); + let ptr = DAGPtr::Cnst(cnst); + leaf_cache.insert(key, ptr); + results.push(ptr); + } + }, - ExprData::Const(name, levels, _) => { - let cnst = alloc_val(Cnst { - name: name.clone(), - levels: levels.clone(), - parents, - }); - DAGPtr::Cnst(cnst) - }, + ExprData::Lit(lit, _) => { + let key = Arc::as_ptr(&expr.0); + if let Some(&cached) = leaf_cache.get(&key) { + if let Some(parent_link) = parents { + add_to_parents(cached, parent_link); + } + results.push(cached); + } else { + let lit_node = alloc_val(LitNode { val: lit.clone(), parents }); + let ptr = DAGPtr::Lit(lit_node); + leaf_cache.insert(key, ptr); + results.push(ptr); + } + }, - ExprData::Lit(lit, _) => { - let lit_node = alloc_val(LitNode { val: lit.clone(), parents }); - DAGPtr::Lit(lit_node) - }, + ExprData::App(fun_expr, arg_expr, _) => { + let app_ptr = alloc_app( + DAGPtr::Var(NonNull::dangling()), + DAGPtr::Var(NonNull::dangling()), + parents, + ); + unsafe { + let app = &mut *app_ptr.as_ptr(); + let fun_ref 
= + NonNull::new(&mut app.fun_ref as *mut Parents).unwrap(); + let arg_ref = + NonNull::new(&mut app.arg_ref as *mut Parents).unwrap(); + // Process arg first (pushed last = processed first after fun) + work.push(Frame::SetAppArg(app_ptr)); + work.push(Frame::Visit { + expr: arg_expr, + depth, + ctx: ctx.clone(), + parents: Some(arg_ref), + }); + work.push(Frame::SetAppFun(app_ptr)); + work.push(Frame::Visit { + expr: fun_expr, + depth, + ctx, + parents: Some(fun_ref), + }); + } + results.push(DAGPtr::App(app_ptr)); + }, - ExprData::App(fun_expr, arg_expr, _) => { - let app_ptr = alloc_app( - DAGPtr::Var(NonNull::dangling()), - DAGPtr::Var(NonNull::dangling()), - parents, - ); - unsafe { - let app = &mut *app_ptr.as_ptr(); - let fun_ref_ptr = - NonNull::new(&mut app.fun_ref as *mut Parents).unwrap(); - let arg_ref_ptr = - NonNull::new(&mut app.arg_ref as *mut Parents).unwrap(); - app.fun = from_expr_go(fun_expr, depth, ctx, Some(fun_ref_ptr)); - app.arg = from_expr_go(arg_expr, depth, ctx, Some(arg_ref_ptr)); - } - DAGPtr::App(app_ptr) - }, + ExprData::Lam(name, typ, body, bi, _) => { + let lam_ptr = + alloc_lam(depth, DAGPtr::Var(NonNull::dangling()), None); + let fun_ptr = alloc_fun( + name.clone(), + bi.clone(), + DAGPtr::Var(NonNull::dangling()), + lam_ptr, + parents, + ); + unsafe { + let fun = &mut *fun_ptr.as_ptr(); + let dom_ref = + NonNull::new(&mut fun.dom_ref as *mut Parents).unwrap(); + let img_ref = + NonNull::new(&mut fun.img_ref as *mut Parents).unwrap(); + add_to_parents(DAGPtr::Lam(lam_ptr), img_ref); + + let dom_ctx = ctx.clone(); + work.push(Frame::FunBody { + lam_ptr, + body, + depth, + ctx, + }); + work.push(Frame::SetFunDom(fun_ptr)); + work.push(Frame::Visit { + expr: typ, + depth, + ctx: dom_ctx, + parents: Some(dom_ref), + }); + } + results.push(DAGPtr::Fun(fun_ptr)); + }, - ExprData::Lam(name, typ, body, bi, _) => { - // Lean Lam → DAG Fun(dom, Lam(bod, var)) - let lam_ptr = - alloc_lam(depth, DAGPtr::Var(NonNull::dangling()), None); - let 
fun_ptr = alloc_fun( - name.clone(), - bi.clone(), - DAGPtr::Var(NonNull::dangling()), - lam_ptr, - parents, - ); - unsafe { - let fun = &mut *fun_ptr.as_ptr(); - let dom_ref_ptr = - NonNull::new(&mut fun.dom_ref as *mut Parents).unwrap(); - fun.dom = from_expr_go(typ, depth, ctx, Some(dom_ref_ptr)); - - // Set Lam's parent to FunImg - let img_ref_ptr = - NonNull::new(&mut fun.img_ref as *mut Parents).unwrap(); - add_to_parents(DAGPtr::Lam(lam_ptr), img_ref_ptr); + ExprData::ForallE(name, typ, body, bi, _) => { + let lam_ptr = + alloc_lam(depth, DAGPtr::Var(NonNull::dangling()), None); + let pi_ptr = alloc_pi( + name.clone(), + bi.clone(), + DAGPtr::Var(NonNull::dangling()), + lam_ptr, + parents, + ); + unsafe { + let pi = &mut *pi_ptr.as_ptr(); + let dom_ref = + NonNull::new(&mut pi.dom_ref as *mut Parents).unwrap(); + let img_ref = + NonNull::new(&mut pi.img_ref as *mut Parents).unwrap(); + add_to_parents(DAGPtr::Lam(lam_ptr), img_ref); + + let dom_ctx = ctx.clone(); + work.push(Frame::PiBody { + lam_ptr, + body, + depth, + ctx, + }); + work.push(Frame::SetPiDom(pi_ptr)); + work.push(Frame::Visit { + expr: typ, + depth, + ctx: dom_ctx, + parents: Some(dom_ref), + }); + } + results.push(DAGPtr::Pi(pi_ptr)); + }, + + ExprData::LetE(name, typ, val, body, non_dep, _) => { + let lam_ptr = + alloc_lam(depth, DAGPtr::Var(NonNull::dangling()), None); + let let_ptr = alloc_let( + name.clone(), + *non_dep, + DAGPtr::Var(NonNull::dangling()), + DAGPtr::Var(NonNull::dangling()), + lam_ptr, + parents, + ); + unsafe { + let let_node = &mut *let_ptr.as_ptr(); + let typ_ref = + NonNull::new(&mut let_node.typ_ref as *mut Parents).unwrap(); + let val_ref = + NonNull::new(&mut let_node.val_ref as *mut Parents).unwrap(); + let bod_ref = + NonNull::new(&mut let_node.bod_ref as *mut Parents).unwrap(); + add_to_parents(DAGPtr::Lam(lam_ptr), bod_ref); + + work.push(Frame::LetBody { + lam_ptr, + body, + depth, + ctx: ctx.clone(), + }); + work.push(Frame::SetLetVal(let_ptr)); + 
work.push(Frame::Visit { + expr: val, + depth, + ctx: ctx.clone(), + parents: Some(val_ref), + }); + work.push(Frame::SetLetTyp(let_ptr)); + work.push(Frame::Visit { + expr: typ, + depth, + ctx, + parents: Some(typ_ref), + }); + } + results.push(DAGPtr::Let(let_ptr)); + }, + ExprData::Proj(type_name, idx, structure, _) => { + let proj_ptr = alloc_proj( + type_name.clone(), + idx.clone(), + DAGPtr::Var(NonNull::dangling()), + parents, + ); + unsafe { + let proj = &mut *proj_ptr.as_ptr(); + let expr_ref = + NonNull::new(&mut proj.expr_ref as *mut Parents).unwrap(); + work.push(Frame::SetProjExpr(proj_ptr)); + work.push(Frame::Visit { + expr: structure, + depth, + ctx, + parents: Some(expr_ref), + }); + } + results.push(DAGPtr::Proj(proj_ptr)); + }, + + ExprData::Mdata(_, inner, _) => { + // Strip metadata, convert inner + work.push(Frame::Visit { expr: inner, depth, ctx, parents }); + }, + + ExprData::Mvar(_name, _) => { + let var = alloc_val(Var { + depth: 0, + binder: BinderPtr::Free, + fvar_name: None, + parents, + }); + results.push(DAGPtr::Var(var)); + }, + } + }, + Frame::SetAppFun(app_ptr) => unsafe { + let result = results.pop().unwrap(); + (*app_ptr.as_ptr()).fun = result; + }, + Frame::SetAppArg(app_ptr) => unsafe { + let result = results.pop().unwrap(); + (*app_ptr.as_ptr()).arg = result; + }, + Frame::SetFunDom(fun_ptr) => unsafe { + let result = results.pop().unwrap(); + (*fun_ptr.as_ptr()).dom = result; + }, + Frame::SetPiDom(pi_ptr) => unsafe { + let result = results.pop().unwrap(); + (*pi_ptr.as_ptr()).dom = result; + }, + Frame::SetLetTyp(let_ptr) => unsafe { + let result = results.pop().unwrap(); + (*let_ptr.as_ptr()).typ = result; + }, + Frame::SetLetVal(let_ptr) => unsafe { + let result = results.pop().unwrap(); + (*let_ptr.as_ptr()).val = result; + }, + Frame::SetProjExpr(proj_ptr) => unsafe { + let result = results.pop().unwrap(); + (*proj_ptr.as_ptr()).expr = result; + }, + Frame::SetLamBod(lam_ptr) => unsafe { + let result = 
results.pop().unwrap(); + (*lam_ptr.as_ptr()).bod = result; + }, + Frame::FunBody { lam_ptr, body, depth, mut ctx } => unsafe { + // Domain has been set; now set up body with var binding let lam = &mut *lam_ptr.as_ptr(); let var_ptr = NonNull::new(&mut lam.var as *mut Var).unwrap(); - let mut new_ctx = ctx.clone(); - new_ctx.insert(depth, var_ptr); - let bod_ref_ptr = + ctx.insert(depth, var_ptr); + let bod_ref = NonNull::new(&mut lam.bod_ref as *mut Parents).unwrap(); - lam.bod = - from_expr_go(body, depth + 1, &new_ctx, Some(bod_ref_ptr)); - } - DAGPtr::Fun(fun_ptr) - }, - - ExprData::ForallE(name, typ, body, bi, _) => { - let lam_ptr = - alloc_lam(depth, DAGPtr::Var(NonNull::dangling()), None); - let pi_ptr = alloc_pi( - name.clone(), - bi.clone(), - DAGPtr::Var(NonNull::dangling()), - lam_ptr, - parents, - ); - unsafe { - let pi = &mut *pi_ptr.as_ptr(); - let dom_ref_ptr = - NonNull::new(&mut pi.dom_ref as *mut Parents).unwrap(); - pi.dom = from_expr_go(typ, depth, ctx, Some(dom_ref_ptr)); - - let img_ref_ptr = - NonNull::new(&mut pi.img_ref as *mut Parents).unwrap(); - add_to_parents(DAGPtr::Lam(lam_ptr), img_ref_ptr); - + work.push(Frame::SetLamBod(lam_ptr)); + work.push(Frame::Visit { + expr: body, + depth: depth + 1, + ctx, + parents: Some(bod_ref), + }); + }, + Frame::PiBody { lam_ptr, body, depth, mut ctx } => unsafe { let lam = &mut *lam_ptr.as_ptr(); let var_ptr = NonNull::new(&mut lam.var as *mut Var).unwrap(); - let mut new_ctx = ctx.clone(); - new_ctx.insert(depth, var_ptr); - let bod_ref_ptr = + ctx.insert(depth, var_ptr); + let bod_ref = NonNull::new(&mut lam.bod_ref as *mut Parents).unwrap(); - lam.bod = - from_expr_go(body, depth + 1, &new_ctx, Some(bod_ref_ptr)); - } - DAGPtr::Pi(pi_ptr) - }, - - ExprData::LetE(name, typ, val, body, non_dep, _) => { - let lam_ptr = - alloc_lam(depth, DAGPtr::Var(NonNull::dangling()), None); - let let_ptr = alloc_let( - name.clone(), - *non_dep, - DAGPtr::Var(NonNull::dangling()), - 
DAGPtr::Var(NonNull::dangling()), - lam_ptr, - parents, - ); - unsafe { - let let_node = &mut *let_ptr.as_ptr(); - let typ_ref_ptr = - NonNull::new(&mut let_node.typ_ref as *mut Parents).unwrap(); - let val_ref_ptr = - NonNull::new(&mut let_node.val_ref as *mut Parents).unwrap(); - let_node.typ = from_expr_go(typ, depth, ctx, Some(typ_ref_ptr)); - let_node.val = from_expr_go(val, depth, ctx, Some(val_ref_ptr)); - - let bod_ref_ptr = - NonNull::new(&mut let_node.bod_ref as *mut Parents).unwrap(); - add_to_parents(DAGPtr::Lam(lam_ptr), bod_ref_ptr); - + work.push(Frame::SetLamBod(lam_ptr)); + work.push(Frame::Visit { + expr: body, + depth: depth + 1, + ctx, + parents: Some(bod_ref), + }); + }, + Frame::LetBody { lam_ptr, body, depth, mut ctx } => unsafe { let lam = &mut *lam_ptr.as_ptr(); let var_ptr = NonNull::new(&mut lam.var as *mut Var).unwrap(); - let mut new_ctx = ctx.clone(); - new_ctx.insert(depth, var_ptr); - let inner_bod_ref_ptr = + ctx.insert(depth, var_ptr); + let bod_ref = NonNull::new(&mut lam.bod_ref as *mut Parents).unwrap(); - lam.bod = - from_expr_go(body, depth + 1, &new_ctx, Some(inner_bod_ref_ptr)); - } - DAGPtr::Let(let_ptr) - }, - - ExprData::Proj(type_name, idx, structure, _) => { - let proj_ptr = alloc_proj( - type_name.clone(), - idx.clone(), - DAGPtr::Var(NonNull::dangling()), - parents, - ); - unsafe { - let proj = &mut *proj_ptr.as_ptr(); - let expr_ref_ptr = - NonNull::new(&mut proj.expr_ref as *mut Parents).unwrap(); - proj.expr = - from_expr_go(structure, depth, ctx, Some(expr_ref_ptr)); - } - DAGPtr::Proj(proj_ptr) - }, - - // Mdata: strip metadata, convert inner expression - ExprData::Mdata(_, inner, _) => from_expr_go(inner, depth, ctx, parents), - - // Mvar: treat as terminal (shouldn't appear in well-typed terms) - ExprData::Mvar(_name, _) => { - let var = alloc_val(Var { depth: 0, binder: BinderPtr::Free, parents }); - DAGPtr::Var(var) - }, + work.push(Frame::SetLamBod(lam_ptr)); + work.push(Frame::Visit { + expr: body, + depth: 
depth + 1, + ctx, + parents: Some(bod_ref), + }); + }, + } } + + results.pop().unwrap() } // ============================================================================ @@ -250,124 +472,193 @@ impl Clone for crate::ix::env::Literal { pub fn to_expr(dag: &DAG) -> Expr { let mut var_map: BTreeMap<*const Var, u64> = BTreeMap::new(); - to_expr_go(dag.head, &mut var_map, 0) + let mut cache: rustc_hash::FxHashMap<(usize, u64), Expr> = + rustc_hash::FxHashMap::default(); + to_expr_go(dag.head, &mut var_map, 0, &mut cache) } fn to_expr_go( node: DAGPtr, var_map: &mut BTreeMap<*const Var, u64>, depth: u64, + cache: &mut rustc_hash::FxHashMap<(usize, u64), Expr>, ) -> Expr { - unsafe { - match node { - DAGPtr::Var(link) => { - let var = link.as_ptr(); - let var_key = var as *const Var; - if let Some(&bind_depth) = var_map.get(&var_key) { - let idx = depth - bind_depth - 1; - Expr::bvar(Nat::from(idx)) - } else { - // Free variable - Expr::bvar(Nat::from((*var).depth)) - } - }, - - DAGPtr::Sort(link) => { - let sort = &*link.as_ptr(); - Expr::sort(sort.level.clone()) - }, - - DAGPtr::Cnst(link) => { - let cnst = &*link.as_ptr(); - Expr::cnst(cnst.name.clone(), cnst.levels.clone()) - }, + // Frame-based iterative conversion from DAG to Expr. + // + // Uses a cache keyed on (dag_ptr_key, depth) to avoid exponential + // blowup when the DAG has sharing (e.g., after beta reduction). + // + // For binder nodes (Fun, Pi, Let, Lam), the pattern is: + // 1. Visit domain/type/value children + // 2. BinderBody: register var in var_map, push Visit for body + // 3. *Build: pop results, unregister var, build Expr + // 4. 
CacheStore: cache the built result + enum Frame { + Visit(DAGPtr, u64), + App, + BinderBody(*const Var, DAGPtr, u64), + FunBuild(Name, BinderInfo, *const Var), + PiBuild(Name, BinderInfo, *const Var), + LetBuild(Name, bool, *const Var), + Proj(Name, Nat), + LamBuild(*const Var), + CacheStore(usize, u64), + } - DAGPtr::Lit(link) => { - let lit = &*link.as_ptr(); - Expr::lit(lit.val.clone()) + let mut work: Vec = vec![Frame::Visit(node, depth)]; + let mut results: Vec = Vec::new(); + + while let Some(frame) = work.pop() { + match frame { + Frame::Visit(node, depth) => unsafe { + // Check cache first for non-Var nodes + match node { + DAGPtr::Var(_) => {}, // Vars depend on var_map, skip cache + _ => { + let key = (dag_ptr_key(node), depth); + if let Some(cached) = cache.get(&key) { + results.push(cached.clone()); + continue; + } + }, + } + match node { + DAGPtr::Var(link) => { + let var = link.as_ptr(); + let var_key = var as *const Var; + if let Some(&bind_depth) = var_map.get(&var_key) { + results.push(Expr::bvar(Nat::from(depth - bind_depth - 1))); + } else if let Some(name) = &(*var).fvar_name { + results.push(Expr::fvar(name.clone())); + } else { + results.push(Expr::bvar(Nat::from((*var).depth))); + } + }, + DAGPtr::Sort(link) => { + let sort = &*link.as_ptr(); + results.push(Expr::sort(sort.level.clone())); + }, + DAGPtr::Cnst(link) => { + let cnst = &*link.as_ptr(); + results.push(Expr::cnst(cnst.name.clone(), cnst.levels.clone())); + }, + DAGPtr::Lit(link) => { + let lit = &*link.as_ptr(); + results.push(Expr::lit(lit.val.clone())); + }, + DAGPtr::App(link) => { + let app = &*link.as_ptr(); + work.push(Frame::CacheStore(dag_ptr_key(node), depth)); + work.push(Frame::App); + work.push(Frame::Visit(app.arg, depth)); + work.push(Frame::Visit(app.fun, depth)); + }, + DAGPtr::Fun(link) => { + let fun = &*link.as_ptr(); + let lam = &*fun.img.as_ptr(); + let var_ptr = &lam.var as *const Var; + work.push(Frame::CacheStore(dag_ptr_key(node), depth)); + 
work.push(Frame::FunBuild( + fun.binder_name.clone(), + fun.binder_info.clone(), + var_ptr, + )); + work.push(Frame::BinderBody(var_ptr, lam.bod, depth)); + work.push(Frame::Visit(fun.dom, depth)); + }, + DAGPtr::Pi(link) => { + let pi = &*link.as_ptr(); + let lam = &*pi.img.as_ptr(); + let var_ptr = &lam.var as *const Var; + work.push(Frame::CacheStore(dag_ptr_key(node), depth)); + work.push(Frame::PiBuild( + pi.binder_name.clone(), + pi.binder_info.clone(), + var_ptr, + )); + work.push(Frame::BinderBody(var_ptr, lam.bod, depth)); + work.push(Frame::Visit(pi.dom, depth)); + }, + DAGPtr::Let(link) => { + let let_node = &*link.as_ptr(); + let lam = &*let_node.bod.as_ptr(); + let var_ptr = &lam.var as *const Var; + work.push(Frame::CacheStore(dag_ptr_key(node), depth)); + work.push(Frame::LetBuild( + let_node.binder_name.clone(), + let_node.non_dep, + var_ptr, + )); + work.push(Frame::BinderBody(var_ptr, lam.bod, depth)); + work.push(Frame::Visit(let_node.val, depth)); + work.push(Frame::Visit(let_node.typ, depth)); + }, + DAGPtr::Proj(link) => { + let proj = &*link.as_ptr(); + work.push(Frame::CacheStore(dag_ptr_key(node), depth)); + work.push(Frame::Proj(proj.type_name.clone(), proj.idx.clone())); + work.push(Frame::Visit(proj.expr, depth)); + }, + DAGPtr::Lam(link) => { + // Standalone Lam: no domain to visit, just body + let lam = &*link.as_ptr(); + let var_ptr = &lam.var as *const Var; + work.push(Frame::CacheStore(dag_ptr_key(node), depth)); + work.push(Frame::LamBuild(var_ptr)); + work.push(Frame::BinderBody(var_ptr, lam.bod, depth)); + }, + } }, - - DAGPtr::App(link) => { - let app = &*link.as_ptr(); - let fun = to_expr_go(app.fun, var_map, depth); - let arg = to_expr_go(app.arg, var_map, depth); - Expr::app(fun, arg) + Frame::App => { + let arg = results.pop().unwrap(); + let fun = results.pop().unwrap(); + results.push(Expr::app(fun, arg)); }, - - DAGPtr::Fun(link) => { - let fun = &*link.as_ptr(); - let lam = &*fun.img.as_ptr(); - let dom = 
to_expr_go(fun.dom, var_map, depth); - let var_ptr = &lam.var as *const Var; + Frame::BinderBody(var_ptr, body, depth) => { var_map.insert(var_ptr, depth); - let bod = to_expr_go(lam.bod, var_map, depth + 1); + work.push(Frame::Visit(body, depth + 1)); + }, + Frame::FunBuild(name, bi, var_ptr) => { var_map.remove(&var_ptr); - Expr::lam( - fun.binder_name.clone(), - dom, - bod, - fun.binder_info.clone(), - ) + let bod = results.pop().unwrap(); + let dom = results.pop().unwrap(); + results.push(Expr::lam(name, dom, bod, bi)); }, - - DAGPtr::Pi(link) => { - let pi = &*link.as_ptr(); - let lam = &*pi.img.as_ptr(); - let dom = to_expr_go(pi.dom, var_map, depth); - let var_ptr = &lam.var as *const Var; - var_map.insert(var_ptr, depth); - let bod = to_expr_go(lam.bod, var_map, depth + 1); + Frame::PiBuild(name, bi, var_ptr) => { var_map.remove(&var_ptr); - Expr::all( - pi.binder_name.clone(), - dom, - bod, - pi.binder_info.clone(), - ) + let bod = results.pop().unwrap(); + let dom = results.pop().unwrap(); + results.push(Expr::all(name, dom, bod, bi)); }, - - DAGPtr::Let(link) => { - let let_node = &*link.as_ptr(); - let lam = &*let_node.bod.as_ptr(); - let typ = to_expr_go(let_node.typ, var_map, depth); - let val = to_expr_go(let_node.val, var_map, depth); - let var_ptr = &lam.var as *const Var; - var_map.insert(var_ptr, depth); - let bod = to_expr_go(lam.bod, var_map, depth + 1); + Frame::LetBuild(name, non_dep, var_ptr) => { var_map.remove(&var_ptr); - Expr::letE( - let_node.binder_name.clone(), - typ, - val, - bod, - let_node.non_dep, - ) + let bod = results.pop().unwrap(); + let val = results.pop().unwrap(); + let typ = results.pop().unwrap(); + results.push(Expr::letE(name, typ, val, bod, non_dep)); }, - - DAGPtr::Proj(link) => { - let proj = &*link.as_ptr(); - let structure = to_expr_go(proj.expr, var_map, depth); - Expr::proj(proj.type_name.clone(), proj.idx.clone(), structure) + Frame::Proj(name, idx) => { + let structure = results.pop().unwrap(); + 
results.push(Expr::proj(name, idx, structure)); }, - - DAGPtr::Lam(link) => { - // Standalone Lam shouldn't appear at the top level, - // but handle it gracefully for completeness. - let lam = &*link.as_ptr(); - let var_ptr = &lam.var as *const Var; - var_map.insert(var_ptr, depth); - let bod = to_expr_go(lam.bod, var_map, depth + 1); + Frame::LamBuild(var_ptr) => { var_map.remove(&var_ptr); - // Wrap in a lambda with anonymous name and default binder info - Expr::lam( + let bod = results.pop().unwrap(); + results.push(Expr::lam( Name::anon(), Expr::sort(Level::zero()), bod, - crate::ix::env::BinderInfo::Default, - ) + BinderInfo::Default, + )); + }, + Frame::CacheStore(key, depth) => { + let result = results.last().unwrap().clone(); + cache.insert((key, depth), result); }, } } + + results.pop().unwrap() } #[cfg(test)] diff --git a/src/ix/kernel/dag.rs b/src/ix/kernel/dag.rs index 9837405f..ae021431 100644 --- a/src/ix/kernel/dag.rs +++ b/src/ix/kernel/dag.rs @@ -2,7 +2,9 @@ use core::ptr::NonNull; use crate::ix::env::{BinderInfo, Level, Literal, Name}; use crate::lean::nat::Nat; -use rustc_hash::FxHashSet; +use rustc_hash::{FxHashMap, FxHashSet}; + +use super::level::subst_level; use super::dll::DLL; @@ -131,17 +133,12 @@ pub struct Var { pub depth: u64, /// Points to the binding Lam, or Free for free variables. pub binder: BinderPtr, + /// If this Var came from an Fvar, preserves the name for roundtrip. + pub fvar_name: Option, /// Parent pointers. pub parents: Option>, } -impl Copy for Var {} -impl Clone for Var { - fn clone(&self) -> Self { - *self - } -} - /// Sort node (universe). 
#[repr(C)] pub struct Sort { @@ -260,7 +257,7 @@ pub fn alloc_lam( let lam_ptr = alloc_val(Lam { bod, bod_ref: DLL::singleton(ParentPtr::Root), - var: Var { depth, binder: BinderPtr::Free, parents: None }, + var: Var { depth, binder: BinderPtr::Free, fvar_name: None, parents: None }, parents, }); unsafe { @@ -469,59 +466,587 @@ pub fn free_dag(dag: DAG) { free_dag_nodes(dag.head, &mut visited); } -fn free_dag_nodes(node: DAGPtr, visited: &mut FxHashSet) { - let key = dag_ptr_key(node); - if !visited.insert(key) { - return; - } - unsafe { - match node { - DAGPtr::Var(link) => { - let var = &*link.as_ptr(); - // Only free separately-allocated free vars; bound vars are - // embedded in their Lam struct and freed with it. - if let BinderPtr::Free = var.binder { +fn free_dag_nodes(root: DAGPtr, visited: &mut FxHashSet) { + let mut stack: Vec = vec![root]; + while let Some(node) = stack.pop() { + let key = dag_ptr_key(node); + if !visited.insert(key) { + continue; + } + unsafe { + match node { + DAGPtr::Var(link) => { + let var = &*link.as_ptr(); + if let BinderPtr::Free = var.binder { + drop(Box::from_raw(link.as_ptr())); + } + }, + DAGPtr::Sort(link) => drop(Box::from_raw(link.as_ptr())), + DAGPtr::Cnst(link) => drop(Box::from_raw(link.as_ptr())), + DAGPtr::Lit(link) => drop(Box::from_raw(link.as_ptr())), + DAGPtr::Lam(link) => { + let lam = &*link.as_ptr(); + stack.push(lam.bod); drop(Box::from_raw(link.as_ptr())); - } - }, - DAGPtr::Sort(link) => drop(Box::from_raw(link.as_ptr())), - DAGPtr::Cnst(link) => drop(Box::from_raw(link.as_ptr())), - DAGPtr::Lit(link) => drop(Box::from_raw(link.as_ptr())), - DAGPtr::Lam(link) => { - let lam = &*link.as_ptr(); - free_dag_nodes(lam.bod, visited); - drop(Box::from_raw(link.as_ptr())); - }, - DAGPtr::Fun(link) => { - let fun = &*link.as_ptr(); - free_dag_nodes(fun.dom, visited); - free_dag_nodes(DAGPtr::Lam(fun.img), visited); - drop(Box::from_raw(link.as_ptr())); - }, - DAGPtr::Pi(link) => { - let pi = &*link.as_ptr(); - 
free_dag_nodes(pi.dom, visited); - free_dag_nodes(DAGPtr::Lam(pi.img), visited); - drop(Box::from_raw(link.as_ptr())); - }, - DAGPtr::App(link) => { - let app = &*link.as_ptr(); - free_dag_nodes(app.fun, visited); - free_dag_nodes(app.arg, visited); - drop(Box::from_raw(link.as_ptr())); - }, - DAGPtr::Let(link) => { - let let_node = &*link.as_ptr(); - free_dag_nodes(let_node.typ, visited); - free_dag_nodes(let_node.val, visited); - free_dag_nodes(DAGPtr::Lam(let_node.bod), visited); - drop(Box::from_raw(link.as_ptr())); - }, - DAGPtr::Proj(link) => { - let proj = &*link.as_ptr(); - free_dag_nodes(proj.expr, visited); - drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Fun(link) => { + let fun = &*link.as_ptr(); + stack.push(fun.dom); + stack.push(DAGPtr::Lam(fun.img)); + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Pi(link) => { + let pi = &*link.as_ptr(); + stack.push(pi.dom); + stack.push(DAGPtr::Lam(pi.img)); + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::App(link) => { + let app = &*link.as_ptr(); + stack.push(app.fun); + stack.push(app.arg); + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Let(link) => { + let let_node = &*link.as_ptr(); + stack.push(let_node.typ); + stack.push(let_node.val); + stack.push(DAGPtr::Lam(let_node.bod)); + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Proj(link) => { + let proj = &*link.as_ptr(); + stack.push(proj.expr); + drop(Box::from_raw(link.as_ptr())); + }, + } + } + } +} + +// ============================================================================ +// DAG utilities for typechecker +// ============================================================================ + +/// Decompose `f a1 a2 ... an` into `(f, [a1, a2, ..., an])` at the DAG level. 
+pub fn dag_unfold_apps(dag: DAGPtr) -> (DAGPtr, Vec) { + let mut args = Vec::new(); + let mut cursor = dag; + loop { + match cursor { + DAGPtr::App(app) => unsafe { + let app_ref = &*app.as_ptr(); + args.push(app_ref.arg); + cursor = app_ref.fun; }, + _ => break, } } + args.reverse(); + (cursor, args) +} + +/// Reconstruct `f a1 a2 ... an` from a head and arguments at the DAG level. +pub fn dag_foldl_apps(fun: DAGPtr, args: &[DAGPtr]) -> DAGPtr { + let mut result = fun; + for &arg in args { + let app = alloc_app(result, arg, None); + result = DAGPtr::App(app); + } + result +} + +/// Substitute universe level parameters in-place throughout a DAG. +/// +/// Replaces `Level::param(params[i])` with `values[i]` in all Sort and Cnst +/// nodes reachable from `root`. Uses a visited set to handle DAG sharing. +/// +/// The DAG must not be shared with other live structures, since this mutates +/// nodes in place (intended for freshly `from_expr`'d DAGs). +pub fn subst_dag_levels( + root: DAGPtr, + params: &[Name], + values: &[Level], +) -> DAGPtr { + if params.is_empty() { + return root; + } + let mut visited = FxHashSet::default(); + let mut stack: Vec = vec![root]; + while let Some(node) = stack.pop() { + let key = dag_ptr_key(node); + if !visited.insert(key) { + continue; + } + unsafe { + match node { + DAGPtr::Sort(p) => { + let sort = &mut *p.as_ptr(); + sort.level = subst_level(&sort.level, params, values); + }, + DAGPtr::Cnst(p) => { + let cnst = &mut *p.as_ptr(); + cnst.levels = + cnst.levels.iter().map(|l| subst_level(l, params, values)).collect(); + }, + DAGPtr::App(p) => { + let app = &*p.as_ptr(); + stack.push(app.fun); + stack.push(app.arg); + }, + DAGPtr::Fun(p) => { + let fun = &*p.as_ptr(); + stack.push(fun.dom); + stack.push(DAGPtr::Lam(fun.img)); + }, + DAGPtr::Pi(p) => { + let pi = &*p.as_ptr(); + stack.push(pi.dom); + stack.push(DAGPtr::Lam(pi.img)); + }, + DAGPtr::Lam(p) => { + let lam = &*p.as_ptr(); + stack.push(lam.bod); + }, + DAGPtr::Let(p) => { + 
let let_node = &*p.as_ptr(); + stack.push(let_node.typ); + stack.push(let_node.val); + stack.push(DAGPtr::Lam(let_node.bod)); + }, + DAGPtr::Proj(p) => { + let proj = &*p.as_ptr(); + stack.push(proj.expr); + }, + DAGPtr::Var(_) | DAGPtr::Lit(_) => {}, + } + } + } + root +} + +// ============================================================================ +// Deep-copy substitution for typechecker +// ============================================================================ + +/// Deep-copy a Lam body, substituting `replacement` for the Lam's bound variable. +/// +/// Unlike `subst_pi_body` (which mutates nodes in place via BUBS), this creates +/// a completely fresh DAG. This prevents the type DAG from sharing mutable nodes +/// with the term DAG, avoiding corruption when WHNF later beta-reduces in the +/// type DAG. +/// +/// The `replacement` is also deep-copied to prevent WHNF's `reduce_lam` from +/// modifying the original term DAG when it beta-reduces through substituted +/// Fun/Lam nodes. Vars not bound within the copy scope (outer-binder vars and +/// free vars) are preserved by pointer to maintain identity for `def_eq`. +/// +/// Deep-copy the Lam body with substitution. Used when the Lam is from +/// the TERM DAG (e.g., `infer_lambda`, `infer_pi`, `infer_let`) to +/// protect the term from destructive in-place modification. +/// +/// The replacement is also deep-copied to isolate the term DAG from +/// WHNF mutations. Vars not bound within the copy scope are preserved +/// by pointer to maintain identity for `def_eq`. 
+pub fn dag_copy_subst(lam: NonNull, replacement: DAGPtr) -> DAGPtr { + use std::sync::atomic::{AtomicU64, Ordering}; + static COPY_SUBST_CALLS: AtomicU64 = AtomicU64::new(0); + static COPY_SUBST_NODES: AtomicU64 = AtomicU64::new(0); + let call_num = COPY_SUBST_CALLS.fetch_add(1, Ordering::Relaxed); + + let mut cache: FxHashMap = FxHashMap::default(); + unsafe { + let lambda = &*lam.as_ptr(); + let var_ptr = + NonNull::new(&lambda.var as *const Var as *mut Var).unwrap(); + let var_key = dag_ptr_key(DAGPtr::Var(var_ptr)); + // Deep-copy the replacement (isolates from term DAG mutations) + let copied_replacement = dag_copy_node(replacement, &mut cache); + let repl_nodes = cache.len(); + // Clear cache: body and replacement are separate DAGs, no shared nodes. + cache.clear(); + // Map the target var to the copied replacement + cache.insert(var_key, copied_replacement); + // Deep copy the body + let result = dag_copy_node(lambda.bod, &mut cache); + let body_nodes = cache.len(); + let total = COPY_SUBST_NODES.fetch_add(body_nodes as u64, Ordering::Relaxed) + body_nodes as u64; + if call_num % 10 == 0 || body_nodes > 1000 { + eprintln!("[dag_copy_subst] call={call_num} repl={repl_nodes} body={body_nodes} total_nodes={total}"); + } + result + } +} + +/// Lightweight substitution for TYPE DAG Lams (from `from_expr` or derived). +/// Only the replacement is deep-copied; the body is modified in-place via +/// BUBS `subst_pi_body`, preserving DAG sharing and avoiding exponential +/// blowup. +pub fn dag_type_subst(lam: NonNull, replacement: DAGPtr) -> DAGPtr { + use super::upcopy::subst_pi_body; + let mut cache: FxHashMap = FxHashMap::default(); + let copied_replacement = dag_copy_node(replacement, &mut cache); + subst_pi_body(lam, copied_replacement) +} + +/// Iteratively copy a DAG node, using `cache` for sharing and var substitution. 
+/// +/// Uses an explicit work stack to avoid stack overflow on deeply nested DAGs +/// (e.g., 40000+ left-nested App chains from unfolded definitions). +fn dag_copy_node( + root: DAGPtr, + cache: &mut FxHashMap, +) -> DAGPtr { + // Stack frames for the iterative traversal. + // Compound nodes use a two-phase approach: + // Visit → push children + Finish frame → children processed → Finish builds node + // Binder nodes (Fun/Pi/Let/Lam) use three phases: + // Visit → push dom/typ/val + CreateLam → CreateLam inserts var mapping + pushes body + Finish + enum Frame { + Visit(DAGPtr), + FinishApp(usize, NonNull), + FinishProj(usize, NonNull), + CreateFunLam(usize, NonNull), + FinishFun(usize, NonNull, NonNull), + CreatePiLam(usize, NonNull), + FinishPi(usize, NonNull, NonNull), + CreateLamBody(usize, NonNull), + // FinishLam(key, new_lam, old_lam) — old_lam needed to look up body key + FinishLam(usize, NonNull, NonNull), + CreateLetLam(usize, NonNull), + FinishLet(usize, NonNull, NonNull), + } + + let mut stack: Vec = vec![Frame::Visit(root)]; + // Track nodes that have been visited (started processing) to prevent + // exponential blowup when copying DAGs with shared compound nodes. + // Without this, a shared node visited from two parents would be + // processed twice, leading to 2^depth duplication. + let mut visited: FxHashSet = FxHashSet::default(); + // Deferred back-edge patches: (key_of_placeholder, original_node) + // WHNF iota reduction can create cyclic DAGs (e.g., Nat.rec step + // function body → recursive Nat.rec result → step function). + // When we encounter a back-edge during copy, we allocate a placeholder + // and record it here. After the main traversal completes, we patch + // each placeholder's children to point to the cached (copied) versions. 
+ let mut deferred: Vec<(usize, DAGPtr)> = Vec::new(); + + while let Some(frame) = stack.pop() { + unsafe { + match frame { + Frame::Visit(node) => { + let key = dag_ptr_key(node); + if cache.contains_key(&key) { + continue; + } + if visited.contains(&key) { + // Cycle back-edge: allocate placeholder, defer patching + match node { + DAGPtr::App(p) => { + let app = &*p.as_ptr(); + let placeholder = alloc_app(app.fun, app.arg, None); + cache.insert(key, DAGPtr::App(placeholder)); + deferred.push((key, node)); + }, + DAGPtr::Proj(p) => { + let proj = &*p.as_ptr(); + let placeholder = alloc_proj( + proj.type_name.clone(), proj.idx.clone(), proj.expr, None, + ); + cache.insert(key, DAGPtr::Proj(placeholder)); + deferred.push((key, node)); + }, + // Leaf-like nodes shouldn't cycle; handle just in case + _ => { + cache.insert(key, node); + }, + } + continue; + } + visited.insert(key); + match node { + DAGPtr::Var(_) => { + // Not in cache: outer-binder or free var. Preserve original. + cache.insert(key, node); + }, + DAGPtr::Sort(p) => { + let sort = &*p.as_ptr(); + cache.insert( + key, + DAGPtr::Sort(alloc_val(Sort { + level: sort.level.clone(), + parents: None, + })), + ); + }, + DAGPtr::Cnst(p) => { + let cnst = &*p.as_ptr(); + cache.insert( + key, + DAGPtr::Cnst(alloc_val(Cnst { + name: cnst.name.clone(), + levels: cnst.levels.clone(), + parents: None, + })), + ); + }, + DAGPtr::Lit(p) => { + let lit = &*p.as_ptr(); + cache.insert( + key, + DAGPtr::Lit(alloc_val(LitNode { + val: lit.val.clone(), + parents: None, + })), + ); + }, + DAGPtr::App(p) => { + let app = &*p.as_ptr(); + // Finish after children; visit fun then arg + stack.push(Frame::FinishApp(key, p)); + stack.push(Frame::Visit(app.arg)); + stack.push(Frame::Visit(app.fun)); + }, + DAGPtr::Fun(p) => { + let fun = &*p.as_ptr(); + // Phase 1: visit dom, then create Lam + stack.push(Frame::CreateFunLam(key, p)); + stack.push(Frame::Visit(fun.dom)); + }, + DAGPtr::Pi(p) => { + let pi = &*p.as_ptr(); + 
stack.push(Frame::CreatePiLam(key, p)); + stack.push(Frame::Visit(pi.dom)); + }, + DAGPtr::Lam(p) => { + // Standalone Lam: create Lam, then visit body + stack.push(Frame::CreateLamBody(key, p)); + }, + DAGPtr::Let(p) => { + let let_node = &*p.as_ptr(); + // Visit typ and val, then create Lam + stack.push(Frame::CreateLetLam(key, p)); + stack.push(Frame::Visit(let_node.val)); + stack.push(Frame::Visit(let_node.typ)); + }, + DAGPtr::Proj(p) => { + let proj = &*p.as_ptr(); + stack.push(Frame::FinishProj(key, p)); + stack.push(Frame::Visit(proj.expr)); + }, + } + }, + + Frame::FinishApp(key, app_ptr) => { + let app = &*app_ptr.as_ptr(); + let new_fun = cache[&dag_ptr_key(app.fun)]; + let new_arg = cache[&dag_ptr_key(app.arg)]; + let new_app = alloc_app(new_fun, new_arg, None); + let app_ref = &mut *new_app.as_ptr(); + let fun_ref = + NonNull::new(&mut app_ref.fun_ref as *mut Parents).unwrap(); + add_to_parents(new_fun, fun_ref); + let arg_ref = + NonNull::new(&mut app_ref.arg_ref as *mut Parents).unwrap(); + add_to_parents(new_arg, arg_ref); + cache.insert(key, DAGPtr::App(new_app)); + }, + + Frame::FinishProj(key, proj_ptr) => { + let proj = &*proj_ptr.as_ptr(); + let new_expr = cache[&dag_ptr_key(proj.expr)]; + let new_proj = alloc_proj( + proj.type_name.clone(), + proj.idx.clone(), + new_expr, + None, + ); + let proj_ref = &mut *new_proj.as_ptr(); + let expr_ref = + NonNull::new(&mut proj_ref.expr_ref as *mut Parents).unwrap(); + add_to_parents(new_expr, expr_ref); + cache.insert(key, DAGPtr::Proj(new_proj)); + }, + + // --- Fun binder: dom visited, create Lam, visit body --- + Frame::CreateFunLam(key, fun_ptr) => { + let fun = &*fun_ptr.as_ptr(); + let old_lam = &*fun.img.as_ptr(); + let old_var_ptr = + NonNull::new(&old_lam.var as *const Var as *mut Var).unwrap(); + let old_var_key = dag_ptr_key(DAGPtr::Var(old_var_ptr)); + let new_lam = alloc_lam( + old_lam.var.depth, + DAGPtr::Var(NonNull::dangling()), + None, + ); + let new_lam_ref = &mut *new_lam.as_ptr(); + 
let new_var = + NonNull::new(&mut new_lam_ref.var as *mut Var).unwrap(); + cache.insert(old_var_key, DAGPtr::Var(new_var)); + // Phase 2: visit body, then finish + stack.push(Frame::FinishFun(key, fun_ptr, new_lam)); + stack.push(Frame::Visit(old_lam.bod)); + }, + + Frame::FinishFun(key, fun_ptr, new_lam) => { + let fun = &*fun_ptr.as_ptr(); + let old_lam = &*fun.img.as_ptr(); + let new_dom = cache[&dag_ptr_key(fun.dom)]; + let new_bod = cache[&dag_ptr_key(old_lam.bod)]; + let new_lam_ref = &mut *new_lam.as_ptr(); + new_lam_ref.bod = new_bod; + let bod_ref = + NonNull::new(&mut new_lam_ref.bod_ref as *mut Parents).unwrap(); + add_to_parents(new_bod, bod_ref); + let new_fun_node = alloc_fun( + fun.binder_name.clone(), + fun.binder_info.clone(), + new_dom, + new_lam, + None, + ); + let fun_ref = &mut *new_fun_node.as_ptr(); + let dom_ref = + NonNull::new(&mut fun_ref.dom_ref as *mut Parents).unwrap(); + add_to_parents(new_dom, dom_ref); + let img_ref = + NonNull::new(&mut fun_ref.img_ref as *mut Parents).unwrap(); + add_to_parents(DAGPtr::Lam(new_lam), img_ref); + cache.insert(key, DAGPtr::Fun(new_fun_node)); + }, + + // --- Pi binder: dom visited, create Lam, visit body --- + Frame::CreatePiLam(key, pi_ptr) => { + let pi = &*pi_ptr.as_ptr(); + let old_lam = &*pi.img.as_ptr(); + let old_var_ptr = + NonNull::new(&old_lam.var as *const Var as *mut Var).unwrap(); + let old_var_key = dag_ptr_key(DAGPtr::Var(old_var_ptr)); + let new_lam = alloc_lam( + old_lam.var.depth, + DAGPtr::Var(NonNull::dangling()), + None, + ); + let new_lam_ref = &mut *new_lam.as_ptr(); + let new_var = + NonNull::new(&mut new_lam_ref.var as *mut Var).unwrap(); + cache.insert(old_var_key, DAGPtr::Var(new_var)); + stack.push(Frame::FinishPi(key, pi_ptr, new_lam)); + stack.push(Frame::Visit(old_lam.bod)); + }, + + Frame::FinishPi(key, pi_ptr, new_lam) => { + let pi = &*pi_ptr.as_ptr(); + let old_lam = &*pi.img.as_ptr(); + let new_dom = cache[&dag_ptr_key(pi.dom)]; + let new_bod = 
cache[&dag_ptr_key(old_lam.bod)]; + let new_lam_ref = &mut *new_lam.as_ptr(); + new_lam_ref.bod = new_bod; + let bod_ref = + NonNull::new(&mut new_lam_ref.bod_ref as *mut Parents).unwrap(); + add_to_parents(new_bod, bod_ref); + let new_pi = alloc_pi( + pi.binder_name.clone(), + pi.binder_info.clone(), + new_dom, + new_lam, + None, + ); + let pi_ref = &mut *new_pi.as_ptr(); + let dom_ref = + NonNull::new(&mut pi_ref.dom_ref as *mut Parents).unwrap(); + add_to_parents(new_dom, dom_ref); + let img_ref = + NonNull::new(&mut pi_ref.img_ref as *mut Parents).unwrap(); + add_to_parents(DAGPtr::Lam(new_lam), img_ref); + cache.insert(key, DAGPtr::Pi(new_pi)); + }, + + // --- Standalone Lam: create Lam, visit body --- + Frame::CreateLamBody(key, old_lam_ptr) => { + let old_lam = &*old_lam_ptr.as_ptr(); + let old_var_ptr = + NonNull::new(&old_lam.var as *const Var as *mut Var).unwrap(); + let old_var_key = dag_ptr_key(DAGPtr::Var(old_var_ptr)); + let new_lam = alloc_lam( + old_lam.var.depth, + DAGPtr::Var(NonNull::dangling()), + None, + ); + let new_lam_ref = &mut *new_lam.as_ptr(); + let new_var = + NonNull::new(&mut new_lam_ref.var as *mut Var).unwrap(); + cache.insert(old_var_key, DAGPtr::Var(new_var)); + stack.push(Frame::FinishLam(key, new_lam, old_lam_ptr)); + stack.push(Frame::Visit(old_lam.bod)); + }, + + Frame::FinishLam(key, new_lam, old_lam_ptr) => { + let old_lam = &*old_lam_ptr.as_ptr(); + let new_bod = cache[&dag_ptr_key(old_lam.bod)]; + let new_lam_ref = &mut *new_lam.as_ptr(); + new_lam_ref.bod = new_bod; + let bod_ref = + NonNull::new(&mut new_lam_ref.bod_ref as *mut Parents).unwrap(); + add_to_parents(new_bod, bod_ref); + cache.insert(key, DAGPtr::Lam(new_lam)); + }, + + // --- Let binder: typ+val visited, create Lam, visit body --- + Frame::CreateLetLam(key, let_ptr) => { + let let_node = &*let_ptr.as_ptr(); + let old_lam = &*let_node.bod.as_ptr(); + let old_var_ptr = + NonNull::new(&old_lam.var as *const Var as *mut Var).unwrap(); + let old_var_key = 
dag_ptr_key(DAGPtr::Var(old_var_ptr)); + let new_lam = alloc_lam( + old_lam.var.depth, + DAGPtr::Var(NonNull::dangling()), + None, + ); + let new_lam_ref = &mut *new_lam.as_ptr(); + let new_var = + NonNull::new(&mut new_lam_ref.var as *mut Var).unwrap(); + cache.insert(old_var_key, DAGPtr::Var(new_var)); + stack.push(Frame::FinishLet(key, let_ptr, new_lam)); + stack.push(Frame::Visit(old_lam.bod)); + }, + + Frame::FinishLet(key, let_ptr, new_lam) => { + let let_node = &*let_ptr.as_ptr(); + let old_lam = &*let_node.bod.as_ptr(); + let new_typ = cache[&dag_ptr_key(let_node.typ)]; + let new_val = cache[&dag_ptr_key(let_node.val)]; + let new_bod = cache[&dag_ptr_key(old_lam.bod)]; + let new_lam_ref = &mut *new_lam.as_ptr(); + new_lam_ref.bod = new_bod; + let bod_ref = + NonNull::new(&mut new_lam_ref.bod_ref as *mut Parents).unwrap(); + add_to_parents(new_bod, bod_ref); + let new_let = alloc_let( + let_node.binder_name.clone(), + let_node.non_dep, + new_typ, + new_val, + new_lam, + None, + ); + let let_ref = &mut *new_let.as_ptr(); + let typ_ref = + NonNull::new(&mut let_ref.typ_ref as *mut Parents).unwrap(); + add_to_parents(new_typ, typ_ref); + let val_ref = + NonNull::new(&mut let_ref.val_ref as *mut Parents).unwrap(); + add_to_parents(new_val, val_ref); + let bod_ref2 = + NonNull::new(&mut let_ref.bod_ref as *mut Parents).unwrap(); + add_to_parents(DAGPtr::Lam(new_lam), bod_ref2); + cache.insert(key, DAGPtr::Let(new_let)); + }, + } + } + } + + cache[&dag_ptr_key(root)] } diff --git a/src/ix/kernel/dag_tc.rs b/src/ix/kernel/dag_tc.rs new file mode 100644 index 00000000..3b70d03d --- /dev/null +++ b/src/ix/kernel/dag_tc.rs @@ -0,0 +1,2857 @@ +use core::ptr::NonNull; + +use num_bigint::BigUint; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; +use rustc_hash::FxHashMap; + +use crate::ix::env::{ + BinderInfo, ConstantInfo, Env, Level, Literal, Name, ReducibilityHints, +}; +use crate::lean::nat::Nat; + +use super::convert::{from_expr, to_expr}; +use 
super::dag::*; +use super::error::TcError; +use super::level::{ + all_expr_uparams_defined, eq_antisymm, eq_antisymm_many, is_zero, + no_dupes_all_params, +}; +use super::upcopy::replace_child; +use super::whnf::{ + has_loose_bvars, mk_name2, nat_lit_dag, subst_expr_levels, + try_reduce_native_dag, try_reduce_nat_dag, whnf_dag, +}; + +type TcResult = Result; + +/// DAG-native type checker. +/// +/// Operates directly on `DAGPtr` nodes, avoiding Expr↔DAG round-trips. +/// Caches are keyed by `dag_ptr_key` (raw pointer address), which is safe +/// because DAG nodes are never freed during a single `check_declar` call. +pub struct DagTypeChecker<'env> { + pub env: &'env Env, + pub whnf_cache: FxHashMap, + pub whnf_no_delta_cache: FxHashMap, + pub infer_cache: FxHashMap, + /// Cache for `infer_const` results, keyed by the Blake3 hash of the + /// Cnst node's Expr representation (name + levels). Avoids repeated + /// `from_expr` calls for the same constant at the same universe levels. + pub const_type_cache: FxHashMap, + pub local_counter: u64, + pub local_types: FxHashMap, + /// Stack of corresponding bound variable pairs for binder comparison. + /// Each entry `(key_x, key_y)` means `Var_x` and `Var_y` should be + /// treated as equal when comparing under their respective binders. 
+ binder_eq_map: Vec<(usize, usize)>, + // Debug counters + whnf_calls: u64, + def_eq_calls: u64, + infer_calls: u64, + infer_depth: u64, + infer_max_depth: u64, +} + +impl<'env> DagTypeChecker<'env> { + pub fn new(env: &'env Env) -> Self { + DagTypeChecker { + env, + whnf_cache: FxHashMap::default(), + whnf_no_delta_cache: FxHashMap::default(), + infer_cache: FxHashMap::default(), + const_type_cache: FxHashMap::default(), + local_counter: 0, + local_types: FxHashMap::default(), + binder_eq_map: Vec::new(), + whnf_calls: 0, + def_eq_calls: 0, + infer_calls: 0, + infer_depth: 0, + infer_max_depth: 0, + } + } + + // ========================================================================== + // WHNF with caching + // ========================================================================== + + /// Reduce a DAG node to weak head normal form. + /// + /// Checks the cache first, then calls `whnf_dag` and caches the result. + pub fn whnf(&mut self, ptr: DAGPtr) -> DAGPtr { + self.whnf_calls += 1; + let key = dag_ptr_key(ptr); + if let Some(&cached) = self.whnf_cache.get(&key) { + return cached; + } + let t0 = std::time::Instant::now(); + let mut dag = DAG { head: ptr }; + whnf_dag(&mut dag, self.env, false); + let result = dag.head; + let ms = t0.elapsed().as_millis(); + if ms > 100 { + eprintln!("[whnf SLOW] {}ms whnf_calls={}", ms, self.whnf_calls); + } + self.whnf_cache.insert(key, result); + result + } + + /// Reduce to WHNF without delta (definition) unfolding. + /// + /// Used in definitional equality to try structural comparison before + /// committing to delta reduction. 
+ pub fn whnf_no_delta(&mut self, ptr: DAGPtr) -> DAGPtr { + self.whnf_calls += 1; + if self.whnf_calls % 100 == 0 { + eprintln!("[DagTC::whnf_no_delta] calls={}", self.whnf_calls); + } + let key = dag_ptr_key(ptr); + if let Some(&cached) = self.whnf_no_delta_cache.get(&key) { + return cached; + } + let mut dag = DAG { head: ptr }; + whnf_dag(&mut dag, self.env, true); + let result = dag.head; + self.whnf_no_delta_cache.insert(key, result); + result + } + + // ========================================================================== + // Ensure helpers + // ========================================================================== + + /// If `ptr` is already a Sort, return its level. Otherwise WHNF and check. + pub fn ensure_sort(&mut self, ptr: DAGPtr) -> TcResult { + if let DAGPtr::Sort(p) = ptr { + let level = unsafe { &(*p.as_ptr()).level }; + return Ok(level.clone()); + } + let t0 = std::time::Instant::now(); + let whnfd = self.whnf(ptr); + let ms = t0.elapsed().as_millis(); + if ms > 100 { + eprintln!("[ensure_sort] whnf took {}ms", ms); + } + match whnfd { + DAGPtr::Sort(p) => { + let level = unsafe { &(*p.as_ptr()).level }; + Ok(level.clone()) + }, + _ => Err(TcError::TypeExpected { + expr: dag_to_expr(ptr), + inferred: dag_to_expr(whnfd), + }), + } + } + + /// If `ptr` is already a Pi, return it. Otherwise WHNF and check. + pub fn ensure_pi(&mut self, ptr: DAGPtr) -> TcResult { + if let DAGPtr::Pi(_) = ptr { + return Ok(ptr); + } + let t0 = std::time::Instant::now(); + let whnfd = self.whnf(ptr); + let ms = t0.elapsed().as_millis(); + if ms > 100 { + eprintln!("[ensure_pi] whnf took {}ms", ms); + } + match whnfd { + DAGPtr::Pi(_) => Ok(whnfd), + _ => Err(TcError::FunctionExpected { + expr: dag_to_expr(ptr), + inferred: dag_to_expr(whnfd), + }), + } + } + + /// Infer the type of `ptr` and ensure it's a Sort; return the universe level. 
+ pub fn infer_sort_of(&mut self, ptr: DAGPtr) -> TcResult { + let ty = self.infer(ptr)?; + let whnfd = self.whnf(ty); + self.ensure_sort(whnfd) + } + + // ========================================================================== + // Definitional equality + // ========================================================================== + + /// Check definitional equality of two DAG nodes. + /// + /// Uses a conjunction work stack: processes pairs iteratively, all must + /// be equal. Binder comparison uses recursive calls with a binder + /// correspondence map rather than pushing raw bodies. + pub fn def_eq(&mut self, x: DAGPtr, y: DAGPtr) -> bool { + self.def_eq_calls += 1; + eprintln!("[def_eq#{}] depth={}", self.def_eq_calls, self.infer_depth); + const STEP_LIMIT: u64 = 1_000_000; + let mut work: Vec<(DAGPtr, DAGPtr)> = vec![(x, y)]; + let mut steps: u64 = 0; + while let Some((x, y)) = work.pop() { + steps += 1; + if steps > STEP_LIMIT { + return false; + } + if !self.def_eq_step(x, y, &mut work) { + return false; + } + } + true + } + + /// Quick syntactic checks at DAG level. 
+ fn def_eq_quick_check(&self, x: DAGPtr, y: DAGPtr) -> Option { + if dag_ptr_key(x) == dag_ptr_key(y) { + return Some(true); + } + unsafe { + match (x, y) { + (DAGPtr::Sort(a), DAGPtr::Sort(b)) => { + Some(eq_antisymm(&(*a.as_ptr()).level, &(*b.as_ptr()).level)) + }, + (DAGPtr::Cnst(a), DAGPtr::Cnst(b)) => { + let ca = &*a.as_ptr(); + let cb = &*b.as_ptr(); + if ca.name == cb.name && eq_antisymm_many(&ca.levels, &cb.levels) { + Some(true) + } else { + None // different names may still be delta-equal + } + }, + (DAGPtr::Lit(a), DAGPtr::Lit(b)) => { + Some((*a.as_ptr()).val == (*b.as_ptr()).val) + }, + (DAGPtr::Var(a), DAGPtr::Var(b)) => { + let va = &*a.as_ptr(); + let vb = &*b.as_ptr(); + match (&va.fvar_name, &vb.fvar_name) { + (Some(na), Some(nb)) => { + if na == nb { Some(true) } else { None } + }, + (None, None) => { + let ka = dag_ptr_key(x); + let kb = dag_ptr_key(y); + Some( + self + .binder_eq_map + .iter() + .any(|&(ma, mb)| ma == ka && mb == kb), + ) + }, + _ => Some(false), + } + }, + _ => None, + } + } + } + + /// Process one def_eq pair. 
  /// Process one def_eq pair from the work stack.
  ///
  /// Pipeline (order matters — cheap checks first, delta unfolding last):
  /// quick syntactic check → WHNF without delta → quick check again →
  /// proof irrelevance → lazy delta → congruence/eta/unit fallbacks.
  /// Congruence cases may push sub-pairs onto `work`.
  fn def_eq_step(
    &mut self,
    x: DAGPtr,
    y: DAGPtr,
    work: &mut Vec<(DAGPtr, DAGPtr)>,
  ) -> bool {
    if let Some(quick) = self.def_eq_quick_check(x, y) {
      return quick;
    }
    // Cheap head normalization first: avoids unfolding definitions when
    // a structural comparison already decides the pair.
    let x_n = self.whnf_no_delta(x);
    let y_n = self.whnf_no_delta(y);
    if let Some(quick) = self.def_eq_quick_check(x_n, y_n) {
      return quick;
    }
    if self.proof_irrel_eq(x_n, y_n) {
      return true;
    }
    match self.lazy_delta_step(x_n, y_n) {
      DagDeltaResult::Found(result) => result,
      DagDeltaResult::Exhausted(x_e, y_e) => {
        // Delta could not decide: fall back to congruence rules, then
        // the extensionality rules (eta for functions and structures),
        // then unit-like types. First success wins.
        if self.def_eq_const(x_e, y_e) { return true; }
        if self.def_eq_proj_push(x_e, y_e, work) { return true; }
        if self.def_eq_app_push(x_e, y_e, work) { return true; }
        if self.def_eq_binder_full(x_e, y_e) { return true; }
        if self.try_eta_expansion(x_e, y_e) { return true; }
        if self.try_eta_struct(x_e, y_e) { return true; }
        if self.is_def_eq_unit_like(x_e, y_e) { return true; }
        false
      },
    }
  }

  // --- Proof irrelevance ---

  /// If both x and y are proofs of the same proposition, they are def-eq.
  ///
  /// Inference failures are treated as "not a proof" rather than errors,
  /// since this is only an optimization over the other equality rules.
  fn proof_irrel_eq(&mut self, x: DAGPtr, y: DAGPtr) -> bool {
    // Skip for binder types: inferring Fun/Pi/Lam would recurse into
    // binder bodies. Kept as a conservative guard for def_eq_binder_full.
    if matches!(x, DAGPtr::Fun(_) | DAGPtr::Pi(_) | DAGPtr::Lam(_)) {
      return false;
    }
    if matches!(y, DAGPtr::Fun(_) | DAGPtr::Pi(_) | DAGPtr::Lam(_)) {
      return false;
    }
    let x_ty = match self.infer(x) {
      Ok(ty) => ty,
      Err(_) => return false,
    };
    if !self.is_proposition(x_ty) {
      return false;
    }
    let y_ty = match self.infer(y) {
      Ok(ty) => ty,
      Err(_) => return false,
    };
    if !self.is_proposition(y_ty) {
      return false;
    }
    // Both are proofs; they are def-eq iff they prove the same proposition.
    self.def_eq(x_ty, y_ty)
  }

  /// Check if a type lives in Prop (Sort 0).
  fn is_proposition(&mut self, ty: DAGPtr) -> bool {
    let whnfd = self.whnf(ty);
    match whnfd {
      DAGPtr::Sort(s) => unsafe { is_zero(&(*s.as_ptr()).level) },
      _ => false,
    }
  }

  // --- Lazy delta ---

  /// Lazily unfold definitions on both sides until equality is decided or
  /// no more unfolding is possible.
  ///
  /// Strategy: unfold the side whose head definition has the *higher*
  /// reducibility hint first (so both sides converge toward the same
  /// unfolding depth); when heads and hints match, try eager congruence
  /// before unfolding both. Nat literals get special native handling.
  /// Bounded by `MAX_DELTA_ITERS` to guarantee termination of this loop.
  fn lazy_delta_step(
    &mut self,
    x: DAGPtr,
    y: DAGPtr,
  ) -> DagDeltaResult {
    let mut x = x;
    let mut y = y;
    let mut iters: u32 = 0;
    const MAX_DELTA_ITERS: u32 = 10_000;
    loop {
      iters += 1;
      if iters > MAX_DELTA_ITERS {
        // Give up on delta; caller falls back to congruence rules.
        return DagDeltaResult::Exhausted(x, y);
      }

      // Nat zero/succ offsets can be decided without unfolding.
      if let Some(quick) = self.def_eq_nat_offset(x, y) {
        return DagDeltaResult::Found(quick);
      }

      // Native Nat reduction on either side, re-checking cheaply after.
      if let Some(x_r) = try_lazy_delta_nat_native(x, self.env) {
        let x_r = self.whnf_no_delta(x_r);
        if let Some(quick) = self.def_eq_quick_check(x_r, y) {
          return DagDeltaResult::Found(quick);
        }
        x = x_r;
        continue;
      }
      if let Some(y_r) = try_lazy_delta_nat_native(y, self.env) {
        let y_r = self.whnf_no_delta(y_r);
        if let Some(quick) = self.def_eq_quick_check(x, y_r) {
          return DagDeltaResult::Found(quick);
        }
        y = y_r;
        continue;
      }

      let x_def = dag_get_applied_def(x, self.env);
      let y_def = dag_get_applied_def(y, self.env);
      match (&x_def, &y_def) {
        (None, None) => return DagDeltaResult::Exhausted(x, y),
        (Some(_), None) => {
          x = self.dag_delta(x);
        },
        (None, Some(_)) => {
          y = self.dag_delta(y);
        },
        (Some((x_name, x_hint)), Some((y_name, y_hint))) => {
          if x_name == y_name && x_hint == y_hint {
            // Same head definition: try argument-wise congruence before
            // paying for unfolding both sides.
            if self.def_eq_app_eager(x, y) {
              return DagDeltaResult::Found(true);
            }
            x = self.dag_delta(x);
            y = self.dag_delta(y);
          } else if hint_lt(x_hint, y_hint) {
            y = self.dag_delta(y);
          } else {
            x = self.dag_delta(x);
          }
        },
      }

      if let Some(quick) = self.def_eq_quick_check(x, y) {
        return DagDeltaResult::Found(quick);
      }
    }
  }

  /// Unfold a definition and do cheap WHNF (no delta).
+ fn dag_delta(&mut self, ptr: DAGPtr) -> DAGPtr { + match dag_try_unfold_def(ptr, self.env) { + Some(unfolded) => self.whnf_no_delta(unfolded), + None => ptr, + } + } + + // --- Nat offset equality --- + + fn def_eq_nat_offset( + &mut self, + x: DAGPtr, + y: DAGPtr, + ) -> Option { + if is_nat_zero_dag(x) && is_nat_zero_dag(y) { + return Some(true); + } + match (is_nat_succ_dag(x), is_nat_succ_dag(y)) { + (Some(x_pred), Some(y_pred)) => Some(self.def_eq(x_pred, y_pred)), + _ => None, + } + } + + // --- Congruence --- + + fn def_eq_const(&self, x: DAGPtr, y: DAGPtr) -> bool { + unsafe { + match (x, y) { + (DAGPtr::Cnst(a), DAGPtr::Cnst(b)) => { + let ca = &*a.as_ptr(); + let cb = &*b.as_ptr(); + ca.name == cb.name && eq_antisymm_many(&ca.levels, &cb.levels) + }, + _ => false, + } + } + } + + fn def_eq_proj_push( + &self, + x: DAGPtr, + y: DAGPtr, + work: &mut Vec<(DAGPtr, DAGPtr)>, + ) -> bool { + unsafe { + match (x, y) { + (DAGPtr::Proj(a), DAGPtr::Proj(b)) => { + let pa = &*a.as_ptr(); + let pb = &*b.as_ptr(); + if pa.idx == pb.idx { + work.push((pa.expr, pb.expr)); + true + } else { + false + } + }, + _ => false, + } + } + } + + fn def_eq_app_push( + &self, + x: DAGPtr, + y: DAGPtr, + work: &mut Vec<(DAGPtr, DAGPtr)>, + ) -> bool { + let (f1, args1) = dag_unfold_apps(x); + if args1.is_empty() { + return false; + } + let (f2, args2) = dag_unfold_apps(y); + if args2.is_empty() { + return false; + } + if args1.len() != args2.len() { + return false; + } + work.push((f1, f2)); + for (&a, &b) in args1.iter().zip(args2.iter()) { + work.push((a, b)); + } + true + } + + /// Eager app congruence (used by lazy_delta_step). 
  /// Eager application congruence: like `def_eq_app_push` but checks the
  /// head and all argument pairs immediately via recursive `def_eq` calls
  /// instead of deferring to the work stack. Used by `lazy_delta_step`
  /// before unfolding identical heads.
  fn def_eq_app_eager(&mut self, x: DAGPtr, y: DAGPtr) -> bool {
    let (f1, args1) = dag_unfold_apps(x);
    if args1.is_empty() {
      return false;
    }
    let (f2, args2) = dag_unfold_apps(y);
    if args2.is_empty() {
      return false;
    }
    if args1.len() != args2.len() {
      return false;
    }
    if !self.def_eq(f1, f2) {
      return false;
    }
    args1.iter().zip(args2.iter()).all(|(&a, &b)| self.def_eq(a, b))
  }

  // --- Binder full ---

  /// Compare Pi/Fun binders: peel matching layers, push var correspondence
  /// into `binder_eq_map`, and compare bodies recursively.
  ///
  /// Invariant: exactly `n_pushed` correspondence pairs are pushed, and all
  /// of them are popped on every exit path, so `binder_eq_map` is restored
  /// whether the comparison succeeds or fails.
  fn def_eq_binder_full(&mut self, x: DAGPtr, y: DAGPtr) -> bool {
    let mut cx = x;
    let mut cy = y;
    let mut matched = false;
    let mut n_pushed: usize = 0;
    loop {
      unsafe {
        match (cx, cy) {
          (DAGPtr::Pi(px), DAGPtr::Pi(py)) => {
            let pi_x = &*px.as_ptr();
            let pi_y = &*py.as_ptr();
            // Domains must agree before descending under the binder.
            if !self.def_eq(pi_x.dom, pi_y.dom) {
              for _ in 0..n_pushed {
                self.binder_eq_map.pop();
              }
              return false;
            }
            let lam_x = &*pi_x.img.as_ptr();
            let lam_y = &*pi_y.img.as_ptr();
            // Record that these two bound variables correspond, so
            // def_eq_quick_check treats them as equal in the bodies.
            let var_x_ptr = NonNull::new(
              &lam_x.var as *const Var as *mut Var,
            )
            .unwrap();
            let var_y_ptr = NonNull::new(
              &lam_y.var as *const Var as *mut Var,
            )
            .unwrap();
            self.binder_eq_map.push((
              dag_ptr_key(DAGPtr::Var(var_x_ptr)),
              dag_ptr_key(DAGPtr::Var(var_y_ptr)),
            ));
            n_pushed += 1;
            cx = lam_x.bod;
            cy = lam_y.bod;
            matched = true;
          },
          (DAGPtr::Fun(fx), DAGPtr::Fun(fy)) => {
            // Same peeling logic as the Pi case, for lambda binders.
            let fun_x = &*fx.as_ptr();
            let fun_y = &*fy.as_ptr();
            if !self.def_eq(fun_x.dom, fun_y.dom) {
              for _ in 0..n_pushed {
                self.binder_eq_map.pop();
              }
              return false;
            }
            let lam_x = &*fun_x.img.as_ptr();
            let lam_y = &*fun_y.img.as_ptr();
            let var_x_ptr = NonNull::new(
              &lam_x.var as *const Var as *mut Var,
            )
            .unwrap();
            let var_y_ptr = NonNull::new(
              &lam_y.var as *const Var as *mut Var,
            )
            .unwrap();
            self.binder_eq_map.push((
              dag_ptr_key(DAGPtr::Var(var_x_ptr)),
              dag_ptr_key(DAGPtr::Var(var_y_ptr)),
            ));
            n_pushed += 1;
            cx = lam_x.bod;
            cy = lam_y.bod;
            matched = true;
          },
          _ => break,
        }
      }
    }
    if !matched {
      return false;
    }
    // Compare the innermost bodies with all correspondences in scope.
    let result = self.def_eq(cx, cy);
    for _ in 0..n_pushed {
      self.binder_eq_map.pop();
    }
    result
  }

  // --- Eta expansion ---

  /// Eta for functions, tried in both directions.
  fn try_eta_expansion(&mut self, x: DAGPtr, y: DAGPtr) -> bool {
    self.try_eta_expansion_aux(x, y)
      || self.try_eta_expansion_aux(y, x)
  }

  /// Eta: `fun x => f x` ≡ `f` when `f : (x : A) → B`.
  ///
  /// `x` must be a lambda and `y` must have a Pi type; then compares the
  /// lambda body against a freshly built application `y var_x`.
  /// NOTE(review): the temporary App node is not registered as a parent of
  /// `y`/`var_x` — presumably fine for comparison-only nodes; confirm
  /// against the DAG ownership rules in dag.rs.
  fn try_eta_expansion_aux(
    &mut self,
    x: DAGPtr,
    y: DAGPtr,
  ) -> bool {
    let fx = match x {
      DAGPtr::Fun(f) => f,
      _ => return false,
    };
    let y_ty = match self.infer(y) {
      Ok(t) => t,
      Err(_) => return false,
    };
    let y_ty_whnf = self.whnf(y_ty);
    if !matches!(y_ty_whnf, DAGPtr::Pi(_)) {
      return false;
    }
    unsafe {
      let fun_x = &*fx.as_ptr();
      let lam_x = &*fun_x.img.as_ptr();
      let var_x_ptr =
        NonNull::new(&lam_x.var as *const Var as *mut Var).unwrap();
      let var_x = DAGPtr::Var(var_x_ptr);
      // Build eta body: App(y, var_x)
      // Using the SAME var_x on both sides, so pointer identity
      // handles bound variable matching without binder_eq_map.
      let eta_body = DAGPtr::App(alloc_app(y, var_x, None));
      self.def_eq(lam_x.bod, eta_body)
    }
  }

  // --- Struct eta ---

  /// Eta for structures, tried in both directions.
  fn try_eta_struct(&mut self, x: DAGPtr, y: DAGPtr) -> bool {
    self.try_eta_struct_core(x, y)
      || self.try_eta_struct_core(y, x)
  }

  /// Structure eta: `p =def= S.mk (S.1 p) (S.2 p)` when S is a
  /// single-constructor non-recursive inductive with no indices.
  /// Core of structure eta: `s` must be a fully applied constructor of a
  /// structure-like inductive; then each field of `s` is compared against
  /// the corresponding projection of `t`.
  fn try_eta_struct_core(&mut self, t: DAGPtr, s: DAGPtr) -> bool {
    let (head, args) = dag_unfold_apps(s);
    let ctor_name = match head {
      DAGPtr::Cnst(c) => unsafe { (*c.as_ptr()).name.clone() },
      _ => return false,
    };
    let ctor_info = match self.env.get(&ctor_name) {
      Some(ConstantInfo::CtorInfo(c)) => c,
      _ => return false,
    };
    if !is_structure_like(&ctor_info.induct, self.env) {
      return false;
    }
    // NOTE(review): to_u64().unwrap() panics if the counts exceed u64 —
    // presumably impossible for a checked environment; confirm.
    let num_params = ctor_info.num_params.to_u64().unwrap() as usize;
    let num_fields = ctor_info.num_fields.to_u64().unwrap() as usize;
    // The constructor must be fully applied (all params + all fields).
    if args.len() != num_params + num_fields {
      return false;
    }
    for i in 0..num_fields {
      // Compare field i of `s` with the i-th projection of `t`.
      let field = args[num_params + i];
      let proj = alloc_proj(
        ctor_info.induct.clone(),
        Nat::from(i as u64),
        t,
        None,
      );
      if !self.def_eq(field, DAGPtr::Proj(proj)) {
        return false;
      }
    }
    true
  }

  // --- Unit-like equality ---

  /// Types with a single zero-field constructor have all inhabitants def-eq.
  ///
  /// Requires the two sides to have def-eq types whose WHNF head is an
  /// inductive with exactly one constructor taking no fields.
  fn is_def_eq_unit_like(&mut self, x: DAGPtr, y: DAGPtr) -> bool {
    let x_ty = match self.infer(x) {
      Ok(ty) => ty,
      Err(_) => return false,
    };
    let y_ty = match self.infer(y) {
      Ok(ty) => ty,
      Err(_) => return false,
    };
    if !self.def_eq(x_ty, y_ty) {
      return false;
    }
    let whnf_ty = self.whnf(x_ty);
    let (head, _) = dag_unfold_apps(whnf_ty);
    let name = match head {
      DAGPtr::Cnst(c) => unsafe { (*c.as_ptr()).name.clone() },
      _ => return false,
    };
    match self.env.get(&name) {
      Some(ConstantInfo::InductInfo(iv)) => {
        if iv.ctors.len() != 1 {
          return false;
        }
        if let Some(ConstantInfo::CtorInfo(c)) =
          self.env.get(&iv.ctors[0])
        {
          c.num_fields == Nat::ZERO
        } else {
          false
        }
      },
      _ => false,
    }
  }

  /// Assert that two DAG nodes are definitionally equal; return TcError if not.
  /// Assert that two DAG nodes are definitionally equal.
  ///
  /// # Errors
  /// `TcError::DefEqFailure` with both sides converted back to `Expr`.
  pub fn assert_def_eq(
    &mut self,
    x: DAGPtr,
    y: DAGPtr,
  ) -> TcResult<()> {
    if self.def_eq(x, y) {
      Ok(())
    } else {
      Err(TcError::DefEqFailure {
        lhs: dag_to_expr(x),
        rhs: dag_to_expr(y),
      })
    }
  }

  // ==========================================================================
  // Local context management
  // ==========================================================================

  /// Create a fresh free variable for entering a binder.
  ///
  /// Returns a `DAGPtr::Var` with a unique `fvar_name` (derived from the
  /// binder name and a monotonic counter) and records `ty` as its type
  /// in `local_types`.
  pub fn mk_dag_local(&mut self, name: &Name, ty: DAGPtr) -> DAGPtr {
    let id = self.local_counter;
    self.local_counter += 1;
    // Uniqueness comes from the counter suffix, even if `name` repeats.
    let local_name = Name::num(name.clone(), Nat::from(id));
    let var = alloc_val(Var {
      depth: 0,
      binder: BinderPtr::Free,
      fvar_name: Some(local_name.clone()),
      parents: None,
    });
    self.local_types.insert(local_name, ty);
    DAGPtr::Var(var)
  }

  // ==========================================================================
  // Type inference
  // ==========================================================================

  /// Infer the type of a DAG node.
  ///
  /// Results are memoized in `infer_cache`, keyed by node address.
+ pub fn infer(&mut self, ptr: DAGPtr) -> TcResult { + self.infer_calls += 1; + self.infer_depth += 1; + // Heartbeat every 500 calls + if self.infer_calls % 500 == 0 { + eprintln!("[infer HEARTBEAT] calls={} depth={} cache={} whnf={} def_eq={} copy_subst_total_nodes=?", + self.infer_calls, self.infer_depth, self.infer_cache.len(), self.whnf_calls, self.def_eq_calls); + } + if self.infer_depth > self.infer_max_depth { + self.infer_max_depth = self.infer_depth; + if self.infer_max_depth % 5 == 0 || self.infer_max_depth > 20 { + let detail = unsafe { match ptr { + DAGPtr::Cnst(p) => format!("Cnst({})", (*p.as_ptr()).name.pretty()), + DAGPtr::App(_) => "App".to_string(), + DAGPtr::Fun(p) => format!("Fun({})", (*p.as_ptr()).binder_name.pretty()), + DAGPtr::Pi(p) => format!("Pi({})", (*p.as_ptr()).binder_name.pretty()), + _ => format!("{:?}", std::mem::discriminant(&ptr)), + }}; + eprintln!("[infer] NEW MAX DEPTH={} calls={} cache={} {detail}", self.infer_max_depth, self.infer_calls, self.infer_cache.len()); + } + } + if self.infer_calls % 1000 == 0 { + eprintln!("[infer] calls={} depth={} cache={}", self.infer_calls, self.infer_depth, self.infer_cache.len()); + } + let key = dag_ptr_key(ptr); + if let Some(&cached) = self.infer_cache.get(&key) { + self.infer_depth -= 1; + return Ok(cached); + } + let t0 = std::time::Instant::now(); + let result = self.infer_core(ptr)?; + let ms = t0.elapsed().as_millis(); + if ms > 100 { + let detail = unsafe { match ptr { + DAGPtr::Cnst(p) => format!("Cnst({})", (*p.as_ptr()).name.pretty()), + DAGPtr::App(_) => "App".to_string(), + DAGPtr::Fun(p) => format!("Fun({})", (*p.as_ptr()).binder_name.pretty()), + DAGPtr::Pi(p) => format!("Pi({})", (*p.as_ptr()).binder_name.pretty()), + _ => format!("{:?}", std::mem::discriminant(&ptr)), + }}; + eprintln!("[infer] depth={} took {}ms {detail}", self.infer_depth, ms); + } + self.infer_cache.insert(key, result); + self.infer_depth -= 1; + Ok(result) + } + + fn infer_core(&mut self, ptr: DAGPtr) 
-> TcResult { + match ptr { + DAGPtr::Var(p) => unsafe { + let var = &*p.as_ptr(); + match &var.fvar_name { + Some(name) => match self.local_types.get(name) { + Some(&ty) => Ok(ty), + None => Err(TcError::KernelException { + msg: "cannot infer type of free variable without context" + .into(), + }), + }, + None => match var.binder { + BinderPtr::Free => Err(TcError::FreeBoundVariable { + idx: var.depth, + }), + BinderPtr::Lam(_) => Err(TcError::KernelException { + msg: "unexpected bound variable during inference".into(), + }), + }, + } + }, + DAGPtr::Sort(p) => { + let level = unsafe { &(*p.as_ptr()).level }; + let result = alloc_val(Sort { + level: Level::succ(level.clone()), + parents: None, + }); + Ok(DAGPtr::Sort(result)) + }, + DAGPtr::Cnst(p) => { + let (name, levels) = unsafe { + let cnst = &*p.as_ptr(); + (cnst.name.clone(), cnst.levels.clone()) + }; + self.infer_const(&name, &levels) + }, + DAGPtr::App(_) => self.infer_app(ptr), + DAGPtr::Fun(_) => self.infer_lambda(ptr), + DAGPtr::Pi(_) => self.infer_pi(ptr), + DAGPtr::Let(p) => { + let (typ, val, bod_lam) = unsafe { + let let_node = &*p.as_ptr(); + (let_node.typ, let_node.val, let_node.bod) + }; + let val_ty = self.infer(val)?; + self.assert_def_eq(val_ty, typ)?; + let body = dag_copy_subst(bod_lam, val); + self.infer(body) + }, + DAGPtr::Lit(p) => { + let val = unsafe { &(*p.as_ptr()).val }; + self.infer_lit(val) + }, + DAGPtr::Proj(p) => { + let (type_name, idx, structure) = unsafe { + let proj = &*p.as_ptr(); + (proj.type_name.clone(), proj.idx.clone(), proj.expr) + }; + self.infer_proj(&type_name, &idx, structure, ptr) + }, + DAGPtr::Lam(_) => Err(TcError::KernelException { + msg: "unexpected standalone Lam during inference".into(), + }), + } + } + + fn infer_const( + &mut self, + name: &Name, + levels: &[Level], + ) -> TcResult { + // Build a cache key from the constant's name + universe level hashes. 
+ let cache_key = { + let mut hasher = blake3::Hasher::new(); + hasher.update(name.get_hash().as_bytes()); + for l in levels { + hasher.update(l.get_hash().as_bytes()); + } + hasher.finalize() + }; + if let Some(&cached) = self.const_type_cache.get(&cache_key) { + return Ok(cached); + } + + let ci = self + .env + .get(name) + .ok_or_else(|| TcError::UnknownConst { name: name.clone() })?; + + let decl_params = ci.get_level_params(); + if levels.len() != decl_params.len() { + return Err(TcError::KernelException { + msg: format!( + "universe parameter count mismatch for {}", + name.pretty() + ), + }); + } + + let ty = ci.get_type(); + let dag = from_expr(ty); + let result = subst_dag_levels(dag.head, decl_params, levels); + self.const_type_cache.insert(cache_key, result); + Ok(result) + } + + fn infer_app(&mut self, e: DAGPtr) -> TcResult { + let (fun, args) = dag_unfold_apps(e); + let mut fun_ty = self.infer(fun)?; + + for &arg in args.iter() { + let pi = self.ensure_pi(fun_ty)?; + + let (dom, img) = unsafe { + match pi { + DAGPtr::Pi(p) => { + let pi_ref = &*p.as_ptr(); + (pi_ref.dom, pi_ref.img) + }, + _ => unreachable!(), + } + }; + let arg_ty = self.infer(arg)?; + if !self.def_eq(arg_ty, dom) { + return Err(TcError::DefEqFailure { + lhs: dag_to_expr(arg_ty), + rhs: dag_to_expr(dom), + }); + } + eprintln!("[infer_app] before dag_copy_subst"); + fun_ty = dag_copy_subst(img, arg); + eprintln!("[infer_app] after dag_copy_subst"); + } + + Ok(fun_ty) + } + + fn infer_lambda(&mut self, e: DAGPtr) -> TcResult { + let mut cursor = e; + let mut locals: Vec = Vec::new(); + let mut binder_doms: Vec = Vec::new(); + let mut binder_infos: Vec = Vec::new(); + let mut binder_names: Vec = Vec::new(); + + // Peel Fun layers + let mut binder_idx = 0usize; + while let DAGPtr::Fun(fun_ptr) = cursor { + let t_binder = std::time::Instant::now(); + let (name, bi, dom, img) = unsafe { + let fun = &*fun_ptr.as_ptr(); + ( + fun.binder_name.clone(), + fun.binder_info.clone(), + fun.dom, + 
fun.img, + ) + }; + + let t_sort = std::time::Instant::now(); + self.infer_sort_of(dom)?; + let sort_ms = t_sort.elapsed().as_millis(); + + let local = self.mk_dag_local(&name, dom); + locals.push(local); + binder_doms.push(dom); + binder_infos.push(bi); + binder_names.push(name.clone()); + + // Enter the binder: deep copy because img is from the TERM DAG + let t_copy = std::time::Instant::now(); + cursor = dag_copy_subst(img, local); + let copy_ms = t_copy.elapsed().as_millis(); + + let total_ms = t_binder.elapsed().as_millis(); + if total_ms > 5 { + eprintln!("[infer_lambda] binder#{binder_idx} {} total={}ms sort={}ms copy={}ms", + name.pretty(), total_ms, sort_ms, copy_ms); + } + binder_idx += 1; + } + + // Infer the body type + let t_body = std::time::Instant::now(); + let body_ty = self.infer(cursor)?; + let body_ms = t_body.elapsed().as_millis(); + if body_ms > 5 { + eprintln!("[infer_lambda] body={}ms after {} binders", body_ms, binder_idx); + } + + // Abstract back: build Pi telescope over the locals + Ok(build_pi_over_locals( + body_ty, + &locals, + &binder_names, + &binder_infos, + &binder_doms, + )) + } + + fn infer_pi(&mut self, e: DAGPtr) -> TcResult { + let mut cursor = e; + let mut locals: Vec = Vec::new(); + let mut universes: Vec = Vec::new(); + + // Peel Pi layers + while let DAGPtr::Pi(pi_ptr) = cursor { + let (name, dom, img) = unsafe { + let pi = &*pi_ptr.as_ptr(); + (pi.binder_name.clone(), pi.dom, pi.img) + }; + + let dom_univ = self.infer_sort_of(dom)?; + universes.push(dom_univ); + + let local = self.mk_dag_local(&name, dom); + locals.push(local); + + // Enter the binder: deep copy because img is from the TERM DAG + cursor = dag_copy_subst(img, local); + } + + // The body must also be a type + let mut result_level = self.infer_sort_of(cursor)?; + + // Compute imax of all levels (innermost first) + for univ in universes.into_iter().rev() { + result_level = Level::imax(univ, result_level); + } + + let result = alloc_val(Sort { + level: 
result_level, + parents: None, + }); + Ok(DAGPtr::Sort(result)) + } + + fn infer_lit(&mut self, lit: &Literal) -> TcResult { + let name = match lit { + Literal::NatVal(_) => Name::str(Name::anon(), "Nat".into()), + Literal::StrVal(_) => Name::str(Name::anon(), "String".into()), + }; + let cnst = alloc_val(Cnst { name, levels: vec![], parents: None }); + Ok(DAGPtr::Cnst(cnst)) + } + + fn infer_proj( + &mut self, + type_name: &Name, + idx: &Nat, + structure: DAGPtr, + _proj_expr: DAGPtr, + ) -> TcResult { + let structure_ty = self.infer(structure)?; + let structure_ty_whnf = self.whnf(structure_ty); + + let (head, struct_ty_args) = dag_unfold_apps(structure_ty_whnf); + let (head_name, head_levels) = unsafe { + match head { + DAGPtr::Cnst(p) => { + let cnst = &*p.as_ptr(); + (cnst.name.clone(), cnst.levels.clone()) + }, + _ => { + return Err(TcError::KernelException { + msg: "projection structure type is not a constant".into(), + }) + }, + } + }; + + let ind = self.env.get(&head_name).ok_or_else(|| { + TcError::UnknownConst { name: head_name.clone() } + })?; + + let (num_params, ctor_name) = match ind { + ConstantInfo::InductInfo(iv) => { + let ctor = iv.ctors.first().ok_or_else(|| { + TcError::KernelException { + msg: "inductive has no constructors".into(), + } + })?; + (iv.num_params.to_u64().unwrap(), ctor.clone()) + }, + _ => { + return Err(TcError::KernelException { + msg: "projection type is not an inductive".into(), + }) + }, + }; + + let ctor_ci = self.env.get(&ctor_name).ok_or_else(|| { + TcError::UnknownConst { name: ctor_name.clone() } + })?; + + let ctor_ty_dag = from_expr(ctor_ci.get_type()); + let mut ctor_ty = subst_dag_levels( + ctor_ty_dag.head, + ctor_ci.get_level_params(), + &head_levels, + ); + + // Skip params: instantiate with the actual type arguments + for i in 0..num_params as usize { + let whnf_ty = self.whnf(ctor_ty); + match whnf_ty { + DAGPtr::Pi(p) => { + let img = unsafe { (*p.as_ptr()).img }; + ctor_ty = dag_copy_subst(img, 
struct_ty_args[i]); + }, + _ => { + return Err(TcError::KernelException { + msg: "ran out of constructor telescope (params)".into(), + }) + }, + } + } + + // Walk to the idx-th field, substituting projections + let idx_usize = idx.to_u64().unwrap() as usize; + for i in 0..idx_usize { + let whnf_ty = self.whnf(ctor_ty); + match whnf_ty { + DAGPtr::Pi(p) => { + let img = unsafe { (*p.as_ptr()).img }; + let proj = alloc_proj( + type_name.clone(), + Nat::from(i as u64), + structure, + None, + ); + ctor_ty = dag_copy_subst(img, DAGPtr::Proj(proj)); + }, + _ => { + return Err(TcError::KernelException { + msg: "ran out of constructor telescope (fields)".into(), + }) + }, + } + } + + // Extract the target field's type (the domain of the next Pi) + let whnf_ty = self.whnf(ctor_ty); + match whnf_ty { + DAGPtr::Pi(p) => { + let dom = unsafe { (*p.as_ptr()).dom }; + Ok(dom) + }, + _ => Err(TcError::KernelException { + msg: "ran out of constructor telescope (target field)".into(), + }), + } + } + + // ========================================================================== + // Declaration checking + // ========================================================================== + + /// Validate a declaration's type: no duplicate uparams, no loose bvars, + /// all uparams defined, and type infers to a Sort. 
+ pub fn check_declar_info( + &mut self, + info: &crate::ix::env::ConstantVal, + ) -> TcResult<()> { + if !no_dupes_all_params(&info.level_params) { + return Err(TcError::KernelException { + msg: format!( + "duplicate universe parameters in {}", + info.name.pretty() + ), + }); + } + if has_loose_bvars(&info.typ) { + return Err(TcError::KernelException { + msg: format!( + "free bound variables in type of {}", + info.name.pretty() + ), + }); + } + if !all_expr_uparams_defined(&info.typ, &info.level_params) { + return Err(TcError::KernelException { + msg: format!( + "undeclared universe parameters in type of {}", + info.name.pretty() + ), + }); + } + let ty_dag = from_expr(&info.typ).head; + self.infer_sort_of(ty_dag)?; + Ok(()) + } + + /// Check a declaration with both type and value (DefnInfo, ThmInfo, OpaqueInfo). + fn check_value_declar( + &mut self, + cnst: &crate::ix::env::ConstantVal, + value: &crate::ix::env::Expr, + ) -> TcResult<()> { + let t_start = std::time::Instant::now(); + self.check_declar_info(cnst)?; + eprintln!("[cvd @{}ms] check_declar_info done", t_start.elapsed().as_millis()); + if !all_expr_uparams_defined(value, &cnst.level_params) { + return Err(TcError::KernelException { + msg: format!( + "undeclared universe parameters in value of {}", + cnst.name.pretty() + ), + }); + } + let t1 = std::time::Instant::now(); + let val_dag = from_expr(value).head; + eprintln!("[check_value_declar] {} from_expr(value): {}ms", cnst.name.pretty(), t1.elapsed().as_millis()); + let t2 = std::time::Instant::now(); + let inferred_type = self.infer(val_dag)?; + eprintln!("[check_value_declar] {} infer: {}ms", cnst.name.pretty(), t2.elapsed().as_millis()); + let t3 = std::time::Instant::now(); + let ty_dag = from_expr(&cnst.typ).head; + eprintln!("[check_value_declar] {} from_expr(type): {}ms", cnst.name.pretty(), t3.elapsed().as_millis()); + if !self.def_eq(inferred_type, ty_dag) { + let lhs_expr = dag_to_expr(inferred_type); + let rhs_expr = dag_to_expr(ty_dag); + 
return Err(TcError::DefEqFailure { + lhs: lhs_expr, + rhs: rhs_expr, + }); + } + Ok(()) + } + + /// Check a single declaration. + pub fn check_declar( + &mut self, + ci: &ConstantInfo, + ) -> TcResult<()> { + match ci { + ConstantInfo::AxiomInfo(v) => { + self.check_declar_info(&v.cnst)?; + }, + ConstantInfo::DefnInfo(v) => { + self.check_value_declar(&v.cnst, &v.value)?; + }, + ConstantInfo::ThmInfo(v) => { + self.check_value_declar(&v.cnst, &v.value)?; + }, + ConstantInfo::OpaqueInfo(v) => { + self.check_value_declar(&v.cnst, &v.value)?; + }, + ConstantInfo::QuotInfo(v) => { + self.check_declar_info(&v.cnst)?; + super::quot::check_quot(self.env)?; + }, + ConstantInfo::InductInfo(v) => { + // Use Expr-level TypeChecker for structural inductive validation + // (positivity, return types, field universes). These checks aren't + // performance-critical and work on small type telescopes. + let mut expr_tc = super::tc::TypeChecker::new(self.env); + super::inductive::check_inductive(v, &mut expr_tc)?; + }, + ConstantInfo::CtorInfo(v) => { + self.check_declar_info(&v.cnst)?; + if self.env.get(&v.induct).is_none() { + return Err(TcError::UnknownConst { + name: v.induct.clone(), + }); + } + }, + ConstantInfo::RecInfo(v) => { + self.check_declar_info(&v.cnst)?; + for ind_name in &v.all { + if self.env.get(ind_name).is_none() { + return Err(TcError::UnknownConst { + name: ind_name.clone(), + }); + } + } + super::inductive::validate_k_flag(v, self.env)?; + }, + } + Ok(()) + } +} + + +/// Convert a DAGPtr to an Expr. Used only when constructing TcError values. +fn dag_to_expr(ptr: DAGPtr) -> crate::ix::env::Expr { + let dag = DAG { head: ptr }; + to_expr(&dag) +} + +/// Check all declarations in an environment in parallel using the DAG TC. 
+pub fn dag_check_env(env: &Env) -> Vec<(Name, TcError)> { + use std::collections::BTreeSet; + use std::io::Write; + use std::sync::Mutex; + use std::sync::atomic::{AtomicUsize, Ordering}; + + let total = env.len(); + let checked = AtomicUsize::new(0); + + struct Display { + active: BTreeSet, + prev_lines: usize, + } + let display = + Mutex::new(Display { active: BTreeSet::new(), prev_lines: 0 }); + + let refresh = |d: &mut Display, checked: usize| { + let mut stderr = std::io::stderr().lock(); + if d.prev_lines > 0 { + write!(stderr, "\x1b[{}A", d.prev_lines).ok(); + } + write!( + stderr, + "\x1b[2K[dag_check_env] {}/{} — {} active\n", + checked, + total, + d.active.len() + ) + .ok(); + let mut new_lines = 1; + for name in &d.active { + write!(stderr, "\x1b[2K {}\n", name).ok(); + new_lines += 1; + } + let extra = d.prev_lines.saturating_sub(new_lines); + for _ in 0..extra { + write!(stderr, "\x1b[2K\n").ok(); + } + if extra > 0 { + write!(stderr, "\x1b[{}A", extra).ok(); + } + d.prev_lines = new_lines; + stderr.flush().ok(); + }; + + env + .par_iter() + .filter_map(|(name, ci): (&Name, &ConstantInfo)| { + let pretty = name.pretty(); + { + let mut d = display.lock().unwrap(); + d.active.insert(pretty.clone()); + refresh(&mut d, checked.load(Ordering::Relaxed)); + } + + let mut tc = DagTypeChecker::new(env); + let result = tc.check_declar(ci); + + let n = checked.fetch_add(1, Ordering::Relaxed) + 1; + { + let mut d = display.lock().unwrap(); + d.active.remove(&pretty); + refresh(&mut d, n); + } + + match result { + Ok(()) => None, + Err(e) => Some((name.clone(), e)), + } + }) + .collect() +} + +// ============================================================================ +// build_pi_over_locals +// ============================================================================ + +/// Abstract free variables back into a Pi telescope. 
///
/// Given a `body` type (DAGPtr containing free Vars created by `mk_dag_local`)
/// and corresponding binder information, builds a Pi telescope at the DAG level.
///
/// Processes binders from innermost (last) to outermost (first). For each:
/// 1. Allocates a `Lam` with `bod = current_result`
/// 2. Calls `replace_child(free_var, lam.var)` to redirect all references
/// 3. Allocates `Pi(name, bi, dom, lam)` and wires parent pointers
///
/// NOTE(review): `locals`, `names`, `bis`, and `doms` are assumed to be the
/// same length (one entry per binder) — callers in this file uphold that;
/// confirm before reusing elsewhere.
pub fn build_pi_over_locals(
  body: DAGPtr,
  locals: &[DAGPtr],
  names: &[Name],
  bis: &[BinderInfo],
  doms: &[DAGPtr],
) -> DAGPtr {
  let mut result = body;
  // Process from innermost (last) to outermost (first)
  for i in (0..locals.len()).rev() {
    // 1. Allocate Lam wrapping the current result
    let lam = alloc_lam(0, result, None);
    // SAFETY: `lam` was just allocated above, so the pointer is valid and
    // uniquely referenced; the field addresses taken here point into that
    // live allocation.
    unsafe {
      let lam_ref = &mut *lam.as_ptr();
      // Wire bod_ref as parent of result
      let bod_ref =
        NonNull::new(&mut lam_ref.bod_ref as *mut Parents).unwrap();
      add_to_parents(result, bod_ref);
      // 2. Redirect all references from the free var to the bound var
      // (this must happen before the Pi is built so the free local no
      // longer occurs anywhere under `result`).
      let new_var = NonNull::new(&mut lam_ref.var as *mut Var).unwrap();
      replace_child(locals[i], DAGPtr::Var(new_var));
    }
    // 3. Allocate Pi
    let pi = alloc_pi(names[i].clone(), bis[i].clone(), doms[i], lam, None);
    // SAFETY: `pi` was just allocated above; same argument as for `lam`.
    unsafe {
      let pi_ref = &mut *pi.as_ptr();
      // Wire dom_ref as parent of doms[i]
      let dom_ref =
        NonNull::new(&mut pi_ref.dom_ref as *mut Parents).unwrap();
      add_to_parents(doms[i], dom_ref);
      // Wire img_ref as parent of Lam
      let img_ref =
        NonNull::new(&mut pi_ref.img_ref as *mut Parents).unwrap();
      add_to_parents(DAGPtr::Lam(lam), img_ref);
    }
    result = DAGPtr::Pi(pi);
  }
  result
}

// ============================================================================
// Definitional equality helpers (free functions)
// ============================================================================

/// Result of lazy delta reduction at DAG level.
+enum DagDeltaResult { + Found(bool), + Exhausted(DAGPtr, DAGPtr), +} + +/// Get the name and reducibility hint of an applied definition. +fn dag_get_applied_def( + ptr: DAGPtr, + env: &Env, +) -> Option<(Name, ReducibilityHints)> { + let (head, _) = dag_unfold_apps(ptr); + let name = match head { + DAGPtr::Cnst(c) => unsafe { (*c.as_ptr()).name.clone() }, + _ => return None, + }; + let ci = env.get(&name)?; + match ci { + ConstantInfo::DefnInfo(d) => { + if d.hints == ReducibilityHints::Opaque { + None + } else { + Some((name, d.hints)) + } + }, + ConstantInfo::ThmInfo(_) => { + Some((name, ReducibilityHints::Opaque)) + }, + _ => None, + } +} + +/// Try to unfold a definition at DAG level. +fn dag_try_unfold_def(ptr: DAGPtr, env: &Env) -> Option { + let (head, args) = dag_unfold_apps(ptr); + let (name, levels) = match head { + DAGPtr::Cnst(c) => unsafe { + let cr = &*c.as_ptr(); + (cr.name.clone(), cr.levels.clone()) + }, + _ => return None, + }; + let ci = env.get(&name)?; + let (def_params, def_value) = match ci { + ConstantInfo::DefnInfo(d) => { + if d.hints == ReducibilityHints::Opaque { + return None; + } + (&d.cnst.level_params, &d.value) + }, + ConstantInfo::ThmInfo(t) => (&t.cnst.level_params, &t.value), + _ => return None, + }; + if levels.len() != def_params.len() { + return None; + } + let val = subst_expr_levels(def_value, def_params, &levels); + let val_dag = from_expr(&val); + Some(dag_foldl_apps(val_dag.head, &args)) +} + +/// Try nat/native reduction before delta. +fn try_lazy_delta_nat_native(ptr: DAGPtr, env: &Env) -> Option { + let (head, args) = dag_unfold_apps(ptr); + match head { + DAGPtr::Cnst(c) => unsafe { + let name = &(*c.as_ptr()).name; + if let Some(r) = try_reduce_native_dag(name, &args) { + return Some(r); + } + if let Some(r) = try_reduce_nat_dag(name, &args, env) { + return Some(r); + } + None + }, + _ => None, + } +} + +/// Check if a DAGPtr is Nat.zero (either constructor or literal 0). 
+fn is_nat_zero_dag(ptr: DAGPtr) -> bool { + unsafe { + match ptr { + DAGPtr::Cnst(c) => (*c.as_ptr()).name == mk_name2("Nat", "zero"), + DAGPtr::Lit(l) => { + matches!(&(*l.as_ptr()).val, Literal::NatVal(n) if n.0 == BigUint::ZERO) + }, + _ => false, + } + } +} + +/// If expression is `Nat.succ arg` or `lit (n+1)`, return the predecessor. +fn is_nat_succ_dag(ptr: DAGPtr) -> Option { + unsafe { + match ptr { + DAGPtr::App(app) => { + let a = &*app.as_ptr(); + match a.fun { + DAGPtr::Cnst(c) + if (*c.as_ptr()).name == mk_name2("Nat", "succ") => + { + Some(a.arg) + }, + _ => None, + } + }, + DAGPtr::Lit(l) => match &(*l.as_ptr()).val { + Literal::NatVal(n) if n.0 > BigUint::ZERO => { + Some(nat_lit_dag(Nat(n.0.clone() - BigUint::from(1u64)))) + }, + _ => None, + }, + _ => None, + } + } +} + +/// Check if a name refers to a structure-like inductive: +/// exactly 1 constructor, not recursive, no indices. +fn is_structure_like(name: &Name, env: &Env) -> bool { + match env.get(name) { + Some(ConstantInfo::InductInfo(iv)) => { + iv.ctors.len() == 1 && !iv.is_rec && iv.num_indices == Nat::ZERO + }, + _ => false, + } +} + +/// Compare reducibility hints for ordering. 
/// Strict "less than" on reducibility hints, ordering
/// Opaque < Regular(h) < Abbrev (lower = unfold later during lazy delta).
///
/// NOTE(review): `(Opaque, Opaque)` hits the first arm and returns `true`,
/// so this relation is not irreflexive — `hint_lt(x, x)` holds for two
/// opaque hints. If the def_eq caller treats `hint_lt` as a strict order
/// (unfolding only the "greater" side), two opaque-hinted heads will be
/// unfolded asymmetrically instead of together — confirm this is intended.
fn hint_lt(a: &ReducibilityHints, b: &ReducibilityHints) -> bool {
  match (a, b) {
    (ReducibilityHints::Opaque, _) => true,
    (_, ReducibilityHints::Opaque) => false,
    (ReducibilityHints::Abbrev, _) => false,
    (_, ReducibilityHints::Abbrev) => true,
    (ReducibilityHints::Regular(ha), ReducibilityHints::Regular(hb)) => {
      ha < hb
    },
  }
}

#[cfg(test)]
mod tests {
  use super::*;
  use crate::ix::env::{BinderInfo, Expr, Level, Literal};
  use crate::ix::kernel::convert::from_expr;

  fn mk_name(s: &str) -> Name {
    Name::str(Name::anon(), s.into())
  }

  fn nat_type() -> Expr {
    Expr::cnst(mk_name("Nat"), vec![])
  }

  // ========================================================================
  // subst_dag_levels tests
  // ========================================================================

  #[test]
  fn subst_dag_levels_empty_params() {
    let e = Expr::sort(Level::param(mk_name("u")));
    let dag = from_expr(&e);
    let result = subst_dag_levels(dag.head, &[], &[]);
    let result_dag = DAG { head: result };
    let result_expr = to_expr(&result_dag);
    assert_eq!(result_expr, e);
  }

  #[test]
  fn subst_dag_levels_sort() {
    let u_name = mk_name("u");
    let e = Expr::sort(Level::param(u_name.clone()));
    let dag = from_expr(&e);
    let result = subst_dag_levels(dag.head, &[u_name], &[Level::zero()]);
    let result_dag = DAG { head: result };
    let result_expr = to_expr(&result_dag);
    assert_eq!(result_expr, Expr::sort(Level::zero()));
  }

  #[test]
  fn subst_dag_levels_cnst() {
    let u_name = mk_name("u");
    let e = Expr::cnst(mk_name("List"), vec![Level::param(u_name.clone())]);
    let dag = from_expr(&e);
    let one = Level::succ(Level::zero());
    let result = subst_dag_levels(dag.head, &[u_name], &[one.clone()]);
    let result_dag = DAG { head: result };
    let result_expr = to_expr(&result_dag);
    assert_eq!(result_expr, Expr::cnst(mk_name("List"), vec![one]));
  }

  #[test]
  fn subst_dag_levels_nested() {
    // Pi (A : Sort
u with u := 1 + let u_name = mk_name("u"); + let sort_u = Expr::sort(Level::param(u_name.clone())); + let e = Expr::all( + mk_name("A"), + sort_u.clone(), + sort_u, + BinderInfo::Default, + ); + let dag = from_expr(&e); + let one = Level::succ(Level::zero()); + let result = subst_dag_levels(dag.head, &[u_name], &[one.clone()]); + let result_dag = DAG { head: result }; + let result_expr = to_expr(&result_dag); + let sort_1 = Expr::sort(one); + let expected = Expr::all( + mk_name("A"), + sort_1.clone(), + sort_1, + BinderInfo::Default, + ); + assert_eq!(result_expr, expected); + } + + #[test] + fn subst_dag_levels_no_levels_unchanged() { + // Expression with no Sort or Cnst nodes — pure lambda + let e = Expr::lam( + mk_name("x"), + Expr::lit(Literal::NatVal(Nat::from(0u64))), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let dag = from_expr(&e); + let u_name = mk_name("u"); + let result = + subst_dag_levels(dag.head, &[u_name], &[Level::zero()]); + let result_dag = DAG { head: result }; + let result_expr = to_expr(&result_dag); + assert_eq!(result_expr, e); + } + + // ======================================================================== + // mk_dag_local tests + // ======================================================================== + + #[test] + fn mk_dag_local_creates_free_var() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let name = mk_name("x"); + let ty = from_expr(&nat_type()).head; + let local = tc.mk_dag_local(&name, ty); + match local { + DAGPtr::Var(p) => unsafe { + let var = &*p.as_ptr(); + assert!(matches!(var.binder, BinderPtr::Free)); + assert!(var.fvar_name.is_some()); + }, + _ => panic!("Expected Var"), + } + assert_eq!(tc.local_counter, 1); + assert_eq!(tc.local_types.len(), 1); + } + + #[test] + fn mk_dag_local_unique_names() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let name = mk_name("x"); + let ty = from_expr(&nat_type()).head; + let l1 = tc.mk_dag_local(&name, 
ty); + let ty2 = from_expr(&nat_type()).head; + let l2 = tc.mk_dag_local(&name, ty2); + // Different pointer identities + assert_ne!(dag_ptr_key(l1), dag_ptr_key(l2)); + // Different fvar names + unsafe { + let n1 = match l1 { + DAGPtr::Var(p) => (*p.as_ptr()).fvar_name.clone().unwrap(), + _ => panic!(), + }; + let n2 = match l2 { + DAGPtr::Var(p) => (*p.as_ptr()).fvar_name.clone().unwrap(), + _ => panic!(), + }; + assert_ne!(n1, n2); + } + } + + // ======================================================================== + // build_pi_over_locals tests + // ======================================================================== + + #[test] + fn build_pi_single_binder() { + // Build: Pi (x : Nat) → Nat + // body = Nat (doesn't reference x), locals = [x_free] + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let nat_dag = from_expr(&nat_type()).head; + let x_local = tc.mk_dag_local(&mk_name("x"), nat_dag); + // Body doesn't use x + let body = from_expr(&nat_type()).head; + let result = build_pi_over_locals( + body, + &[x_local], + &[mk_name("x")], + &[BinderInfo::Default], + &[nat_dag], + ); + let result_dag = DAG { head: result }; + let result_expr = to_expr(&result_dag); + let expected = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + assert_eq!(result_expr, expected); + } + + #[test] + fn build_pi_dependent() { + // Build: Pi (A : Sort 0) → A + // body = A_local (references A), locals = [A_local] + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let sort0 = from_expr(&Expr::sort(Level::zero())).head; + let a_local = tc.mk_dag_local(&mk_name("A"), sort0); + // Body IS the local variable + let result = build_pi_over_locals( + a_local, + &[a_local], + &[mk_name("A")], + &[BinderInfo::Default], + &[sort0], + ); + let result_dag = DAG { head: result }; + let result_expr = to_expr(&result_dag); + let expected = Expr::all( + mk_name("A"), + Expr::sort(Level::zero()), + 
Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + assert_eq!(result_expr, expected); + } + + #[test] + fn build_pi_two_binders() { + // Build: Pi (A : Sort 0) (x : A) → A + // Should produce: ForallE A (Sort 0) (ForallE x (bvar 0) (bvar 1)) + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let sort0 = from_expr(&Expr::sort(Level::zero())).head; + let a_local = tc.mk_dag_local(&mk_name("A"), sort0); + let x_local = tc.mk_dag_local(&mk_name("x"), a_local); + // Body is a_local (the type A) + let result = build_pi_over_locals( + a_local, + &[a_local, x_local], + &[mk_name("A"), mk_name("x")], + &[BinderInfo::Default, BinderInfo::Default], + &[sort0, a_local], + ); + let result_dag = DAG { head: result }; + let result_expr = to_expr(&result_dag); + let expected = Expr::all( + mk_name("A"), + Expr::sort(Level::zero()), + Expr::all( + mk_name("x"), + Expr::bvar(Nat::from(0u64)), + Expr::bvar(Nat::from(1u64)), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + assert_eq!(result_expr, expected); + } + + // ======================================================================== + // DagTypeChecker core method tests + // ======================================================================== + + #[test] + fn whnf_sort_is_identity() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let sort = alloc_val(Sort { level: Level::zero(), parents: None }); + let ptr = DAGPtr::Sort(sort); + let result = tc.whnf(ptr); + assert_eq!(dag_ptr_key(result), dag_ptr_key(ptr)); + } + + #[test] + fn whnf_caches_result() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let sort = alloc_val(Sort { level: Level::zero(), parents: None }); + let ptr = DAGPtr::Sort(sort); + let r1 = tc.whnf(ptr); + let r2 = tc.whnf(ptr); + assert_eq!(dag_ptr_key(r1), dag_ptr_key(r2)); + assert_eq!(tc.whnf_cache.len(), 1); + } + + #[test] + fn whnf_no_delta_caches_result() { + let env = Env::default(); + let mut tc = 
DagTypeChecker::new(&env); + let sort = alloc_val(Sort { level: Level::zero(), parents: None }); + let ptr = DAGPtr::Sort(sort); + let r1 = tc.whnf_no_delta(ptr); + let r2 = tc.whnf_no_delta(ptr); + assert_eq!(dag_ptr_key(r1), dag_ptr_key(r2)); + assert_eq!(tc.whnf_no_delta_cache.len(), 1); + } + + #[test] + fn ensure_sort_on_sort() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let sort = alloc_val(Sort { level: Level::zero(), parents: None }); + let result = tc.ensure_sort(DAGPtr::Sort(sort)); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Level::zero()); + } + + #[test] + fn ensure_sort_on_non_sort() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let lit = alloc_val(LitNode { + val: Literal::NatVal(Nat::from(42u64)), + parents: None, + }); + let result = tc.ensure_sort(DAGPtr::Lit(lit)); + assert!(result.is_err()); + } + + #[test] + fn ensure_pi_on_pi() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let sort = alloc_val(Sort { level: Level::zero(), parents: None }); + let lam = alloc_lam(0, DAGPtr::Sort(sort), None); + let pi = alloc_pi( + mk_name("x"), + BinderInfo::Default, + DAGPtr::Sort(sort), + lam, + None, + ); + let result = tc.ensure_pi(DAGPtr::Pi(pi)); + assert!(result.is_ok()); + } + + #[test] + fn ensure_pi_on_non_pi() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let lit = alloc_val(LitNode { + val: Literal::NatVal(Nat::from(42u64)), + parents: None, + }); + let result = tc.ensure_pi(DAGPtr::Lit(lit)); + assert!(result.is_err()); + } + + #[test] + fn infer_sort_zero() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let sort = alloc_val(Sort { level: Level::zero(), parents: None }); + let result = tc.infer(DAGPtr::Sort(sort)).unwrap(); + match result { + DAGPtr::Sort(p) => unsafe { + assert_eq!((*p.as_ptr()).level, Level::succ(Level::zero())); + }, + _ => panic!("Expected Sort"), + } + } + + #[test] + fn 
infer_fvar() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let nat_dag = from_expr(&nat_type()).head; + let local = tc.mk_dag_local(&mk_name("x"), nat_dag); + let result = tc.infer(local).unwrap(); + assert_eq!(dag_ptr_key(result), dag_ptr_key(nat_dag)); + } + + #[test] + fn infer_caches_result() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let sort = alloc_val(Sort { level: Level::zero(), parents: None }); + let ptr = DAGPtr::Sort(sort); + let r1 = tc.infer(ptr).unwrap(); + let r2 = tc.infer(ptr).unwrap(); + assert_eq!(dag_ptr_key(r1), dag_ptr_key(r2)); + assert_eq!(tc.infer_cache.len(), 1); + } + + #[test] + fn def_eq_pointer_identity() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let sort = alloc_val(Sort { level: Level::zero(), parents: None }); + let ptr = DAGPtr::Sort(sort); + assert!(tc.def_eq(ptr, ptr)); + } + + #[test] + fn def_eq_sort_structural() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let s1 = alloc_val(Sort { level: Level::zero(), parents: None }); + let s2 = alloc_val(Sort { level: Level::zero(), parents: None }); + // Same level, different pointers — structurally equal + assert!(tc.def_eq(DAGPtr::Sort(s1), DAGPtr::Sort(s2))); + } + + #[test] + fn def_eq_sort_different_levels() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let s1 = alloc_val(Sort { level: Level::zero(), parents: None }); + let s2 = alloc_val(Sort { + level: Level::succ(Level::zero()), + parents: None, + }); + assert!(!tc.def_eq(DAGPtr::Sort(s1), DAGPtr::Sort(s2))); + } + + #[test] + fn assert_def_eq_ok() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let sort = alloc_val(Sort { level: Level::zero(), parents: None }); + let ptr = DAGPtr::Sort(sort); + assert!(tc.assert_def_eq(ptr, ptr).is_ok()); + } + + #[test] + fn assert_def_eq_err() { + let env = Env::default(); + let mut tc = DagTypeChecker::new(&env); + let 
s1 = alloc_val(Sort { level: Level::zero(), parents: None }); + let s2 = alloc_val(Sort { + level: Level::succ(Level::zero()), + parents: None, + }); + assert!(tc.assert_def_eq(DAGPtr::Sort(s1), DAGPtr::Sort(s2)).is_err()); + } + + // ======================================================================== + // Type inference tests (Step 3) + // ======================================================================== + + use crate::ix::env::{ + AxiomVal, ConstantVal, ConstructorVal, InductiveVal, + }; + + fn mk_name2(a: &str, b: &str) -> Name { + Name::str(Name::str(Name::anon(), a.into()), b.into()) + } + + fn nat_zero() -> Expr { + Expr::cnst(mk_name2("Nat", "zero"), vec![]) + } + + fn prop() -> Expr { + Expr::sort(Level::zero()) + } + + /// Build a minimal environment with Nat, Nat.zero, Nat.succ. + fn mk_nat_env() -> Env { + let mut env = Env::default(); + let nat_name = mk_name("Nat"); + env.insert( + nat_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: nat_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![nat_name.clone()], + ctors: vec![mk_name2("Nat", "zero"), mk_name2("Nat", "succ")], + num_nested: Nat::from(0u64), + is_rec: true, + is_unsafe: false, + is_reflexive: false, + }), + ); + let zero_name = mk_name2("Nat", "zero"); + env.insert( + zero_name.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: zero_name.clone(), + level_params: vec![], + typ: nat_type(), + }, + induct: mk_name("Nat"), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + let succ_name = mk_name2("Nat", "succ"); + let succ_ty = Expr::all( + mk_name("n"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + env.insert( + succ_name.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: succ_name.clone(), + 
level_params: vec![], + typ: succ_ty, + }, + induct: mk_name("Nat"), + cidx: Nat::from(1u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(1u64), + is_unsafe: false, + }), + ); + env + } + + /// Helper: infer the type of an Expr via the DAG TC, return as Expr. + fn dag_infer(env: &Env, e: &Expr) -> Result { + let mut tc = DagTypeChecker::new(env); + let dag = from_expr(e); + let result = tc.infer(dag.head)?; + Ok(dag_to_expr(result)) + } + + // -- Const inference -- + + #[test] + fn dag_infer_const_nat() { + let env = mk_nat_env(); + let ty = dag_infer(&env, &Expr::cnst(mk_name("Nat"), vec![])).unwrap(); + assert_eq!(ty, Expr::sort(Level::succ(Level::zero()))); + } + + #[test] + fn dag_infer_const_nat_zero() { + let env = mk_nat_env(); + let ty = dag_infer(&env, &nat_zero()).unwrap(); + assert_eq!(ty, nat_type()); + } + + #[test] + fn dag_infer_const_nat_succ() { + let env = mk_nat_env(); + let ty = + dag_infer(&env, &Expr::cnst(mk_name2("Nat", "succ"), vec![])).unwrap(); + let expected = Expr::all( + mk_name("n"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + assert_eq!(ty, expected); + } + + #[test] + fn dag_infer_const_unknown() { + let env = Env::default(); + assert!(dag_infer(&env, &Expr::cnst(mk_name("Nope"), vec![])).is_err()); + } + + #[test] + fn dag_infer_const_universe_mismatch() { + let env = mk_nat_env(); + assert!( + dag_infer(&env, &Expr::cnst(mk_name("Nat"), vec![Level::zero()])) + .is_err() + ); + } + + // -- Lit inference -- + + #[test] + fn dag_infer_nat_lit() { + let env = Env::default(); + let ty = + dag_infer(&env, &Expr::lit(Literal::NatVal(Nat::from(42u64)))).unwrap(); + assert_eq!(ty, nat_type()); + } + + #[test] + fn dag_infer_string_lit() { + let env = Env::default(); + let ty = + dag_infer(&env, &Expr::lit(Literal::StrVal("hello".into()))).unwrap(); + assert_eq!(ty, Expr::cnst(mk_name("String"), vec![])); + } + + // -- App inference -- + + #[test] + fn dag_infer_app_succ_zero() { + // Nat.succ Nat.zero : Nat + let 
env = mk_nat_env(); + let e = Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + nat_zero(), + ); + let ty = dag_infer(&env, &e).unwrap(); + assert_eq!(ty, nat_type()); + } + + #[test] + fn dag_infer_app_identity() { + // (fun x : Nat => x) Nat.zero : Nat + let env = mk_nat_env(); + let id_fn = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let e = Expr::app(id_fn, nat_zero()); + let ty = dag_infer(&env, &e).unwrap(); + assert_eq!(ty, nat_type()); + } + + // -- Lambda inference -- + + #[test] + fn dag_infer_identity_lambda() { + // fun (x : Nat) => x : Nat → Nat + let env = mk_nat_env(); + let e = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let ty = dag_infer(&env, &e).unwrap(); + let expected = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + assert_eq!(ty, expected); + } + + #[test] + fn dag_infer_const_lambda() { + // fun (x : Nat) (y : Nat) => x : Nat → Nat → Nat + let env = mk_nat_env(); + let k_fn = Expr::lam( + mk_name("x"), + nat_type(), + Expr::lam( + mk_name("y"), + nat_type(), + Expr::bvar(Nat::from(1u64)), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + let ty = dag_infer(&env, &k_fn).unwrap(); + let expected = Expr::all( + mk_name("x"), + nat_type(), + Expr::all( + mk_name("y"), + nat_type(), + nat_type(), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + assert_eq!(ty, expected); + } + + // -- Pi inference -- + + #[test] + fn dag_infer_pi_nat_to_nat() { + // (Nat → Nat) : Sort 1 + let env = mk_nat_env(); + let pi = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let ty = dag_infer(&env, &pi).unwrap(); + if let crate::ix::env::ExprData::Sort(level, _) = ty.as_data() { + assert!( + crate::ix::kernel::level::eq_antisymm( + level, + &Level::succ(Level::zero()) + ), + "Nat → Nat should live in Sort 1, got {:?}", + level + ); + } else { + 
panic!("Expected Sort, got {:?}", ty); + } + } + + #[test] + fn dag_infer_pi_prop_to_prop() { + // P → P : Prop (where P : Prop) + let mut env = Env::default(); + let p_name = mk_name("P"); + env.insert( + p_name.clone(), + ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: p_name.clone(), + level_params: vec![], + typ: prop(), + }, + is_unsafe: false, + }), + ); + let p = Expr::cnst(p_name, vec![]); + let pi = + Expr::all(mk_name("x"), p.clone(), p.clone(), BinderInfo::Default); + let ty = dag_infer(&env, &pi).unwrap(); + if let crate::ix::env::ExprData::Sort(level, _) = ty.as_data() { + assert!( + crate::ix::kernel::level::is_zero(level), + "Prop → Prop should live in Prop, got {:?}", + level + ); + } else { + panic!("Expected Sort, got {:?}", ty); + } + } + + // -- Let inference -- + + #[test] + fn dag_infer_let_simple() { + // let x : Nat := Nat.zero in x : Nat + let env = mk_nat_env(); + let e = Expr::letE( + mk_name("x"), + nat_type(), + nat_zero(), + Expr::bvar(Nat::from(0u64)), + false, + ); + let ty = dag_infer(&env, &e).unwrap(); + assert_eq!(ty, nat_type()); + } + + // -- Error cases -- + + #[test] + fn dag_infer_free_bvar_fails() { + let env = Env::default(); + assert!(dag_infer(&env, &Expr::bvar(Nat::from(0u64))).is_err()); + } + + #[test] + fn dag_infer_fvar_unknown_fails() { + let env = Env::default(); + assert!(dag_infer(&env, &Expr::fvar(mk_name("x"))).is_err()); + } + + // ======================================================================== + // Definitional equality tests (Step 4) + // ======================================================================== + + use crate::ix::env::{ + DefinitionSafety, DefinitionVal, ReducibilityHints, TheoremVal, + }; + + /// Helper: check def_eq of two Expr via the DAG TC. 
+ fn dag_def_eq(env: &Env, x: &Expr, y: &Expr) -> bool { + let mut tc = DagTypeChecker::new(env); + let dx = from_expr(x); + let dy = from_expr(y); + tc.def_eq(dx.head, dy.head) + } + + // -- Reflexivity -- + + #[test] + fn dag_def_eq_reflexive_sort() { + let env = Env::default(); + let e = Expr::sort(Level::zero()); + assert!(dag_def_eq(&env, &e, &e)); + } + + #[test] + fn dag_def_eq_reflexive_const() { + let env = mk_nat_env(); + let e = nat_zero(); + assert!(dag_def_eq(&env, &e, &e)); + } + + // -- Sort equality -- + + #[test] + fn dag_def_eq_sort_max_comm() { + let env = Env::default(); + let u = Level::param(mk_name("u")); + let v = Level::param(mk_name("v")); + let s1 = Expr::sort(Level::max(u.clone(), v.clone())); + let s2 = Expr::sort(Level::max(v, u)); + assert!(dag_def_eq(&env, &s1, &s2)); + } + + #[test] + fn dag_def_eq_sort_not_equal() { + let env = Env::default(); + let s0 = Expr::sort(Level::zero()); + let s1 = Expr::sort(Level::succ(Level::zero())); + assert!(!dag_def_eq(&env, &s0, &s1)); + } + + // -- Alpha equivalence -- + + #[test] + fn dag_def_eq_alpha_lambda() { + let env = mk_nat_env(); + let e1 = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let e2 = Expr::lam( + mk_name("y"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + assert!(dag_def_eq(&env, &e1, &e2)); + } + + #[test] + fn dag_def_eq_alpha_pi() { + let env = mk_nat_env(); + let e1 = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let e2 = Expr::all( + mk_name("y"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + assert!(dag_def_eq(&env, &e1, &e2)); + } + + // -- Beta equivalence -- + + #[test] + fn dag_def_eq_beta() { + let env = mk_nat_env(); + let id_fn = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let lhs = Expr::app(id_fn, nat_zero()); + assert!(dag_def_eq(&env, &lhs, &nat_zero())); + } + + 
#[test] + fn dag_def_eq_beta_nested() { + let env = mk_nat_env(); + let inner = Expr::lam( + mk_name("y"), + nat_type(), + Expr::bvar(Nat::from(1u64)), + BinderInfo::Default, + ); + let k_fn = Expr::lam( + mk_name("x"), + nat_type(), + inner, + BinderInfo::Default, + ); + let lhs = Expr::app(Expr::app(k_fn, nat_zero()), nat_zero()); + assert!(dag_def_eq(&env, &lhs, &nat_zero())); + } + + // -- Delta equivalence -- + + #[test] + fn dag_def_eq_delta() { + let mut env = mk_nat_env(); + let my_zero = mk_name("myZero"); + env.insert( + my_zero.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: my_zero.clone(), + level_params: vec![], + typ: nat_type(), + }, + value: nat_zero(), + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![my_zero.clone()], + }), + ); + let lhs = Expr::cnst(my_zero, vec![]); + assert!(dag_def_eq(&env, &lhs, &nat_zero())); + } + + #[test] + fn dag_def_eq_delta_both_sides() { + let mut env = mk_nat_env(); + for name_str in &["a", "b"] { + let n = mk_name(name_str); + env.insert( + n.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: n.clone(), + level_params: vec![], + typ: nat_type(), + }, + value: nat_zero(), + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![n], + }), + ); + } + let a = Expr::cnst(mk_name("a"), vec![]); + let b = Expr::cnst(mk_name("b"), vec![]); + assert!(dag_def_eq(&env, &a, &b)); + } + + // -- Zeta equivalence -- + + #[test] + fn dag_def_eq_zeta() { + let env = mk_nat_env(); + let lhs = Expr::letE( + mk_name("x"), + nat_type(), + nat_zero(), + Expr::bvar(Nat::from(0u64)), + false, + ); + assert!(dag_def_eq(&env, &lhs, &nat_zero())); + } + + // -- Negative tests -- + + #[test] + fn dag_def_eq_different_consts() { + let env = Env::default(); + let nat = nat_type(); + let string = Expr::cnst(mk_name("String"), vec![]); + assert!(!dag_def_eq(&env, &nat, &string)); + } + + // -- App congruence -- + + #[test] + fn 
dag_def_eq_app_congruence() { + let env = mk_nat_env(); + let f = Expr::cnst(mk_name2("Nat", "succ"), vec![]); + let a = nat_zero(); + let lhs = Expr::app(f.clone(), a.clone()); + let rhs = Expr::app(f, a); + assert!(dag_def_eq(&env, &lhs, &rhs)); + } + + #[test] + fn dag_def_eq_app_different_args() { + let env = mk_nat_env(); + let succ = Expr::cnst(mk_name2("Nat", "succ"), vec![]); + let lhs = Expr::app(succ.clone(), nat_zero()); + let rhs = Expr::app(succ.clone(), Expr::app(succ, nat_zero())); + assert!(!dag_def_eq(&env, &lhs, &rhs)); + } + + // -- Eta expansion -- + + #[test] + fn dag_def_eq_eta_lam_vs_const() { + let env = mk_nat_env(); + let succ = Expr::cnst(mk_name2("Nat", "succ"), vec![]); + let eta_expanded = Expr::lam( + mk_name("x"), + nat_type(), + Expr::app(succ.clone(), Expr::bvar(Nat::from(0u64))), + BinderInfo::Default, + ); + assert!(dag_def_eq(&env, &eta_expanded, &succ)); + } + + #[test] + fn dag_def_eq_eta_symmetric() { + let env = mk_nat_env(); + let succ = Expr::cnst(mk_name2("Nat", "succ"), vec![]); + let eta_expanded = Expr::lam( + mk_name("x"), + nat_type(), + Expr::app(succ.clone(), Expr::bvar(Nat::from(0u64))), + BinderInfo::Default, + ); + assert!(dag_def_eq(&env, &succ, &eta_expanded)); + } + + // -- Binder full comparison -- + + #[test] + fn dag_def_eq_binder_full_different_domains() { + // (x : myNat) → Nat =def= (x : Nat) → Nat + let mut env = mk_nat_env(); + let my_nat = mk_name("myNat"); + env.insert( + my_nat.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: my_nat.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + value: nat_type(), + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![my_nat.clone()], + }), + ); + let lhs = Expr::all( + mk_name("x"), + Expr::cnst(my_nat, vec![]), + nat_type(), + BinderInfo::Default, + ); + let rhs = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + 
assert!(dag_def_eq(&env, &lhs, &rhs)); + } + + #[test] + fn dag_def_eq_binder_dependent() { + // Pi (A : Sort 0) (x : A) → A =def= Pi (B : Sort 0) (y : B) → B + let env = Env::default(); + let lhs = Expr::all( + mk_name("A"), + Expr::sort(Level::zero()), + Expr::all( + mk_name("x"), + Expr::bvar(Nat::from(0u64)), + Expr::bvar(Nat::from(1u64)), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + let rhs = Expr::all( + mk_name("B"), + Expr::sort(Level::zero()), + Expr::all( + mk_name("y"), + Expr::bvar(Nat::from(0u64)), + Expr::bvar(Nat::from(1u64)), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + assert!(dag_def_eq(&env, &lhs, &rhs)); + } + + // -- Nat offset equality -- + + #[test] + fn dag_def_eq_nat_zero_ctor_vs_lit() { + let env = mk_nat_env(); + let lit0 = Expr::lit(Literal::NatVal(Nat::from(0u64))); + assert!(dag_def_eq(&env, &nat_zero(), &lit0)); + } + + #[test] + fn dag_def_eq_nat_lit_vs_succ_lit() { + let env = mk_nat_env(); + let succ_4 = Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + Expr::lit(Literal::NatVal(Nat::from(4u64))), + ); + let lit5 = Expr::lit(Literal::NatVal(Nat::from(5u64))); + assert!(dag_def_eq(&env, &lit5, &succ_4)); + } + + #[test] + fn dag_def_eq_nat_lit_not_equal() { + let env = Env::default(); + let a = Expr::lit(Literal::NatVal(Nat::from(1u64))); + let b = Expr::lit(Literal::NatVal(Nat::from(2u64))); + assert!(!dag_def_eq(&env, &a, &b)); + } + + // -- Lazy delta with hints -- + + #[test] + fn dag_def_eq_lazy_delta_higher_unfolds_first() { + let mut env = mk_nat_env(); + let a = mk_name("a"); + let b = mk_name("b"); + env.insert( + a.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: a.clone(), + level_params: vec![], + typ: nat_type(), + }, + value: nat_zero(), + hints: ReducibilityHints::Regular(1), + safety: DefinitionSafety::Safe, + all: vec![a.clone()], + }), + ); + env.insert( + b.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: b.clone(), + 
level_params: vec![], + typ: nat_type(), + }, + value: Expr::cnst(a, vec![]), + hints: ReducibilityHints::Regular(2), + safety: DefinitionSafety::Safe, + all: vec![b.clone()], + }), + ); + let lhs = Expr::cnst(b, vec![]); + assert!(dag_def_eq(&env, &lhs, &nat_zero())); + } + + // -- Proof irrelevance -- + + #[test] + fn dag_def_eq_proof_irrel() { + let mut env = mk_nat_env(); + let true_name = mk_name("True"); + let intro_name = mk_name2("True", "intro"); + env.insert( + true_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: true_name.clone(), + level_params: vec![], + typ: prop(), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![true_name.clone()], + ctors: vec![intro_name.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + env.insert( + intro_name.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: intro_name.clone(), + level_params: vec![], + typ: Expr::cnst(true_name.clone(), vec![]), + }, + induct: true_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + let true_ty = Expr::cnst(true_name, vec![]); + let thm_a = mk_name("thmA"); + let thm_b = mk_name("thmB"); + env.insert( + thm_a.clone(), + ConstantInfo::ThmInfo(TheoremVal { + cnst: ConstantVal { + name: thm_a.clone(), + level_params: vec![], + typ: true_ty.clone(), + }, + value: Expr::cnst(intro_name.clone(), vec![]), + all: vec![thm_a.clone()], + }), + ); + env.insert( + thm_b.clone(), + ConstantInfo::ThmInfo(TheoremVal { + cnst: ConstantVal { + name: thm_b.clone(), + level_params: vec![], + typ: true_ty, + }, + value: Expr::cnst(intro_name, vec![]), + all: vec![thm_b.clone()], + }), + ); + let a = Expr::cnst(thm_a, vec![]); + let b = Expr::cnst(thm_b, vec![]); + assert!(dag_def_eq(&env, &a, &b)); + } + + // -- Proj congruence -- + + #[test] + fn 
dag_def_eq_proj_congruence() { + let env = Env::default(); + let s = nat_zero(); + let lhs = Expr::proj(mk_name("S"), Nat::from(0u64), s.clone()); + let rhs = Expr::proj(mk_name("S"), Nat::from(0u64), s); + assert!(dag_def_eq(&env, &lhs, &rhs)); + } + + #[test] + fn dag_def_eq_proj_different_idx() { + let env = Env::default(); + let s = nat_zero(); + let lhs = Expr::proj(mk_name("S"), Nat::from(0u64), s.clone()); + let rhs = Expr::proj(mk_name("S"), Nat::from(1u64), s); + assert!(!dag_def_eq(&env, &lhs, &rhs)); + } + + // -- Beta-delta combined -- + + #[test] + fn dag_def_eq_beta_delta_combined() { + let mut env = mk_nat_env(); + let my_id = mk_name("myId"); + let fun_ty = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + env.insert( + my_id.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: my_id.clone(), + level_params: vec![], + typ: fun_ty, + }, + value: Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ), + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![my_id.clone()], + }), + ); + let lhs = Expr::app(Expr::cnst(my_id, vec![]), nat_zero()); + assert!(dag_def_eq(&env, &lhs, &nat_zero())); + } + + // -- Unit-like equality -- + + #[test] + fn dag_def_eq_unit_like() { + let mut env = mk_nat_env(); + let unit_name = mk_name("Unit"); + let unit_star = mk_name2("Unit", "star"); + env.insert( + unit_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: unit_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![unit_name.clone()], + ctors: vec![unit_star.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + env.insert( + unit_star.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: unit_star.clone(), 
+ level_params: vec![], + typ: Expr::cnst(unit_name.clone(), vec![]), + }, + induct: unit_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + // Two distinct fvars of type Unit should be def-eq + let unit_ty = Expr::cnst(unit_name, vec![]); + let mut tc = DagTypeChecker::new(&env); + let x_ty = from_expr(&unit_ty).head; + let x = tc.mk_dag_local(&mk_name("x"), x_ty); + let y_ty = from_expr(&unit_ty).head; + let y = tc.mk_dag_local(&mk_name("y"), y_ty); + assert!(tc.def_eq(x, y)); + } + + // -- Nat add through def_eq -- + + #[test] + fn dag_def_eq_nat_add_result_vs_lit() { + let env = mk_nat_env(); + let add_3_4 = Expr::app( + Expr::app( + Expr::cnst(mk_name2("Nat", "add"), vec![]), + Expr::lit(Literal::NatVal(Nat::from(3u64))), + ), + Expr::lit(Literal::NatVal(Nat::from(4u64))), + ); + let lit7 = Expr::lit(Literal::NatVal(Nat::from(7u64))); + assert!(dag_def_eq(&env, &add_3_4, &lit7)); + } +} diff --git a/src/ix/kernel/def_eq.rs b/src/ix/kernel/def_eq.rs index c2110381..ada12904 100644 --- a/src/ix/kernel/def_eq.rs +++ b/src/ix/kernel/def_eq.rs @@ -1,5 +1,6 @@ use crate::ix::env::*; use crate::lean::nat::Nat; +use num_bigint::BigUint; use super::level::{eq_antisymm, eq_antisymm_many}; use super::tc::TypeChecker; @@ -12,13 +13,40 @@ enum DeltaResult { } /// Check definitional equality of two expressions. +/// +/// Uses a conjunction work stack: processes pairs iteratively, all must be equal. pub fn def_eq(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { + const DEF_EQ_STEP_LIMIT: u64 = 1_000_000; + let mut work: Vec<(Expr, Expr)> = vec![(x.clone(), y.clone())]; + let mut steps: u64 = 0; + + while let Some((x, y)) = work.pop() { + steps += 1; + if steps > DEF_EQ_STEP_LIMIT { + eprintln!("[def_eq] step limit exceeded ({steps} steps)"); + return false; + } + if !def_eq_step(&x, &y, &mut work, tc) { + return false; + } + } + true +} + +/// Process one def_eq pair. 
Returns false if definitely not equal. +/// May push additional pairs onto `work` that must all be equal. +fn def_eq_step( + x: &Expr, + y: &Expr, + work: &mut Vec<(Expr, Expr)>, + tc: &mut TypeChecker, +) -> bool { if let Some(quick) = def_eq_quick_check(x, y) { return quick; } - let x_n = tc.whnf(x); - let y_n = tc.whnf(y); + let x_n = tc.whnf_no_delta(x); + let y_n = tc.whnf_no_delta(y); if let Some(quick) = def_eq_quick_check(&x_n, &y_n) { return quick; @@ -32,9 +60,9 @@ pub fn def_eq(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { DeltaResult::Found(result) => result, DeltaResult::Exhausted(x_e, y_e) => { def_eq_const(&x_e, &y_e) - || def_eq_proj(&x_e, &y_e, tc) - || def_eq_app(&x_e, &y_e, tc) - || def_eq_binder_full(&x_e, &y_e, tc) + || def_eq_proj_push(&x_e, &y_e, work) + || def_eq_app_push(&x_e, &y_e, work) + || def_eq_binder_full_push(&x_e, &y_e, work) || try_eta_expansion(&x_e, &y_e, tc) || try_eta_struct(&x_e, &y_e, tc) || is_def_eq_unit_like(&x_e, &y_e, tc) @@ -82,16 +110,50 @@ fn def_eq_const(x: &Expr, y: &Expr) -> bool { } } -fn def_eq_proj(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { +/// Proj congruence: push structure pair onto work stack. +fn def_eq_proj_push( + x: &Expr, + y: &Expr, + work: &mut Vec<(Expr, Expr)>, +) -> bool { match (x.as_data(), y.as_data()) { ( ExprData::Proj(_, idx_l, structure_l, _), ExprData::Proj(_, idx_r, structure_r, _), - ) => idx_l == idx_r && def_eq(structure_l, structure_r, tc), + ) if idx_l == idx_r => { + work.push((structure_l.clone(), structure_r.clone())); + true + }, _ => false, } } +/// App congruence: push head + arg pairs onto work stack. 
+fn def_eq_app_push( + x: &Expr, + y: &Expr, + work: &mut Vec<(Expr, Expr)>, +) -> bool { + let (f1, args1) = unfold_apps(x); + if args1.is_empty() { + return false; + } + let (f2, args2) = unfold_apps(y); + if args2.is_empty() { + return false; + } + if args1.len() != args2.len() { + return false; + } + + work.push((f1, f2)); + for (a, b) in args1.into_iter().zip(args2.into_iter()) { + work.push((a, b)); + } + true +} + +/// Eager app congruence (used by lazy_delta_step where we need a definitive answer). fn def_eq_app(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { let (f1, args1) = unfold_apps(x); if args1.is_empty() { @@ -111,24 +173,47 @@ fn def_eq_app(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { args1.iter().zip(args2.iter()).all(|(a, b)| def_eq(a, b, tc)) } -/// Full recursive binder comparison: two Pi or two Lam types with -/// definitionally equal domain types and bodies (ignoring binder names). -fn def_eq_binder_full( +/// Iterative binder comparison: peel matching Pi/Lam layers, pushing +/// domain pairs and the final body pair onto the work stack. 
+fn def_eq_binder_full_push( x: &Expr, y: &Expr, - tc: &mut TypeChecker, + work: &mut Vec<(Expr, Expr)>, ) -> bool { - match (x.as_data(), y.as_data()) { - ( - ExprData::ForallE(_, t1, b1, _, _), - ExprData::ForallE(_, t2, b2, _, _), - ) => def_eq(t1, t2, tc) && def_eq(b1, b2, tc), - ( - ExprData::Lam(_, t1, b1, _, _), - ExprData::Lam(_, t2, b2, _, _), - ) => def_eq(t1, t2, tc) && def_eq(b1, b2, tc), - _ => false, + let mut cx = x.clone(); + let mut cy = y.clone(); + let mut matched = false; + + loop { + match (cx.as_data(), cy.as_data()) { + ( + ExprData::ForallE(_, t1, b1, _, _), + ExprData::ForallE(_, t2, b2, _, _), + ) => { + work.push((t1.clone(), t2.clone())); + cx = b1.clone(); + cy = b2.clone(); + matched = true; + }, + ( + ExprData::Lam(_, t1, b1, _, _), + ExprData::Lam(_, t2, b2, _, _), + ) => { + work.push((t1.clone(), t2.clone())); + cx = b1.clone(); + cy = b2.clone(); + matched = true; + }, + _ => break, + } + } + + if !matched { + return false; } + // Push the final body pair + work.push((cx, cy)); + true } /// Proof irrelevance: if both x and y are proofs of the same proposition, @@ -293,6 +378,66 @@ fn is_def_eq_unit_like(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { } } +/// Check if expression is Nat zero (either `Nat.zero` or `lit 0`). +/// Matches Lean 4's `is_nat_zero`. +fn is_nat_zero(e: &Expr) -> bool { + match e.as_data() { + ExprData::Const(name, _, _) => *name == mk_name2("Nat", "zero"), + ExprData::Lit(Literal::NatVal(n), _) => n.0 == BigUint::ZERO, + _ => false, + } +} + +/// If expression is `Nat.succ arg` or `lit (n+1)`, return the predecessor. +/// Matches Lean 4's `is_nat_succ` / lean4lean's `isNatSuccOf?`. 
+fn is_nat_succ(e: &Expr) -> Option<Expr> {
+  match e.as_data() {
+    ExprData::App(f, arg, _) => match f.as_data() {
+      ExprData::Const(name, _, _) if *name == mk_name2("Nat", "succ") => {
+        Some(arg.clone())
+      },
+      _ => None,
+    },
+    ExprData::Lit(Literal::NatVal(n), _) if n.0 > BigUint::ZERO => {
+      Some(Expr::lit(Literal::NatVal(Nat(
+        n.0.clone() - BigUint::from(1u64),
+      ))))
+    },
+    _ => None,
+  }
+}
+
+/// Nat offset equality: `Nat.zero =?= Nat.zero` → true,
+/// `Nat.succ n =?= Nat.succ m` → `n =?= m` (recursively via def_eq).
+/// Also handles nat literals: `lit 5 =?= Nat.succ (lit 4)` → true.
+/// Matches Lean 4's `is_def_eq_offset`.
+fn def_eq_nat_offset(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> Option<bool> {
+  if is_nat_zero(x) && is_nat_zero(y) {
+    return Some(true);
+  }
+  match (is_nat_succ(x), is_nat_succ(y)) {
+    (Some(x_pred), Some(y_pred)) => Some(def_eq(&x_pred, &y_pred, tc)),
+    _ => None,
+  }
+}
+
+/// Try to reduce via nat operations or native reductions, returning the reduced form if successful.
+fn try_lazy_delta_nat_native(e: &Expr, env: &Env) -> Option<Expr> {
+  let (head, args) = unfold_apps(e);
+  match head.as_data() {
+    ExprData::Const(name, _, _) => {
+      if let Some(r) = try_reduce_native(name, &args) {
+        return Some(r);
+      }
+      if let Some(r) = try_reduce_nat(e, env) {
+        return Some(r);
+      }
+      None
+    },
+    _ => None,
+  }
+}
+
+/// Lazy delta reduction: unfold definitions step by step.
fn lazy_delta_step( x: &Expr, @@ -301,8 +446,38 @@ fn lazy_delta_step( ) -> DeltaResult { let mut x = x.clone(); let mut y = y.clone(); + let mut iters: u32 = 0; + const MAX_DELTA_ITERS: u32 = 10_000; loop { + iters += 1; + if iters > MAX_DELTA_ITERS { + return DeltaResult::Exhausted(x, y); + } + + // Nat offset comparison (Lean 4: isDefEqOffset) + if let Some(quick) = def_eq_nat_offset(&x, &y, tc) { + return DeltaResult::Found(quick); + } + + // Try nat/native reduction on each side before delta + if let Some(x_r) = try_lazy_delta_nat_native(&x, tc.env) { + let x_r = tc.whnf_no_delta(&x_r); + if let Some(quick) = def_eq_quick_check(&x_r, &y) { + return DeltaResult::Found(quick); + } + x = x_r; + continue; + } + if let Some(y_r) = try_lazy_delta_nat_native(&y, tc.env) { + let y_r = tc.whnf_no_delta(&y_r); + if let Some(quick) = def_eq_quick_check(&x, &y_r) { + return DeltaResult::Found(quick); + } + y = y_r; + continue; + } + let x_def = get_applied_def(&x, tc.env); let y_def = get_applied_def(&y, tc.env); @@ -362,10 +537,11 @@ fn get_applied_def( } } -/// Unfold a definition and do cheap WHNF. +/// Unfold a definition and do cheap WHNF (no delta). +/// Matches lean4lean: `let delta e := whnfCore (unfoldDefinition env e).get!`. fn delta(e: &Expr, tc: &mut TypeChecker) -> Expr { match try_unfold_def(e, tc.env) { - Some(unfolded) => tc.whnf(&unfolded), + Some(unfolded) => tc.whnf_no_delta(&unfolded), None => e.clone(), } } @@ -1295,4 +1471,262 @@ mod tests { let y = tc.mk_local(&mk_name("y"), &unit_ty); assert!(tc.def_eq(&x, &y)); } + + // ========================================================================== + // ThmInfo fix: theorems must not enter lazy_delta_step + // ========================================================================== + + /// Build an env with Nat + two ThmInfo constants. 
+ fn mk_thm_env() -> Env { + let mut env = mk_nat_env(); + let thm_a = mk_name("thmA"); + let thm_b = mk_name("thmB"); + let prop = Expr::sort(Level::zero()); + // Two theorems with the same type (True : Prop) + let true_name = mk_name("True"); + env.insert( + true_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: true_name.clone(), + level_params: vec![], + typ: prop.clone(), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![true_name.clone()], + ctors: vec![mk_name2("True", "intro")], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + let intro_name = mk_name2("True", "intro"); + env.insert( + intro_name.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: intro_name.clone(), + level_params: vec![], + typ: Expr::cnst(true_name.clone(), vec![]), + }, + induct: true_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + let true_ty = Expr::cnst(true_name, vec![]); + env.insert( + thm_a.clone(), + ConstantInfo::ThmInfo(TheoremVal { + cnst: ConstantVal { + name: thm_a.clone(), + level_params: vec![], + typ: true_ty.clone(), + }, + value: Expr::cnst(intro_name.clone(), vec![]), + all: vec![thm_a.clone()], + }), + ); + env.insert( + thm_b.clone(), + ConstantInfo::ThmInfo(TheoremVal { + cnst: ConstantVal { + name: thm_b.clone(), + level_params: vec![], + typ: true_ty, + }, + value: Expr::cnst(intro_name, vec![]), + all: vec![thm_b.clone()], + }), + ); + env + } + + #[test] + fn test_def_eq_theorem_vs_theorem_terminates() { + // Two theorem constants of the same Prop type should be def-eq + // via proof irrelevance (not via delta). Before the fix, this + // would infinite loop because get_applied_def returned Some for ThmInfo. 
+ let env = mk_thm_env(); + let mut tc = TypeChecker::new(&env); + let a = Expr::cnst(mk_name("thmA"), vec![]); + let b = Expr::cnst(mk_name("thmB"), vec![]); + assert!(tc.def_eq(&a, &b)); + } + + #[test] + fn test_def_eq_theorem_vs_constructor_terminates() { + // A theorem constant vs a constructor of the same type must terminate. + let env = mk_thm_env(); + let mut tc = TypeChecker::new(&env); + let thm = Expr::cnst(mk_name("thmA"), vec![]); + let ctor = Expr::cnst(mk_name2("True", "intro"), vec![]); + // Both have type True (a Prop), so proof irrelevance should make them def-eq + assert!(tc.def_eq(&thm, &ctor)); + } + + #[test] + fn test_get_applied_def_includes_theorems_as_opaque() { + let env = mk_thm_env(); + let thm = Expr::cnst(mk_name("thmA"), vec![]); + let result = get_applied_def(&thm, &env); + assert!(result.is_some()); + let (_, hints) = result.unwrap(); + assert_eq!(hints, ReducibilityHints::Opaque); + } + + // ========================================================================== + // Nat offset equality (is_nat_zero, is_nat_succ, def_eq_nat_offset) + // ========================================================================== + + fn nat_lit(n: u64) -> Expr { + Expr::lit(Literal::NatVal(Nat::from(n))) + } + + #[test] + fn test_is_nat_zero_ctor() { + assert!(super::is_nat_zero(&nat_zero())); + } + + #[test] + fn test_is_nat_zero_lit() { + assert!(super::is_nat_zero(&nat_lit(0))); + } + + #[test] + fn test_is_nat_zero_nonzero_lit() { + assert!(!super::is_nat_zero(&nat_lit(5))); + } + + #[test] + fn test_is_nat_succ_ctor() { + let succ_zero = Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + nat_lit(4), + ); + let pred = super::is_nat_succ(&succ_zero); + assert!(pred.is_some()); + assert_eq!(pred.unwrap(), nat_lit(4)); + } + + #[test] + fn test_is_nat_succ_lit() { + // lit 5 should decompose to lit 4 (Lean 4: isNatSuccOf?) 
+ let pred = super::is_nat_succ(&nat_lit(5)); + assert!(pred.is_some()); + assert_eq!(pred.unwrap(), nat_lit(4)); + } + + #[test] + fn test_is_nat_succ_lit_one() { + // lit 1 should decompose to lit 0 + let pred = super::is_nat_succ(&nat_lit(1)); + assert!(pred.is_some()); + assert_eq!(pred.unwrap(), nat_lit(0)); + } + + #[test] + fn test_is_nat_succ_lit_zero() { + // lit 0 should NOT decompose (it's zero, not succ of anything) + assert!(super::is_nat_succ(&nat_lit(0)).is_none()); + } + + #[test] + fn test_is_nat_succ_nat_zero_ctor() { + assert!(super::is_nat_succ(&nat_zero()).is_none()); + } + + #[test] + fn def_eq_nat_zero_ctor_vs_lit() { + // Nat.zero =def= lit 0 + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + assert!(tc.def_eq(&nat_zero(), &nat_lit(0))); + } + + #[test] + fn def_eq_nat_lit_vs_succ_lit() { + // lit 5 =def= Nat.succ (lit 4) + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let succ_4 = Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + nat_lit(4), + ); + assert!(tc.def_eq(&nat_lit(5), &succ_4)); + } + + #[test] + fn def_eq_nat_succ_lit_vs_lit() { + // Nat.succ (lit 4) =def= lit 5 + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let succ_4 = Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + nat_lit(4), + ); + assert!(tc.def_eq(&succ_4, &nat_lit(5))); + } + + #[test] + fn def_eq_nat_lit_one_vs_succ_zero() { + // lit 1 =def= Nat.succ Nat.zero + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let succ_zero = Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + nat_zero(), + ); + assert!(tc.def_eq(&nat_lit(1), &succ_zero)); + } + + #[test] + fn def_eq_nat_lit_not_equal_succ() { + // lit 5 ≠ Nat.succ (lit 5) + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let succ_5 = Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + nat_lit(5), + ); + assert!(!tc.def_eq(&nat_lit(5), &succ_5)); + } + + #[test] + fn def_eq_nat_add_result_vs_lit() 
{ + // Nat.add (lit 3) (lit 4) =def= lit 7 + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let add_3_4 = Expr::app( + Expr::app( + Expr::cnst(mk_name2("Nat", "add"), vec![]), + nat_lit(3), + ), + nat_lit(4), + ); + assert!(tc.def_eq(&add_3_4, &nat_lit(7))); + } + + #[test] + fn def_eq_nat_add_vs_succ() { + // Nat.add (lit 3) (lit 4) =def= Nat.succ (lit 6) + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let add_3_4 = Expr::app( + Expr::app( + Expr::cnst(mk_name2("Nat", "add"), vec![]), + nat_lit(3), + ), + nat_lit(4), + ); + let succ_6 = Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + nat_lit(6), + ); + assert!(tc.def_eq(&add_3_4, &succ_6)); + } } diff --git a/src/ix/kernel/inductive.rs b/src/ix/kernel/inductive.rs index a06ed819..4cf79d45 100644 --- a/src/ix/kernel/inductive.rs +++ b/src/ix/kernel/inductive.rs @@ -157,23 +157,33 @@ pub fn validate_k_flag( /// Check if an expression mentions a constant by name. fn expr_mentions_const(e: &Expr, name: &Name) -> bool { - match e.as_data() { - ExprData::Const(n, _, _) => n == name, - ExprData::App(f, a, _) => { - expr_mentions_const(f, name) || expr_mentions_const(a, name) - }, - ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { - expr_mentions_const(t, name) || expr_mentions_const(b, name) - }, - ExprData::LetE(_, t, v, b, _, _) => { - expr_mentions_const(t, name) - || expr_mentions_const(v, name) - || expr_mentions_const(b, name) - }, - ExprData::Proj(_, _, s, _) => expr_mentions_const(s, name), - ExprData::Mdata(_, inner, _) => expr_mentions_const(inner, name), - _ => false, + let mut stack: Vec<&Expr> = vec![e]; + while let Some(e) = stack.pop() { + match e.as_data() { + ExprData::Const(n, _, _) => { + if n == name { + return true; + } + }, + ExprData::App(f, a, _) => { + stack.push(f); + stack.push(a); + }, + ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { + stack.push(t); + stack.push(b); + }, + ExprData::LetE(_, t, v, b, _, 
_) => { + stack.push(t); + stack.push(v); + stack.push(b); + }, + ExprData::Proj(_, _, s, _) => stack.push(s), + ExprData::Mdata(_, inner, _) => stack.push(inner), + _ => {}, + } } + false } /// Check that no inductive name from `ind.all` appears in a negative position @@ -228,44 +238,49 @@ fn check_strict_positivity( ind_names: &[Name], tc: &mut TypeChecker, ) -> TcResult<()> { - let whnf_ty = tc.whnf(ty); - - // If no inductive name is mentioned, we're fine - if !ind_names.iter().any(|n| expr_mentions_const(&whnf_ty, n)) { - return Ok(()); - } - - match whnf_ty.as_data() { - ExprData::ForallE(_, domain, body, _, _) => { - // Domain must NOT mention any inductive name - for ind_name in ind_names { - if expr_mentions_const(domain, ind_name) { - return Err(TcError::KernelException { - msg: format!( - "inductive {} occurs in negative position (strict positivity violation)", - ind_name.pretty() - ), - }); + let mut current = ty.clone(); + loop { + let whnf_ty = tc.whnf(¤t); + + // If no inductive name is mentioned, we're fine + if !ind_names.iter().any(|n| expr_mentions_const(&whnf_ty, n)) { + return Ok(()); + } + + match whnf_ty.as_data() { + ExprData::ForallE(_, domain, body, _, _) => { + // Domain must NOT mention any inductive name + for ind_name in ind_names { + if expr_mentions_const(domain, ind_name) { + return Err(TcError::KernelException { + msg: format!( + "inductive {} occurs in negative position (strict positivity violation)", + ind_name.pretty() + ), + }); + } } - } - // Recurse into body - check_strict_positivity(body, ind_names, tc) - }, - _ => { - // The inductive is mentioned and we're not in a Pi — check if - // it's simply an application `I args...` (which is OK). 
- let (head, _) = unfold_apps(&whnf_ty); - match head.as_data() { - ExprData::Const(name, _, _) - if ind_names.iter().any(|n| n == name) => - { - Ok(()) - }, - _ => Err(TcError::KernelException { - msg: "inductive type occurs in a non-positive position".into(), - }), - } - }, + // Continue with body (was tail-recursive) + current = body.clone(); + }, + _ => { + // The inductive is mentioned and we're not in a Pi — check if + // it's simply an application `I args...` (which is OK). + let (head, _) = unfold_apps(&whnf_ty); + match head.as_data() { + ExprData::Const(name, _, _) + if ind_names.iter().any(|n| n == name) => + { + return Ok(()); + }, + _ => { + return Err(TcError::KernelException { + msg: "inductive type occurs in a non-positive position".into(), + }); + }, + } + }, + } } } diff --git a/src/ix/kernel/level.rs b/src/ix/kernel/level.rs index 90931ca6..80195e35 100644 --- a/src/ix/kernel/level.rs +++ b/src/ix/kernel/level.rs @@ -245,31 +245,41 @@ pub fn all_uparams_defined(level: &Level, params: &[Name]) -> bool { /// Check that all universe parameters in an expression are contained in `params`. /// Recursively walks the Expr, checking all Levels in Sort and Const nodes. 
pub fn all_expr_uparams_defined(e: &Expr, params: &[Name]) -> bool { - match e.as_data() { - ExprData::Sort(level, _) => all_uparams_defined(level, params), - ExprData::Const(_, levels, _) => { - levels.iter().all(|l| all_uparams_defined(l, params)) - }, - ExprData::App(f, a, _) => { - all_expr_uparams_defined(f, params) - && all_expr_uparams_defined(a, params) - }, - ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { - all_expr_uparams_defined(t, params) - && all_expr_uparams_defined(b, params) - }, - ExprData::LetE(_, t, v, b, _, _) => { - all_expr_uparams_defined(t, params) - && all_expr_uparams_defined(v, params) - && all_expr_uparams_defined(b, params) - }, - ExprData::Proj(_, _, s, _) => all_expr_uparams_defined(s, params), - ExprData::Mdata(_, inner, _) => all_expr_uparams_defined(inner, params), - ExprData::Bvar(..) - | ExprData::Fvar(..) - | ExprData::Mvar(..) - | ExprData::Lit(..) => true, + let mut stack: Vec<&Expr> = vec![e]; + while let Some(e) = stack.pop() { + match e.as_data() { + ExprData::Sort(level, _) => { + if !all_uparams_defined(level, params) { + return false; + } + }, + ExprData::Const(_, levels, _) => { + if !levels.iter().all(|l| all_uparams_defined(l, params)) { + return false; + } + }, + ExprData::App(f, a, _) => { + stack.push(f); + stack.push(a); + }, + ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { + stack.push(t); + stack.push(b); + }, + ExprData::LetE(_, t, v, b, _, _) => { + stack.push(t); + stack.push(v); + stack.push(b); + }, + ExprData::Proj(_, _, s, _) => stack.push(s), + ExprData::Mdata(_, inner, _) => stack.push(inner), + ExprData::Bvar(..) + | ExprData::Fvar(..) + | ExprData::Mvar(..) + | ExprData::Lit(..) => {}, + } } + true } /// Check that a list of levels are all Params with no duplicates. 
diff --git a/src/ix/kernel/mod.rs b/src/ix/kernel/mod.rs index d6a5750e..23aea4f6 100644 --- a/src/ix/kernel/mod.rs +++ b/src/ix/kernel/mod.rs @@ -1,5 +1,6 @@ pub mod convert; pub mod dag; +pub mod dag_tc; pub mod def_eq; pub mod dll; pub mod error; diff --git a/src/ix/kernel/tc.rs b/src/ix/kernel/tc.rs index e80416fd..604fbf02 100644 --- a/src/ix/kernel/tc.rs +++ b/src/ix/kernel/tc.rs @@ -1,5 +1,6 @@ use crate::ix::env::*; use crate::lean::nat::Nat; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use rustc_hash::FxHashMap; use super::def_eq::def_eq; @@ -13,9 +14,13 @@ type TcResult = Result; pub struct TypeChecker<'env> { pub env: &'env Env, pub whnf_cache: FxHashMap, + pub whnf_no_delta_cache: FxHashMap, pub infer_cache: FxHashMap, pub local_counter: u64, pub local_types: FxHashMap, + pub def_eq_calls: u64, + pub whnf_calls: u64, + pub infer_calls: u64, } impl<'env> TypeChecker<'env> { @@ -23,9 +28,13 @@ impl<'env> TypeChecker<'env> { TypeChecker { env, whnf_cache: FxHashMap::default(), + whnf_no_delta_cache: FxHashMap::default(), infer_cache: FxHashMap::default(), local_counter: 0, local_types: FxHashMap::default(), + def_eq_calls: 0, + whnf_calls: 0, + infer_calls: 0, } } @@ -37,8 +46,33 @@ impl<'env> TypeChecker<'env> { if let Some(cached) = self.whnf_cache.get(e) { return cached.clone(); } + self.whnf_calls += 1; + let tag = match e.as_data() { + ExprData::Sort(..) => "Sort", + ExprData::Const(_, _, _) => "Const", + ExprData::App(..) => "App", + ExprData::Lam(..) => "Lam", + ExprData::ForallE(..) => "Pi", + ExprData::LetE(..) => "Let", + ExprData::Lit(..) => "Lit", + ExprData::Proj(..) => "Proj", + ExprData::Fvar(..) => "Fvar", + ExprData::Bvar(..) => "Bvar", + ExprData::Mvar(..) => "Mvar", + ExprData::Mdata(..) 
=> "Mdata", + }; + eprintln!("[tc.whnf] #{} {tag}", self.whnf_calls); let result = whnf(e, self.env); - self.whnf_cache.insert(e.clone(), result.clone()); + eprintln!("[tc.whnf] #{} {tag} done", self.whnf_calls); + result + } + + pub fn whnf_no_delta(&mut self, e: &Expr) -> Expr { + if let Some(cached) = self.whnf_no_delta_cache.get(e) { + return cached.clone(); + } + let result = whnf_no_delta(e, self.env); + self.whnf_no_delta_cache.insert(e.clone(), result.clone()); result } @@ -102,40 +136,87 @@ impl<'env> TypeChecker<'env> { if let Some(cached) = self.infer_cache.get(e) { return Ok(cached.clone()); } + self.infer_calls += 1; + let tag = match e.as_data() { + ExprData::Sort(..) => "Sort".to_string(), + ExprData::Const(n, _, _) => format!("Const({})", n.pretty()), + ExprData::App(..) => "App".to_string(), + ExprData::Lam(..) => "Lam".to_string(), + ExprData::ForallE(..) => "Pi".to_string(), + ExprData::LetE(..) => "Let".to_string(), + ExprData::Lit(..) => "Lit".to_string(), + ExprData::Proj(..) => "Proj".to_string(), + ExprData::Fvar(n, _) => format!("Fvar({})", n.pretty()), + ExprData::Bvar(..) => "Bvar".to_string(), + ExprData::Mvar(..) => "Mvar".to_string(), + ExprData::Mdata(..) => "Mdata".to_string(), + }; + eprintln!("[tc.infer] #{} {tag}", self.infer_calls); let result = self.infer_core(e)?; self.infer_cache.insert(e.clone(), result.clone()); Ok(result) } fn infer_core(&mut self, e: &Expr) -> TcResult { - match e.as_data() { - ExprData::Sort(level, _) => self.infer_sort(level), - ExprData::Const(name, levels, _) => self.infer_const(name, levels), - ExprData::App(..) => self.infer_app(e), - ExprData::Lam(..) => self.infer_lambda(e), - ExprData::ForallE(..) 
=> self.infer_pi(e), - ExprData::LetE(_, typ, val, body, _, _) => { - self.infer_let(typ, val, body) - }, - ExprData::Lit(lit, _) => self.infer_lit(lit), - ExprData::Proj(type_name, idx, structure, _) => { - self.infer_proj(type_name, idx, structure) - }, - ExprData::Mdata(_, inner, _) => self.infer(inner), - ExprData::Fvar(name, _) => { - match self.local_types.get(name) { - Some(ty) => Ok(ty.clone()), - None => Err(TcError::KernelException { - msg: "cannot infer type of free variable without context".into(), - }), - } - }, - ExprData::Bvar(idx, _) => Err(TcError::FreeBoundVariable { - idx: idx.to_u64().unwrap_or(u64::MAX), - }), - ExprData::Mvar(..) => Err(TcError::KernelException { - msg: "cannot infer type of metavariable".into(), - }), + // Peel Mdata and Let layers iteratively to avoid stack depth + let mut cursor = e.clone(); + loop { + match cursor.as_data() { + ExprData::Mdata(_, inner, _) => { + // Check cache for inner before recursing + if let Some(cached) = self.infer_cache.get(inner) { + return Ok(cached.clone()); + } + cursor = inner.clone(); + continue; + }, + ExprData::LetE(_, typ, val, body, _, _) => { + let val_ty = self.infer(val)?; + self.assert_def_eq(&val_ty, typ)?; + let body_inst = inst(body, &[val.clone()]); + // Check cache for body_inst before looping + if let Some(cached) = self.infer_cache.get(&body_inst) { + return Ok(cached.clone()); + } + // Cache the current let expression's result once we compute it + let orig = cursor.clone(); + cursor = body_inst; + // We need to compute the result and cache it for `orig` + let result = self.infer(&cursor)?; + self.infer_cache.insert(orig, result.clone()); + return Ok(result); + }, + ExprData::Sort(level, _) => return self.infer_sort(level), + ExprData::Const(name, levels, _) => { + return self.infer_const(name, levels) + }, + ExprData::App(..) => return self.infer_app(&cursor), + ExprData::Lam(..) => return self.infer_lambda(&cursor), + ExprData::ForallE(..) 
=> return self.infer_pi(&cursor), + ExprData::Lit(lit, _) => return self.infer_lit(lit), + ExprData::Proj(type_name, idx, structure, _) => { + return self.infer_proj(type_name, idx, structure) + }, + ExprData::Fvar(name, _) => { + return match self.local_types.get(name) { + Some(ty) => Ok(ty.clone()), + None => Err(TcError::KernelException { + msg: "cannot infer type of free variable without context" + .into(), + }), + } + }, + ExprData::Bvar(idx, _) => { + return Err(TcError::FreeBoundVariable { + idx: idx.to_u64().unwrap_or(u64::MAX), + }) + }, + ExprData::Mvar(..) => { + return Err(TcError::KernelException { + msg: "cannot infer type of metavariable".into(), + }) + }, + } } } @@ -253,19 +334,6 @@ impl<'env> TypeChecker<'env> { Ok(Expr::sort(result_level)) } - fn infer_let( - &mut self, - typ: &Expr, - val: &Expr, - body: &Expr, - ) -> TcResult { - // Verify value matches declared type - let val_ty = self.infer(val)?; - self.assert_def_eq(&val_ty, typ)?; - let body_inst = inst(body, &[val.clone()]); - self.infer(&body_inst) - } - fn infer_lit(&mut self, lit: &Literal) -> TcResult { match lit { Literal::NatVal(_) => { @@ -375,7 +443,11 @@ impl<'env> TypeChecker<'env> { // ========================================================================== pub fn def_eq(&mut self, x: &Expr, y: &Expr) -> bool { - def_eq(x, y, self) + self.def_eq_calls += 1; + eprintln!("[tc.def_eq] #{}", self.def_eq_calls); + let result = def_eq(x, y, self); + eprintln!("[tc.def_eq] #{} done => {result}", self.def_eq_calls); + result } pub fn assert_def_eq(&mut self, x: &Expr, y: &Expr) -> TcResult<()> { @@ -432,6 +504,31 @@ impl<'env> TypeChecker<'env> { Ok(()) } + /// Check a declaration that has both a type and a value (DefnInfo, ThmInfo, OpaqueInfo). 
+ fn check_value_declar( + &mut self, + cnst: &ConstantVal, + value: &Expr, + ) -> TcResult<()> { + eprintln!("[check_value_declar] checking type for {}", cnst.name.pretty()); + self.check_declar_info(cnst)?; + eprintln!("[check_value_declar] type OK, checking value uparams"); + if !all_expr_uparams_defined(value, &cnst.level_params) { + return Err(TcError::KernelException { + msg: format!( + "undeclared universe parameters in value of {}", + cnst.name.pretty() + ), + }); + } + eprintln!("[check_value_declar] inferring value type"); + let inferred_type = self.infer(value)?; + eprintln!("[check_value_declar] inferred, checking def_eq"); + self.assert_def_eq(&inferred_type, &cnst.typ)?; + eprintln!("[check_value_declar] done"); + Ok(()) + } + /// Check a single declaration. pub fn check_declar( &mut self, @@ -442,43 +539,13 @@ impl<'env> TypeChecker<'env> { self.check_declar_info(&v.cnst)?; }, ConstantInfo::DefnInfo(v) => { - self.check_declar_info(&v.cnst)?; - if !all_expr_uparams_defined(&v.value, &v.cnst.level_params) { - return Err(TcError::KernelException { - msg: format!( - "undeclared universe parameters in value of {}", - v.cnst.name.pretty() - ), - }); - } - let inferred_type = self.infer(&v.value)?; - self.assert_def_eq(&inferred_type, &v.cnst.typ)?; + self.check_value_declar(&v.cnst, &v.value)?; }, ConstantInfo::ThmInfo(v) => { - self.check_declar_info(&v.cnst)?; - if !all_expr_uparams_defined(&v.value, &v.cnst.level_params) { - return Err(TcError::KernelException { - msg: format!( - "undeclared universe parameters in value of {}", - v.cnst.name.pretty() - ), - }); - } - let inferred_type = self.infer(&v.value)?; - self.assert_def_eq(&inferred_type, &v.cnst.typ)?; + self.check_value_declar(&v.cnst, &v.value)?; }, ConstantInfo::OpaqueInfo(v) => { - self.check_declar_info(&v.cnst)?; - if !all_expr_uparams_defined(&v.value, &v.cnst.level_params) { - return Err(TcError::KernelException { - msg: format!( - "undeclared universe parameters in value of {}", - 
v.cnst.name.pretty() - ), - }); - } - let inferred_type = self.infer(&v.value)?; - self.assert_def_eq(&inferred_type, &v.cnst.typ)?; + self.check_value_declar(&v.cnst, &v.value)?; }, ConstantInfo::QuotInfo(v) => { self.check_declar_info(&v.cnst)?; @@ -512,16 +579,77 @@ impl<'env> TypeChecker<'env> { } } -/// Check all declarations in an environment. +/// Check all declarations in an environment in parallel. pub fn check_env(env: &Env) -> Vec<(Name, TcError)> { - let mut errors = Vec::new(); - for (name, ci) in env.iter() { - let mut tc = TypeChecker::new(env); - if let Err(e) = tc.check_declar(ci) { - errors.push((name.clone(), e)); - } + use std::collections::BTreeSet; + use std::io::Write; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Mutex; + + let total = env.len(); + let checked = AtomicUsize::new(0); + + struct Display { + active: BTreeSet, + prev_lines: usize, } - errors + let display = Mutex::new(Display { active: BTreeSet::new(), prev_lines: 0 }); + + let refresh = |d: &mut Display, checked: usize| { + let mut stderr = std::io::stderr().lock(); + if d.prev_lines > 0 { + write!(stderr, "\x1b[{}A", d.prev_lines).ok(); + } + write!( + stderr, + "\x1b[2K[check_env] {}/{} — {} active\n", + checked, + total, + d.active.len() + ) + .ok(); + let mut new_lines = 1; + for name in &d.active { + write!(stderr, "\x1b[2K {}\n", name).ok(); + new_lines += 1; + } + let extra = d.prev_lines.saturating_sub(new_lines); + for _ in 0..extra { + write!(stderr, "\x1b[2K\n").ok(); + } + if extra > 0 { + write!(stderr, "\x1b[{}A", extra).ok(); + } + d.prev_lines = new_lines; + stderr.flush().ok(); + }; + + env + .par_iter() + .filter_map(|(name, ci)| { + let pretty = name.pretty(); + { + let mut d = display.lock().unwrap(); + d.active.insert(pretty.clone()); + refresh(&mut d, checked.load(Ordering::Relaxed)); + } + + let mut tc = TypeChecker::new(env); + let result = tc.check_declar(ci); + + let n = checked.fetch_add(1, Ordering::Relaxed) + 1; + { + let mut d 
= display.lock().unwrap(); + d.active.remove(&pretty); + refresh(&mut d, n); + } + + match result { + Ok(()) => None, + Err(e) => Some((name.clone(), e)), + } + }) + .collect() } #[cfg(test)] @@ -553,9 +681,18 @@ mod tests { Expr::sort(Level::param(mk_name("u"))) } - /// Build a minimal environment with Nat, Nat.zero, and Nat.succ. + fn bvar(n: u64) -> Expr { + Expr::bvar(Nat::from(n)) + } + + fn nat_succ_expr() -> Expr { + Expr::cnst(mk_name2("Nat", "succ"), vec![]) + } + + /// Build a minimal environment with Nat, Nat.zero, Nat.succ, and Nat.rec. fn mk_nat_env() -> Env { let mut env = Env::default(); + let u = mk_name("u"); let nat_name = mk_name("Nat"); // Nat : Sort 1 @@ -614,6 +751,147 @@ mod tests { }); env.insert(succ_name, succ); + // Nat.rec.{u} : + // {motive : Nat → Sort u} → + // motive Nat.zero → + // ((n : Nat) → motive n → motive (Nat.succ n)) → + // (t : Nat) → motive t + let rec_name = mk_name2("Nat", "rec"); + + // Build the type with de Bruijn indices. + // Binder stack (from outermost): motive(3), z(2), s(1), t(0) + // At the innermost body: motive=bvar(3), z=bvar(2), s=bvar(1), t=bvar(0) + let motive_type = Expr::all( + mk_name("_"), + nat_type(), + Expr::sort(Level::param(u.clone())), + BinderInfo::Default, + ); // Nat → Sort u + + // s type: (n : Nat) → motive n → motive (Nat.succ n) + // At s's position: motive=bvar(1), z=bvar(0) + // Inside forallE "n": motive=bvar(2), z=bvar(1), n=bvar(0) + // Inside forallE "_": motive=bvar(3), z=bvar(2), n=bvar(1), _=bvar(0) + let s_type = Expr::all( + mk_name("n"), + nat_type(), + Expr::all( + mk_name("_"), + Expr::app(bvar(2), bvar(0)), // motive n + Expr::app(bvar(3), Expr::app(nat_succ_expr(), bvar(1))), // motive (Nat.succ n) + BinderInfo::Default, + ), + BinderInfo::Default, + ); + + let rec_type = Expr::all( + mk_name("motive"), + motive_type.clone(), + Expr::all( + mk_name("z"), + Expr::app(bvar(0), nat_zero()), // motive Nat.zero + Expr::all( + mk_name("s"), + s_type, + Expr::all( + 
mk_name("t"), + nat_type(), + Expr::app(bvar(3), bvar(0)), // motive t + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Implicit, + ); + + // Zero rule RHS: fun (motive) (z) (s) => z + // Inside: motive=bvar(2), z=bvar(1), s=bvar(0) + let zero_rhs = Expr::lam( + mk_name("motive"), + motive_type.clone(), + Expr::lam( + mk_name("z"), + Expr::app(bvar(0), nat_zero()), + Expr::lam( + mk_name("s"), + nat_type(), // placeholder type for s (not checked) + bvar(1), // z + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + + // Succ rule RHS: fun (motive) (z) (s) (n) => s n (Nat.rec.{u} motive z s n) + // Inside: motive=bvar(3), z=bvar(2), s=bvar(1), n=bvar(0) + let nat_rec_u = + Expr::cnst(rec_name.clone(), vec![Level::param(u.clone())]); + let recursive_call = Expr::app( + Expr::app( + Expr::app( + Expr::app(nat_rec_u, bvar(3)), // Nat.rec motive + bvar(2), // z + ), + bvar(1), // s + ), + bvar(0), // n + ); + let succ_rhs = Expr::lam( + mk_name("motive"), + motive_type, + Expr::lam( + mk_name("z"), + Expr::app(bvar(0), nat_zero()), + Expr::lam( + mk_name("s"), + nat_type(), // placeholder + Expr::lam( + mk_name("n"), + nat_type(), + Expr::app( + Expr::app(bvar(1), bvar(0)), // s n + recursive_call, // (Nat.rec motive z s n) + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + + let rec = ConstantInfo::RecInfo(RecursorVal { + cnst: ConstantVal { + name: rec_name.clone(), + level_params: vec![u], + typ: rec_type, + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![ + RecursorRule { + ctor: mk_name2("Nat", "zero"), + n_fields: Nat::from(0u64), + rhs: zero_rhs, + }, + RecursorRule { + ctor: mk_name2("Nat", "succ"), + n_fields: Nat::from(1u64), + rhs: succ_rhs, + }, + ], + k: false, + is_unsafe: false, + }); + 
env.insert(rec_name, rec); + env } @@ -1691,4 +1969,219 @@ mod tests { }); assert!(tc.check_declar(&rec).is_err()); } + + // ========================================================================== + // check_declar: Nat.add via Nat.rec + // ========================================================================== + + #[test] + fn check_nat_add_via_rec() { + // Nat.add : Nat → Nat → Nat := + // fun (n m : Nat) => @Nat.rec.{1} (fun _ => Nat) n (fun _ ih => Nat.succ ih) m + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + + let nat = nat_type(); + let nat_rec_1 = Expr::cnst( + mk_name2("Nat", "rec"), + vec![Level::succ(Level::zero())], + ); + + // motive: fun (_ : Nat) => Nat + let motive = Expr::lam( + mk_name("_"), + nat.clone(), + nat.clone(), + BinderInfo::Default, + ); + + // step: fun (_ : Nat) (ih : Nat) => Nat.succ ih + let step = Expr::lam( + mk_name("_"), + nat.clone(), + Expr::lam( + mk_name("ih"), + nat.clone(), + Expr::app(nat_succ_expr(), bvar(0)), // Nat.succ ih + BinderInfo::Default, + ), + BinderInfo::Default, + ); + + // value: fun (n m : Nat) => @Nat.rec.{1} (fun _ => Nat) n (fun _ ih => Nat.succ ih) m + // = fun n m => Nat.rec motive n step m + let body = Expr::app( + Expr::app( + Expr::app( + Expr::app(nat_rec_1, motive), + bvar(1), // n + ), + step, + ), + bvar(0), // m + ); + let value = Expr::lam( + mk_name("n"), + nat.clone(), + Expr::lam( + mk_name("m"), + nat.clone(), + body, + BinderInfo::Default, + ), + BinderInfo::Default, + ); + + let typ = Expr::all( + mk_name("n"), + nat.clone(), + Expr::all(mk_name("m"), nat.clone(), nat, BinderInfo::Default), + BinderInfo::Default, + ); + + let defn = ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: mk_name2("Nat", "add"), + level_params: vec![], + typ, + }, + value, + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![mk_name2("Nat", "add")], + }); + assert!(tc.check_declar(&defn).is_ok()); + } + + /// Build mk_nat_env + Nat.add 
definition in the env. + fn mk_nat_add_env() -> Env { + let mut env = mk_nat_env(); + let nat = nat_type(); + + let nat_rec_1 = Expr::cnst( + mk_name2("Nat", "rec"), + vec![Level::succ(Level::zero())], + ); + + let motive = Expr::lam( + mk_name("_"), + nat.clone(), + nat.clone(), + BinderInfo::Default, + ); + + let step = Expr::lam( + mk_name("_"), + nat.clone(), + Expr::lam( + mk_name("ih"), + nat.clone(), + Expr::app(nat_succ_expr(), bvar(0)), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + + let body = Expr::app( + Expr::app( + Expr::app( + Expr::app(nat_rec_1, motive), + bvar(1), // n + ), + step, + ), + bvar(0), // m + ); + let value = Expr::lam( + mk_name("n"), + nat.clone(), + Expr::lam( + mk_name("m"), + nat.clone(), + body, + BinderInfo::Default, + ), + BinderInfo::Default, + ); + + let typ = Expr::all( + mk_name("n"), + nat.clone(), + Expr::all(mk_name("m"), nat.clone(), nat, BinderInfo::Default), + BinderInfo::Default, + ); + + env.insert( + mk_name2("Nat", "add"), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: mk_name2("Nat", "add"), + level_params: vec![], + typ, + }, + value, + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![mk_name2("Nat", "add")], + }), + ); + + env + } + + #[test] + fn check_nat_add_env() { + // Verify that the full Nat + Nat.add environment typechecks + let env = mk_nat_add_env(); + let errors = check_env(&env); + assert!(errors.is_empty(), "Expected no errors, got: {:?}", errors); + } + + #[test] + fn whnf_nat_add_zero_zero() { + // Nat.add Nat.zero Nat.zero should WHNF to 0 (as nat literal) + let env = mk_nat_add_env(); + let e = Expr::app( + Expr::app( + Expr::cnst(mk_name2("Nat", "add"), vec![]), + nat_zero(), + ), + nat_zero(), + ); + let result = whnf(&e, &env); + assert_eq!(result, Expr::lit(Literal::NatVal(Nat::from(0u64)))); + } + + #[test] + fn whnf_nat_add_lit() { + // Nat.add 2 3 should WHNF to 5 + let env = mk_nat_add_env(); + let two = 
Expr::lit(Literal::NatVal(Nat::from(2u64))); + let three = Expr::lit(Literal::NatVal(Nat::from(3u64))); + let e = Expr::app( + Expr::app( + Expr::cnst(mk_name2("Nat", "add"), vec![]), + two, + ), + three, + ); + let result = whnf(&e, &env); + assert_eq!(result, Expr::lit(Literal::NatVal(Nat::from(5u64)))); + } + + #[test] + fn infer_nat_add_applied() { + // Nat.add Nat.zero Nat.zero : Nat + let env = mk_nat_add_env(); + let mut tc = TypeChecker::new(&env); + let e = Expr::app( + Expr::app( + Expr::cnst(mk_name2("Nat", "add"), vec![]), + nat_zero(), + ), + nat_zero(), + ); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, nat_type()); + } } diff --git a/src/ix/kernel/upcopy.rs b/src/ix/kernel/upcopy.rs index 89dae8a0..a3657ac4 100644 --- a/src/ix/kernel/upcopy.rs +++ b/src/ix/kernel/upcopy.rs @@ -10,223 +10,225 @@ use super::dll::DLL; // ============================================================================ pub fn upcopy(new_child: DAGPtr, cc: ParentPtr) { - unsafe { - match cc { - ParentPtr::Root => {}, - ParentPtr::LamBod(link) => { - let lam = &*link.as_ptr(); - let var = &lam.var; - let new_lam = alloc_lam(var.depth, new_child, None); - let new_lam_ref = &mut *new_lam.as_ptr(); - let bod_ref_ptr = - NonNull::new(&mut new_lam_ref.bod_ref as *mut Parents).unwrap(); - add_to_parents(new_child, bod_ref_ptr); - let new_var_ptr = - NonNull::new(&mut new_lam_ref.var as *mut Var).unwrap(); - for parent in DLL::iter_option(var.parents) { - upcopy(DAGPtr::Var(new_var_ptr), *parent); - } - for parent in DLL::iter_option(lam.parents) { - upcopy(DAGPtr::Lam(new_lam), *parent); - } - }, - ParentPtr::AppFun(link) => { - let app = &mut *link.as_ptr(); - match app.copy { - Some(cache) => { - (*cache.as_ptr()).fun = new_child; - }, - None => { - let new_app = alloc_app_no_uplinks(new_child, app.arg); - app.copy = Some(new_app); - for parent in DLL::iter_option(app.parents) { - upcopy(DAGPtr::App(new_app), *parent); - } - }, - } - }, - ParentPtr::AppArg(link) => { - let app 
= &mut *link.as_ptr(); - match app.copy { - Some(cache) => { - (*cache.as_ptr()).arg = new_child; - }, - None => { - let new_app = alloc_app_no_uplinks(app.fun, new_child); - app.copy = Some(new_app); - for parent in DLL::iter_option(app.parents) { - upcopy(DAGPtr::App(new_app), *parent); - } - }, - } - }, - ParentPtr::FunDom(link) => { - let fun = &mut *link.as_ptr(); - match fun.copy { - Some(cache) => { - (*cache.as_ptr()).dom = new_child; - }, - None => { - let new_fun = alloc_fun_no_uplinks( - fun.binder_name.clone(), - fun.binder_info.clone(), - new_child, - fun.img, - ); - fun.copy = Some(new_fun); - for parent in DLL::iter_option(fun.parents) { - upcopy(DAGPtr::Fun(new_fun), *parent); - } - }, - } - }, - ParentPtr::FunImg(link) => { - let fun = &mut *link.as_ptr(); - // new_child must be a Lam - let new_lam = match new_child { - DAGPtr::Lam(p) => p, - _ => panic!("FunImg parent expects Lam child"), - }; - match fun.copy { - Some(cache) => { - (*cache.as_ptr()).img = new_lam; - }, - None => { - let new_fun = alloc_fun_no_uplinks( - fun.binder_name.clone(), - fun.binder_info.clone(), - fun.dom, - new_lam, - ); - fun.copy = Some(new_fun); - for parent in DLL::iter_option(fun.parents) { - upcopy(DAGPtr::Fun(new_fun), *parent); - } - }, - } - }, - ParentPtr::PiDom(link) => { - let pi = &mut *link.as_ptr(); - match pi.copy { - Some(cache) => { - (*cache.as_ptr()).dom = new_child; - }, - None => { - let new_pi = alloc_pi_no_uplinks( - pi.binder_name.clone(), - pi.binder_info.clone(), - new_child, - pi.img, - ); - pi.copy = Some(new_pi); - for parent in DLL::iter_option(pi.parents) { - upcopy(DAGPtr::Pi(new_pi), *parent); - } - }, - } - }, - ParentPtr::PiImg(link) => { - let pi = &mut *link.as_ptr(); - let new_lam = match new_child { - DAGPtr::Lam(p) => p, - _ => panic!("PiImg parent expects Lam child"), - }; - match pi.copy { - Some(cache) => { - (*cache.as_ptr()).img = new_lam; - }, - None => { - let new_pi = alloc_pi_no_uplinks( - pi.binder_name.clone(), - 
pi.binder_info.clone(), - pi.dom, - new_lam, - ); - pi.copy = Some(new_pi); - for parent in DLL::iter_option(pi.parents) { - upcopy(DAGPtr::Pi(new_pi), *parent); - } - }, - } - }, - ParentPtr::LetTyp(link) => { - let let_node = &mut *link.as_ptr(); - match let_node.copy { - Some(cache) => { - (*cache.as_ptr()).typ = new_child; - }, - None => { - let new_let = alloc_let_no_uplinks( - let_node.binder_name.clone(), - let_node.non_dep, - new_child, - let_node.val, - let_node.bod, - ); - let_node.copy = Some(new_let); - for parent in DLL::iter_option(let_node.parents) { - upcopy(DAGPtr::Let(new_let), *parent); - } - }, - } - }, - ParentPtr::LetVal(link) => { - let let_node = &mut *link.as_ptr(); - match let_node.copy { - Some(cache) => { - (*cache.as_ptr()).val = new_child; - }, - None => { - let new_let = alloc_let_no_uplinks( - let_node.binder_name.clone(), - let_node.non_dep, - let_node.typ, - new_child, - let_node.bod, - ); - let_node.copy = Some(new_let); - for parent in DLL::iter_option(let_node.parents) { - upcopy(DAGPtr::Let(new_let), *parent); - } - }, - } - }, - ParentPtr::LetBod(link) => { - let let_node = &mut *link.as_ptr(); - let new_lam = match new_child { - DAGPtr::Lam(p) => p, - _ => panic!("LetBod parent expects Lam child"), - }; - match let_node.copy { - Some(cache) => { - (*cache.as_ptr()).bod = new_lam; - }, - None => { - let new_let = alloc_let_no_uplinks( - let_node.binder_name.clone(), - let_node.non_dep, - let_node.typ, - let_node.val, - new_lam, - ); - let_node.copy = Some(new_let); - for parent in DLL::iter_option(let_node.parents) { - upcopy(DAGPtr::Let(new_let), *parent); - } - }, - } - }, - ParentPtr::ProjExpr(link) => { - let proj = &*link.as_ptr(); - let new_proj = alloc_proj_no_uplinks( - proj.type_name.clone(), - proj.idx.clone(), - new_child, - ); - for parent in DLL::iter_option(proj.parents) { - upcopy(DAGPtr::Proj(new_proj), *parent); - } - }, + let mut stack: Vec<(DAGPtr, ParentPtr)> = vec![(new_child, cc)]; + while let 
Some((new_child, cc)) = stack.pop() { + unsafe { + match cc { + ParentPtr::Root => {}, + ParentPtr::LamBod(link) => { + let lam = &*link.as_ptr(); + let var = &lam.var; + let new_lam = alloc_lam(var.depth, new_child, None); + let new_lam_ref = &mut *new_lam.as_ptr(); + let bod_ref_ptr = + NonNull::new(&mut new_lam_ref.bod_ref as *mut Parents).unwrap(); + add_to_parents(new_child, bod_ref_ptr); + let new_var_ptr = + NonNull::new(&mut new_lam_ref.var as *mut Var).unwrap(); + for parent in DLL::iter_option(var.parents) { + stack.push((DAGPtr::Var(new_var_ptr), *parent)); + } + for parent in DLL::iter_option(lam.parents) { + stack.push((DAGPtr::Lam(new_lam), *parent)); + } + }, + ParentPtr::AppFun(link) => { + let app = &mut *link.as_ptr(); + match app.copy { + Some(cache) => { + (*cache.as_ptr()).fun = new_child; + }, + None => { + let new_app = alloc_app_no_uplinks(new_child, app.arg); + app.copy = Some(new_app); + for parent in DLL::iter_option(app.parents) { + stack.push((DAGPtr::App(new_app), *parent)); + } + }, + } + }, + ParentPtr::AppArg(link) => { + let app = &mut *link.as_ptr(); + match app.copy { + Some(cache) => { + (*cache.as_ptr()).arg = new_child; + }, + None => { + let new_app = alloc_app_no_uplinks(app.fun, new_child); + app.copy = Some(new_app); + for parent in DLL::iter_option(app.parents) { + stack.push((DAGPtr::App(new_app), *parent)); + } + }, + } + }, + ParentPtr::FunDom(link) => { + let fun = &mut *link.as_ptr(); + match fun.copy { + Some(cache) => { + (*cache.as_ptr()).dom = new_child; + }, + None => { + let new_fun = alloc_fun_no_uplinks( + fun.binder_name.clone(), + fun.binder_info.clone(), + new_child, + fun.img, + ); + fun.copy = Some(new_fun); + for parent in DLL::iter_option(fun.parents) { + stack.push((DAGPtr::Fun(new_fun), *parent)); + } + }, + } + }, + ParentPtr::FunImg(link) => { + let fun = &mut *link.as_ptr(); + let new_lam = match new_child { + DAGPtr::Lam(p) => p, + _ => panic!("FunImg parent expects Lam child"), + }; + match 
fun.copy { + Some(cache) => { + (*cache.as_ptr()).img = new_lam; + }, + None => { + let new_fun = alloc_fun_no_uplinks( + fun.binder_name.clone(), + fun.binder_info.clone(), + fun.dom, + new_lam, + ); + fun.copy = Some(new_fun); + for parent in DLL::iter_option(fun.parents) { + stack.push((DAGPtr::Fun(new_fun), *parent)); + } + }, + } + }, + ParentPtr::PiDom(link) => { + let pi = &mut *link.as_ptr(); + match pi.copy { + Some(cache) => { + (*cache.as_ptr()).dom = new_child; + }, + None => { + let new_pi = alloc_pi_no_uplinks( + pi.binder_name.clone(), + pi.binder_info.clone(), + new_child, + pi.img, + ); + pi.copy = Some(new_pi); + for parent in DLL::iter_option(pi.parents) { + stack.push((DAGPtr::Pi(new_pi), *parent)); + } + }, + } + }, + ParentPtr::PiImg(link) => { + let pi = &mut *link.as_ptr(); + let new_lam = match new_child { + DAGPtr::Lam(p) => p, + _ => panic!("PiImg parent expects Lam child"), + }; + match pi.copy { + Some(cache) => { + (*cache.as_ptr()).img = new_lam; + }, + None => { + let new_pi = alloc_pi_no_uplinks( + pi.binder_name.clone(), + pi.binder_info.clone(), + pi.dom, + new_lam, + ); + pi.copy = Some(new_pi); + for parent in DLL::iter_option(pi.parents) { + stack.push((DAGPtr::Pi(new_pi), *parent)); + } + }, + } + }, + ParentPtr::LetTyp(link) => { + let let_node = &mut *link.as_ptr(); + match let_node.copy { + Some(cache) => { + (*cache.as_ptr()).typ = new_child; + }, + None => { + let new_let = alloc_let_no_uplinks( + let_node.binder_name.clone(), + let_node.non_dep, + new_child, + let_node.val, + let_node.bod, + ); + let_node.copy = Some(new_let); + for parent in DLL::iter_option(let_node.parents) { + stack.push((DAGPtr::Let(new_let), *parent)); + } + }, + } + }, + ParentPtr::LetVal(link) => { + let let_node = &mut *link.as_ptr(); + match let_node.copy { + Some(cache) => { + (*cache.as_ptr()).val = new_child; + }, + None => { + let new_let = alloc_let_no_uplinks( + let_node.binder_name.clone(), + let_node.non_dep, + let_node.typ, + 
new_child, + let_node.bod, + ); + let_node.copy = Some(new_let); + for parent in DLL::iter_option(let_node.parents) { + stack.push((DAGPtr::Let(new_let), *parent)); + } + }, + } + }, + ParentPtr::LetBod(link) => { + let let_node = &mut *link.as_ptr(); + let new_lam = match new_child { + DAGPtr::Lam(p) => p, + _ => panic!("LetBod parent expects Lam child"), + }; + match let_node.copy { + Some(cache) => { + (*cache.as_ptr()).bod = new_lam; + }, + None => { + let new_let = alloc_let_no_uplinks( + let_node.binder_name.clone(), + let_node.non_dep, + let_node.typ, + let_node.val, + new_lam, + ); + let_node.copy = Some(new_let); + for parent in DLL::iter_option(let_node.parents) { + stack.push((DAGPtr::Let(new_let), *parent)); + } + }, + } + }, + ParentPtr::ProjExpr(link) => { + let proj = &*link.as_ptr(); + let new_proj = alloc_proj_no_uplinks( + proj.type_name.clone(), + proj.idx.clone(), + new_child, + ); + for parent in DLL::iter_option(proj.parents) { + stack.push((DAGPtr::Proj(new_proj), *parent)); + } + }, + } } } } @@ -352,79 +354,82 @@ fn alloc_proj_no_uplinks( // ============================================================================ pub fn clean_up(cc: &ParentPtr) { - unsafe { - match cc { - ParentPtr::Root => {}, - ParentPtr::LamBod(link) => { - let lam = &*link.as_ptr(); - for parent in DLL::iter_option(lam.var.parents) { - clean_up(parent); - } - for parent in DLL::iter_option(lam.parents) { - clean_up(parent); - } - }, - ParentPtr::AppFun(link) | ParentPtr::AppArg(link) => { - let app = &mut *link.as_ptr(); - if let Some(app_copy) = app.copy { - let App { fun, arg, fun_ref, arg_ref, .. 
} = - &mut *app_copy.as_ptr(); - app.copy = None; - add_to_parents(*fun, NonNull::new(fun_ref).unwrap()); - add_to_parents(*arg, NonNull::new(arg_ref).unwrap()); - for parent in DLL::iter_option(app.parents) { - clean_up(parent); + let mut stack: Vec = vec![*cc]; + while let Some(cc) = stack.pop() { + unsafe { + match cc { + ParentPtr::Root => {}, + ParentPtr::LamBod(link) => { + let lam = &*link.as_ptr(); + for parent in DLL::iter_option(lam.var.parents) { + stack.push(*parent); } - } - }, - ParentPtr::FunDom(link) | ParentPtr::FunImg(link) => { - let fun = &mut *link.as_ptr(); - if let Some(fun_copy) = fun.copy { - let Fun { dom, img, dom_ref, img_ref, .. } = - &mut *fun_copy.as_ptr(); - fun.copy = None; - add_to_parents(*dom, NonNull::new(dom_ref).unwrap()); - add_to_parents(DAGPtr::Lam(*img), NonNull::new(img_ref).unwrap()); - for parent in DLL::iter_option(fun.parents) { - clean_up(parent); + for parent in DLL::iter_option(lam.parents) { + stack.push(*parent); } - } - }, - ParentPtr::PiDom(link) | ParentPtr::PiImg(link) => { - let pi = &mut *link.as_ptr(); - if let Some(pi_copy) = pi.copy { - let Pi { dom, img, dom_ref, img_ref, .. } = - &mut *pi_copy.as_ptr(); - pi.copy = None; - add_to_parents(*dom, NonNull::new(dom_ref).unwrap()); - add_to_parents(DAGPtr::Lam(*img), NonNull::new(img_ref).unwrap()); - for parent in DLL::iter_option(pi.parents) { - clean_up(parent); + }, + ParentPtr::AppFun(link) | ParentPtr::AppArg(link) => { + let app = &mut *link.as_ptr(); + if let Some(app_copy) = app.copy { + let App { fun, arg, fun_ref, arg_ref, .. 
} = + &mut *app_copy.as_ptr(); + app.copy = None; + add_to_parents(*fun, NonNull::new(fun_ref).unwrap()); + add_to_parents(*arg, NonNull::new(arg_ref).unwrap()); + for parent in DLL::iter_option(app.parents) { + stack.push(*parent); + } } - } - }, - ParentPtr::LetTyp(link) - | ParentPtr::LetVal(link) - | ParentPtr::LetBod(link) => { - let let_node = &mut *link.as_ptr(); - if let Some(let_copy) = let_node.copy { - let LetNode { typ, val, bod, typ_ref, val_ref, bod_ref, .. } = - &mut *let_copy.as_ptr(); - let_node.copy = None; - add_to_parents(*typ, NonNull::new(typ_ref).unwrap()); - add_to_parents(*val, NonNull::new(val_ref).unwrap()); - add_to_parents(DAGPtr::Lam(*bod), NonNull::new(bod_ref).unwrap()); - for parent in DLL::iter_option(let_node.parents) { - clean_up(parent); + }, + ParentPtr::FunDom(link) | ParentPtr::FunImg(link) => { + let fun = &mut *link.as_ptr(); + if let Some(fun_copy) = fun.copy { + let Fun { dom, img, dom_ref, img_ref, .. } = + &mut *fun_copy.as_ptr(); + fun.copy = None; + add_to_parents(*dom, NonNull::new(dom_ref).unwrap()); + add_to_parents(DAGPtr::Lam(*img), NonNull::new(img_ref).unwrap()); + for parent in DLL::iter_option(fun.parents) { + stack.push(*parent); + } } - } - }, - ParentPtr::ProjExpr(link) => { - let proj = &*link.as_ptr(); - for parent in DLL::iter_option(proj.parents) { - clean_up(parent); - } - }, + }, + ParentPtr::PiDom(link) | ParentPtr::PiImg(link) => { + let pi = &mut *link.as_ptr(); + if let Some(pi_copy) = pi.copy { + let Pi { dom, img, dom_ref, img_ref, .. 
} = + &mut *pi_copy.as_ptr(); + pi.copy = None; + add_to_parents(*dom, NonNull::new(dom_ref).unwrap()); + add_to_parents(DAGPtr::Lam(*img), NonNull::new(img_ref).unwrap()); + for parent in DLL::iter_option(pi.parents) { + stack.push(*parent); + } + } + }, + ParentPtr::LetTyp(link) + | ParentPtr::LetVal(link) + | ParentPtr::LetBod(link) => { + let let_node = &mut *link.as_ptr(); + if let Some(let_copy) = let_node.copy { + let LetNode { typ, val, bod, typ_ref, val_ref, bod_ref, .. } = + &mut *let_copy.as_ptr(); + let_node.copy = None; + add_to_parents(*typ, NonNull::new(typ_ref).unwrap()); + add_to_parents(*val, NonNull::new(val_ref).unwrap()); + add_to_parents(DAGPtr::Lam(*bod), NonNull::new(bod_ref).unwrap()); + for parent in DLL::iter_option(let_node.parents) { + stack.push(*parent); + } + } + }, + ParentPtr::ProjExpr(link) => { + let proj = &*link.as_ptr(); + for parent in DLL::iter_option(proj.parents) { + stack.push(*parent); + } + }, + } } } } @@ -476,119 +481,122 @@ pub fn replace_child(old: DAGPtr, new: DAGPtr) { // Free dead nodes // ============================================================================ -pub fn free_dead_node(node: DAGPtr) { - unsafe { - match node { - DAGPtr::Lam(link) => { - let lam = &*link.as_ptr(); - let bod_ref_ptr = &lam.bod_ref as *const Parents; - if let Some(remaining) = (*bod_ref_ptr).unlink_node() { - set_parents(lam.bod, Some(remaining)); - } else { - set_parents(lam.bod, None); - free_dead_node(lam.bod); - } - drop(Box::from_raw(link.as_ptr())); - }, - DAGPtr::App(link) => { - let app = &*link.as_ptr(); - let fun_ref_ptr = &app.fun_ref as *const Parents; - if let Some(remaining) = (*fun_ref_ptr).unlink_node() { - set_parents(app.fun, Some(remaining)); - } else { - set_parents(app.fun, None); - free_dead_node(app.fun); - } - let arg_ref_ptr = &app.arg_ref as *const Parents; - if let Some(remaining) = (*arg_ref_ptr).unlink_node() { - set_parents(app.arg, Some(remaining)); - } else { - set_parents(app.arg, None); - 
free_dead_node(app.arg); - } - drop(Box::from_raw(link.as_ptr())); - }, - DAGPtr::Fun(link) => { - let fun = &*link.as_ptr(); - let dom_ref_ptr = &fun.dom_ref as *const Parents; - if let Some(remaining) = (*dom_ref_ptr).unlink_node() { - set_parents(fun.dom, Some(remaining)); - } else { - set_parents(fun.dom, None); - free_dead_node(fun.dom); - } - let img_ref_ptr = &fun.img_ref as *const Parents; - if let Some(remaining) = (*img_ref_ptr).unlink_node() { - set_parents(DAGPtr::Lam(fun.img), Some(remaining)); - } else { - set_parents(DAGPtr::Lam(fun.img), None); - free_dead_node(DAGPtr::Lam(fun.img)); - } - drop(Box::from_raw(link.as_ptr())); - }, - DAGPtr::Pi(link) => { - let pi = &*link.as_ptr(); - let dom_ref_ptr = &pi.dom_ref as *const Parents; - if let Some(remaining) = (*dom_ref_ptr).unlink_node() { - set_parents(pi.dom, Some(remaining)); - } else { - set_parents(pi.dom, None); - free_dead_node(pi.dom); - } - let img_ref_ptr = &pi.img_ref as *const Parents; - if let Some(remaining) = (*img_ref_ptr).unlink_node() { - set_parents(DAGPtr::Lam(pi.img), Some(remaining)); - } else { - set_parents(DAGPtr::Lam(pi.img), None); - free_dead_node(DAGPtr::Lam(pi.img)); - } - drop(Box::from_raw(link.as_ptr())); - }, - DAGPtr::Let(link) => { - let let_node = &*link.as_ptr(); - let typ_ref_ptr = &let_node.typ_ref as *const Parents; - if let Some(remaining) = (*typ_ref_ptr).unlink_node() { - set_parents(let_node.typ, Some(remaining)); - } else { - set_parents(let_node.typ, None); - free_dead_node(let_node.typ); - } - let val_ref_ptr = &let_node.val_ref as *const Parents; - if let Some(remaining) = (*val_ref_ptr).unlink_node() { - set_parents(let_node.val, Some(remaining)); - } else { - set_parents(let_node.val, None); - free_dead_node(let_node.val); - } - let bod_ref_ptr = &let_node.bod_ref as *const Parents; - if let Some(remaining) = (*bod_ref_ptr).unlink_node() { - set_parents(DAGPtr::Lam(let_node.bod), Some(remaining)); - } else { - set_parents(DAGPtr::Lam(let_node.bod), 
None); - free_dead_node(DAGPtr::Lam(let_node.bod)); - } - drop(Box::from_raw(link.as_ptr())); - }, - DAGPtr::Proj(link) => { - let proj = &*link.as_ptr(); - let expr_ref_ptr = &proj.expr_ref as *const Parents; - if let Some(remaining) = (*expr_ref_ptr).unlink_node() { - set_parents(proj.expr, Some(remaining)); - } else { - set_parents(proj.expr, None); - free_dead_node(proj.expr); - } - drop(Box::from_raw(link.as_ptr())); - }, - DAGPtr::Var(link) => { - let var = &*link.as_ptr(); - if let BinderPtr::Free = var.binder { +pub fn free_dead_node(root: DAGPtr) { + let mut stack: Vec = vec![root]; + while let Some(node) = stack.pop() { + unsafe { + match node { + DAGPtr::Lam(link) => { + let lam = &*link.as_ptr(); + let bod_ref_ptr = &lam.bod_ref as *const Parents; + if let Some(remaining) = (*bod_ref_ptr).unlink_node() { + set_parents(lam.bod, Some(remaining)); + } else { + set_parents(lam.bod, None); + stack.push(lam.bod); + } drop(Box::from_raw(link.as_ptr())); - } - }, - DAGPtr::Sort(link) => drop(Box::from_raw(link.as_ptr())), - DAGPtr::Cnst(link) => drop(Box::from_raw(link.as_ptr())), - DAGPtr::Lit(link) => drop(Box::from_raw(link.as_ptr())), + }, + DAGPtr::App(link) => { + let app = &*link.as_ptr(); + let fun_ref_ptr = &app.fun_ref as *const Parents; + if let Some(remaining) = (*fun_ref_ptr).unlink_node() { + set_parents(app.fun, Some(remaining)); + } else { + set_parents(app.fun, None); + stack.push(app.fun); + } + let arg_ref_ptr = &app.arg_ref as *const Parents; + if let Some(remaining) = (*arg_ref_ptr).unlink_node() { + set_parents(app.arg, Some(remaining)); + } else { + set_parents(app.arg, None); + stack.push(app.arg); + } + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Fun(link) => { + let fun = &*link.as_ptr(); + let dom_ref_ptr = &fun.dom_ref as *const Parents; + if let Some(remaining) = (*dom_ref_ptr).unlink_node() { + set_parents(fun.dom, Some(remaining)); + } else { + set_parents(fun.dom, None); + stack.push(fun.dom); + } + let img_ref_ptr = 
&fun.img_ref as *const Parents; + if let Some(remaining) = (*img_ref_ptr).unlink_node() { + set_parents(DAGPtr::Lam(fun.img), Some(remaining)); + } else { + set_parents(DAGPtr::Lam(fun.img), None); + stack.push(DAGPtr::Lam(fun.img)); + } + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Pi(link) => { + let pi = &*link.as_ptr(); + let dom_ref_ptr = &pi.dom_ref as *const Parents; + if let Some(remaining) = (*dom_ref_ptr).unlink_node() { + set_parents(pi.dom, Some(remaining)); + } else { + set_parents(pi.dom, None); + stack.push(pi.dom); + } + let img_ref_ptr = &pi.img_ref as *const Parents; + if let Some(remaining) = (*img_ref_ptr).unlink_node() { + set_parents(DAGPtr::Lam(pi.img), Some(remaining)); + } else { + set_parents(DAGPtr::Lam(pi.img), None); + stack.push(DAGPtr::Lam(pi.img)); + } + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Let(link) => { + let let_node = &*link.as_ptr(); + let typ_ref_ptr = &let_node.typ_ref as *const Parents; + if let Some(remaining) = (*typ_ref_ptr).unlink_node() { + set_parents(let_node.typ, Some(remaining)); + } else { + set_parents(let_node.typ, None); + stack.push(let_node.typ); + } + let val_ref_ptr = &let_node.val_ref as *const Parents; + if let Some(remaining) = (*val_ref_ptr).unlink_node() { + set_parents(let_node.val, Some(remaining)); + } else { + set_parents(let_node.val, None); + stack.push(let_node.val); + } + let bod_ref_ptr = &let_node.bod_ref as *const Parents; + if let Some(remaining) = (*bod_ref_ptr).unlink_node() { + set_parents(DAGPtr::Lam(let_node.bod), Some(remaining)); + } else { + set_parents(DAGPtr::Lam(let_node.bod), None); + stack.push(DAGPtr::Lam(let_node.bod)); + } + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Proj(link) => { + let proj = &*link.as_ptr(); + let expr_ref_ptr = &proj.expr_ref as *const Parents; + if let Some(remaining) = (*expr_ref_ptr).unlink_node() { + set_parents(proj.expr, Some(remaining)); + } else { + set_parents(proj.expr, None); + stack.push(proj.expr); + } + 
drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Var(link) => { + let var = &*link.as_ptr(); + if let BinderPtr::Free = var.binder { + drop(Box::from_raw(link.as_ptr())); + } + }, + DAGPtr::Sort(link) => drop(Box::from_raw(link.as_ptr())), + DAGPtr::Cnst(link) => drop(Box::from_raw(link.as_ptr())), + DAGPtr::Lit(link) => drop(Box::from_raw(link.as_ptr())), + } } } } @@ -598,6 +606,11 @@ pub fn free_dead_node(node: DAGPtr) { // ============================================================================ /// Contract a lambda redex: (Fun dom (Lam bod var)) arg → [arg/var]bod. +/// +/// After substitution, propagates the result through the redex App's parent +/// pointers (via `replace_child`) and frees the dead App/Fun/Lam nodes. +/// This ensures that enclosing DAG structures are properly updated, enabling +/// DAG-native sub-term WHNF without Expr roundtrips. pub fn reduce_lam(redex: NonNull, lam: NonNull) -> DAGPtr { unsafe { let app = &*redex.as_ptr(); @@ -605,18 +618,46 @@ pub fn reduce_lam(redex: NonNull, lam: NonNull) -> DAGPtr { let var = &lambda.var; let arg = app.arg; + // Perform substitution if DLL::is_singleton(lambda.parents) { - if DLL::is_empty(var.parents) { - return lambda.bod; + if !DLL::is_empty(var.parents) { + replace_child(DAGPtr::Var(NonNull::from(var)), arg); + } + } else if !DLL::is_empty(var.parents) { + // General case: upcopy arg through var's parents + for parent in DLL::iter_option(var.parents) { + upcopy(arg, *parent); + } + for parent in DLL::iter_option(var.parents) { + clean_up(parent); } - replace_child(DAGPtr::Var(NonNull::from(var)), arg); - return lambda.bod; } + lambda.bod + } +} + +/// Substitute an argument into a Pi's body: given `Pi(dom, Lam(var, body))` +/// and `arg`, produce `[arg/var]body`. Used for computing the result type +/// of function application during type inference. +/// +/// Unlike `reduce_lam`, this does NOT consume the enclosing App/Fun — it +/// works directly on the Pi's Lam node. 
The Lam should typically be +/// singly-parented (freshly inferred types are not shared). +pub fn subst_pi_body(lam: NonNull, arg: DAGPtr) -> DAGPtr { + unsafe { + let lambda = &*lam.as_ptr(); + let var = &lambda.var; + if DLL::is_empty(var.parents) { return lambda.bod; } + if DLL::is_singleton(lambda.parents) { + replace_child(DAGPtr::Var(NonNull::from(var)), arg); + return lambda.bod; + } + // General case: upcopy arg through var's parents for parent in DLL::iter_option(var.parents) { upcopy(arg, *parent); @@ -629,6 +670,9 @@ pub fn reduce_lam(redex: NonNull, lam: NonNull) -> DAGPtr { } /// Contract a let redex: Let(typ, val, Lam(bod, var)) → [val/var]bod. +/// +/// After substitution, propagates the result through the Let node's parent +/// pointers (via `replace_child`) and frees the dead Let/Lam nodes. pub fn reduce_let(let_node: NonNull) -> DAGPtr { unsafe { let ln = &*let_node.as_ptr(); @@ -636,24 +680,20 @@ pub fn reduce_let(let_node: NonNull) -> DAGPtr { let var = &lam.var; let val = ln.val; + // Perform substitution if DLL::is_singleton(lam.parents) { - if DLL::is_empty(var.parents) { - return lam.bod; + if !DLL::is_empty(var.parents) { + replace_child(DAGPtr::Var(NonNull::from(var)), val); + } + } else if !DLL::is_empty(var.parents) { + for parent in DLL::iter_option(var.parents) { + upcopy(val, *parent); + } + for parent in DLL::iter_option(var.parents) { + clean_up(parent); } - replace_child(DAGPtr::Var(NonNull::from(var)), val); - return lam.bod; - } - - if DLL::is_empty(var.parents) { - return lam.bod; } - for parent in DLL::iter_option(var.parents) { - upcopy(val, *parent); - } - for parent in DLL::iter_option(var.parents) { - clean_up(parent); - } lam.bod } } diff --git a/src/ix/kernel/whnf.rs b/src/ix/kernel/whnf.rs index 4fdde07a..d7cef49a 100644 --- a/src/ix/kernel/whnf.rs +++ b/src/ix/kernel/whnf.rs @@ -8,14 +8,16 @@ use super::convert::{from_expr, to_expr}; use super::dag::*; use super::level::{simplify, subst_level}; use 
super::upcopy::{reduce_lam, reduce_let}; - +use crate::ix::env::Literal; // ============================================================================ // Expression helpers (inst, unfold_apps, foldl_apps, subst_expr_levels) // ============================================================================ -/// Instantiate bound variables: `body[0 := substs[0], 1 := substs[1], ...]`. -/// `substs[0]` replaces `Bvar(0)` (innermost). +/// Instantiate bound variables: `body[0 := substs[n-1], 1 := substs[n-2], ...]`. +/// Follows Lean 4's `instantiate` convention: `substs[0]` is the outermost +/// variable and replaces `Bvar(n-1)`, while `substs[n-1]` is the innermost +/// and replaces `Bvar(0)`. pub fn inst(body: &Expr, substs: &[Expr]) -> Expr { if substs.is_empty() { return body.clone(); @@ -24,56 +26,108 @@ pub fn inst(body: &Expr, substs: &[Expr]) -> Expr { } fn inst_aux(e: &Expr, substs: &[Expr], offset: u64) -> Expr { - match e.as_data() { - ExprData::Bvar(idx, _) => { - let idx_u64 = idx.to_u64().unwrap_or(u64::MAX); - if idx_u64 >= offset { - let adjusted = (idx_u64 - offset) as usize; - if adjusted < substs.len() { - return substs[adjusted].clone(); - } - } - e.clone() - }, - ExprData::App(f, a, _) => { - let f2 = inst_aux(f, substs, offset); - let a2 = inst_aux(a, substs, offset); - Expr::app(f2, a2) - }, - ExprData::Lam(n, t, b, bi, _) => { - let t2 = inst_aux(t, substs, offset); - let b2 = inst_aux(b, substs, offset + 1); - Expr::lam(n.clone(), t2, b2, bi.clone()) - }, - ExprData::ForallE(n, t, b, bi, _) => { - let t2 = inst_aux(t, substs, offset); - let b2 = inst_aux(b, substs, offset + 1); - Expr::all(n.clone(), t2, b2, bi.clone()) - }, - ExprData::LetE(n, t, v, b, nd, _) => { - let t2 = inst_aux(t, substs, offset); - let v2 = inst_aux(v, substs, offset); - let b2 = inst_aux(b, substs, offset + 1); - Expr::letE(n.clone(), t2, v2, b2, *nd) - }, - ExprData::Proj(n, i, s, _) => { - let s2 = inst_aux(s, substs, offset); - Expr::proj(n.clone(), i.clone(), s2) - 
}, - ExprData::Mdata(kvs, inner, _) => { - let inner2 = inst_aux(inner, substs, offset); - Expr::mdata(kvs.clone(), inner2) - }, - // Terminals with no bound vars - ExprData::Sort(..) - | ExprData::Const(..) - | ExprData::Fvar(..) - | ExprData::Mvar(..) - | ExprData::Lit(..) => e.clone(), + enum Frame<'a> { + Visit(&'a Expr, u64), + App, + Lam(Name, BinderInfo), + All(Name, BinderInfo), + LetE(Name, bool), + Proj(Name, Nat), + Mdata(Vec<(Name, DataValue)>), + } + + let mut work: Vec> = vec![Frame::Visit(e, offset)]; + let mut results: Vec = Vec::new(); + + while let Some(frame) = work.pop() { + match frame { + Frame::Visit(e, offset) => match e.as_data() { + ExprData::Bvar(idx, _) => { + let idx_u64 = idx.to_u64().unwrap_or(u64::MAX); + if idx_u64 >= offset { + let adjusted = (idx_u64 - offset) as usize; + if adjusted < substs.len() { + // Lean 4 convention: substs[0] = outermost, substs[n-1] = innermost + // bvar(0) = innermost → substs[n-1], bvar(n-1) = outermost → substs[0] + results.push(substs[substs.len() - 1 - adjusted].clone()); + continue; + } + } + results.push(e.clone()); + }, + ExprData::App(f, a, _) => { + work.push(Frame::App); + work.push(Frame::Visit(a, offset)); + work.push(Frame::Visit(f, offset)); + }, + ExprData::Lam(n, t, b, bi, _) => { + work.push(Frame::Lam(n.clone(), bi.clone())); + work.push(Frame::Visit(b, offset + 1)); + work.push(Frame::Visit(t, offset)); + }, + ExprData::ForallE(n, t, b, bi, _) => { + work.push(Frame::All(n.clone(), bi.clone())); + work.push(Frame::Visit(b, offset + 1)); + work.push(Frame::Visit(t, offset)); + }, + ExprData::LetE(n, t, v, b, nd, _) => { + work.push(Frame::LetE(n.clone(), *nd)); + work.push(Frame::Visit(b, offset + 1)); + work.push(Frame::Visit(v, offset)); + work.push(Frame::Visit(t, offset)); + }, + ExprData::Proj(n, i, s, _) => { + work.push(Frame::Proj(n.clone(), i.clone())); + work.push(Frame::Visit(s, offset)); + }, + ExprData::Mdata(kvs, inner, _) => { + work.push(Frame::Mdata(kvs.clone())); + 
work.push(Frame::Visit(inner, offset)); + }, + ExprData::Sort(..) + | ExprData::Const(..) + | ExprData::Fvar(..) + | ExprData::Mvar(..) + | ExprData::Lit(..) => results.push(e.clone()), + }, + Frame::App => { + let a = results.pop().unwrap(); + let f = results.pop().unwrap(); + results.push(Expr::app(f, a)); + }, + Frame::Lam(n, bi) => { + let b = results.pop().unwrap(); + let t = results.pop().unwrap(); + results.push(Expr::lam(n, t, b, bi)); + }, + Frame::All(n, bi) => { + let b = results.pop().unwrap(); + let t = results.pop().unwrap(); + results.push(Expr::all(n, t, b, bi)); + }, + Frame::LetE(n, nd) => { + let b = results.pop().unwrap(); + let v = results.pop().unwrap(); + let t = results.pop().unwrap(); + results.push(Expr::letE(n, t, v, b, nd)); + }, + Frame::Proj(n, i) => { + let s = results.pop().unwrap(); + results.push(Expr::proj(n, i, s)); + }, + Frame::Mdata(kvs) => { + let inner = results.pop().unwrap(); + results.push(Expr::mdata(kvs, inner)); + }, + } } + + results.pop().unwrap() } -/// Abstract: replace free variable `fvar` with `Bvar(offset)` in `e`. +/// Abstract: replace free variables with bound variables. +/// Follows Lean 4 convention: `fvars[0]` (outermost) maps to `Bvar(n-1+offset)`, +/// `fvars[n-1]` (innermost) maps to `Bvar(0+offset)`. pub fn abstr(e: &Expr, fvars: &[Expr]) -> Expr { if fvars.is_empty() { return e.clone(); @@ -82,50 +136,107 @@ pub fn abstr(e: &Expr, fvars: &[Expr]) -> Expr { } fn abstr_aux(e: &Expr, fvars: &[Expr], offset: u64) -> Expr { - match e.as_data() { - ExprData::Fvar(..) 
=> { - for (i, fv) in fvars.iter().enumerate().rev() { - if e == fv { - return Expr::bvar(Nat::from(i as u64 + offset)); - } - } - e.clone() - }, - ExprData::App(f, a, _) => { - let f2 = abstr_aux(f, fvars, offset); - let a2 = abstr_aux(a, fvars, offset); - Expr::app(f2, a2) - }, - ExprData::Lam(n, t, b, bi, _) => { - let t2 = abstr_aux(t, fvars, offset); - let b2 = abstr_aux(b, fvars, offset + 1); - Expr::lam(n.clone(), t2, b2, bi.clone()) - }, - ExprData::ForallE(n, t, b, bi, _) => { - let t2 = abstr_aux(t, fvars, offset); - let b2 = abstr_aux(b, fvars, offset + 1); - Expr::all(n.clone(), t2, b2, bi.clone()) - }, - ExprData::LetE(n, t, v, b, nd, _) => { - let t2 = abstr_aux(t, fvars, offset); - let v2 = abstr_aux(v, fvars, offset); - let b2 = abstr_aux(b, fvars, offset + 1); - Expr::letE(n.clone(), t2, v2, b2, *nd) - }, - ExprData::Proj(n, i, s, _) => { - let s2 = abstr_aux(s, fvars, offset); - Expr::proj(n.clone(), i.clone(), s2) - }, - ExprData::Mdata(kvs, inner, _) => { - let inner2 = abstr_aux(inner, fvars, offset); - Expr::mdata(kvs.clone(), inner2) - }, - ExprData::Bvar(..) - | ExprData::Sort(..) - | ExprData::Const(..) - | ExprData::Mvar(..) - | ExprData::Lit(..) => e.clone(), + enum Frame<'a> { + Visit(&'a Expr, u64), + App, + Lam(Name, BinderInfo), + All(Name, BinderInfo), + LetE(Name, bool), + Proj(Name, Nat), + Mdata(Vec<(Name, DataValue)>), + } + + let mut work: Vec> = vec![Frame::Visit(e, offset)]; + let mut results: Vec = Vec::new(); + + while let Some(frame) = work.pop() { + match frame { + Frame::Visit(e, offset) => match e.as_data() { + ExprData::Fvar(..) 
=> { + let n = fvars.len(); + let mut found = false; + for (i, fv) in fvars.iter().enumerate() { + if e == fv { + // fvars[0] (outermost) → Bvar(n-1+offset) + // fvars[n-1] (innermost) → Bvar(0+offset) + let bvar_idx = (n - 1 - i) as u64 + offset; + results.push(Expr::bvar(Nat::from(bvar_idx))); + found = true; + break; + } + } + if !found { + results.push(e.clone()); + } + }, + ExprData::App(f, a, _) => { + work.push(Frame::App); + work.push(Frame::Visit(a, offset)); + work.push(Frame::Visit(f, offset)); + }, + ExprData::Lam(n, t, b, bi, _) => { + work.push(Frame::Lam(n.clone(), bi.clone())); + work.push(Frame::Visit(b, offset + 1)); + work.push(Frame::Visit(t, offset)); + }, + ExprData::ForallE(n, t, b, bi, _) => { + work.push(Frame::All(n.clone(), bi.clone())); + work.push(Frame::Visit(b, offset + 1)); + work.push(Frame::Visit(t, offset)); + }, + ExprData::LetE(n, t, v, b, nd, _) => { + work.push(Frame::LetE(n.clone(), *nd)); + work.push(Frame::Visit(b, offset + 1)); + work.push(Frame::Visit(v, offset)); + work.push(Frame::Visit(t, offset)); + }, + ExprData::Proj(n, i, s, _) => { + work.push(Frame::Proj(n.clone(), i.clone())); + work.push(Frame::Visit(s, offset)); + }, + ExprData::Mdata(kvs, inner, _) => { + work.push(Frame::Mdata(kvs.clone())); + work.push(Frame::Visit(inner, offset)); + }, + ExprData::Bvar(..) + | ExprData::Sort(..) + | ExprData::Const(..) + | ExprData::Mvar(..) + | ExprData::Lit(..) 
=> results.push(e.clone()), + }, + Frame::App => { + let a = results.pop().unwrap(); + let f = results.pop().unwrap(); + results.push(Expr::app(f, a)); + }, + Frame::Lam(n, bi) => { + let b = results.pop().unwrap(); + let t = results.pop().unwrap(); + results.push(Expr::lam(n, t, b, bi)); + }, + Frame::All(n, bi) => { + let b = results.pop().unwrap(); + let t = results.pop().unwrap(); + results.push(Expr::all(n, t, b, bi)); + }, + Frame::LetE(n, nd) => { + let b = results.pop().unwrap(); + let v = results.pop().unwrap(); + let t = results.pop().unwrap(); + results.push(Expr::letE(n, t, v, b, nd)); + }, + Frame::Proj(n, i) => { + let s = results.pop().unwrap(); + results.push(Expr::proj(n, i, s)); + }, + Frame::Mdata(kvs) => { + let inner = results.pop().unwrap(); + results.push(Expr::mdata(kvs, inner)); + }, + } } + + results.pop().unwrap() } /// Decompose `f a1 a2 ... an` into `(f, [a1, a2, ..., an])`. @@ -154,66 +265,134 @@ pub fn foldl_apps(mut fun: Expr, args: impl Iterator) -> Expr { } /// Substitute universe level parameters in an expression. 
-pub fn subst_expr_levels( - e: &Expr, - params: &[Name], - values: &[Level], -) -> Expr { +pub fn subst_expr_levels(e: &Expr, params: &[Name], values: &[Level]) -> Expr { if params.is_empty() { return e.clone(); } subst_expr_levels_aux(e, params, values) } -fn subst_expr_levels_aux( - e: &Expr, - params: &[Name], - values: &[Level], -) -> Expr { - match e.as_data() { - ExprData::Sort(level, _) => { - Expr::sort(subst_level(level, params, values)) - }, - ExprData::Const(name, levels, _) => { - let new_levels: Vec = - levels.iter().map(|l| subst_level(l, params, values)).collect(); - Expr::cnst(name.clone(), new_levels) - }, - ExprData::App(f, a, _) => { - let f2 = subst_expr_levels_aux(f, params, values); - let a2 = subst_expr_levels_aux(a, params, values); - Expr::app(f2, a2) - }, - ExprData::Lam(n, t, b, bi, _) => { - let t2 = subst_expr_levels_aux(t, params, values); - let b2 = subst_expr_levels_aux(b, params, values); - Expr::lam(n.clone(), t2, b2, bi.clone()) - }, - ExprData::ForallE(n, t, b, bi, _) => { - let t2 = subst_expr_levels_aux(t, params, values); - let b2 = subst_expr_levels_aux(b, params, values); - Expr::all(n.clone(), t2, b2, bi.clone()) - }, - ExprData::LetE(n, t, v, b, nd, _) => { - let t2 = subst_expr_levels_aux(t, params, values); - let v2 = subst_expr_levels_aux(v, params, values); - let b2 = subst_expr_levels_aux(b, params, values); - Expr::letE(n.clone(), t2, v2, b2, *nd) - }, - ExprData::Proj(n, i, s, _) => { - let s2 = subst_expr_levels_aux(s, params, values); - Expr::proj(n.clone(), i.clone(), s2) - }, - ExprData::Mdata(kvs, inner, _) => { - let inner2 = subst_expr_levels_aux(inner, params, values); - Expr::mdata(kvs.clone(), inner2) - }, - // No levels to substitute - ExprData::Bvar(..) - | ExprData::Fvar(..) - | ExprData::Mvar(..) - | ExprData::Lit(..) 
=> e.clone(), +fn subst_expr_levels_aux(e: &Expr, params: &[Name], values: &[Level]) -> Expr { + use rustc_hash::FxHashMap; + use std::sync::Arc; + + enum Frame<'a> { + Visit(&'a Expr), + CacheResult(*const ExprData), + App, + Lam(Name, BinderInfo), + All(Name, BinderInfo), + LetE(Name, bool), + Proj(Name, Nat), + Mdata(Vec<(Name, DataValue)>), + } + + let mut cache: FxHashMap<*const ExprData, Expr> = FxHashMap::default(); + let mut work: Vec> = vec![Frame::Visit(e)]; + let mut results: Vec = Vec::new(); + + while let Some(frame) = work.pop() { + match frame { + Frame::Visit(e) => { + let key = Arc::as_ptr(&e.0); + if let Some(cached) = cache.get(&key) { + results.push(cached.clone()); + continue; + } + match e.as_data() { + ExprData::Sort(level, _) => { + let r = Expr::sort(subst_level(level, params, values)); + cache.insert(key, r.clone()); + results.push(r); + }, + ExprData::Const(name, levels, _) => { + let new_levels: Vec = + levels.iter().map(|l| subst_level(l, params, values)).collect(); + let r = Expr::cnst(name.clone(), new_levels); + cache.insert(key, r.clone()); + results.push(r); + }, + ExprData::App(f, a, _) => { + work.push(Frame::CacheResult(key)); + work.push(Frame::App); + work.push(Frame::Visit(a)); + work.push(Frame::Visit(f)); + }, + ExprData::Lam(n, t, b, bi, _) => { + work.push(Frame::CacheResult(key)); + work.push(Frame::Lam(n.clone(), bi.clone())); + work.push(Frame::Visit(b)); + work.push(Frame::Visit(t)); + }, + ExprData::ForallE(n, t, b, bi, _) => { + work.push(Frame::CacheResult(key)); + work.push(Frame::All(n.clone(), bi.clone())); + work.push(Frame::Visit(b)); + work.push(Frame::Visit(t)); + }, + ExprData::LetE(n, t, v, b, nd, _) => { + work.push(Frame::CacheResult(key)); + work.push(Frame::LetE(n.clone(), *nd)); + work.push(Frame::Visit(b)); + work.push(Frame::Visit(v)); + work.push(Frame::Visit(t)); + }, + ExprData::Proj(n, i, s, _) => { + work.push(Frame::CacheResult(key)); + work.push(Frame::Proj(n.clone(), i.clone())); + 
work.push(Frame::Visit(s)); + }, + ExprData::Mdata(kvs, inner, _) => { + work.push(Frame::CacheResult(key)); + work.push(Frame::Mdata(kvs.clone())); + work.push(Frame::Visit(inner)); + }, + ExprData::Bvar(..) + | ExprData::Fvar(..) + | ExprData::Mvar(..) + | ExprData::Lit(..) => { + cache.insert(key, e.clone()); + results.push(e.clone()); + }, + } + }, + Frame::CacheResult(key) => { + let result = results.last().unwrap().clone(); + cache.insert(key, result); + }, + Frame::App => { + let a = results.pop().unwrap(); + let f = results.pop().unwrap(); + results.push(Expr::app(f, a)); + }, + Frame::Lam(n, bi) => { + let b = results.pop().unwrap(); + let t = results.pop().unwrap(); + results.push(Expr::lam(n, t, b, bi)); + }, + Frame::All(n, bi) => { + let b = results.pop().unwrap(); + let t = results.pop().unwrap(); + results.push(Expr::all(n, t, b, bi)); + }, + Frame::LetE(n, nd) => { + let b = results.pop().unwrap(); + let v = results.pop().unwrap(); + let t = results.pop().unwrap(); + results.push(Expr::letE(n, t, v, b, nd)); + }, + Frame::Proj(n, i) => { + let s = results.pop().unwrap(); + results.push(Expr::proj(n, i, s)); + }, + Frame::Mdata(kvs) => { + let inner = results.pop().unwrap(); + results.push(Expr::mdata(kvs, inner)); + }, + } } + + results.pop().unwrap() } /// Check if an expression has any loose bound variables above `offset`. 
@@ -222,40 +401,60 @@ pub fn has_loose_bvars(e: &Expr) -> bool { } fn has_loose_bvars_aux(e: &Expr, depth: u64) -> bool { - match e.as_data() { - ExprData::Bvar(idx, _) => idx.to_u64().unwrap_or(u64::MAX) >= depth, - ExprData::App(f, a, _) => { - has_loose_bvars_aux(f, depth) || has_loose_bvars_aux(a, depth) - }, - ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { - has_loose_bvars_aux(t, depth) || has_loose_bvars_aux(b, depth + 1) - }, - ExprData::LetE(_, t, v, b, _, _) => { - has_loose_bvars_aux(t, depth) - || has_loose_bvars_aux(v, depth) - || has_loose_bvars_aux(b, depth + 1) - }, - ExprData::Proj(_, _, s, _) => has_loose_bvars_aux(s, depth), - ExprData::Mdata(_, inner, _) => has_loose_bvars_aux(inner, depth), - _ => false, + let mut stack: Vec<(&Expr, u64)> = vec![(e, depth)]; + while let Some((e, depth)) = stack.pop() { + match e.as_data() { + ExprData::Bvar(idx, _) => { + if idx.to_u64().unwrap_or(u64::MAX) >= depth { + return true; + } + }, + ExprData::App(f, a, _) => { + stack.push((f, depth)); + stack.push((a, depth)); + }, + ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { + stack.push((t, depth)); + stack.push((b, depth + 1)); + }, + ExprData::LetE(_, t, v, b, _, _) => { + stack.push((t, depth)); + stack.push((v, depth)); + stack.push((b, depth + 1)); + }, + ExprData::Proj(_, _, s, _) => stack.push((s, depth)), + ExprData::Mdata(_, inner, _) => stack.push((inner, depth)), + _ => {}, + } } + false } /// Check if expression contains any free variables (Fvar). pub fn has_fvars(e: &Expr) -> bool { - match e.as_data() { - ExprData::Fvar(..) 
=> true, - ExprData::App(f, a, _) => has_fvars(f) || has_fvars(a), - ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { - has_fvars(t) || has_fvars(b) - }, - ExprData::LetE(_, t, v, b, _, _) => { - has_fvars(t) || has_fvars(v) || has_fvars(b) - }, - ExprData::Proj(_, _, s, _) => has_fvars(s), - ExprData::Mdata(_, inner, _) => has_fvars(inner), - _ => false, + let mut stack: Vec<&Expr> = vec![e]; + while let Some(e) = stack.pop() { + match e.as_data() { + ExprData::Fvar(..) => return true, + ExprData::App(f, a, _) => { + stack.push(f); + stack.push(a); + }, + ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { + stack.push(t); + stack.push(b); + }, + ExprData::LetE(_, t, v, b, _, _) => { + stack.push(t); + stack.push(v); + stack.push(b); + }, + ExprData::Proj(_, _, s, _) => stack.push(s), + ExprData::Mdata(_, inner, _) => stack.push(inner), + _ => {}, + } } + false } // ============================================================================ @@ -277,16 +476,63 @@ pub(crate) fn mk_name2(a: &str, b: &str) -> Name { /// iota/quot/nat/projection, and uses DAG-level splicing for delta. pub fn whnf(e: &Expr, env: &Env) -> Expr { let mut dag = from_expr(e); - whnf_dag(&mut dag, env); + whnf_dag(&mut dag, env, false); + let result = to_expr(&dag); + free_dag(dag); + result +} + + + +/// WHNF without delta reduction (beta/zeta/iota/quot/nat/proj only). +/// Matches Lean 4's `whnf_core` used in `is_def_eq_core`. +pub fn whnf_no_delta(e: &Expr, env: &Env) -> Expr { + let mut dag = from_expr(e); + whnf_dag(&mut dag, env, true); let result = to_expr(&dag); free_dag(dag); result } + /// Trail-based WHNF on DAG. Walks down the App spine collecting a trail, /// then dispatches on the head node. -fn whnf_dag(dag: &mut DAG, env: &Env) { +/// When `no_delta` is true, skips delta (definition) unfolding. 
+pub(crate) fn whnf_dag(dag: &mut DAG, env: &Env, no_delta: bool) { + use std::sync::atomic::{AtomicU64, Ordering}; + static WHNF_DEPTH: AtomicU64 = AtomicU64::new(0); + static WHNF_TOTAL: AtomicU64 = AtomicU64::new(0); + + let depth = WHNF_DEPTH.fetch_add(1, Ordering::Relaxed); + let total = WHNF_TOTAL.fetch_add(1, Ordering::Relaxed); + if depth > 50 || total % 10_000 == 0 { + eprintln!("[whnf_dag] depth={depth} total={total} no_delta={no_delta}"); + } + if depth > 200 { + eprintln!("[whnf_dag] DEPTH LIMIT depth={depth}, bailing"); + WHNF_DEPTH.fetch_sub(1, Ordering::Relaxed); + return; + } + + const WHNF_STEP_LIMIT: u64 = 100_000; + let mut steps: u64 = 0; + let whnf_done = |depth| { WHNF_DEPTH.fetch_sub(1, Ordering::Relaxed); }; loop { + steps += 1; + if steps > WHNF_STEP_LIMIT { + eprintln!("[whnf_dag] step limit exceeded ({steps}) depth={depth}"); + whnf_done(depth); + return; + } + if steps <= 5 || steps % 10_000 == 0 { + let head_variant = match dag.head { + DAGPtr::Var(_) => "Var", DAGPtr::Sort(_) => "Sort", DAGPtr::Cnst(_) => "Cnst", + DAGPtr::App(_) => "App", DAGPtr::Fun(_) => "Fun", DAGPtr::Pi(_) => "Pi", + DAGPtr::Let(_) => "Let", DAGPtr::Lit(_) => "Lit", DAGPtr::Proj(_) => "Proj", + DAGPtr::Lam(_) => "Lam", + }; + eprintln!("[whnf_dag] step={steps} head={head_variant} trail_build_start"); + } // Build trail of App nodes by walking down the fun chain let mut trail: Vec> = Vec::new(); let mut cursor = dag.head; @@ -295,12 +541,26 @@ fn whnf_dag(dag: &mut DAG, env: &Env) { match cursor { DAGPtr::App(app) => { trail.push(app); + if trail.len() > 100_000 { + eprintln!("[whnf_dag] TRAIL OVERFLOW: trail.len()={} — possible App cycle!", trail.len()); + whnf_done(depth); return; + } cursor = unsafe { (*app.as_ptr()).fun }; }, _ => break, } } + if steps <= 5 || steps % 10_000 == 0 { + let cursor_variant = match cursor { + DAGPtr::Var(_) => "Var", DAGPtr::Sort(_) => "Sort", DAGPtr::Cnst(_) => "Cnst", + DAGPtr::App(_) => "App", DAGPtr::Fun(_) => "Fun", 
DAGPtr::Pi(_) => "Pi", + DAGPtr::Let(_) => "Let", DAGPtr::Lit(_) => "Lit", DAGPtr::Proj(_) => "Proj", + DAGPtr::Lam(_) => "Lam", + }; + eprintln!("[whnf_dag] step={steps} trail_len={} cursor={cursor_variant}", trail.len()); + } + match cursor { // Beta: Fun at head with args on trail DAGPtr::Fun(fun_ptr) if !trail.is_empty() => { @@ -320,23 +580,23 @@ fn whnf_dag(dag: &mut DAG, env: &Env) { // Const: try iota, quot, nat, then delta DAGPtr::Cnst(_) => { - // Try iota, quot, nat at Expr level - if try_expr_reductions(dag, env) { + // Try iota, quot, nat + if try_dag_reductions(dag, env) { continue; } - // Try delta (definition unfolding) on DAG - if try_dag_delta(dag, &trail, env) { + // Try delta (definition unfolding) on DAG, unless no_delta + if !no_delta && try_dag_delta(dag, &trail, env) { continue; } - return; // stuck + whnf_done(depth); return; // stuck }, // Proj: try projection reduction (Expr-level fallback) DAGPtr::Proj(_) => { - if try_expr_reductions(dag, env) { + if try_dag_reductions(dag, env) { continue; } - return; // stuck + whnf_done(depth); return; // stuck }, // Sort: simplify level in place @@ -345,7 +605,7 @@ fn whnf_dag(dag: &mut DAG, env: &Env) { let sort = &mut *sort_ptr.as_ptr(); sort.level = simplify(&sort.level); } - return; + whnf_done(depth); return; }, // Mdata: strip metadata (Expr-level fallback) @@ -353,15 +613,15 @@ fn whnf_dag(dag: &mut DAG, env: &Env) { // Check if this is a Nat literal that could be a Nat.succ application // by trying Expr-level reductions (which handles nat ops) if !trail.is_empty() { - if try_expr_reductions(dag, env) { + if try_dag_reductions(dag, env) { continue; } } - return; + whnf_done(depth); return; }, // Everything else (Var, Pi, Lam without args, etc.): already WHNF - _ => return, + _ => { whnf_done(depth); return; }, } } } @@ -369,11 +629,7 @@ fn whnf_dag(dag: &mut DAG, env: &Env) { /// Set the DAG head after a reduction step. /// If trail is empty, the result becomes the new head. 
/// If trail is non-empty, splice result into the innermost remaining App. -fn set_dag_head( - dag: &mut DAG, - result: DAGPtr, - trail: &[NonNull], -) { +fn set_dag_head(dag: &mut DAG, result: DAGPtr, trail: &[NonNull]) { if trail.is_empty() { dag.head = result; } else { @@ -384,138 +640,56 @@ fn set_dag_head( } } -/// Try iota/quot/nat/projection reductions at Expr level. -/// Converts current DAG to Expr, attempts reduction, converts back if -/// successful. -fn try_expr_reductions(dag: &mut DAG, env: &Env) -> bool { - let current_expr = to_expr(&DAG { head: dag.head }); - - let (head, args) = unfold_apps(¤t_expr); +/// Try iota/quot/nat/projection reductions directly on DAG. +fn try_dag_reductions(dag: &mut DAG, env: &Env) -> bool { + let (head, args) = dag_unfold_apps(dag.head); - let reduced = match head.as_data() { - ExprData::Const(name, levels, _) => { - // Try iota (recursor) reduction - if let Some(result) = try_reduce_rec(name, levels, &args, env) { + let reduced = match head { + DAGPtr::Cnst(cnst) => unsafe { + let cnst_ref = &*cnst.as_ptr(); + if let Some(result) = + try_reduce_rec_dag(&cnst_ref.name, &cnst_ref.levels, &args, env) + { Some(result) - } - // Try quotient reduction - else if let Some(result) = try_reduce_quot(name, &args, env) { + } else if let Some(result) = + try_reduce_quot_dag(&cnst_ref.name, &args, env) + { Some(result) - } - // Try nat reduction - else if let Some(result) = - try_reduce_nat(¤t_expr, env) + } else if let Some(result) = + try_reduce_native_dag(&cnst_ref.name, &args) + { + Some(result) + } else if let Some(result) = + try_reduce_nat_dag(&cnst_ref.name, &args, env) { Some(result) } else { None } }, - ExprData::Proj(type_name, idx, structure, _) => { - reduce_proj(type_name, idx, structure, env) - .map(|result| foldl_apps(result, args.into_iter())) - }, - ExprData::Mdata(_, inner, _) => { - Some(foldl_apps(inner.clone(), args.into_iter())) + DAGPtr::Proj(proj) => unsafe { + let proj_ref = &*proj.as_ptr(); + 
reduce_proj_dag(&proj_ref.type_name, &proj_ref.idx, proj_ref.expr, env) + .map(|result| dag_foldl_apps(result, &args)) }, _ => None, }; - if let Some(result_expr) = reduced { - let result_dag = from_expr(&result_expr); - dag.head = result_dag.head; + if let Some(result) = reduced { + dag.head = result; true } else { false } } -/// Try delta (definition) unfolding on DAG. -/// Looks up the constant, substitutes universe levels in the definition body, -/// converts it to a DAG, and splices it into the current DAG. -fn try_dag_delta( - dag: &mut DAG, - trail: &[NonNull], - env: &Env, -) -> bool { - // Extract constant info from head - let cnst_ref = match dag_head_past_trail(dag, trail) { - DAGPtr::Cnst(cnst) => unsafe { &*cnst.as_ptr() }, - _ => return false, - }; - - let ci = match env.get(&cnst_ref.name) { - Some(c) => c, - None => return false, - }; - let (def_params, def_value) = match ci { - ConstantInfo::DefnInfo(d) - if d.hints != ReducibilityHints::Opaque => - { - (&d.cnst.level_params, &d.value) - }, - _ => return false, - }; - - if cnst_ref.levels.len() != def_params.len() { - return false; - } - - // Substitute levels at Expr level, then convert to DAG - let val = subst_expr_levels(def_value, def_params, &cnst_ref.levels); - let body_dag = from_expr(&val); - - // Splice body into the working DAG - set_dag_head(dag, body_dag.head, trail); - true -} - -/// Get the head node past the trail (the non-App node at the bottom). -fn dag_head_past_trail( - dag: &DAG, - trail: &[NonNull], -) -> DAGPtr { - if trail.is_empty() { - dag.head - } else { - unsafe { (*trail.last().unwrap().as_ptr()).fun } - } -} - -/// Try to unfold a definition at the head. 
-pub fn try_unfold_def(e: &Expr, env: &Env) -> Option { - let (head, args) = unfold_apps(e); - let (name, levels) = match head.as_data() { - ExprData::Const(name, levels, _) => (name, levels), - _ => return None, - }; - - let ci = env.get(name)?; - let (def_params, def_value) = match ci { - ConstantInfo::DefnInfo(d) => { - if d.hints == ReducibilityHints::Opaque { - return None; - } - (&d.cnst.level_params, &d.value) - }, - _ => return None, - }; - - if levels.len() != def_params.len() { - return None; - } - - let val = subst_expr_levels(def_value, def_params, levels); - Some(foldl_apps(val, args.into_iter())) -} - -/// Try to reduce a recursor application (iota reduction). -fn try_reduce_rec( +/// Try to reduce a recursor application (iota reduction) on DAG. +fn try_reduce_rec_dag( name: &Name, levels: &[Level], - args: &[Expr], + args: &[DAGPtr], env: &Env, -) -> Option { +) -> Option { let ci = env.get(name)?; let rec = match ci { ConstantInfo::RecInfo(r) => r, @@ -529,150 +703,104 @@ fn try_reduce_rec( let major = args.get(major_idx)?; - // WHNF the major premise - let major_whnf = whnf(major, env); - - // Handle nat literal → constructor - let major_ctor = match major_whnf.as_data() { - ExprData::Lit(Literal::NatVal(n), _) => nat_lit_to_constructor(n), - _ => major_whnf.clone(), + // WHNF the major premise directly on the DAG + let mut major_dag = DAG { head: *major }; + whnf_dag(&mut major_dag, env, false); + + // Decompose the major premise into (ctor_head, ctor_args) at DAG level. + // Handle nat literal → constructor form as DAG nodes directly. 
+ let (ctor_head, ctor_args) = match major_dag.head { + DAGPtr::Lit(lit) => unsafe { + match &(*lit.as_ptr()).val { + Literal::NatVal(n) => { + if n.0 == BigUint::ZERO { + let zero = DAGPtr::Cnst(alloc_val(Cnst { + name: mk_name2("Nat", "zero"), + levels: vec![], + parents: None, + })); + (zero, vec![]) + } else { + let pred = Nat(n.0.clone() - BigUint::from(1u64)); + let succ = DAGPtr::Cnst(alloc_val(Cnst { + name: mk_name2("Nat", "succ"), + levels: vec![], + parents: None, + })); + let pred_lit = nat_lit_dag(pred); + (succ, vec![pred_lit]) + } + }, + _ => return None, + } + }, + _ => dag_unfold_apps(major_dag.head), }; - let (ctor_head, ctor_args) = unfold_apps(&major_ctor); - - // Find the matching rec rule - let ctor_name = match ctor_head.as_data() { - ExprData::Const(name, _, _) => name, + // Find the matching rec rule by reading ctor name from DAG head + let ctor_name = match ctor_head { + DAGPtr::Cnst(cnst) => unsafe { &(*cnst.as_ptr()).name }, _ => return None, }; - let rule = rec.rules.iter().find(|r| &r.ctor == ctor_name)?; + let rule = rec.rules.iter().find(|r| r.ctor == *ctor_name)?; let n_fields = rule.n_fields.to_u64().unwrap() as usize; let num_params = rec.num_params.to_u64().unwrap() as usize; let num_motives = rec.num_motives.to_u64().unwrap() as usize; let num_minors = rec.num_minors.to_u64().unwrap() as usize; - // The constructor args may have extra params for nested inductives - let ctor_args_wo_params = - if ctor_args.len() >= n_fields { - &ctor_args[ctor_args.len() - n_fields..] 
- } else { - return None; - }; - - // Substitute universe levels in the rule's RHS - let rhs = subst_expr_levels( - &rule.rhs, - &rec.cnst.level_params, - levels, - ); - - // Apply: params, motives, minors - let prefix_count = num_params + num_motives + num_minors; - let mut result = rhs; - for arg in args.iter().take(prefix_count) { - result = Expr::app(result, arg.clone()); - } - - // Apply constructor fields - for arg in ctor_args_wo_params { - result = Expr::app(result, arg.clone()); - } - - // Apply remaining args after major - for arg in args.iter().skip(major_idx + 1) { - result = Expr::app(result, arg.clone()); + if ctor_args.len() < n_fields { + return None; } + let ctor_fields = &ctor_args[ctor_args.len() - n_fields..]; - Some(result) -} - -/// Convert a Nat literal to its constructor form. -fn nat_lit_to_constructor(n: &Nat) -> Expr { - if n.0 == BigUint::ZERO { - Expr::cnst(mk_name2("Nat", "zero"), vec![]) - } else { - let pred = Nat(n.0.clone() - BigUint::from(1u64)); - let pred_expr = Expr::lit(Literal::NatVal(pred)); - Expr::app(Expr::cnst(mk_name2("Nat", "succ"), vec![]), pred_expr) - } -} + // Build RHS as DAG: from_expr(subst_expr_levels(rule.rhs, ...)) once + // (unavoidable — rule RHS is stored as Expr in Env) + let rhs_expr = subst_expr_levels(&rule.rhs, &rec.cnst.level_params, levels); + let rhs_dag = from_expr(&rhs_expr); -/// Convert a string literal to its constructor form: -/// `"hello"` → `String.mk (List.cons 'h' (List.cons 'e' ... List.nil))` -/// where chars are represented as `Char.ofNat n`. 
-fn string_lit_to_constructor(s: &str) -> Expr { - let list_name = Name::str(Name::anon(), "List".into()); - let char_name = Name::str(Name::anon(), "Char".into()); - let char_type = Expr::cnst(char_name.clone(), vec![]); - - // Build the list from right to left - // List.nil.{0} : List Char - let nil = Expr::app( - Expr::cnst( - Name::str(list_name.clone(), "nil".into()), - vec![Level::succ(Level::zero())], - ), - char_type.clone(), - ); - - let result = s.chars().rev().fold(nil, |acc, c| { - let char_val = Expr::app( - Expr::cnst(Name::str(char_name.clone(), "ofNat".into()), vec![]), - Expr::lit(Literal::NatVal(Nat::from(c as u64))), - ); - // List.cons.{0} Char char_val acc - Expr::app( - Expr::app( - Expr::app( - Expr::cnst( - Name::str(list_name.clone(), "cons".into()), - vec![Level::succ(Level::zero())], - ), - char_type.clone(), - ), - char_val, - ), - acc, - ) - }); + // Collect all args at DAG level: params+motives+minors, ctor_fields, rest + let prefix_count = num_params + num_motives + num_minors; + let mut all_args: Vec = + Vec::with_capacity(prefix_count + n_fields + args.len() - major_idx - 1); + all_args.extend_from_slice(&args[..prefix_count]); + all_args.extend_from_slice(ctor_fields); + all_args.extend_from_slice(&args[major_idx + 1..]); - // String.mk list - Expr::app( - Expr::cnst( - Name::str(Name::str(Name::anon(), "String".into()), "mk".into()), - vec![], - ), - result, - ) + Some(dag_foldl_apps(rhs_dag.head, &all_args)) } -/// Try to reduce a projection. -fn reduce_proj( +/// Try to reduce a projection on DAG. 
+fn reduce_proj_dag( _type_name: &Name, idx: &Nat, - structure: &Expr, + structure: DAGPtr, env: &Env, -) -> Option { - let structure_whnf = whnf(structure, env); - - // Handle string literal → constructor - let structure_ctor = match structure_whnf.as_data() { - ExprData::Lit(Literal::StrVal(s), _) => { - string_lit_to_constructor(s) +) -> Option { + // WHNF the structure directly on the DAG + let mut struct_dag = DAG { head: structure }; + whnf_dag(&mut struct_dag, env, false); + + // Handle string literal → constructor form at DAG level + let struct_whnf = match struct_dag.head { + DAGPtr::Lit(lit) => unsafe { + match &(*lit.as_ptr()).val { + Literal::StrVal(s) => string_lit_to_dag_ctor(s), + _ => struct_dag.head, + } }, - _ => structure_whnf, + _ => struct_dag.head, }; - let (ctor_head, ctor_args) = unfold_apps(&structure_ctor); + // Decompose at DAG level + let (ctor_head, ctor_args) = dag_unfold_apps(struct_whnf); - let ctor_name = match ctor_head.as_data() { - ExprData::Const(name, _, _) => name, + let ctor_name = match ctor_head { + DAGPtr::Cnst(cnst) => unsafe { &(*cnst.as_ptr()).name }, _ => return None, }; - // Look up constructor to get num_params let ci = env.get(ctor_name)?; let num_params = match ci { ConstantInfo::CtorInfo(c) => c.num_params.to_u64().unwrap() as usize, @@ -680,15 +808,15 @@ fn reduce_proj( }; let field_idx = num_params + idx.to_u64().unwrap() as usize; - ctor_args.get(field_idx).cloned() + ctor_args.get(field_idx).copied() } -/// Try to reduce a quotient operation. -fn try_reduce_quot( +/// Try to reduce a quotient operation on DAG. 
+fn try_reduce_quot_dag( name: &Name, - args: &[Expr], + args: &[DAGPtr], env: &Env, -) -> Option { +) -> Option { let ci = env.get(name)?; let kind = match ci { ConstantInfo::QuotInfo(q) => &q.kind, @@ -702,33 +830,304 @@ fn try_reduce_quot( }; let qmk = args.get(qmk_idx)?; - let qmk_whnf = whnf(qmk, env); - // Check that the head is Quot.mk - let (qmk_head, _) = unfold_apps(&qmk_whnf); - match qmk_head.as_data() { - ExprData::Const(n, _, _) if *n == mk_name2("Quot", "mk") => {}, + // WHNF the Quot.mk arg directly on the DAG + let mut qmk_dag = DAG { head: *qmk }; + whnf_dag(&mut qmk_dag, env, false); + + // Check that the head is Quot.mk at DAG level + let (qmk_head, _) = dag_unfold_apps(qmk_dag.head); + match qmk_head { + DAGPtr::Cnst(cnst) => unsafe { + if (*cnst.as_ptr()).name != mk_name2("Quot", "mk") { + return None; + } + }, _ => return None, } let f = args.get(3)?; - // Extract the argument of Quot.mk - let qmk_arg = match qmk_whnf.as_data() { - ExprData::App(_, arg, _) => arg, + // Extract the argument of Quot.mk (the outermost App's arg) + let qmk_arg = match qmk_dag.head { + DAGPtr::App(app) => unsafe { (*app.as_ptr()).arg }, _ => return None, }; - let mut result = Expr::app(f.clone(), qmk_arg.clone()); - for arg in args.iter().skip(rest_idx) { - result = Expr::app(result, arg.clone()); + // Build result directly at DAG level: f qmk_arg rest_args... + let mut result_args = Vec::with_capacity(1 + args.len() - rest_idx); + result_args.push(qmk_arg); + result_args.extend_from_slice(&args[rest_idx..]); + Some(dag_foldl_apps(*f, &result_args)) +} + +/// Try to reduce `Lean.reduceBool` / `Lean.reduceNat` on DAG. 
+pub(crate) fn try_reduce_native_dag(name: &Name, args: &[DAGPtr]) -> Option { + if args.len() != 1 { + return None; + } + let reduce_bool = mk_name2("Lean", "reduceBool"); + let reduce_nat = mk_name2("Lean", "reduceNat"); + if *name == reduce_bool || *name == reduce_nat { + Some(args[0]) + } else { + None } +} - Some(result) +/// Try to reduce nat operations on DAG. +pub(crate) fn try_reduce_nat_dag( + name: &Name, + args: &[DAGPtr], + env: &Env, +) -> Option { + match args.len() { + 1 => { + if *name == mk_name2("Nat", "succ") { + // WHNF the arg directly on the DAG + let mut arg_dag = DAG { head: args[0] }; + whnf_dag(&mut arg_dag, env, false); + let n = get_nat_value_dag(arg_dag.head)?; + let result = alloc_val(LitNode { + val: Literal::NatVal(Nat(n + BigUint::from(1u64))), + parents: None, + }); + Some(DAGPtr::Lit(result)) + } else { + None + } + }, + 2 => { + // WHNF both args directly on the DAG + let mut a_dag = DAG { head: args[0] }; + whnf_dag(&mut a_dag, env, false); + let mut b_dag = DAG { head: args[1] }; + whnf_dag(&mut b_dag, env, false); + let a = get_nat_value_dag(a_dag.head)?; + let b = get_nat_value_dag(b_dag.head)?; + + if *name == mk_name2("Nat", "add") { + Some(nat_lit_dag(Nat(a + b))) + } else if *name == mk_name2("Nat", "sub") { + Some(nat_lit_dag(Nat(if a >= b { a - b } else { BigUint::ZERO }))) + } else if *name == mk_name2("Nat", "mul") { + Some(nat_lit_dag(Nat(a * b))) + } else if *name == mk_name2("Nat", "div") { + Some(nat_lit_dag(Nat(if b == BigUint::ZERO { + BigUint::ZERO + } else { + a / b + }))) + } else if *name == mk_name2("Nat", "mod") { + Some(nat_lit_dag(Nat(if b == BigUint::ZERO { a } else { a % b }))) + } else if *name == mk_name2("Nat", "beq") { + Some(bool_to_dag(a == b)) + } else if *name == mk_name2("Nat", "ble") { + Some(bool_to_dag(a <= b)) + } else if *name == mk_name2("Nat", "pow") { + let exp = u32::try_from(&b).unwrap_or(u32::MAX); + Some(nat_lit_dag(Nat(a.pow(exp)))) + } else if *name == mk_name2("Nat", "land") { 
+ Some(nat_lit_dag(Nat(a & b))) + } else if *name == mk_name2("Nat", "lor") { + Some(nat_lit_dag(Nat(a | b))) + } else if *name == mk_name2("Nat", "xor") { + Some(nat_lit_dag(Nat(a ^ b))) + } else if *name == mk_name2("Nat", "shiftLeft") { + let shift = u64::try_from(&b).unwrap_or(u64::MAX); + Some(nat_lit_dag(Nat(a << shift))) + } else if *name == mk_name2("Nat", "shiftRight") { + let shift = u64::try_from(&b).unwrap_or(u64::MAX); + Some(nat_lit_dag(Nat(a >> shift))) + } else if *name == mk_name2("Nat", "blt") { + Some(bool_to_dag(a < b)) + } else { + None + } + }, + _ => None, + } +} + +/// Extract a nat value from a DAGPtr (analog of get_nat_value_expr). +fn get_nat_value_dag(ptr: DAGPtr) -> Option { + unsafe { + match ptr { + DAGPtr::Lit(lit) => match &(*lit.as_ptr()).val { + Literal::NatVal(n) => Some(n.0.clone()), + _ => None, + }, + DAGPtr::Cnst(cnst) => { + if (*cnst.as_ptr()).name == mk_name2("Nat", "zero") { + Some(BigUint::ZERO) + } else { + None + } + }, + _ => None, + } + } +} + +/// Allocate a Nat literal DAG node. +pub(crate) fn nat_lit_dag(n: Nat) -> DAGPtr { + DAGPtr::Lit(alloc_val(LitNode { val: Literal::NatVal(n), parents: None })) +} + +/// Convert a bool to a DAG constant (Bool.true / Bool.false). +fn bool_to_dag(b: bool) -> DAGPtr { + let name = + if b { mk_name2("Bool", "true") } else { mk_name2("Bool", "false") }; + DAGPtr::Cnst(alloc_val(Cnst { name, levels: vec![], parents: None })) +} + +/// Build `String.mk (List.cons (Char.ofNat n1) (List.cons ... List.nil))` +/// entirely at the DAG level (no Expr round-trip). 
+fn string_lit_to_dag_ctor(s: &str) -> DAGPtr { + let list_name = Name::str(Name::anon(), "List".into()); + let char_name = Name::str(Name::anon(), "Char".into()); + let char_type = DAGPtr::Cnst(alloc_val(Cnst { + name: char_name.clone(), + levels: vec![], + parents: None, + })); + let nil = DAGPtr::App(alloc_app( + DAGPtr::Cnst(alloc_val(Cnst { + name: Name::str(list_name.clone(), "nil".into()), + levels: vec![Level::succ(Level::zero())], + parents: None, + })), + char_type, + None, + )); + let list = s.chars().rev().fold(nil, |acc, c| { + let of_nat = DAGPtr::Cnst(alloc_val(Cnst { + name: Name::str(char_name.clone(), "ofNat".into()), + levels: vec![], + parents: None, + })); + let char_val = + DAGPtr::App(alloc_app(of_nat, nat_lit_dag(Nat::from(c as u64)), None)); + let char_type_copy = DAGPtr::Cnst(alloc_val(Cnst { + name: char_name.clone(), + levels: vec![], + parents: None, + })); + let cons = DAGPtr::Cnst(alloc_val(Cnst { + name: Name::str(list_name.clone(), "cons".into()), + levels: vec![Level::succ(Level::zero())], + parents: None, + })); + let c1 = DAGPtr::App(alloc_app(cons, char_type_copy, None)); + let c2 = DAGPtr::App(alloc_app(c1, char_val, None)); + DAGPtr::App(alloc_app(c2, acc, None)) + }); + let string_mk = DAGPtr::Cnst(alloc_val(Cnst { + name: Name::str(Name::str(Name::anon(), "String".into()), "mk".into()), + levels: vec![], + parents: None, + })); + DAGPtr::App(alloc_app(string_mk, list, None)) +} + +/// Try delta (definition) unfolding on DAG. +/// Looks up the constant, substitutes universe levels in the definition body, +/// converts it to a DAG, and splices it into the current DAG. 
+fn try_dag_delta(dag: &mut DAG, trail: &[NonNull], env: &Env) -> bool { + // Extract constant info from head + let cnst_ref = match dag_head_past_trail(dag, trail) { + DAGPtr::Cnst(cnst) => unsafe { &*cnst.as_ptr() }, + _ => return false, + }; + + let ci = match env.get(&cnst_ref.name) { + Some(c) => c, + None => return false, + }; + let (def_params, def_value) = match ci { + ConstantInfo::DefnInfo(d) if d.hints != ReducibilityHints::Opaque => { + (&d.cnst.level_params, &d.value) + }, + _ => return false, + }; + + if cnst_ref.levels.len() != def_params.len() { + return false; + } + + eprintln!("[try_dag_delta] unfolding: {}", cnst_ref.name.pretty()); + + // Substitute levels at Expr level, then convert to DAG + let val = subst_expr_levels(def_value, def_params, &cnst_ref.levels); + eprintln!("[try_dag_delta] subst done, calling from_expr"); + let body_dag = from_expr(&val); + eprintln!("[try_dag_delta] from_expr done, calling set_dag_head"); + + // Splice body into the working DAG + set_dag_head(dag, body_dag.head, trail); + eprintln!("[try_dag_delta] set_dag_head done"); + true +} + +/// Get the head node past the trail (the non-App node at the bottom). +fn dag_head_past_trail(dag: &DAG, trail: &[NonNull]) -> DAGPtr { + if trail.is_empty() { + dag.head + } else { + unsafe { (*trail.last().unwrap().as_ptr()).fun } + } +} + +/// Try to unfold a definition at the head. 
+pub fn try_unfold_def(e: &Expr, env: &Env) -> Option { + let (head, args) = unfold_apps(e); + let (name, levels) = match head.as_data() { + ExprData::Const(name, levels, _) => (name, levels), + _ => return None, + }; + + let ci = env.get(name)?; + let (def_params, def_value) = match ci { + ConstantInfo::DefnInfo(d) => { + if d.hints == ReducibilityHints::Opaque { + return None; + } + (&d.cnst.level_params, &d.value) + }, + ConstantInfo::ThmInfo(t) => (&t.cnst.level_params, &t.value), + _ => return None, + }; + + if levels.len() != def_params.len() { + return None; + } + + let val = subst_expr_levels(def_value, def_params, levels); + Some(foldl_apps(val, args.into_iter())) +} + +/// Try to reduce `Lean.reduceBool` / `Lean.reduceNat`. +/// +/// These are opaque constants with special kernel reduction rules. In the Lean 4 +/// kernel they evaluate their argument using compiled native code. Since both are +/// semantically identity functions (`fun b => b` / `fun n => n`), we simply +/// return the argument and let the WHNF loop continue reducing it via our +/// existing efficient paths (e.g. `try_reduce_nat` handles `Nat.ble` etc. in O(1)). +pub(crate) fn try_reduce_native(name: &Name, args: &[Expr]) -> Option { + if args.len() != 1 { + return None; + } + let reduce_bool = mk_name2("Lean", "reduceBool"); + let reduce_nat = mk_name2("Lean", "reduceNat"); + if *name == reduce_bool || *name == reduce_nat { + Some(args[0].clone()) + } else { + None + } } /// Try to reduce nat operations. 
-fn try_reduce_nat(e: &Expr, env: &Env) -> Option { +pub(crate) fn try_reduce_nat(e: &Expr, env: &Env) -> Option { if has_fvars(e) { return None; } @@ -818,11 +1217,8 @@ fn get_nat_value(e: &Expr) -> Option { } fn bool_to_expr(b: bool) -> Option { - let name = if b { - mk_name2("Bool", "true") - } else { - mk_name2("Bool", "false") - }; + let name = + if b { mk_name2("Bool", "true") } else { mk_name2("Bool", "false") }; Some(Expr::cnst(name, vec![])) } @@ -865,12 +1261,8 @@ mod tests { BinderInfo::Default, ); let result = inst(&body, &[nat_zero()]); - let expected = Expr::lam( - Name::anon(), - nat_type(), - nat_zero(), - BinderInfo::Default, - ); + let expected = + Expr::lam(Name::anon(), nat_type(), nat_zero(), BinderInfo::Default); assert_eq!(result, expected); } @@ -927,11 +1319,7 @@ mod tests { env.insert( n.clone(), ConstantInfo::DefnInfo(DefinitionVal { - cnst: ConstantVal { - name: n.clone(), - level_params: vec![], - typ, - }, + cnst: ConstantVal { name: n.clone(), level_params: vec![], typ }, value, hints: ReducibilityHints::Abbrev, safety: DefinitionSafety::Safe, @@ -1198,7 +1586,10 @@ mod tests { fn test_nat_shift_right() { let env = Env::default(); let e = Expr::app( - Expr::app(Expr::cnst(mk_name2("Nat", "shiftRight"), vec![]), nat_lit(256)), + Expr::app( + Expr::cnst(mk_name2("Nat", "shiftRight"), vec![]), + nat_lit(256), + ), nat_lit(4), ); assert_eq!(whnf(&e, &env), nat_lit(16)); @@ -1336,12 +1727,8 @@ mod tests { #[test] fn test_whnf_pi_unchanged() { let env = Env::default(); - let e = Expr::all( - mk_name("x"), - nat_type(), - nat_type(), - BinderInfo::Default, - ); + let e = + Expr::all(mk_name("x"), nat_type(), nat_type(), BinderInfo::Default); let result = whnf(&e, &env); assert_eq!(result, e); } @@ -1417,4 +1804,371 @@ mod tests { let result = subst_expr_levels(&e, &[u_name], &[Level::zero()]); assert_eq!(result, Expr::sort(Level::zero())); } + + // ========================================================================== + // Nat.rec on 
large literals — reproduces the hang + // ========================================================================== + + /// Build a minimal env with Nat, Nat.zero, Nat.succ, and Nat.rec. + fn mk_nat_rec_env() -> Env { + let mut env = Env::default(); + let nat_name = mk_name("Nat"); + let zero_name = mk_name2("Nat", "zero"); + let succ_name = mk_name2("Nat", "succ"); + let rec_name = mk_name2("Nat", "rec"); + + // Nat : Sort 1 + env.insert( + nat_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: nat_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![nat_name.clone()], + ctors: vec![zero_name.clone(), succ_name.clone()], + num_nested: Nat::from(0u64), + is_rec: true, + is_unsafe: false, + is_reflexive: false, + }), + ); + + // Nat.zero : Nat + env.insert( + zero_name.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: zero_name.clone(), + level_params: vec![], + typ: nat_type(), + }, + induct: nat_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + + // Nat.succ : Nat → Nat + env.insert( + succ_name.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: succ_name.clone(), + level_params: vec![], + typ: Expr::all( + mk_name("n"), + nat_type(), + nat_type(), + BinderInfo::Default, + ), + }, + induct: nat_name.clone(), + cidx: Nat::from(1u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(1u64), + is_unsafe: false, + }), + ); + + // Nat.rec.{u} : (motive : Nat → Sort u) → motive Nat.zero → + // ((n : Nat) → motive n → motive (Nat.succ n)) → (t : Nat) → motive t + // Rules: + // Nat.rec m z s Nat.zero => z + // Nat.rec m z s (Nat.succ n) => s n (Nat.rec m z s n) + let u = mk_name("u"); + env.insert( + rec_name.clone(), + ConstantInfo::RecInfo(RecursorVal { + cnst: ConstantVal { + 
name: rec_name.clone(), + level_params: vec![u.clone()], + typ: Expr::sort(Level::param(u.clone())), // placeholder + }, + all: vec![nat_name], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![ + // Nat.rec m z s Nat.zero => z + RecursorRule { + ctor: zero_name, + n_fields: Nat::from(0u64), + // RHS is just bvar(1) = z (the zero minor) + // After substitution: Nat.rec m z s Nat.zero + // => rule.rhs applied to [m, z, s] + // => z + rhs: Expr::bvar(Nat::from(1u64)), + }, + // Nat.rec m z s (Nat.succ n) => s n (Nat.rec m z s n) + RecursorRule { + ctor: succ_name, + n_fields: Nat::from(1u64), + // RHS = fun n => s n (Nat.rec m z s n) + // But actually the rule rhs receives [m, z, s] then [n] as args + // rhs = bvar(0) = s, applied to the field n + // Actually the recursor rule rhs is applied as: + // rhs m z s + // For Nat.succ with 1 field (the predecessor n): + // rhs m z s n => s n (Nat.rec.{u} m z s n) + // So rhs = lam receiving params+minors then fields: + // Actually, rhs is an expression that gets applied to + // [params..., motives..., minors..., fields...] + // For Nat.rec: 0 params, 1 motive, 2 minors, 1 field + // So rhs gets applied to: m z s n + // We want: s n (Nat.rec.{u} m z s n) + // As a closed term using bvars after inst: + // After being applied to m z s n: + // bvar(3) = m, bvar(2) = z, bvar(1) = s, bvar(0) = n + // We want: s n (Nat.rec.{u} m z s n) + // = app(app(bvar(1), bvar(0)), + // app(app(app(app(Nat.rec.{u}, bvar(3)), bvar(2)), bvar(1)), bvar(0))) + // But wait, rhs is not a lambda - it gets args applied directly. + // The rhs just receives the args via Expr::app in try_reduce_rec. + // So rhs should be a term that, after being applied to m, z, s, n, + // produces s n (Nat.rec m z s n). 
+ // + // Simplest: rhs is a 4-arg lambda + rhs: Expr::lam( + mk_name("m"), + Expr::sort(Level::zero()), // placeholder type + Expr::lam( + mk_name("z"), + Expr::sort(Level::zero()), + Expr::lam( + mk_name("s"), + Expr::sort(Level::zero()), + Expr::lam( + mk_name("n"), + nat_type(), + // body: s n (Nat.rec.{u} m z s n) + // bvar(3)=m, bvar(2)=z, bvar(1)=s, bvar(0)=n + Expr::app( + Expr::app( + Expr::bvar(Nat::from(1u64)), // s + Expr::bvar(Nat::from(0u64)), // n + ), + Expr::app( + Expr::app( + Expr::app( + Expr::app( + Expr::cnst( + rec_name.clone(), + vec![Level::param(u.clone())], + ), + Expr::bvar(Nat::from(3u64)), // m + ), + Expr::bvar(Nat::from(2u64)), // z + ), + Expr::bvar(Nat::from(1u64)), // s + ), + Expr::bvar(Nat::from(0u64)), // n + ), + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + }, + ], + k: false, + is_unsafe: false, + }), + ); + + env + } + + #[test] + fn test_nat_rec_small_literal() { + // Nat.rec (fun _ => Nat) 0 (fun n _ => Nat.succ n) 3 + // Should reduce to 3 (identity via recursion) + let env = mk_nat_rec_env(); + let motive = + Expr::lam(mk_name("_"), nat_type(), nat_type(), BinderInfo::Default); + let zero_case = nat_lit(0); + let succ_case = Expr::lam( + mk_name("n"), + nat_type(), + Expr::lam( + mk_name("_"), + nat_type(), + Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + Expr::bvar(Nat::from(1u64)), + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + let e = Expr::app( + Expr::app( + Expr::app( + Expr::app( + Expr::cnst( + mk_name2("Nat", "rec"), + vec![Level::succ(Level::zero())], + ), + motive, + ), + zero_case, + ), + succ_case, + ), + nat_lit(3), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(3)); + } + + #[test] + fn test_nat_rec_large_literal_hangs() { + // This test demonstrates the O(n) recursor peeling issue. + // Nat.rec on 65536 (2^16) — would take 65536 recursive steps. 
+ // We use a timeout-style approach: just verify it works for small n + // and document that large n hangs. + let env = mk_nat_rec_env(); + let motive = + Expr::lam(mk_name("_"), nat_type(), nat_type(), BinderInfo::Default); + let zero_case = nat_lit(0); + let succ_case = Expr::lam( + mk_name("n"), + nat_type(), + Expr::lam( + mk_name("_"), + nat_type(), + Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + Expr::bvar(Nat::from(1u64)), + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + // Test with 100 — should be fast enough + let e = Expr::app( + Expr::app( + Expr::app( + Expr::app( + Expr::cnst( + mk_name2("Nat", "rec"), + vec![Level::succ(Level::zero())], + ), + motive.clone(), + ), + zero_case.clone(), + ), + succ_case.clone(), + ), + nat_lit(100), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(100)); + + // nat_lit(65536) would hang here — that's the bug to fix + } + + // ========================================================================== + // try_reduce_native tests (Lean.reduceBool / Lean.reduceNat) + // ========================================================================== + + #[test] + fn test_reduce_bool_true() { + // Lean.reduceBool Bool.true → Bool.true + let args = vec![Expr::cnst(mk_name2("Bool", "true"), vec![])]; + let result = try_reduce_native(&mk_name2("Lean", "reduceBool"), &args); + assert_eq!(result, Some(Expr::cnst(mk_name2("Bool", "true"), vec![]))); + } + + #[test] + fn test_reduce_nat_literal() { + // Lean.reduceNat (lit 42) → lit 42 + let args = vec![nat_lit(42)]; + let result = try_reduce_native(&mk_name2("Lean", "reduceNat"), &args); + assert_eq!(result, Some(nat_lit(42))); + } + + #[test] + fn test_reduce_bool_with_nat_ble() { + // Lean.reduceBool (Nat.ble 3 5) → passes through the arg + // WHNF will then reduce Nat.ble 3 5 → Bool.true + let ble_expr = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "ble"), vec![]), nat_lit(3)), + nat_lit(5), + ); + let args = vec![ble_expr.clone()]; 
+ let result = try_reduce_native(&mk_name2("Lean", "reduceBool"), &args); + assert_eq!(result, Some(ble_expr)); + + // Verify WHNF continues reducing the returned argument + let env = Env::default(); + let full_result = whnf(&result.unwrap(), &env); + assert_eq!(full_result, Expr::cnst(mk_name2("Bool", "true"), vec![])); + } + + #[test] + fn test_reduce_native_wrong_name() { + let args = vec![nat_lit(1)]; + assert_eq!(try_reduce_native(&mk_name2("Lean", "other"), &args), None); + } + + #[test] + fn test_reduce_native_wrong_arity() { + // 0 args + let empty: Vec = vec![]; + assert_eq!(try_reduce_native(&mk_name2("Lean", "reduceBool"), &empty), None); + // 2 args + let two = vec![nat_lit(1), nat_lit(2)]; + assert_eq!(try_reduce_native(&mk_name2("Lean", "reduceBool"), &two), None); + } + + #[test] + fn test_nat_rec_65536() { + let env = mk_nat_rec_env(); + let motive = + Expr::lam(mk_name("_"), nat_type(), nat_type(), BinderInfo::Default); + let zero_case = nat_lit(0); + let succ_case = Expr::lam( + mk_name("n"), + nat_type(), + Expr::lam( + mk_name("_"), + nat_type(), + Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + Expr::bvar(Nat::from(1u64)), + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + let e = Expr::app( + Expr::app( + Expr::app( + Expr::app( + Expr::cnst( + mk_name2("Nat", "rec"), + vec![Level::succ(Level::zero())], + ), + motive, + ), + zero_case, + ), + succ_case, + ), + nat_lit(65536), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(65536)); + } } diff --git a/src/lean/ffi.rs b/src/lean/ffi.rs index 07003a57..40553a06 100644 --- a/src/lean/ffi.rs +++ b/src/lean/ffi.rs @@ -6,6 +6,7 @@ pub mod lean_env; // Modular FFI structure pub mod builder; // IxEnvBuilder struct +pub mod check; // Kernel type-checking: rs_check_env pub mod compile; // Compilation: rs_compile_env_full, rs_compile_phases, etc. 
pub mod graph; // Graph/SCC: rs_build_ref_graph, rs_compute_sccs pub mod ix; // Ix types: Name, Level, Expr, ConstantInfo, Environment diff --git a/src/lean/ffi/check.rs b/src/lean/ffi/check.rs new file mode 100644 index 00000000..01e69cc7 --- /dev/null +++ b/src/lean/ffi/check.rs @@ -0,0 +1,182 @@ +//! FFI bridge for the Rust kernel type-checker. +//! +//! Provides `extern "C"` function callable from Lean via `@[extern]`: +//! - `rs_check_env`: type-check all declarations in a Lean environment + +use std::ffi::{CString, c_void}; + +use super::builder::LeanBuildCache; +use super::ffi_io_guard; +use super::ix::expr::build_expr; +use super::ix::name::build_name; +use super::lean_env::lean_ptr_to_env; +use crate::ix::env::{ConstantInfo, Name}; +use crate::ix::kernel::dag_tc::{DagTypeChecker, dag_check_env}; +use crate::ix::kernel::error::TcError; +use crate::lean::string::LeanStringObject; +use crate::lean::{ + as_ref_unsafe, lean_alloc_array, lean_alloc_ctor, lean_array_set_core, + lean_ctor_set, lean_ctor_set_uint64, lean_io_result_mk_ok, lean_mk_string, +}; + +/// Build a Lean `Ix.Kernel.CheckError` constructor from a Rust `TcError`. 
+/// +/// Constructor tags (must match the Lean `inductive CheckError`): +/// - 0: typeExpected (2 obj: expr, inferred) +/// - 1: functionExpected (2 obj: expr, inferred) +/// - 2: typeMismatch (3 obj: expected, found, expr) +/// - 3: defEqFailure (2 obj: lhs, rhs) +/// - 4: unknownConst (1 obj: name) +/// - 5: duplicateUniverse (1 obj: name) +/// - 6: freeBoundVariable (0 obj + 8 byte scalar: idx) +/// - 7: kernelException (1 obj: msg) +unsafe fn build_check_error( + cache: &mut LeanBuildCache, + err: &TcError, +) -> *mut c_void { + unsafe { + match err { + TcError::TypeExpected { expr, inferred } => { + let obj = lean_alloc_ctor(0, 2, 0); + lean_ctor_set(obj, 0, build_expr(cache, expr)); + lean_ctor_set(obj, 1, build_expr(cache, inferred)); + obj + }, + TcError::FunctionExpected { expr, inferred } => { + let obj = lean_alloc_ctor(1, 2, 0); + lean_ctor_set(obj, 0, build_expr(cache, expr)); + lean_ctor_set(obj, 1, build_expr(cache, inferred)); + obj + }, + TcError::TypeMismatch { expected, found, expr } => { + let obj = lean_alloc_ctor(2, 3, 0); + lean_ctor_set(obj, 0, build_expr(cache, expected)); + lean_ctor_set(obj, 1, build_expr(cache, found)); + lean_ctor_set(obj, 2, build_expr(cache, expr)); + obj + }, + TcError::DefEqFailure { lhs, rhs } => { + let obj = lean_alloc_ctor(3, 2, 0); + lean_ctor_set(obj, 0, build_expr(cache, lhs)); + lean_ctor_set(obj, 1, build_expr(cache, rhs)); + obj + }, + TcError::UnknownConst { name } => { + let obj = lean_alloc_ctor(4, 1, 0); + lean_ctor_set(obj, 0, build_name(cache, name)); + obj + }, + TcError::DuplicateUniverse { name } => { + let obj = lean_alloc_ctor(5, 1, 0); + lean_ctor_set(obj, 0, build_name(cache, name)); + obj + }, + TcError::FreeBoundVariable { idx } => { + let obj = lean_alloc_ctor(6, 0, 8); + lean_ctor_set_uint64(obj, 0, *idx); + obj + }, + TcError::KernelException { msg } => { + let c_msg = CString::new(msg.as_str()) + .unwrap_or_else(|_| CString::new("kernel exception").unwrap()); + let obj = 
lean_alloc_ctor(7, 1, 0); + lean_ctor_set(obj, 0, lean_mk_string(c_msg.as_ptr())); + obj + }, + } + } +} + +/// FFI function to type-check all declarations in a Lean environment using the +/// Rust kernel. Returns `IO (Array (Ix.Name × CheckError))`. +#[unsafe(no_mangle)] +pub extern "C" fn rs_check_env(env_consts_ptr: *const c_void) -> *mut c_void { + ffi_io_guard(std::panic::AssertUnwindSafe(|| { + let rust_env = lean_ptr_to_env(env_consts_ptr); + let errors = dag_check_env(&rust_env); + let mut cache = LeanBuildCache::new(); + unsafe { + let arr = lean_alloc_array(errors.len(), errors.len()); + for (i, (name, tc_err)) in errors.iter().enumerate() { + let name_obj = build_name(&mut cache, name); + let err_obj = build_check_error(&mut cache, tc_err); + let pair = lean_alloc_ctor(0, 2, 0); // Prod.mk + lean_ctor_set(pair, 0, name_obj); + lean_ctor_set(pair, 1, err_obj); + lean_array_set_core(arr, i, pair); + } + lean_io_result_mk_ok(arr) + } + })) +} + +/// Parse a dotted name string (e.g. "ISize.toInt16_ofIntLE") into a `Name`. +fn parse_name(s: &str) -> Name { + let mut name = Name::anon(); + for part in s.split('.') { + name = Name::str(name, part.to_string()); + } + name +} + +/// FFI function to type-check a single constant by name. +/// Takes the environment and a dotted name string. +/// Returns `IO (Option CheckError)` — `none` on success, `some err` on failure. 
+#[unsafe(no_mangle)] +pub extern "C" fn rs_check_const( + env_consts_ptr: *const c_void, + name_ptr: *const c_void, +) -> *mut c_void { + ffi_io_guard(std::panic::AssertUnwindSafe(|| { + eprintln!("[rs_check_const] entered FFI"); + let rust_env = lean_ptr_to_env(env_consts_ptr); + let name_str: &LeanStringObject = as_ref_unsafe(name_ptr.cast()); + let name = parse_name(&name_str.as_string()); + eprintln!("[rs_check_const] checking: {}", name.pretty()); + + let ci = match rust_env.get(&name) { + Some(ci) => { + match ci { + ConstantInfo::DefnInfo(d) => { + eprintln!("[rs_check_const] type: {:#?}", d.cnst.typ); + eprintln!("[rs_check_const] value: {:#?}", d.value); + eprintln!("[rs_check_const] hints: {:?}", d.hints); + }, + _ => {}, + } + ci + }, + None => { + // Return some (kernelException "not found") + let err = TcError::KernelException { + msg: format!("constant not found: {}", name.pretty()), + }; + let mut cache = LeanBuildCache::new(); + unsafe { + let err_obj = build_check_error(&mut cache, &err); + let some = lean_alloc_ctor(1, 1, 0); // Option.some + lean_ctor_set(some, 0, err_obj); + return lean_io_result_mk_ok(some); + } + }, + }; + + let mut tc = DagTypeChecker::new(&rust_env); + match tc.check_declar(ci) { + Ok(()) => unsafe { + // Option.none = ctor tag 0, 0 fields + let none = lean_alloc_ctor(0, 0, 0); + lean_io_result_mk_ok(none) + }, + Err(e) => { + let mut cache = LeanBuildCache::new(); + unsafe { + let err_obj = build_check_error(&mut cache, &e); + let some = lean_alloc_ctor(1, 1, 0); // Option.some + lean_ctor_set(some, 0, err_obj); + lean_io_result_mk_ok(some) + } + }, + } + })) +} diff --git a/src/lean/ffi/lean_env.rs b/src/lean/ffi/lean_env.rs index 3817e0e4..2562cd94 100644 --- a/src/lean/ffi/lean_env.rs +++ b/src/lean/ffi/lean_env.rs @@ -852,8 +852,10 @@ fn analyze_const_size(stt: &crate::ix::compile::CompileState, name_str: &str) { // BFS through all transitive dependencies while let Some(dep_addr) = queue.pop_front() { if let 
Some(dep_const) = stt.env.consts.get(&dep_addr) { - // Get the name for this dependency - let dep_name_opt = stt.env.get_name_by_addr(&dep_addr); + // Get the name for this dependency (linear scan through named entries) + let dep_name_opt = stt.env.named.iter() + .find(|entry| entry.value().addr == dep_addr) + .map(|entry| entry.key().clone()); let dep_name_str = dep_name_opt .as_ref() .map_or_else(|| format!("{:?}", dep_addr), |n| n.pretty()); From ff923998e917f5d12c67f892c133a27ae3a2d875 Mon Sep 17 00:00:00 2001 From: "John C. Burnham" Date: Fri, 20 Feb 2026 08:13:47 -0500 Subject: [PATCH 3/5] reenable printing type of erroring constants --- Ix/Kernel/Infer.lean | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/Ix/Kernel/Infer.lean b/Ix/Kernel/Infer.lean index 1d0b0159..0c161539 100644 --- a/Ix/Kernel/Infer.lean +++ b/Ix/Kernel/Infer.lean @@ -387,12 +387,7 @@ def typecheckAllIO (kenv : Env m) (prims : Primitives) (quotInit : Bool := true) let total := items.size for h : idx in [:total] do let (addr, ci) := items[idx] - --let typ := ci.type.pp - --let val := match ci.value? with - -- | some v => s!"\n value: {v.pp}" - -- | none => "" - let (typ, val) := ("_", "_") - (← IO.getStdout).putStrLn s!" [{idx + 1}/{total}] {ci.cv.name} ({ci.kindName})\n type: {typ}{val}" + (← IO.getStdout).putStrLn s!" [{idx + 1}/{total}] {ci.cv.name} ({ci.kindName})" (← IO.getStdout).flush match typecheckConst kenv prims addr quotInit with | .ok () => @@ -400,6 +395,10 @@ def typecheckAllIO (kenv : Env m) (prims : Primitives) (quotInit : Bool := true) (← IO.getStdout).flush | .error e => let header := s!"constant {ci.cv.name} ({ci.kindName}, {addr})" + let typ := ci.type.pp + let val := match ci.value? with + | some v => s!"\n value: {v.pp}" + | none => "" return .error s!"{header}: {e}\n type: {typ}{val}" return .ok () From 14380d835eed56f0622b1ad324ee119462fe4800 Mon Sep 17 00:00:00 2001 From: "John C. 
Burnham" Date: Fri, 20 Feb 2026 08:20:21 -0500 Subject: [PATCH 4/5] move error printing to end to unhide if types are long --- Ix/Kernel/Infer.lean | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Ix/Kernel/Infer.lean b/Ix/Kernel/Infer.lean index 0c161539..cc2d89e5 100644 --- a/Ix/Kernel/Infer.lean +++ b/Ix/Kernel/Infer.lean @@ -399,7 +399,8 @@ def typecheckAllIO (kenv : Env m) (prims : Primitives) (quotInit : Bool := true) let val := match ci.value? with | some v => s!"\n value: {v.pp}" | none => "" - return .error s!"{header}: {e}\n type: {typ}{val}" + IO.println s!"type: {typ}{val}" + return .error s!"{header}: {e}" return .ok () end Ix.Kernel From c77d3096feceef297e4d438b94006d67d4e7495b Mon Sep 17 00:00:00 2001 From: "John C. Burnham" Date: Fri, 20 Feb 2026 11:41:48 -0500 Subject: [PATCH 5/5] correctness improvements and ST caching --- Ix/Kernel/Equal.lean | 13 +- Ix/Kernel/Eval.lean | 49 +- Ix/Kernel/Infer.lean | 266 ++++++++- Ix/Kernel/TypecheckM.lean | 82 +-- Tests/Ix/KernelTests.lean | 494 ++++++++++++++++- Tests/Ix/PP.lean | 26 +- src/ix/kernel/def_eq.rs | 12 +- src/ix/kernel/inductive.rs | 1041 +++++++++++++++++++++++++++++++++++- src/ix/kernel/level.rs | 85 +++ src/ix/kernel/tc.rs | 318 ++++++++++- src/ix/kernel/whnf.rs | 13 +- 11 files changed, 2275 insertions(+), 124 deletions(-) diff --git a/Ix/Kernel/Equal.lean b/Ix/Kernel/Equal.lean index 4f219b7c..a2e8db92 100644 --- a/Ix/Kernel/Equal.lean +++ b/Ix/Kernel/Equal.lean @@ -34,7 +34,7 @@ private def equalUnivArrays (us us' : Array (Level m)) : Bool := mutual /-- Try eta expansion for structure-like types. -/ - partial def tryEtaStruct (lvl : Nat) (term term' : SusValue m) : TypecheckM m Bool := do + partial def tryEtaStruct (lvl : Nat) (term term' : SusValue m) : TypecheckM m σ Bool := do match term'.get with | .app (.const k _ _) args _ => match (← get).typedConsts.get? k with @@ -59,7 +59,7 @@ mutual /-- Check if two suspended values are definitionally equal at the given level. 
Assumes both have the same type and live in the same context. -/ - partial def equal (lvl : Nat) (term term' : SusValue m) : TypecheckM m Bool := + partial def equal (lvl : Nat) (term term' : SusValue m) : TypecheckM m σ Bool := match term.info, term'.info with | .unit, .unit => pure true | .proof, .proof => pure true @@ -67,9 +67,10 @@ mutual if (← read).trace then dbg_trace s!"equal: {term.get.ctorName} vs {term'.get.ctorName}" -- Fast path: pointer equality on thunks if susValuePtrEq term term' then return true - -- Check equality cache + -- Check equality cache via ST.Ref let key := susValueCacheKey term term' - if let some true := (← get).equalCache.get? key then return true + let eqCache ← (← read).equalCacheRef.get + if let some true := eqCache.get? key then return true let tv := term.get let tv' := term'.get let result ← match tv, tv' with @@ -151,11 +152,11 @@ mutual dbg_trace s!"equal FALLTHROUGH at lvl={lvl}: lhs={tv.dump} rhs={tv'.dump}" pure false if result then - modify fun stt => { stt with equalCache := stt.equalCache.insert key true } + let _ ← (← read).equalCacheRef.modify fun c => c.insert key true return result /-- Check if two lists of suspended values are pointwise equal. -/ - partial def equalThunks (lvl : Nat) (vals vals' : List (SusValue m)) : TypecheckM m Bool := + partial def equalThunks (lvl : Nat) (vals vals' : List (SusValue m)) : TypecheckM m σ Bool := match vals, vals' with | val :: vals, val' :: vals' => do let eq ← equal lvl val val' diff --git a/Ix/Kernel/Eval.lean b/Ix/Kernel/Eval.lean index 9fa74125..eed16e52 100644 --- a/Ix/Kernel/Eval.lean +++ b/Ix/Kernel/Eval.lean @@ -35,7 +35,7 @@ def listGet? (l : List α) (n : Nat) : Option α := /-- Try to reduce a primitive operation if all arguments are available. 
-/ private def tryPrimOp (prims : Primitives) (addr : Address) - (args : List (SusValue m)) : TypecheckM m (Option (Value m)) := do + (args : List (SusValue m)) : TypecheckM m σ (Option (Value m)) := do -- Nat.succ: 1 arg if addr == prims.natSucc then if args.length >= 1 then @@ -78,7 +78,7 @@ private def tryPrimOp (prims : Primitives) (addr : Address) /-- Expand a string literal to its constructor form: String.mk (list-of-chars). Each character is represented as Char.ofNat n, and the list uses List.cons/List.nil at universe level 0. -/ -def strLitToCtorVal (prims : Primitives) (s : String) : TypecheckM m (Value m) := do +def strLitToCtorVal (prims : Primitives) (s : String) : TypecheckM m σ (Value m) := do let charMkName ← lookupName prims.charMk let charName ← lookupName prims.char let listNilName ← lookupName prims.listNil @@ -105,7 +105,7 @@ def strLitToCtorVal (prims : Primitives) (s : String) : TypecheckM m (Value m) : mutual /-- Evaluate a typed expression to a value. -/ - partial def eval (t : TypedExpr m) : TypecheckM m (Value m) := withFuelCheck do + partial def eval (t : TypedExpr m) : TypecheckM m σ (Value m) := withFuelCheck do if (← read).trace then dbg_trace s!"eval: {t.body.tag}" match t.body with | .app fnc arg => do @@ -171,7 +171,7 @@ mutual pure (.app (.proj typeAddr idx ⟨ti, val⟩ typeName) [] []) | e => throw s!"Value is impossible to project: {e.ctorName}" - partial def evalTyped (t : TypedExpr m) : TypecheckM m (AddInfo (TypeInfo m) (Value m)) := do + partial def evalTyped (t : TypedExpr m) : TypecheckM m σ (AddInfo (TypeInfo m) (Value m)) := do let reducedInfo := t.info.update (← read).env.univs.toArray let value ← eval t pure ⟨reducedInfo, value⟩ @@ -180,11 +180,12 @@ mutual Theorems are treated as opaque (not unfolded) — proof irrelevance handles equality of proof terms, and this avoids deep recursion through proof bodies. Caches evaluated definition bodies to avoid redundant evaluation. 
-/ - partial def evalConst' (addr : Address) (univs : Array (Level m)) (name : MetaField m Ix.Name := default) : TypecheckM m (Value m) := do + partial def evalConst' (addr : Address) (univs : Array (Level m)) (name : MetaField m Ix.Name := default) : TypecheckM m σ (Value m) := do match (← read).kenv.find? addr with | some (.defnInfo _) => - -- Check eval cache (must also match universe parameters) - if let some (cachedUnivs, cachedVal) := (← get).evalCache.get? addr then + -- Check eval cache via ST.Ref (persists across thunks) + let cache ← (← read).evalCacheRef.get + if let some (cachedUnivs, cachedVal) := cache.get? addr then if cachedUnivs == univs then return cachedVal ensureTypedConst addr match (← get).typedConsts.get? addr with @@ -192,29 +193,29 @@ mutual if part then pure (mkConst addr univs name) else let val ← withEnv (.mk [] univs.toList) (eval deref) - modify fun stt => { stt with evalCache := stt.evalCache.insert addr (univs, val) } + let _ ← (← read).evalCacheRef.modify fun c => c.insert addr (univs, val) pure val | _ => throw "Invalid const kind for evaluation" | _ => pure (mkConst addr univs name) /-- Evaluate a constant: check if it's Nat.zero, a primitive op, or unfold it. -/ - partial def evalConst (addr : Address) (univs : Array (Level m)) (name : MetaField m Ix.Name := default) : TypecheckM m (Value m) := do + partial def evalConst (addr : Address) (univs : Array (Level m)) (name : MetaField m Ix.Name := default) : TypecheckM m σ (Value m) := do let prims := (← read).prims if addr == prims.natZero then pure (.lit (.natVal 0)) else if isPrimOp prims addr then pure (mkConst addr univs name) else evalConst' addr univs name /-- Create a suspended value from a typed expression, capturing context. 
-/ - partial def suspend (expr : TypedExpr m) (ctx : TypecheckCtx m) (stt : TypecheckState m) : SusValue m := + partial def suspend (expr : TypedExpr m) (ctx : TypecheckCtx m σ) (stt : TypecheckState m) : SusValue m := let thunk : Thunk (Value m) := .mk fun _ => - match TypecheckM.run ctx stt (eval expr) with + match pureRunST (TypecheckM.run ctx stt (eval expr)) with | .ok a => a | .error e => .exception e let reducedInfo := expr.info.update ctx.env.univs.toArray ⟨reducedInfo, thunk⟩ /-- Apply a value to an argument. -/ - partial def apply (val : AddInfo (TypeInfo m) (Value m)) (arg : SusValue m) : TypecheckM m (Value m) := do + partial def apply (val : AddInfo (TypeInfo m) (Value m)) (arg : SusValue m) : TypecheckM m σ (Value m) := do if (← read).trace then dbg_trace s!"apply: {val.body.ctorName}" match val.body with | .lam _ bod lamEnv _ _ => @@ -233,7 +234,7 @@ mutual /-- Apply a named constant to arguments, handling recursors, quotients, and primitives. -/ partial def applyConst (addr : Address) (univs : Array (Level m)) (arg : SusValue m) (args : List (SusValue m)) (info : TypeInfo m) (infos : List (TypeInfo m)) - (name : MetaField m Ix.Name := default) : TypecheckM m (Value m) := do + (name : MetaField m Ix.Name := default) : TypecheckM m σ (Value m) := do let prims := (← read).prims -- Try primitive operations if let some result ← tryPrimOp prims addr (arg :: args) then @@ -326,7 +327,7 @@ mutual /-- Apply a quotient to a value. -/ partial def applyQuot (_prims : Primitives) (major : SusValue m) (args : List (SusValue m)) - (reduceSize argPos : Nat) (default : Value m) : TypecheckM m (Value m) := + (reduceSize argPos : Nat) (default : Value m) : TypecheckM m σ (Value m) := let argsLength := args.length + 1 if argsLength == reduceSize then match major.get with @@ -343,7 +344,7 @@ mutual else throw s!"argsLength {argsLength} can't be greater than reduceSize {reduceSize}" /-- Convert a nat literal to Nat.succ/Nat.zero constructors. 
-/ - partial def toCtorIfLit (prims : Primitives) : Value m → TypecheckM m (Value m) + partial def toCtorIfLit (prims : Primitives) : Value m → TypecheckM m σ (Value m) | .lit (.natVal 0) => do let name ← lookupName prims.natZero pure (Value.neu (.const prims.natZero #[] name)) @@ -357,7 +358,7 @@ end /-! ## Quoting (read-back from Value to Expr) -/ mutual - partial def quote (lvl : Nat) : Value m → TypecheckM m (Expr m) + partial def quote (lvl : Nat) : Value m → TypecheckM m σ (Expr m) | .sort univ => do let env := (← read).env pure (.sort (instBulkReduce env.univs.toArray univ)) @@ -379,14 +380,14 @@ mutual | .lit lit => pure (.lit lit) | .exception e => throw e - partial def quoteTyped (lvl : Nat) (val : AddInfo (TypeInfo m) (Value m)) : TypecheckM m (TypedExpr m) := do + partial def quoteTyped (lvl : Nat) (val : AddInfo (TypeInfo m) (Value m)) : TypecheckM m σ (TypedExpr m) := do pure ⟨val.info, ← quote lvl val.body⟩ - partial def quoteTypedExpr (lvl : Nat) (t : TypedExpr m) (env : ValEnv m) : TypecheckM m (TypedExpr m) := do + partial def quoteTypedExpr (lvl : Nat) (t : TypedExpr m) (env : ValEnv m) : TypecheckM m σ (TypedExpr m) := do let e ← quoteExpr lvl t.body env pure ⟨t.info, e⟩ - partial def quoteExpr (lvl : Nat) (expr : Expr m) (env : ValEnv m) : TypecheckM m (Expr m) := + partial def quoteExpr (lvl : Nat) (expr : Expr m) (env : ValEnv m) : TypecheckM m σ (Expr m) := match expr with | .bvar idx _ => do match listGet? env.exprs idx with @@ -421,7 +422,7 @@ mutual pure (.proj typeAddr idx struct name) | .lit .. => pure expr - partial def quoteNeutral (lvl : Nat) : Neutral m → TypecheckM m (Expr m) + partial def quoteNeutral (lvl : Nat) : Neutral m → TypecheckM m σ (Expr m) | .fvar idx name => do pure (.bvar (lvl - idx - 1) name) | .const addr univs name => do @@ -501,22 +502,22 @@ partial def foldLiterals (prims : Primitives) : Expr m → Expr m /-- Pretty-print a value by quoting it back to an Expr, then using Expr.pp. 
Folds Nat/String constructor chains back to literals for readability. -/ -partial def ppValue (lvl : Nat) (v : Value m) : TypecheckM m String := do +partial def ppValue (lvl : Nat) (v : Value m) : TypecheckM m σ String := do let expr ← quote lvl v let expr := foldLiterals (← read).prims expr return expr.pp /-- Pretty-print a suspended value. -/ -partial def ppSusValue (lvl : Nat) (sv : SusValue m) : TypecheckM m String := +partial def ppSusValue (lvl : Nat) (sv : SusValue m) : TypecheckM m σ String := ppValue lvl sv.get /-- Pretty-print a value, falling back to the shallow summary on error. -/ -partial def tryPpValue (lvl : Nat) (v : Value m) : TypecheckM m String := do +partial def tryPpValue (lvl : Nat) (v : Value m) : TypecheckM m σ String := do try ppValue lvl v catch _ => return v.summary /-- Apply a value to a list of arguments. -/ -def applyType (v : Value m) (args : List (SusValue m)) : TypecheckM m (Value m) := +def applyType (v : Value m) (args : List (SusValue m)) : TypecheckM m σ (Value m) := match args with | [] => pure v | arg :: rest => do diff --git a/Ix/Kernel/Infer.lean b/Ix/Kernel/Infer.lean index cc2d89e5..0dacf465 100644 --- a/Ix/Kernel/Infer.lean +++ b/Ix/Kernel/Infer.lean @@ -7,22 +7,102 @@ import Ix.Kernel.Equal namespace Ix.Kernel +/-! ## Inductive validation helpers -/ + +/-- Check if an expression mentions a constant at the given address. -/ +partial def exprMentionsConst (e : Expr m) (addr : Address) : Bool := + match e with + | .const a _ _ => a == addr + | .app fn arg => exprMentionsConst fn addr || exprMentionsConst arg addr + | .lam ty body _ _ => exprMentionsConst ty addr || exprMentionsConst body addr + | .forallE ty body _ _ => exprMentionsConst ty addr || exprMentionsConst body addr + | .letE ty val body _ => exprMentionsConst ty addr || exprMentionsConst val addr || exprMentionsConst body addr + | .proj _ _ s _ => exprMentionsConst s addr + | _ => false + +/-- Check strict positivity of a field type w.r.t. 
a set of inductive addresses. + Returns true if positive, false if negative occurrence found. -/ +partial def checkStrictPositivity (ty : Expr m) (indAddrs : Array Address) : Bool := + -- If no inductive is mentioned, we're fine + if !indAddrs.any (exprMentionsConst ty ·) then true + else match ty with + | .forallE domain body _ _ => + -- Domain must NOT mention any inductive + if indAddrs.any (exprMentionsConst domain ·) then false + -- Continue checking body + else checkStrictPositivity body indAddrs + | e => + -- Not a forall — must be the inductive at the head + let fn := e.getAppFn + match fn with + | .const addr _ _ => indAddrs.any (· == addr) + | _ => false + +/-- Walk a Pi chain, skip numParams binders, then check positivity of each field. + Returns an error message or none on success. -/ +partial def checkCtorPositivity (ctorType : Expr m) (numParams : Nat) (indAddrs : Array Address) + : Option String := + go ctorType numParams +where + go (ty : Expr m) (remainingParams : Nat) : Option String := + match ty with + | .forallE _domain body _name _bi => + if remainingParams > 0 then + go body (remainingParams - 1) + else + -- This is a field — check positivity of its domain + let domain := ty.bindingDomain! + if !checkStrictPositivity domain indAddrs then + some "inductive occurs in negative position (strict positivity violation)" + else + go body 0 + | _ => none + +/-- Walk a Pi chain past numParams + numFields binders to get the return type. + Returns the return type expression (with bvars). -/ +def getCtorReturnType (ctorType : Expr m) (numParams numFields : Nat) : Expr m := + go ctorType (numParams + numFields) +where + go (ty : Expr m) (n : Nat) : Expr m := + match n, ty with + | 0, e => e + | n+1, .forallE _ body _ _ => go body n + | _, e => e + +/-- Extract result universe level from an inductive type expression. + Walks past all forall binders to find the final Sort. 
-/ +def getIndResultLevel (indType : Expr m) : Option (Level m) := + go indType +where + go : Expr m → Option (Level m) + | .forallE _ body _ _ => go body + | .sort lvl => some lvl + | _ => none + +/-- Check if a level is definitively non-zero (always ≥ 1). -/ +partial def levelIsNonZero : Level m → Bool + | .succ _ => true + | .zero => false + | .param .. => false -- could be zero + | .max a b => levelIsNonZero a || levelIsNonZero b + | .imax _ b => levelIsNonZero b + /-! ## Type info helpers -/ def lamInfo : TypeInfo m → TypeInfo m | .proof => .proof | _ => .none -def piInfo (dom img : TypeInfo m) : TypecheckM m (TypeInfo m) := match dom, img with +def piInfo (dom img : TypeInfo m) : TypecheckM m σ (TypeInfo m) := match dom, img with | .sort lvl, .sort lvl' => pure (.sort (Level.reduceIMax lvl lvl')) | _, _ => pure .none -def eqSortInfo (inferType expectType : SusValue m) : TypecheckM m Bool := do +def eqSortInfo (inferType expectType : SusValue m) : TypecheckM m σ Bool := do match inferType.info, expectType.info with | .sort lvl, .sort lvl' => pure (Level.equalLevel lvl lvl') | _, _ => pure true -- info unavailable; defer to structural equality -def infoFromType (typ : SusValue m) : TypecheckM m (TypeInfo m) := +def infoFromType (typ : SusValue m) : TypecheckM m σ (TypeInfo m) := match typ.info with | .sort (.zero) => pure .proof | _ => @@ -45,7 +125,7 @@ def infoFromType (typ : SusValue m) : TypecheckM m (TypeInfo m) := mutual /-- Check that a term has a given type. -/ - partial def check (term : Expr m) (type : SusValue m) : TypecheckM m (TypedExpr m) := do + partial def check (term : Expr m) (type : SusValue m) : TypecheckM m σ (TypedExpr m) := do if (← read).trace then dbg_trace s!"check: {term.tag}" let (te, inferType) ← infer term if !(← eqSortInfo inferType type) then @@ -60,7 +140,7 @@ mutual pure te /-- Infer the type of an expression, returning the typed expression and its type. 
-/ - partial def infer (term : Expr m) : TypecheckM m (TypedExpr m × SusValue m) := withFuelCheck do + partial def infer (term : Expr m) : TypecheckM m σ (TypedExpr m × SusValue m) := withFuelCheck do if (← read).trace then dbg_trace s!"infer: {term.tag}" match term with | .bvar idx bvarName => do @@ -194,7 +274,7 @@ mutual | _ => throw "Impossible case: structure type does not have enough fields" /-- Check if an expression is a Sort, returning the typed expr and the universe level. -/ - partial def isSort (expr : Expr m) : TypecheckM m (TypedExpr m × Level m) := do + partial def isSort (expr : Expr m) : TypecheckM m σ (TypedExpr m × Level m) := do let (te, typ) ← infer expr match typ.get with | .sort u => pure (te, u) @@ -204,7 +284,7 @@ mutual /-- Get structure info from a value that should be a structure type. -/ partial def getStructInfo (v : Value m) : - TypecheckM m (TypedExpr m × List (Level m) × List (SusValue m)) := do + TypecheckM m σ (TypedExpr m × List (Level m) × List (SusValue m)) := do match v with | .app (.const indAddr univs _) params _ => match (← read).kenv.find? indAddr with @@ -226,13 +306,13 @@ mutual /-- Typecheck a constant. With fresh state per declaration, dependencies get provisional entries via `ensureTypedConst` and are assumed well-typed. -/ - partial def checkConst (addr : Address) : TypecheckM m Unit := withResetCtx do + partial def checkConst (addr : Address) : TypecheckM m σ Unit := withResetCtx do -- Reset fuel and per-constant caches - modify fun stt => { stt with - fuel := defaultFuel - evalCache := {} - equalCache := {} - constTypeCache := {} } + modify fun stt => { stt with constTypeCache := {} } + let ctx ← read + let _ ← ctx.fuelRef.set defaultFuel + let _ ← ctx.evalCacheRef.set {} + let _ ← ctx.equalCacheRef.set {} -- Skip if already in typedConsts (provisional entry is fine — dependency assumed well-typed) if (← get).typedConsts.get? 
addr |>.isSome then return () @@ -286,7 +366,12 @@ mutual ensureTypedConst indAddr -- Check recursor type let (type, _) ← isSort ci.type - -- Check recursor rules + -- (#3) Validate K-flag instead of trusting the environment + if v.k then + validateKFlag v indAddr + -- (#4) Validate recursor rules + validateRecursorRules v indAddr + -- Check recursor rules (type-check RHS) let typedRules ← v.rules.mapM fun rule => do let (rhs, _) ← infer rule.rhs pure (rule.nfields, rhs) @@ -295,7 +380,7 @@ mutual /-- Walk a Pi chain to extract the return sort level (the universe of the result type). Assumes the expression ends in `Sort u` after `numBinders` forall binders. -/ - partial def getReturnSort (expr : Expr m) (numBinders : Nat) : TypecheckM m (Level m) := + partial def getReturnSort (expr : Expr m) (numBinders : Nat) : TypecheckM m σ (Level m) := match numBinders, expr with | 0, .sort u => do let univs := (← read).env.univs.toArray @@ -316,7 +401,7 @@ mutual | _, _ => throw "inductive type has fewer binders than expected" /-- Typecheck a mutual inductive block starting from one of its addresses. -/ - partial def checkIndBlock (addr : Address) : TypecheckM m Unit := do + partial def checkIndBlock (addr : Address) : TypecheckM m σ Unit := do let ci ← derefConst addr -- Find the inductive info let indInfo ← match ci with @@ -337,6 +422,13 @@ mutual | some (.ctorInfo cv) => cv.numFields > 0 | _ => false modify fun stt => { stt with typedConsts := stt.typedConsts.insert addr (TypedConst.inductive type isStruct) } + + -- Collect all inductive addresses in this mutual block + let indAddrs := iv.all + + -- Get the inductive's result universe level + let indResultLevel := getIndResultLevel iv.type + -- Check constructors for (ctorAddr, cidx) in iv.ctors.toList.zipIdx do match (← read).kenv.find? 
ctorAddr with @@ -344,23 +436,146 @@ mutual let ctorUnivs := cv.toConstantVal.mkUnivParams let (ctorType, _) ← withEnv (.mk [] ctorUnivs.toList) (isSort cv.type) modify fun stt => { stt with typedConsts := stt.typedConsts.insert ctorAddr (TypedConst.constructor ctorType cidx cv.numFields) } + + -- (#5) Check constructor parameter count matches inductive + if cv.numParams != iv.numParams then + throw s!"Constructor {ctorAddr} has {cv.numParams} params but inductive has {iv.numParams}" + + -- (#1) Positivity checking (skip for unsafe inductives) + if !iv.isUnsafe then + match checkCtorPositivity cv.type cv.numParams indAddrs with + | some msg => throw s!"Constructor {ctorAddr}: {msg}" + | none => pure () + + -- (#2) Universe constraint checking on constructor fields + -- Each non-parameter field's sort must be ≤ the inductive's result sort. + -- We check this by inferring the sort of each field type and comparing levels. + if !iv.isUnsafe then + if let some indLvl := indResultLevel then + let indLvlReduced := Level.instBulkReduce univs indLvl + checkFieldUniverses cv.type cv.numParams ctorAddr indLvlReduced + + -- (#6) Check indices in ctor return type don't mention the inductive + if !iv.isUnsafe then + let retType := getCtorReturnType cv.type cv.numParams cv.numFields + let args := retType.getAppArgs + -- Index arguments are those after numParams + for i in [iv.numParams:args.size] do + for indAddr in indAddrs do + if exprMentionsConst args[i]! indAddr then + throw s!"Constructor {ctorAddr} index argument mentions the inductive (unsound)" + | _ => throw s!"Constructor {ctorAddr} not found" -- Note: recursors are checked individually via checkConst's .recInfo branch, -- which calls checkConst on the inductives first then checks rules. + + /-- Check that constructor field types have sorts ≤ the inductive's result sort. 
-/ + partial def checkFieldUniverses (ctorType : Expr m) (numParams : Nat) + (ctorAddr : Address) (indLvl : Level m) : TypecheckM m σ Unit := + go ctorType numParams + where + go (ty : Expr m) (remainingParams : Nat) : TypecheckM m σ Unit := + match ty with + | .forallE dom body piName _ => + if remainingParams > 0 then do + let (domTe, _) ← isSort dom + let ctx ← read + let stt ← get + let domVal := suspend domTe ctx stt + let var := mkSusVar (← infoFromType domVal) ctx.lvl piName + withExtendedCtx var domVal (go body (remainingParams - 1)) + else do + -- This is a field — infer its sort level and check ≤ indLvl + let (domTe, fieldSortLvl) ← isSort dom + let fieldReduced := Level.reduce fieldSortLvl + let indReduced := Level.reduce indLvl + -- Allow if field ≤ ind, OR if ind is Prop (is_zero allows any field) + if !Level.leq fieldReduced indReduced 0 && !Level.isZero indReduced then + throw s!"Constructor {ctorAddr} field type lives in a universe larger than the inductive's universe" + let ctx ← read + let stt ← get + let domVal := suspend domTe ctx stt + let var := mkSusVar (← infoFromType domVal) ctx.lvl piName + withExtendedCtx var domVal (go body 0) + | _ => pure () + + /-- (#3) Validate K-flag: requires non-mutual, Prop, single ctor, zero fields. -/ + partial def validateKFlag (rec : RecursorVal m) (indAddr : Address) : TypecheckM m σ Unit := do + -- Must be non-mutual + if rec.all.size != 1 then + throw "recursor claims K but inductive is mutual" + -- Look up the inductive + match (← read).kenv.find? 
indAddr with + | some (.inductInfo iv) => + -- Must be in Prop + match getIndResultLevel iv.type with + | some lvl => + if levelIsNonZero lvl then + throw s!"recursor claims K but inductive is not in Prop" + | none => throw "recursor claims K but cannot determine inductive's result sort" + -- Must have single constructor + if iv.ctors.size != 1 then + throw s!"recursor claims K but inductive has {iv.ctors.size} constructors (need 1)" + -- Constructor must have zero fields + match (← read).kenv.find? iv.ctors[0]! with + | some (.ctorInfo cv) => + if cv.numFields != 0 then + throw s!"recursor claims K but constructor has {cv.numFields} fields (need 0)" + | _ => throw "recursor claims K but constructor not found" + | _ => throw s!"recursor claims K but {indAddr} is not an inductive" + + /-- (#4) Validate recursor rules: check rule count, ctor membership, field counts. -/ + partial def validateRecursorRules (rec : RecursorVal m) (indAddr : Address) : TypecheckM m σ Unit := do + -- Collect all constructors from the mutual block + let mut allCtors : Array Address := #[] + for iAddr in rec.all do + match (← read).kenv.find? iAddr with + | some (.inductInfo iv) => + allCtors := allCtors ++ iv.ctors + | _ => throw s!"recursor references {iAddr} which is not an inductive" + -- Check rule count + if rec.rules.size != allCtors.size then + throw s!"recursor has {rec.rules.size} rules but inductive(s) have {allCtors.size} constructors" + -- Check each rule + for h : i in [:rec.rules.size] do + let rule := rec.rules[i] + -- Rule's constructor must match expected constructor in order + if rule.ctor != allCtors[i]! then + throw s!"recursor rule {i} has constructor {rule.ctor} but expected {allCtors[i]!}" + -- Look up the constructor and validate nfields + match (← read).kenv.find? 
rule.ctor with + | some (.ctorInfo cv) => + if rule.nfields != cv.numFields then + throw s!"recursor rule for {rule.ctor} has nfields={rule.nfields} but constructor has {cv.numFields} fields" + | _ => throw s!"recursor rule constructor {rule.ctor} not found" + -- Validate structural counts against the inductive + match (← read).kenv.find? indAddr with + | some (.inductInfo iv) => + if rec.numParams != iv.numParams then + throw s!"recursor numParams={rec.numParams} but inductive has {iv.numParams}" + if rec.numIndices != iv.numIndices then + throw s!"recursor numIndices={rec.numIndices} but inductive has {iv.numIndices}" + | _ => pure () + end -- mutual /-! ## Top-level entry points -/ /-- Typecheck a single constant by address. -/ def typecheckConst (kenv : Env m) (prims : Primitives) (addr : Address) - (quotInit : Bool := true) : Except String Unit := do - let ctx : TypecheckCtx m := { - lvl := 0, env := default, types := [], kenv := kenv, - prims := prims, safety := .safe, quotInit := quotInit, - mutTypes := default, recAddr? := none - } - let stt : TypecheckState m := { typedConsts := default } - TypecheckM.run ctx stt (checkConst addr) + (quotInit : Bool := true) : Except String Unit := + runST fun σ => do + let fuelRef ← ST.mkRef defaultFuel + let evalRef ← ST.mkRef ({} : Std.HashMap Address (Array (Level m) × Value m)) + let equalRef ← ST.mkRef ({} : Std.HashMap (USize × USize) Bool) + let ctx : TypecheckCtx m σ := { + lvl := 0, env := default, types := [], kenv := kenv, + prims := prims, safety := .safe, quotInit := quotInit, + mutTypes := default, recAddr? := none, + fuelRef := fuelRef, evalCacheRef := evalRef, equalCacheRef := equalRef + } + let stt : TypecheckState m := { typedConsts := default } + TypecheckM.run ctx stt (checkConst addr) /-- Typecheck all constants in a kernel environment. Uses fresh state per declaration — dependencies are assumed well-typed. 
-/ @@ -399,7 +614,8 @@ def typecheckAllIO (kenv : Env m) (prims : Primitives) (quotInit : Bool := true) let val := match ci.value? with | some v => s!"\n value: {v.pp}" | none => "" - IO.println s!"type: {typ}{val}" + IO.println s!"type: {typ}" + IO.println s!"val: {val}" return .error s!"{header}: {e}" return .ok () diff --git a/Ix/Kernel/TypecheckM.lean b/Ix/Kernel/TypecheckM.lean index 8b1a93ba..9fb0d2cd 100644 --- a/Ix/Kernel/TypecheckM.lean +++ b/Ix/Kernel/TypecheckM.lean @@ -8,7 +8,7 @@ namespace Ix.Kernel /-! ## Typechecker Context -/ -structure TypecheckCtx (m : MetaMode) where +structure TypecheckCtx (m : MetaMode) (σ : Type) where lvl : Nat env : ValEnv m types : List (SusValue m) @@ -23,29 +23,23 @@ structure TypecheckCtx (m : MetaMode) where /-- Depth fuel: bounds the call-stack depth to prevent native stack overflow. Decremented via the reader on each entry to eval/equal/infer. Thunks inherit the depth from their capture point. -/ - depth : Nat := 3000 + depth : Nat := 10000 /-- Enable dbg_trace on major entry points for debugging. -/ trace : Bool := false - deriving Inhabited + /-- Global fuel counter: bounds total recursive work across all thunks via ST.Ref. -/ + fuelRef : ST.Ref σ Nat + /-- Mutable eval cache: persists across thunk evaluations via ST.Ref. -/ + evalCacheRef : ST.Ref σ (Std.HashMap Address (Array (Level m) × Value m)) + /-- Mutable equality cache: persists across thunk evaluations via ST.Ref. -/ + equalCacheRef : ST.Ref σ (Std.HashMap (USize × USize) Bool) /-! ## Typechecker State -/ /-- Default fuel for bounding total recursive work per constant. -/ -def defaultFuel : Nat := 100000 +def defaultFuel : Nat := 200000 structure TypecheckState (m : MetaMode) where typedConsts : Std.TreeMap Address (TypedConst m) Address.compare - /-- Fuel counter for bounding total recursive work. Decremented on each entry to - eval/equal/infer. Reset at the start of each `checkConst` call. 
-/ - fuel : Nat := defaultFuel - /-- Cache for evaluated constant definitions. Maps an address to its universe - parameters and evaluated value. Universe-polymorphic constants produce different - values for different universe instantiations, so we store and check univs. -/ - evalCache : Std.HashMap Address (Array (Level m) × Value m) := {} - /-- Cache for definitional equality results. Maps `(ptrAddrUnsafe a, ptrAddrUnsafe b)` - (canonicalized so smaller pointer comes first) to `Bool`. Only `true` results are - cached (monotone under state growth). -/ - equalCache : Std.HashMap (USize × USize) Bool := {} /-- Cache for constant type SusValues. When `infer (.const addr _)` computes a suspended type, it is cached here so repeated references to the same constant share the same SusValue pointer, enabling fast-path pointer equality in `equal`. @@ -55,75 +49,87 @@ structure TypecheckState (m : MetaMode) where /-! ## TypecheckM monad -/ -abbrev TypecheckM (m : MetaMode) := ReaderT (TypecheckCtx m) (StateT (TypecheckState m) (Except String)) +abbrev TypecheckM (m : MetaMode) (σ : Type) := + ReaderT (TypecheckCtx m σ) (ExceptT String (StateT (TypecheckState m) (ST σ))) + +def TypecheckM.run (ctx : TypecheckCtx m σ) (stt : TypecheckState m) + (x : TypecheckM m σ α) : ST σ (Except String α) := do + let (result, _) ← StateT.run (ExceptT.run (ReaderT.run x ctx)) stt + pure result + +def TypecheckM.runState (ctx : TypecheckCtx m σ) (stt : TypecheckState m) (x : TypecheckM m σ α) + : ST σ (Except String (α × TypecheckState m)) := do + let (result, stt') ← StateT.run (ExceptT.run (ReaderT.run x ctx)) stt + pure (match result with | .ok a => .ok (a, stt') | .error e => .error e) + +/-! ## pureRunST -/ -def TypecheckM.run (ctx : TypecheckCtx m) (stt : TypecheckState m) (x : TypecheckM m α) : Except String α := - match (StateT.run (ReaderT.run x ctx) stt) with - | .error e => .error e - | .ok (a, _) => .ok a +/-- Unsafe bridge: run ST σ from pure code (for Thunk bodies). 
+ Soundness requires that the side effects (fuel-counter decrements and cache inserts) never change an observed result. -/ +@[inline] unsafe def pureRunSTImpl {σ α : Type} [Inhabited α] (x : ST σ α) : α := + (x (unsafeCast ())).val -def TypecheckM.runState (ctx : TypecheckCtx m) (stt : TypecheckState m) (x : TypecheckM m α) - : Except String (α × TypecheckState m) := - StateT.run (ReaderT.run x ctx) stt +@[implemented_by pureRunSTImpl] +opaque pureRunST {σ α : Type} [Inhabited α] : ST σ α → α /-! ## Context modifiers -/ -def withEnv (env : ValEnv m) : TypecheckM m α → TypecheckM m α := +def withEnv (env : ValEnv m) : TypecheckM m σ α → TypecheckM m σ α := withReader fun ctx => { ctx with env := env } -def withResetCtx : TypecheckM m α → TypecheckM m α := +def withResetCtx : TypecheckM m σ α → TypecheckM m σ α := withReader fun ctx => { ctx with lvl := 0, env := default, types := default, mutTypes := default, recAddr? := none } def withMutTypes (mutTypes : Std.TreeMap Nat (Address × (Array (Level m) → SusValue m)) compare) : - TypecheckM m α → TypecheckM m α := + TypecheckM m σ α → TypecheckM m σ α := withReader fun ctx => { ctx with mutTypes := mutTypes } -def withExtendedCtx (val typ : SusValue m) : TypecheckM m α → TypecheckM m α := +def withExtendedCtx (val typ : SusValue m) : TypecheckM m σ α → TypecheckM m σ α := withReader fun ctx => { ctx with lvl := ctx.lvl + 1, types := typ :: ctx.types, env := ctx.env.extendWith val } -def withExtendedEnv (thunk : SusValue m) : TypecheckM m α → TypecheckM m α := +def withExtendedEnv (thunk : SusValue m) : TypecheckM m σ α → TypecheckM m σ α := withReader fun ctx => { ctx with env := ctx.env.extendWith thunk } def withNewExtendedEnv (env : ValEnv m) (thunk : SusValue m) : - TypecheckM m α → TypecheckM m α := + TypecheckM m σ α → TypecheckM m σ α := withReader fun ctx => { ctx with env := env.extendWith thunk } -def withRecAddr (addr : Address) : TypecheckM m α → TypecheckM m α := +def withRecAddr (addr : Address) : TypecheckM m σ α → TypecheckM m σ α := withReader fun ctx
=> { ctx with recAddr? := some addr } /-- Check both fuel counters, decrement them, and run the action. - State fuel bounds total work (prevents exponential blowup / hanging). - Reader depth bounds call-stack depth (prevents native stack overflow). -/ -def withFuelCheck (action : TypecheckM m α) : TypecheckM m α := do +def withFuelCheck (action : TypecheckM m σ α) : TypecheckM m σ α := do let ctx ← read if ctx.depth == 0 then throw "deep recursion depth limit reached" - let stt ← get - if stt.fuel == 0 then throw "deep recursion work limit reached" - set { stt with fuel := stt.fuel - 1 } + let fuel ← ctx.fuelRef.get + if fuel == 0 then throw "deep recursion fuel limit reached" + let _ ← ctx.fuelRef.set (fuel - 1) withReader (fun ctx => { ctx with depth := ctx.depth - 1 }) action /-! ## Name lookup -/ /-- Look up the MetaField name for a constant address from the kernel environment. -/ -def lookupName (addr : Address) : TypecheckM m (MetaField m Ix.Name) := do +def lookupName (addr : Address) : TypecheckM m σ (MetaField m Ix.Name) := do match (← read).kenv.find? addr with | some ci => pure ci.cv.name | none => pure default /-! ## Const dereferencing -/ -def derefConst (addr : Address) : TypecheckM m (ConstantInfo m) := do +def derefConst (addr : Address) : TypecheckM m σ (ConstantInfo m) := do let ctx ← read match ctx.kenv.find? addr with | some ci => pure ci | none => throw s!"unknown constant {addr}" -def derefTypedConst (addr : Address) : TypecheckM m (TypedConst m) := do +def derefTypedConst (addr : Address) : TypecheckM m σ (TypedConst m) := do match (← get).typedConsts.get? addr with | some tc => pure tc | none => throw s!"typed constant not found: {addr}" @@ -170,7 +176,7 @@ def provisionalTypedConst (ci : ConstantInfo m) : TypedConst m := /-- Ensure a constant has a TypedConst entry. If not already present, build a provisional one from raw ConstantInfo. This avoids the deep recursion of `checkConst` when called from `infer`. 
-/ -def ensureTypedConst (addr : Address) : TypecheckM m Unit := do +def ensureTypedConst (addr : Address) : TypecheckM m σ Unit := do if (← get).typedConsts.get? addr |>.isSome then return () let ci ← derefConst addr let tc := provisionalTypedConst ci diff --git a/Tests/Ix/KernelTests.lean b/Tests/Ix/KernelTests.lean index f1ed3c55..b14dbff4 100644 --- a/Tests/Ix/KernelTests.lean +++ b/Tests/Ix/KernelTests.lean @@ -131,10 +131,38 @@ def testLevelOps : TestSeq := /-! ## Integration tests: Const pipeline -/ -/-- Parse a dotted name string like "Nat.add" into an Ix.Name. -/ -private def parseIxName (s : String) : Ix.Name := - let parts := s.splitOn "." - parts.foldl (fun acc part => Ix.Name.mkStr acc part) Ix.Name.mkAnon +/-- Parse a dotted name string like "Nat.add" into an Ix.Name. + Handles `«...»` quoted name components (e.g. `Foo.«0».Bar`). -/ +private partial def parseIxName (s : String) : Ix.Name := + let parts := splitParts s.toList [] + parts.foldl (fun acc part => + match part with + | .inl str => Ix.Name.mkStr acc str + | .inr nat => Ix.Name.mkNat acc nat + ) Ix.Name.mkAnon +where + /-- Split a dotted name into parts: plain components are .inl; guillemet-quoted components become .inr when they parse as a Nat, otherwise .inl. -/ + splitParts : List Char → List (String ⊕ Nat) → List (String ⊕ Nat) + | [], acc => acc + | '.' :: rest, acc => splitParts rest acc + | '«' :: rest, acc => + let (inside, rest') := collectUntilClose rest "" + let part := match inside.toNat? with + | some n => .inr n + | none => .inl inside + splitParts rest' (acc ++ [part]) + | cs, acc => + let (word, rest) := collectUntilDot cs "" + splitParts rest (if word.isEmpty then acc else acc ++ [.inl word]) + collectUntilClose : List Char → String → String × List Char + | [], s => (s, []) + | '»' :: rest, s => (s, rest) + | c :: rest, s => collectUntilClose rest (s.push c) + collectUntilDot : List Char → String → String × List Char + | [], s => (s, []) + | '.' :: rest, s => (s, '.'
:: rest) + | '«' :: rest, s => (s, '«' :: rest) + | c :: rest, s => collectUntilDot rest (s.push c) /-- Convert a Lean.Name to an Ix.Name (reproducing the Blake3 hashing). -/ private partial def leanNameToIx : Lean.Name → Ix.Name @@ -605,6 +633,461 @@ def negativeTests : TestSeq := return (false, some s!"{failures.size} failure(s)") ) .done +/-! ## Soundness negative tests (inductive validation) -/ + +/-- Helper: make unique addresses from a seed byte. -/ +private def mkAddr (seed : UInt8) : Address := + Address.blake3 (ByteArray.mk #[seed, 0xAA, 0xBB]) + +/-- Soundness negative test suite: verify that the typechecker rejects unsound + inductive declarations (positivity, universe constraints, K-flag, recursor rules). -/ +def soundnessNegativeTests : TestSeq := + .individualIO "kernel soundness negative tests" (do + let prims := buildPrimitives + let mut passed := 0 + let mut failures : Array String := #[] + + -- ======================================================================== + -- Test 1: Positivity violation — Bad | mk : (Bad → Bad) → Bad + -- The inductive appears in negative position (Pi domain). 
+ -- ======================================================================== + do + let badAddr := mkAddr 10 + let badMkAddr := mkAddr 11 + let badType : Expr .anon := .sort (.succ .zero) -- Sort 1 + let badCv : ConstantVal .anon := + { numLevels := 0, type := badType, name := (), levelParams := () } + let badInd : ConstantInfo .anon := .inductInfo { + toConstantVal := badCv, numParams := 0, numIndices := 0, + all := #[badAddr], ctors := #[badMkAddr], numNested := 0, + isRec := true, isUnsafe := false, isReflexive := false + } + -- mk : (Bad → Bad) → Bad + -- The domain (Bad → Bad) has Bad in negative position + let mkType : Expr .anon := + .forallE + (.forallE (.const badAddr #[] ()) (.const badAddr #[] ()) () ()) + (.const badAddr #[] ()) + () () + let mkCv : ConstantVal .anon := + { numLevels := 0, type := mkType, name := (), levelParams := () } + let mkCtor : ConstantInfo .anon := .ctorInfo { + toConstantVal := mkCv, induct := badAddr, cidx := 0, + numParams := 0, numFields := 1, isUnsafe := false + } + let env := ((default : Env .anon).insert badAddr badInd).insert badMkAddr mkCtor + match typecheckConst env prims badAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "positivity-violation: expected error (Bad → Bad in domain)" + + -- ======================================================================== + -- Test 2: Universe constraint violation — Uni1Bad : Sort 1 | mk : Sort 2 → Uni1Bad + -- Field lives in Sort 3 (Sort 2 : Sort 3) but inductive is in Sort 1. + -- (Note: Prop inductives have special exception allowing any field universe, + -- so we test with a Sort 1 inductive instead.) 
+ -- ======================================================================== + do + let ubAddr := mkAddr 20 + let ubMkAddr := mkAddr 21 + let ubType : Expr .anon := .sort (.succ .zero) -- Sort 1 + let ubCv : ConstantVal .anon := + { numLevels := 0, type := ubType, name := (), levelParams := () } + let ubInd : ConstantInfo .anon := .inductInfo { + toConstantVal := ubCv, numParams := 0, numIndices := 0, + all := #[ubAddr], ctors := #[ubMkAddr], numNested := 0, + isRec := false, isUnsafe := false, isReflexive := false + } + -- mk : Sort 2 → Uni1Bad + -- Sort 2 : Sort 3, so field sort = 3. Inductive sort = 1. 3 ≤ 1 fails. + let mkType : Expr .anon := + .forallE (.sort (.succ (.succ .zero))) (.const ubAddr #[] ()) () () + let mkCv : ConstantVal .anon := + { numLevels := 0, type := mkType, name := (), levelParams := () } + let mkCtor : ConstantInfo .anon := .ctorInfo { + toConstantVal := mkCv, induct := ubAddr, cidx := 0, + numParams := 0, numFields := 1, isUnsafe := false + } + let env := ((default : Env .anon).insert ubAddr ubInd).insert ubMkAddr mkCtor + match typecheckConst env prims ubAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "universe-constraint: expected error (Sort 2 field in Sort 1 inductive)" + + -- ======================================================================== + -- Test 3: K-flag invalid — K=true on non-Prop inductive (Sort 1, 2 ctors) + -- ======================================================================== + do + let indAddr := mkAddr 30 + let mk1Addr := mkAddr 31 + let mk2Addr := mkAddr 32 + let recAddr := mkAddr 33 + let indType : Expr .anon := .sort (.succ .zero) -- Sort 1 (not Prop) + let indCv : ConstantVal .anon := + { numLevels := 0, type := indType, name := (), levelParams := () } + let indCI : ConstantInfo .anon := .inductInfo { + toConstantVal := indCv, numParams := 0, numIndices := 0, + all := #[indAddr], ctors := #[mk1Addr, mk2Addr], numNested := 0, + isRec := false, isUnsafe := false, 
isReflexive := false + } + let mk1Cv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } + let mk1CI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mk1Cv, induct := indAddr, cidx := 0, + numParams := 0, numFields := 0, isUnsafe := false + } + let mk2Cv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } + let mk2CI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mk2Cv, induct := indAddr, cidx := 1, + numParams := 0, numFields := 0, isUnsafe := false + } + -- Recursor with k=true on a non-Prop inductive + let recCv : ConstantVal .anon := + { numLevels := 1, type := .sort (.param 0 ()), name := (), levelParams := () } + let recCI : ConstantInfo .anon := .recInfo { + toConstantVal := recCv, all := #[indAddr], + numParams := 0, numIndices := 0, numMotives := 1, numMinors := 2, + rules := #[ + { ctor := mk1Addr, nfields := 0, rhs := .sort .zero }, + { ctor := mk2Addr, nfields := 0, rhs := .sort .zero } + ], + k := true, -- INVALID: not Prop + isUnsafe := false + } + let env := ((((default : Env .anon).insert indAddr indCI).insert mk1Addr mk1CI).insert mk2Addr mk2CI).insert recAddr recCI + match typecheckConst env prims recAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "k-flag-not-prop: expected error" + + -- ======================================================================== + -- Test 4: Recursor wrong rule count — 1 rule for 2-ctor inductive + -- ======================================================================== + do + let indAddr := mkAddr 40 + let mk1Addr := mkAddr 41 + let mk2Addr := mkAddr 42 + let recAddr := mkAddr 43 + let indType : Expr .anon := .sort (.succ .zero) + let indCv : ConstantVal .anon := + { numLevels := 0, type := indType, name := (), levelParams := () } + let indCI : ConstantInfo .anon := .inductInfo { + toConstantVal := indCv, numParams := 0, numIndices := 0, + all := 
#[indAddr], ctors := #[mk1Addr, mk2Addr], numNested := 0, + isRec := false, isUnsafe := false, isReflexive := false + } + let mk1Cv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } + let mk1CI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mk1Cv, induct := indAddr, cidx := 0, + numParams := 0, numFields := 0, isUnsafe := false + } + let mk2Cv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } + let mk2CI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mk2Cv, induct := indAddr, cidx := 1, + numParams := 0, numFields := 0, isUnsafe := false + } + -- Recursor with only 1 rule (should be 2) + let recCv : ConstantVal .anon := + { numLevels := 1, type := .sort (.param 0 ()), name := (), levelParams := () } + let recCI : ConstantInfo .anon := .recInfo { + toConstantVal := recCv, all := #[indAddr], + numParams := 0, numIndices := 0, numMotives := 1, numMinors := 2, + rules := #[{ ctor := mk1Addr, nfields := 0, rhs := .sort .zero }], -- only 1! 
+ k := false, isUnsafe := false + } + let env := ((((default : Env .anon).insert indAddr indCI).insert mk1Addr mk1CI).insert mk2Addr mk2CI).insert recAddr recCI + match typecheckConst env prims recAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "rec-wrong-rule-count: expected error" + + -- ======================================================================== + -- Test 5: Recursor wrong nfields — ctor has 0 fields but rule claims 5 + -- ======================================================================== + do + let indAddr := mkAddr 50 + let mkAddr' := mkAddr 51 + let recAddr := mkAddr 52 + let indType : Expr .anon := .sort (.succ .zero) + let indCv : ConstantVal .anon := + { numLevels := 0, type := indType, name := (), levelParams := () } + let indCI : ConstantInfo .anon := .inductInfo { + toConstantVal := indCv, numParams := 0, numIndices := 0, + all := #[indAddr], ctors := #[mkAddr'], numNested := 0, + isRec := false, isUnsafe := false, isReflexive := false + } + let mkCv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } + let mkCI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mkCv, induct := indAddr, cidx := 0, + numParams := 0, numFields := 0, isUnsafe := false + } + let recCv : ConstantVal .anon := + { numLevels := 1, type := .sort (.param 0 ()), name := (), levelParams := () } + let recCI : ConstantInfo .anon := .recInfo { + toConstantVal := recCv, all := #[indAddr], + numParams := 0, numIndices := 0, numMotives := 1, numMinors := 1, + rules := #[{ ctor := mkAddr', nfields := 5, rhs := .sort .zero }], -- wrong nfields + k := false, isUnsafe := false + } + let env := (((default : Env .anon).insert indAddr indCI).insert mkAddr' mkCI).insert recAddr recCI + match typecheckConst env prims recAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "rec-wrong-nfields: expected error" + + -- 
======================================================================== + -- Test 6: Recursor wrong num_params — rec claims 5 params, inductive has 0 + -- ======================================================================== + do + let indAddr := mkAddr 60 + let mkAddr' := mkAddr 61 + let recAddr := mkAddr 62 + let indType : Expr .anon := .sort (.succ .zero) + let indCv : ConstantVal .anon := + { numLevels := 0, type := indType, name := (), levelParams := () } + let indCI : ConstantInfo .anon := .inductInfo { + toConstantVal := indCv, numParams := 0, numIndices := 0, + all := #[indAddr], ctors := #[mkAddr'], numNested := 0, + isRec := false, isUnsafe := false, isReflexive := false + } + let mkCv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } + let mkCI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mkCv, induct := indAddr, cidx := 0, + numParams := 0, numFields := 0, isUnsafe := false + } + let recCv : ConstantVal .anon := + { numLevels := 1, type := .sort (.param 0 ()), name := (), levelParams := () } + let recCI : ConstantInfo .anon := .recInfo { + toConstantVal := recCv, all := #[indAddr], + numParams := 5, -- wrong: inductive has 0 + numIndices := 0, numMotives := 1, numMinors := 1, + rules := #[{ ctor := mkAddr', nfields := 0, rhs := .sort .zero }], + k := false, isUnsafe := false + } + let env := (((default : Env .anon).insert indAddr indCI).insert mkAddr' mkCI).insert recAddr recCI + match typecheckConst env prims recAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "rec-wrong-num-params: expected error" + + -- ======================================================================== + -- Test 7: Constructor param count mismatch — ctor claims 3 params, ind has 0 + -- ======================================================================== + do + let indAddr := mkAddr 70 + let mkAddr' := mkAddr 71 + let indType : Expr .anon := .sort (.succ .zero) + let 
indCv : ConstantVal .anon := + { numLevels := 0, type := indType, name := (), levelParams := () } + let indCI : ConstantInfo .anon := .inductInfo { + toConstantVal := indCv, numParams := 0, numIndices := 0, + all := #[indAddr], ctors := #[mkAddr'], numNested := 0, + isRec := false, isUnsafe := false, isReflexive := false + } + let mkCv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } + let mkCI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mkCv, induct := indAddr, cidx := 0, + numParams := 3, -- wrong: inductive has 0 + numFields := 0, isUnsafe := false + } + let env := ((default : Env .anon).insert indAddr indCI).insert mkAddr' mkCI + match typecheckConst env prims indAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "ctor-param-mismatch: expected error" + + -- ======================================================================== + -- Test 8: K-flag invalid — K=true on Prop inductive with 2 ctors + -- ======================================================================== + do + let indAddr := mkAddr 80 + let mk1Addr := mkAddr 81 + let mk2Addr := mkAddr 82 + let recAddr := mkAddr 83 + let indType : Expr .anon := .sort .zero -- Prop + let indCv : ConstantVal .anon := + { numLevels := 0, type := indType, name := (), levelParams := () } + let indCI : ConstantInfo .anon := .inductInfo { + toConstantVal := indCv, numParams := 0, numIndices := 0, + all := #[indAddr], ctors := #[mk1Addr, mk2Addr], numNested := 0, + isRec := false, isUnsafe := false, isReflexive := false + } + let mk1Cv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } + let mk1CI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mk1Cv, induct := indAddr, cidx := 0, + numParams := 0, numFields := 0, isUnsafe := false + } + let mk2Cv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } + 
let mk2CI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mk2Cv, induct := indAddr, cidx := 1, + numParams := 0, numFields := 0, isUnsafe := false + } + let recCv : ConstantVal .anon := + { numLevels := 0, type := .sort .zero, name := (), levelParams := () } + let recCI : ConstantInfo .anon := .recInfo { + toConstantVal := recCv, all := #[indAddr], + numParams := 0, numIndices := 0, numMotives := 1, numMinors := 2, + rules := #[ + { ctor := mk1Addr, nfields := 0, rhs := .sort .zero }, + { ctor := mk2Addr, nfields := 0, rhs := .sort .zero } + ], + k := true, -- INVALID: 2 ctors + isUnsafe := false + } + let env := ((((default : Env .anon).insert indAddr indCI).insert mk1Addr mk1CI).insert mk2Addr mk2CI).insert recAddr recCI + match typecheckConst env prims recAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "k-flag-two-ctors: expected error" + + -- ======================================================================== + -- Test 9: Recursor wrong ctor order — rules in wrong order + -- ======================================================================== + do + let indAddr := mkAddr 90 + let mk1Addr := mkAddr 91 + let mk2Addr := mkAddr 92 + let recAddr := mkAddr 93 + let indType : Expr .anon := .sort (.succ .zero) + let indCv : ConstantVal .anon := + { numLevels := 0, type := indType, name := (), levelParams := () } + let indCI : ConstantInfo .anon := .inductInfo { + toConstantVal := indCv, numParams := 0, numIndices := 0, + all := #[indAddr], ctors := #[mk1Addr, mk2Addr], numNested := 0, + isRec := false, isUnsafe := false, isReflexive := false + } + let mk1Cv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } + let mk1CI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mk1Cv, induct := indAddr, cidx := 0, + numParams := 0, numFields := 0, isUnsafe := false + } + let mk2Cv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), 
levelParams := () } + let mk2CI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mk2Cv, induct := indAddr, cidx := 1, + numParams := 0, numFields := 0, isUnsafe := false + } + let recCv : ConstantVal .anon := + { numLevels := 1, type := .sort (.param 0 ()), name := (), levelParams := () } + let recCI : ConstantInfo .anon := .recInfo { + toConstantVal := recCv, all := #[indAddr], + numParams := 0, numIndices := 0, numMotives := 1, numMinors := 2, + rules := #[ + { ctor := mk2Addr, nfields := 0, rhs := .sort .zero }, -- wrong order! + { ctor := mk1Addr, nfields := 0, rhs := .sort .zero } + ], + k := false, isUnsafe := false + } + let env := ((((default : Env .anon).insert indAddr indCI).insert mk1Addr mk1CI).insert mk2Addr mk2CI).insert recAddr recCI + match typecheckConst env prims recAddr with + | .error _ => passed := passed + 1 + | .ok () => failures := failures.push "rec-wrong-ctor-order: expected error" + + -- ======================================================================== + -- Test 10: Valid single-ctor inductive passes (sanity check) + -- ======================================================================== + do + let indAddr := mkAddr 100 + let mkAddr' := mkAddr 101 + let indType : Expr .anon := .sort (.succ .zero) -- Sort 1 + let indCv : ConstantVal .anon := + { numLevels := 0, type := indType, name := (), levelParams := () } + let indCI : ConstantInfo .anon := .inductInfo { + toConstantVal := indCv, numParams := 0, numIndices := 0, + all := #[indAddr], ctors := #[mkAddr'], numNested := 0, + isRec := false, isUnsafe := false, isReflexive := false + } + let mkCv : ConstantVal .anon := + { numLevels := 0, type := .const indAddr #[] (), name := (), levelParams := () } + let mkCI : ConstantInfo .anon := .ctorInfo { + toConstantVal := mkCv, induct := indAddr, cidx := 0, + numParams := 0, numFields := 0, isUnsafe := false + } + let env := ((default : Env .anon).insert indAddr indCI).insert mkAddr' mkCI + match typecheckConst env prims indAddr 
with + | .ok () => passed := passed + 1 + | .error e => failures := failures.push s!"valid-inductive: unexpected error: {e}" + + let totalTests := 10 + IO.println s!"[kernel-soundness] {passed}/{totalTests} passed" + if failures.isEmpty then + return (true, none) + else + for f in failures do IO.println s!" [fail] {f}" + return (false, some s!"{failures.size} failure(s)") + ) .done + +/-! ## Unit tests: helper functions -/ + +def testHelperFunctions : TestSeq := + -- exprMentionsConst + let addr1 := mkAddr 200 + let addr2 := mkAddr 201 + let c1 : Expr .anon := .const addr1 #[] () + let c2 : Expr .anon := .const addr2 #[] () + test "exprMentionsConst: direct match" + (exprMentionsConst c1 addr1) ++ + test "exprMentionsConst: no match" + (!exprMentionsConst c2 addr1) ++ + test "exprMentionsConst: in app fn" + (exprMentionsConst (.app c1 c2) addr1) ++ + test "exprMentionsConst: in app arg" + (exprMentionsConst (.app c2 c1) addr1) ++ + test "exprMentionsConst: in forallE domain" + (exprMentionsConst (.forallE c1 c2 () () : Expr .anon) addr1) ++ + test "exprMentionsConst: in forallE body" + (exprMentionsConst (.forallE c2 c1 () () : Expr .anon) addr1) ++ + test "exprMentionsConst: in lam" + (exprMentionsConst (.lam c1 c2 () () : Expr .anon) addr1) ++ + test "exprMentionsConst: absent in sort" + (!exprMentionsConst (.sort .zero : Expr .anon) addr1) ++ + test "exprMentionsConst: absent in bvar" + (!exprMentionsConst (.bvar 0 () : Expr .anon) addr1) ++ + -- checkStrictPositivity + let indAddrs := #[addr1] + test "checkStrictPositivity: no mention is positive" + (checkStrictPositivity c2 indAddrs) ++ + test "checkStrictPositivity: head occurrence is positive" + (checkStrictPositivity c1 indAddrs) ++ + test "checkStrictPositivity: in Pi domain is negative" + (!checkStrictPositivity (.forallE c1 c2 () () : Expr .anon) indAddrs) ++ + test "checkStrictPositivity: in Pi codomain positive" + (checkStrictPositivity (.forallE c2 c1 () () : Expr .anon) indAddrs) ++ + -- 
getIndResultLevel + test "getIndResultLevel: sort zero" + (getIndResultLevel (.sort .zero : Expr .anon) == some .zero) ++ + test "getIndResultLevel: sort (succ zero)" + (getIndResultLevel (.sort (.succ .zero) : Expr .anon) == some (.succ .zero)) ++ + test "getIndResultLevel: forallE _ (sort zero)" + (getIndResultLevel (.forallE (.sort .zero) (.sort (.succ .zero)) () () : Expr .anon) == some (.succ .zero)) ++ + test "getIndResultLevel: bvar (no sort)" + (getIndResultLevel (.bvar 0 () : Expr .anon) == none) ++ + -- levelIsNonZero + test "levelIsNonZero: zero is false" + (!levelIsNonZero (.zero : Level .anon)) ++ + test "levelIsNonZero: succ zero is true" + (levelIsNonZero (.succ .zero : Level .anon)) ++ + test "levelIsNonZero: param is false" + (!levelIsNonZero (.param 0 () : Level .anon)) ++ + test "levelIsNonZero: max(succ 0, param) is true" + (levelIsNonZero (.max (.succ .zero) (.param 0 ()) : Level .anon)) ++ + test "levelIsNonZero: imax(param, succ 0) is true" + (levelIsNonZero (.imax (.param 0 ()) (.succ .zero) : Level .anon)) ++ + test "levelIsNonZero: imax(succ, param) depends on second" + (!levelIsNonZero (.imax (.succ .zero) (.param 0 ()) : Level .anon)) ++ + -- checkCtorPositivity + test "checkCtorPositivity: no inductive mention is ok" + (checkCtorPositivity c2 0 indAddrs == none) ++ + test "checkCtorPositivity: negative occurrence" + (checkCtorPositivity (.forallE (.forallE c1 c2 () ()) (.const addr1 #[] ()) () () : Expr .anon) 0 indAddrs != none) ++ + -- getCtorReturnType + test "getCtorReturnType: no binders returns expr" + (getCtorReturnType c1 0 0 == c1) ++ + test "getCtorReturnType: skips foralls" + (getCtorReturnType (.forallE c2 c1 () () : Expr .anon) 0 1 == c1) ++ + .done + /-! ## Focused NbE constant tests -/ /-- Test individual constants through the NbE kernel to isolate failures. 
-/ @@ -631,6 +1114,7 @@ def testNbeConsts : TestSeq := "Nat.Linear.Poly.of_denote_eq_cancel", -- String theorem (fuel-sensitive) "String.length_empty", + "_private.Init.Grind.Ring.Basic.«0».Lean.Grind.IsCharP.mk'_aux._proof_1_5", ] let mut passed := 0 let mut failures : Array String := #[] @@ -673,6 +1157,7 @@ def unitSuite : List TestSeq := [ testLevelLeqComplex, testLevelInstBulkReduce, testReducibilityHintsLt, + testHelperFunctions, ] def convertSuite : List TestSeq := [ @@ -686,6 +1171,7 @@ def constSuite : List TestSeq := [ def negativeSuite : List TestSeq := [ negativeTests, + soundnessNegativeTests, ] def anonConvertSuite : List TestSeq := [ diff --git a/Tests/Ix/PP.lean b/Tests/Ix/PP.lean index d96bd0f1..ab52ea3e 100644 --- a/Tests/Ix/PP.lean +++ b/Tests/Ix/PP.lean @@ -248,22 +248,30 @@ def testQuoteRoundtrip : TestSeq := -- Build Value.lam: fun (y : Nat) => y let bodyTE : TypedExpr .meta := ⟨.none, .bvar 0 yName⟩ let lamVal : Value .meta := .lam domVal bodyTE (.mk [] []) yName .default - -- Quote and pp in a minimal TypecheckM context - let ctx : TypecheckCtx .meta := { - lvl := 0, env := .mk [] [], types := [], - kenv := default, prims := buildPrimitives, - safety := .safe, quotInit := true, mutTypes := default, recAddr? := none - } - let stt : TypecheckState .meta := { typedConsts := default } + -- Quote and pp in a minimal TypecheckM context (wrapped in runST for ST.Ref allocation) + let result := runST fun σ => do + let fuelRef ← ST.mkRef Ix.Kernel.defaultFuel + let evalRef ← ST.mkRef ({} : Std.HashMap Address (Array (Level .meta) × Value .meta)) + let equalRef ← ST.mkRef ({} : Std.HashMap (USize × USize) Bool) + let ctx : TypecheckCtx .meta σ := { + lvl := 0, env := .mk [] [], types := [], + kenv := default, prims := buildPrimitives, + safety := .safe, quotInit := true, mutTypes := default, recAddr? 
:= none, + fuelRef := fuelRef, evalCacheRef := evalRef, equalCacheRef := equalRef + } + let stt : TypecheckState .meta := { typedConsts := default } + let piResult ← TypecheckM.run ctx stt (ppValue 0 piVal) + let lamResult ← TypecheckM.run ctx stt (ppValue 0 lamVal) + pure (piResult, lamResult) -- Test pi - match TypecheckM.run ctx stt (ppValue 0 piVal) with + match result.1 with | .ok s => if s != "∀ (x : Nat), Nat" then return (false, some s!"pi round-trip: expected '∀ (x : Nat), Nat', got '{s}'") else pure () | .error e => return (false, some s!"pi round-trip error: {e}") -- Test lam - match TypecheckM.run ctx stt (ppValue 0 lamVal) with + match result.2 with | .ok s => if s != "λ (y : Nat) => y" then return (false, some s!"lam round-trip: expected 'λ (y : Nat) => y', got '{s}'") diff --git a/src/ix/kernel/def_eq.rs b/src/ix/kernel/def_eq.rs index ada12904..0cc24620 100644 --- a/src/ix/kernel/def_eq.rs +++ b/src/ix/kernel/def_eq.rs @@ -530,9 +530,8 @@ fn get_applied_def( Some((name.clone(), d.hints)) } }, - ConstantInfo::ThmInfo(_) => { - Some((name.clone(), ReducibilityHints::Opaque)) - }, + // Theorems are never unfolded — proof irrelevance handles them. + // ConstantInfo::ThmInfo(_) => return None, _ => None, } } @@ -1570,13 +1569,12 @@ mod tests { } #[test] - fn test_get_applied_def_includes_theorems_as_opaque() { + fn test_get_applied_def_excludes_theorems() { + // Theorems should never be unfolded — proof irrelevance handles them. 
let env = mk_thm_env(); let thm = Expr::cnst(mk_name("thmA"), vec![]); let result = get_applied_def(&thm, &env); - assert!(result.is_some()); - let (_, hints) = result.unwrap(); - assert_eq!(hints, ReducibilityHints::Opaque); + assert!(result.is_none()); } // ========================================================================== diff --git a/src/ix/kernel/inductive.rs b/src/ix/kernel/inductive.rs index 4cf79d45..90da54ba 100644 --- a/src/ix/kernel/inductive.rs +++ b/src/ix/kernel/inductive.rs @@ -155,6 +155,216 @@ pub fn validate_k_flag( Ok(()) } +/// Validate recursor rules against the inductive's constructors. +/// Checks: +/// - One rule per constructor +/// - Each rule's constructor exists and belongs to the inductive +/// - Each rule's n_fields matches the constructor's actual field count +/// - Rules are in constructor order +pub fn validate_recursor_rules( + rec: &RecursorVal, + env: &Env, +) -> TcResult<()> { + // Find the primary inductive + if rec.all.is_empty() { + return Err(TcError::KernelException { + msg: "recursor has no associated inductives".into(), + }); + } + let ind_name = &rec.all[0]; + let ind = match env.get(ind_name) { + Some(ConstantInfo::InductInfo(iv)) => iv, + _ => { + return Err(TcError::KernelException { + msg: format!( + "recursor's inductive {} is not an inductive type", + ind_name.pretty() + ), + }) + }, + }; + + // For mutual inductives, collect all constructors in order + let mut all_ctors: Vec = Vec::new(); + for iname in &rec.all { + if let Some(ConstantInfo::InductInfo(iv)) = env.get(iname) { + all_ctors.extend(iv.ctors.iter().cloned()); + } + } + + // Check rule count matches total constructor count + if rec.rules.len() != all_ctors.len() { + return Err(TcError::KernelException { + msg: format!( + "recursor has {} rules but inductive(s) have {} constructors", + rec.rules.len(), + all_ctors.len() + ), + }); + } + + // Check each rule + for (i, rule) in rec.rules.iter().enumerate() { + // Rule's constructor must match 
expected constructor in order + if rule.ctor != all_ctors[i] { + return Err(TcError::KernelException { + msg: format!( + "recursor rule {} has constructor {} but expected {}", + i, + rule.ctor.pretty(), + all_ctors[i].pretty() + ), + }); + } + + // Look up the constructor and validate n_fields + let ctor = match env.get(&rule.ctor) { + Some(ConstantInfo::CtorInfo(c)) => c, + _ => { + return Err(TcError::KernelException { + msg: format!( + "recursor rule constructor {} not found or not a constructor", + rule.ctor.pretty() + ), + }) + }, + }; + + if rule.n_fields != ctor.num_fields { + return Err(TcError::KernelException { + msg: format!( + "recursor rule for {} has n_fields={} but constructor has {} fields", + rule.ctor.pretty(), + rule.n_fields, + ctor.num_fields + ), + }); + } + } + + // Validate structural counts against the inductive + let expected_params = ind.num_params.to_u64().unwrap(); + let rec_params = rec.num_params.to_u64().unwrap(); + if rec_params != expected_params { + return Err(TcError::KernelException { + msg: format!( + "recursor num_params={} but inductive has {} params", + rec_params, expected_params + ), + }); + } + + let expected_indices = ind.num_indices.to_u64().unwrap(); + let rec_indices = rec.num_indices.to_u64().unwrap(); + if rec_indices != expected_indices { + return Err(TcError::KernelException { + msg: format!( + "recursor num_indices={} but inductive has {} indices", + rec_indices, expected_indices + ), + }); + } + + // Validate elimination restriction for Prop inductives. + // If the inductive is in Prop and requires elimination only at universe zero, + // then the recursor must not have extra universe parameters beyond the inductive's. 
+ if !rec.is_unsafe { + if let Some(elim_zero) = elim_only_at_universe_zero(ind, env) { + if elim_zero { + // Recursor should have same number of level params as the inductive + // (no extra universe parameter for the motive's result sort) + let ind_level_count = ind.cnst.level_params.len(); + let rec_level_count = rec.cnst.level_params.len(); + if rec_level_count > ind_level_count { + return Err(TcError::KernelException { + msg: format!( + "recursor has {} universe params but inductive has {} — \ + large elimination is not allowed for this Prop inductive", + rec_level_count, ind_level_count + ), + }); + } + } + } + } + + Ok(()) +} + +/// Compute whether a Prop inductive can only eliminate to Prop (universe zero). +/// +/// Returns `Some(true)` if elimination is restricted to Prop, +/// `Some(false)` if large elimination is allowed, +/// `None` if the inductive is not in Prop (no restriction applies). +/// +/// Matches the C++ kernel's `elim_only_at_universe_zero`: +/// 1. If result universe is always non-zero: None (not a predicate) +/// 2. If mutual: restricted +/// 3. If >1 constructor: restricted +/// 4. If 0 constructors: not restricted (e.g., False) +/// 5. If 1 constructor: restricted iff any non-Prop field doesn't appear in result indices +fn elim_only_at_universe_zero( + ind: &InductiveVal, + env: &Env, +) -> Option { + // Check if the inductive's result is in Prop. + // Walk past all binders to find the final Sort. + let mut ty = ind.cnst.typ.clone(); + loop { + match ty.as_data() { + ExprData::ForallE(_, _, body, _, _) => { + ty = body.clone(); + }, + _ => break, + } + } + let result_level = match ty.as_data() { + ExprData::Sort(l, _) => l, + _ => return None, + }; + + // If the result sort is definitively non-zero (e.g., Sort 1, Sort (u+1)), + // this is not a predicate. + if !level::could_be_zero(result_level) { + return None; + } + + // Must be possibly Prop. Apply the 5 conditions. 
+ + // Condition 2: Mutual inductives → restricted + if ind.all.len() > 1 { + return Some(true); + } + + // Condition 3: >1 constructor → restricted + if ind.ctors.len() > 1 { + return Some(true); + } + + // Condition 4: 0 constructors → not restricted (e.g., False) + if ind.ctors.is_empty() { + return Some(false); + } + + // Condition 5: Single constructor — check fields + let ctor = match env.get(&ind.ctors[0]) { + Some(ConstantInfo::CtorInfo(c)) => c, + _ => return Some(true), // can't look up ctor, be conservative + }; + + // If zero fields, not restricted + if ctor.num_fields == Nat::ZERO { + return Some(false); + } + + // For single-constructor with fields: restricted if any non-Prop field + // doesn't appear in the result type's indices. + // Conservative approximation: if any field exists that could be non-Prop, + // assume restricted. This is safe (may reject some valid large eliminations + // but never allows unsound ones). + Some(true) +} + /// Check if an expression mentions a constant by name. fn expr_mentions_const(e: &Expr, name: &Name) -> bool { let mut stack: Vec<&Expr> = vec![e]; @@ -364,14 +574,33 @@ fn check_field_universe_constraints( /// Verify that a constructor's return type targets the parent inductive. /// Walks the constructor type telescope, then checks that the resulting /// type is an application of the parent inductive with at least `num_params` args. +/// Also validates: +/// - The first `num_params` arguments are definitionally equal to the inductive's parameters. +/// - Index arguments (after params) don't mention the inductive being declared. fn check_ctor_return_type( ctor: &ConstructorVal, ind: &InductiveVal, tc: &mut TypeChecker, ) -> TcResult<()> { - let mut ty = ctor.cnst.typ.clone(); + let num_params = ind.num_params.to_u64().unwrap() as usize; + + // Walk the inductive's type telescope to collect parameter locals. 
+ let mut ind_ty = ind.cnst.typ.clone(); + let mut param_locals = Vec::with_capacity(num_params); + for _ in 0..num_params { + let whnf_ty = tc.whnf(&ind_ty); + match whnf_ty.as_data() { + ExprData::ForallE(name, binder_type, body, _, _) => { + let local = tc.mk_local(name, binder_type); + param_locals.push(local.clone()); + ind_ty = inst(body, &[local]); + }, + _ => break, + } + } - // Walk past all Pi binders + // Walk past all Pi binders in the constructor type. + let mut ty = ctor.cnst.typ.clone(); loop { let whnf_ty = tc.whnf(&ty); match whnf_ty.as_data() { @@ -411,7 +640,6 @@ fn check_ctor_return_type( }); } - let num_params = ind.num_params.to_u64().unwrap() as usize; if args.len() < num_params { return Err(TcError::KernelException { msg: format!( @@ -423,6 +651,35 @@ fn check_ctor_return_type( }); } + // Check that the first num_params arguments match the inductive's parameters. + for i in 0..num_params { + if i < param_locals.len() && !tc.def_eq(&args[i], ¶m_locals[i]) { + return Err(TcError::KernelException { + msg: format!( + "constructor {} parameter {} does not match inductive's parameter", + ctor.cnst.name.pretty(), + i + ), + }); + } + } + + // Check that index arguments (after params) don't mention the inductive. 
+ for i in num_params..args.len() { + for ind_name in &ind.all { + if expr_mentions_const(&args[i], ind_name) { + return Err(TcError::KernelException { + msg: format!( + "constructor {} index argument {} mentions the inductive {}", + ctor.cnst.name.pretty(), + i - num_params, + ind_name.pretty() + ), + }); + } + } + } + Ok(()) } @@ -784,4 +1041,782 @@ mod tests { let mut tc = TypeChecker::new(&env); assert!(check_inductive(ind, &mut tc).is_err()); } + + // ========================================================================== + // Recursor rule validation + // ========================================================================== + + #[test] + fn validate_rec_rules_wrong_count() { + // Nat has 2 ctors but we provide 1 rule + let env = mk_nat_env(); + let rec = RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![RecursorRule { + ctor: mk_name2("Nat", "zero"), + n_fields: Nat::from(0u64), + rhs: Expr::sort(Level::zero()), + }], + k: false, + is_unsafe: false, + }; + assert!(validate_recursor_rules(&rec, &env).is_err()); + } + + #[test] + fn validate_rec_rules_wrong_ctor_order() { + // Provide rules in wrong order (succ first, zero second) + let env = mk_nat_env(); + let rec = RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![ + RecursorRule { + ctor: mk_name2("Nat", "succ"), + n_fields: Nat::from(1u64), + rhs: Expr::sort(Level::zero()), + }, + RecursorRule { + ctor: mk_name2("Nat", "zero"), + n_fields: Nat::from(0u64), + 
rhs: Expr::sort(Level::zero()), + }, + ], + k: false, + is_unsafe: false, + }; + assert!(validate_recursor_rules(&rec, &env).is_err()); + } + + #[test] + fn validate_rec_rules_wrong_nfields() { + // zero has 0 fields but we claim 3 + let env = mk_nat_env(); + let rec = RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![ + RecursorRule { + ctor: mk_name2("Nat", "zero"), + n_fields: Nat::from(3u64), // wrong! + rhs: Expr::sort(Level::zero()), + }, + RecursorRule { + ctor: mk_name2("Nat", "succ"), + n_fields: Nat::from(1u64), + rhs: Expr::sort(Level::zero()), + }, + ], + k: false, + is_unsafe: false, + }; + assert!(validate_recursor_rules(&rec, &env).is_err()); + } + + #[test] + fn validate_rec_rules_bogus_ctor() { + // Rule references a non-existent constructor + let env = mk_nat_env(); + let rec = RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![ + RecursorRule { + ctor: mk_name2("Nat", "zero"), + n_fields: Nat::from(0u64), + rhs: Expr::sort(Level::zero()), + }, + RecursorRule { + ctor: mk_name2("Nat", "bogus"), // doesn't exist + n_fields: Nat::from(1u64), + rhs: Expr::sort(Level::zero()), + }, + ], + k: false, + is_unsafe: false, + }; + assert!(validate_recursor_rules(&rec, &env).is_err()); + } + + #[test] + fn validate_rec_rules_correct() { + // Correct rules for Nat + let env = mk_nat_env(); + let rec = RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec"), + level_params: vec![mk_name("u")], + typ: 
Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![ + RecursorRule { + ctor: mk_name2("Nat", "zero"), + n_fields: Nat::from(0u64), + rhs: Expr::sort(Level::zero()), + }, + RecursorRule { + ctor: mk_name2("Nat", "succ"), + n_fields: Nat::from(1u64), + rhs: Expr::sort(Level::zero()), + }, + ], + k: false, + is_unsafe: false, + }; + assert!(validate_recursor_rules(&rec, &env).is_ok()); + } + + #[test] + fn validate_rec_rules_wrong_num_params() { + // Recursor claims 5 params but Nat has 0 + let env = mk_nat_env(); + let rec = RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(5u64), // wrong + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![ + RecursorRule { + ctor: mk_name2("Nat", "zero"), + n_fields: Nat::from(0u64), + rhs: Expr::sort(Level::zero()), + }, + RecursorRule { + ctor: mk_name2("Nat", "succ"), + n_fields: Nat::from(1u64), + rhs: Expr::sort(Level::zero()), + }, + ], + k: false, + is_unsafe: false, + }; + assert!(validate_recursor_rules(&rec, &env).is_err()); + } + + // ========================================================================== + // K-flag validation + // ========================================================================== + + /// Build a Prop inductive with 1 ctor and 0 fields (Eq-like). 
+ fn mk_k_valid_env() -> Env { + let mut env = mk_nat_env(); + let eq_name = mk_name("KEq"); + let eq_refl = mk_name2("KEq", "refl"); + let u = mk_name("u"); + + // KEq.{u} (α : Sort u) (a b : α) : Prop + let eq_ty = Expr::all( + mk_name("α"), + Expr::sort(Level::param(u.clone())), + Expr::all( + mk_name("a"), + Expr::bvar(Nat::from(0u64)), + Expr::all( + mk_name("b"), + Expr::bvar(Nat::from(1u64)), + Expr::sort(Level::zero()), // Prop + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + env.insert( + eq_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: eq_name.clone(), + level_params: vec![u.clone()], + typ: eq_ty, + }, + num_params: Nat::from(2u64), + num_indices: Nat::from(1u64), + all: vec![eq_name.clone()], + ctors: vec![eq_refl.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: true, + }), + ); + // KEq.refl.{u} (α : Sort u) (a : α) : KEq α a a + let refl_ty = Expr::all( + mk_name("α"), + Expr::sort(Level::param(u.clone())), + Expr::all( + mk_name("a"), + Expr::bvar(Nat::from(0u64)), + Expr::app( + Expr::app( + Expr::app( + Expr::cnst(eq_name.clone(), vec![Level::param(u.clone())]), + Expr::bvar(Nat::from(1u64)), + ), + Expr::bvar(Nat::from(0u64)), + ), + Expr::bvar(Nat::from(0u64)), + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + env.insert( + eq_refl.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: eq_refl, + level_params: vec![u], + typ: refl_ty, + }, + induct: eq_name, + cidx: Nat::from(0u64), + num_params: Nat::from(2u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + env + } + + #[test] + fn validate_k_flag_valid_prop_single_zero_fields() { + let env = mk_k_valid_env(); + let rec = RecursorVal { + cnst: ConstantVal { + name: mk_name2("KEq", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("KEq")], + num_params: 
Nat::from(2u64), + num_indices: Nat::from(1u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(1u64), + rules: vec![RecursorRule { + ctor: mk_name2("KEq", "refl"), + n_fields: Nat::from(0u64), + rhs: Expr::sort(Level::zero()), + }], + k: true, + is_unsafe: false, + }; + assert!(validate_k_flag(&rec, &env).is_ok()); + } + + #[test] + fn validate_k_flag_fails_not_prop() { + // Nat is in Sort 1, not Prop — K should fail + let env = mk_nat_env(); + let rec = RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![], + k: true, + is_unsafe: false, + }; + assert!(validate_k_flag(&rec, &env).is_err()); + } + + #[test] + fn validate_k_flag_fails_multiple_ctors() { + // Even a Prop inductive with 2 ctors can't be K + // We need a Prop inductive with 2 ctors for this test + let mut env = Env::default(); + let p_name = mk_name("P"); + let mk1 = mk_name2("P", "mk1"); + let mk2 = mk_name2("P", "mk2"); + env.insert( + p_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: p_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::zero()), // Prop + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![p_name.clone()], + ctors: vec![mk1.clone(), mk2.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + env.insert( + mk1.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: mk1, + level_params: vec![], + typ: Expr::cnst(p_name.clone(), vec![]), + }, + induct: p_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + env.insert( + mk2.clone(), + ConstantInfo::CtorInfo(ConstructorVal { 
+ cnst: ConstantVal { + name: mk2, + level_params: vec![], + typ: Expr::cnst(p_name.clone(), vec![]), + }, + induct: p_name, + cidx: Nat::from(1u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + let rec = RecursorVal { + cnst: ConstantVal { + name: mk_name2("P", "rec"), + level_params: vec![], + typ: Expr::sort(Level::zero()), + }, + all: vec![mk_name("P")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![], + k: true, + is_unsafe: false, + }; + assert!(validate_k_flag(&rec, &env).is_err()); + } + + #[test] + fn validate_k_flag_false_always_ok() { + // k=false is always conservative, never rejected + let env = mk_nat_env(); + let rec = RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![], + k: false, + is_unsafe: false, + }; + assert!(validate_k_flag(&rec, &env).is_ok()); + } + + #[test] + fn validate_k_flag_fails_mutual() { + // K requires all.len() == 1 + let env = mk_k_valid_env(); + let rec = RecursorVal { + cnst: ConstantVal { + name: mk_name2("KEq", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("KEq"), mk_name("OtherInd")], // mutual + num_params: Nat::from(2u64), + num_indices: Nat::from(1u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(1u64), + rules: vec![], + k: true, + is_unsafe: false, + }; + assert!(validate_k_flag(&rec, &env).is_err()); + } + + // ========================================================================== + // Elimination restriction + // ========================================================================== + + #[test] + fn 
elim_restriction_non_prop_is_none() { + // Nat is in Sort 1, not Prop — no restriction applies + let env = mk_nat_env(); + let ind = match env.get(&mk_name("Nat")).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + assert_eq!(elim_only_at_universe_zero(ind, &env), None); + } + + #[test] + fn elim_restriction_prop_2_ctors_restricted() { + // A Prop inductive with 2 constructors: restricted to Prop elimination + let mut env = Env::default(); + let p_name = mk_name("P2"); + let mk1 = mk_name2("P2", "mk1"); + let mk2 = mk_name2("P2", "mk2"); + env.insert( + p_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: p_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::zero()), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![p_name.clone()], + ctors: vec![mk1.clone(), mk2.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + env.insert(mk1.clone(), ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { name: mk1, level_params: vec![], typ: Expr::cnst(p_name.clone(), vec![]) }, + induct: p_name.clone(), cidx: Nat::from(0u64), num_params: Nat::from(0u64), num_fields: Nat::from(0u64), is_unsafe: false, + })); + env.insert(mk2.clone(), ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { name: mk2, level_params: vec![], typ: Expr::cnst(p_name.clone(), vec![]) }, + induct: p_name.clone(), cidx: Nat::from(1u64), num_params: Nat::from(0u64), num_fields: Nat::from(0u64), is_unsafe: false, + })); + let ind = match env.get(&p_name).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + assert_eq!(elim_only_at_universe_zero(ind, &env), Some(true)); + } + + #[test] + fn elim_restriction_prop_0_ctors_not_restricted() { + // Empty Prop inductive (like False): can eliminate to any universe + let env_name = mk_name("MyFalse"); + let mut env = Env::default(); + env.insert( + env_name.clone(), + 
ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: env_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::zero()), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![env_name.clone()], + ctors: vec![], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + let ind = match env.get(&env_name).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + assert_eq!(elim_only_at_universe_zero(ind, &env), Some(false)); + } + + #[test] + fn elim_restriction_prop_1_ctor_0_fields_not_restricted() { + // Prop inductive, 1 ctor, 0 fields (like True): not restricted + let mut env = Env::default(); + let t_name = mk_name("MyTrue"); + let t_mk = mk_name2("MyTrue", "intro"); + env.insert( + t_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: t_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::zero()), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![t_name.clone()], + ctors: vec![t_mk.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + env.insert( + t_mk.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: t_mk, + level_params: vec![], + typ: Expr::cnst(t_name.clone(), vec![]), + }, + induct: t_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + let ind = match env.get(&t_name).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + assert_eq!(elim_only_at_universe_zero(ind, &env), Some(false)); + } + + #[test] + fn elim_restriction_prop_1_ctor_with_fields_restricted() { + // Prop inductive, 1 ctor with fields: conservatively restricted + // (like Exists) + let mut env = Env::default(); + let ex_name = mk_name("MyExists"); + let ex_mk = mk_name2("MyExists", "intro"); + // For simplicity: MyExists : 
Prop, MyExists.intro : Prop → MyExists + // (simplified from the real Exists which is polymorphic) + env.insert( + ex_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: ex_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::zero()), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![ex_name.clone()], + ctors: vec![ex_mk.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + env.insert( + ex_mk.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: ex_mk, + level_params: vec![], + typ: Expr::all( + mk_name("h"), + Expr::sort(Level::zero()), // a Prop field + Expr::cnst(ex_name.clone(), vec![]), + BinderInfo::Default, + ), + }, + induct: ex_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(1u64), + is_unsafe: false, + }), + ); + let ind = match env.get(&ex_name).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + // Conservative: any fields means restricted + assert_eq!(elim_only_at_universe_zero(ind, &env), Some(true)); + } + + // ========================================================================== + // Index-mentions-inductive check + // ========================================================================== + + #[test] + fn index_mentions_inductive_rejected() { + // Construct an inductive with 1 param and 1 index where the index + // mentions the inductive itself. This should be rejected. + // + // inductive Bad (α : Type) : Bad α → Type + // | mk : Bad α + // + // The ctor return type is `Bad α (Bad.mk α)`, but for the test + // we manually build a ctor whose index arg mentions `Bad`. 
+ let mut env = mk_nat_env(); + let bad_name = mk_name("BadIdx"); + let bad_mk = mk_name2("BadIdx", "mk"); + + // BadIdx (α : Sort 1) : Sort 1 + // (For simplicity, we make it have 1 param and 1 index) + env.insert( + bad_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: bad_name.clone(), + level_params: vec![], + typ: Expr::all( + mk_name("α"), + Expr::sort(Level::succ(Level::zero())), + Expr::all( + mk_name("_idx"), + nat_type(), // index of type Nat + Expr::sort(Level::succ(Level::zero())), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + }, + num_params: Nat::from(1u64), + num_indices: Nat::from(1u64), + all: vec![bad_name.clone()], + ctors: vec![bad_mk.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + + // BadIdx.mk (α : Sort 1) : BadIdx α + // The return type's index argument mentions BadIdx + let bad_idx_expr = Expr::app( + Expr::cnst(bad_name.clone(), vec![]), + Expr::bvar(Nat::from(0u64)), // dummy + ); + let ctor_ret = Expr::app( + Expr::app( + Expr::cnst(bad_name.clone(), vec![]), + Expr::bvar(Nat::from(0u64)), // param α + ), + bad_idx_expr, // index mentions BadIdx! 
+ ); + env.insert( + bad_mk.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: bad_mk, + level_params: vec![], + typ: Expr::all( + mk_name("α"), + Expr::sort(Level::succ(Level::zero())), + ctor_ret, + BinderInfo::Default, + ), + }, + induct: bad_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(1u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + + let ind = match env.get(&bad_name).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + let mut tc = TypeChecker::new(&env); + assert!(check_inductive(ind, &mut tc).is_err()); + } + + // ========================================================================== + // expr_mentions_const + // ========================================================================== + + #[test] + fn expr_mentions_const_direct() { + let name = mk_name("Foo"); + let e = Expr::cnst(name.clone(), vec![]); + assert!(expr_mentions_const(&e, &name)); + } + + #[test] + fn expr_mentions_const_nested_app() { + let name = mk_name("Foo"); + let e = Expr::app( + Expr::cnst(mk_name("bar"), vec![]), + Expr::cnst(name.clone(), vec![]), + ); + assert!(expr_mentions_const(&e, &name)); + } + + #[test] + fn expr_mentions_const_absent() { + let name = mk_name("Foo"); + let e = Expr::app( + Expr::cnst(mk_name("bar"), vec![]), + Expr::cnst(mk_name("baz"), vec![]), + ); + assert!(!expr_mentions_const(&e, &name)); + } + + #[test] + fn expr_mentions_const_in_forall_domain() { + let name = mk_name("Foo"); + let e = Expr::all( + mk_name("x"), + Expr::cnst(name.clone(), vec![]), + Expr::sort(Level::zero()), + BinderInfo::Default, + ); + assert!(expr_mentions_const(&e, &name)); + } + + #[test] + fn expr_mentions_const_in_let() { + let name = mk_name("Foo"); + let e = Expr::letE( + mk_name("x"), + Expr::sort(Level::zero()), + Expr::cnst(name.clone(), vec![]), + Expr::bvar(Nat::from(0u64)), + false, + ); + assert!(expr_mentions_const(&e, &name)); + } } diff --git a/src/ix/kernel/level.rs 
b/src/ix/kernel/level.rs index 80195e35..624f8fb2 100644 --- a/src/ix/kernel/level.rs +++ b/src/ix/kernel/level.rs @@ -54,6 +54,23 @@ pub fn is_zero(l: &Level) -> bool { leq(l, &Level::zero()) } +/// Check if a level could possibly be zero (i.e., not definitively non-zero). +/// Returns false only if the level is guaranteed to be ≥ 1 for all parameter assignments. +pub fn could_be_zero(l: &Level) -> bool { + let s = simplify(l); + could_be_zero_core(&s) +} + +fn could_be_zero_core(l: &Level) -> bool { + match l.as_data() { + LevelData::Zero(_) => true, + LevelData::Succ(..) => false, // n+1 is never zero + LevelData::Param(..) | LevelData::Mvar(..) => true, // parameter could be instantiated to zero + LevelData::Max(a, b, _) => could_be_zero_core(a) && could_be_zero_core(b), + LevelData::Imax(_, b, _) => could_be_zero_core(b), // imax(a, 0) = 0 + } +} + /// Check if `l <= r`. pub fn leq(l: &Level, r: &Level) -> bool { let l_s = simplify(l); @@ -400,4 +417,72 @@ mod tests { let expected = Level::succ(Level::zero()); assert_eq!(result, expected); } + + // ========================================================================== + // could_be_zero + // ========================================================================== + + #[test] + fn could_be_zero_zero() { + assert!(could_be_zero(&Level::zero())); + } + + #[test] + fn could_be_zero_succ_is_false() { + // Succ(0) = 1, never zero + assert!(!could_be_zero(&Level::succ(Level::zero()))); + } + + #[test] + fn could_be_zero_succ_param_is_false() { + // u+1 is never zero regardless of u + let u = Level::param(Name::str(Name::anon(), "u".into())); + assert!(!could_be_zero(&Level::succ(u))); + } + + #[test] + fn could_be_zero_param_is_true() { + // Param u could be zero (instantiated to 0) + let u = Level::param(Name::str(Name::anon(), "u".into())); + assert!(could_be_zero(&u)); + } + + #[test] + fn could_be_zero_max_both_could() { + // max(u, v) could be zero if both u and v could be zero + let u = 
Level::param(Name::str(Name::anon(), "u".into())); + let v = Level::param(Name::str(Name::anon(), "v".into())); + assert!(could_be_zero(&Level::max(u, v))); + } + + #[test] + fn could_be_zero_max_one_nonzero() { + // max(u+1, v) cannot be zero because u+1 ≥ 1 + let u = Level::param(Name::str(Name::anon(), "u".into())); + let v = Level::param(Name::str(Name::anon(), "v".into())); + assert!(!could_be_zero(&Level::max(Level::succ(u), v))); + } + + #[test] + fn could_be_zero_imax_zero_right() { + // imax(u, 0) = 0, so could be zero + let u = Level::param(Name::str(Name::anon(), "u".into())); + assert!(could_be_zero(&Level::imax(u, Level::zero()))); + } + + #[test] + fn could_be_zero_imax_succ_right() { + // imax(u, v+1) = max(u, v+1), never zero since v+1 ≥ 1 + let u = Level::param(Name::str(Name::anon(), "u".into())); + let v = Level::param(Name::str(Name::anon(), "v".into())); + assert!(!could_be_zero(&Level::imax(u, Level::succ(v)))); + } + + #[test] + fn could_be_zero_imax_param_right() { + // imax(u, v): if v=0 then imax(u,0)=0, so could be zero + let u = Level::param(Name::str(Name::anon(), "u".into())); + let v = Level::param(Name::str(Name::anon(), "v".into())); + assert!(could_be_zero(&Level::imax(u, v))); + } } diff --git a/src/ix/kernel/tc.rs b/src/ix/kernel/tc.rs index 604fbf02..59685192 100644 --- a/src/ix/kernel/tc.rs +++ b/src/ix/kernel/tc.rs @@ -573,6 +573,7 @@ impl<'env> TypeChecker<'env> { } } super::inductive::validate_k_flag(v, self.env)?; + super::inductive::validate_recursor_rules(v, self.env)?; }, } Ok(()) @@ -1542,7 +1543,8 @@ mod tests { } #[test] - fn check_rec_with_inductive() { + fn check_rec_empty_rules_fails() { + // Nat has 2 constructors, so 0 rules should fail let env = mk_nat_env(); let mut tc = TypeChecker::new(&env); let rec = ConstantInfo::RecInfo(RecursorVal { @@ -1560,7 +1562,16 @@ mod tests { k: false, is_unsafe: false, }); - assert!(tc.check_declar(&rec).is_ok()); + assert!(tc.check_declar(&rec).is_err()); + } + + #[test] + fn 
check_rec_with_valid_rules() { + // Use the full mk_nat_env which includes Nat.rec with proper rules + let env = mk_nat_env(); + let nat_rec = env.get(&mk_name2("Nat", "rec")).unwrap(); + let mut tc = TypeChecker::new(&env); + assert!(tc.check_declar(nat_rec).is_ok()); } // ========================================================================== @@ -1940,7 +1951,11 @@ mod tests { num_indices: Nat::from(1u64), num_motives: Nat::from(1u64), num_minors: Nat::from(1u64), - rules: vec![], + rules: vec![RecursorRule { + ctor: mk_name2("MyEq", "refl"), + n_fields: Nat::from(0u64), + rhs: Expr::sort(Level::zero()), // placeholder + }], k: true, is_unsafe: false, }); @@ -2184,4 +2199,301 @@ mod tests { let ty = tc.infer(&e).unwrap(); assert_eq!(ty, nat_type()); } + + // ========================================================================== + // check_declar: Recursor rule validation (integration tests) + // ========================================================================== + + #[test] + fn check_rec_wrong_nfields_via_check_declar() { + // Nat.rec with zero rule claiming 5 fields instead of 0 + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let u = mk_name("u"); + + let motive_type = Expr::all( + mk_name("_"), + nat_type(), + Expr::sort(Level::param(u.clone())), + BinderInfo::Default, + ); + let rec_type = Expr::all( + mk_name("motive"), + motive_type, + Expr::sort(Level::param(u.clone())), // simplified + BinderInfo::Implicit, + ); + let rec = ConstantInfo::RecInfo(RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec2"), + level_params: vec![u], + typ: rec_type, + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![ + RecursorRule { + ctor: mk_name2("Nat", "zero"), + n_fields: Nat::from(5u64), // WRONG + rhs: Expr::sort(Level::zero()), + }, + RecursorRule { + ctor: mk_name2("Nat", "succ"), + n_fields: 
Nat::from(1u64), + rhs: Expr::sort(Level::zero()), + }, + ], + k: false, + is_unsafe: false, + }); + assert!(tc.check_declar(&rec).is_err()); + } + + #[test] + fn check_rec_wrong_ctor_order_via_check_declar() { + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let u = mk_name("u"); + + let rec_type = Expr::all( + mk_name("motive"), + Expr::all( + mk_name("_"), + nat_type(), + Expr::sort(Level::param(u.clone())), + BinderInfo::Default, + ), + Expr::sort(Level::param(u.clone())), + BinderInfo::Implicit, + ); + let rec = ConstantInfo::RecInfo(RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec2"), + level_params: vec![u], + typ: rec_type, + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![ + // WRONG ORDER: succ then zero + RecursorRule { + ctor: mk_name2("Nat", "succ"), + n_fields: Nat::from(1u64), + rhs: Expr::sort(Level::zero()), + }, + RecursorRule { + ctor: mk_name2("Nat", "zero"), + n_fields: Nat::from(0u64), + rhs: Expr::sort(Level::zero()), + }, + ], + k: false, + is_unsafe: false, + }); + assert!(tc.check_declar(&rec).is_err()); + } + + #[test] + fn check_rec_wrong_num_params_via_check_declar() { + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let u = mk_name("u"); + + let rec_type = Expr::all( + mk_name("motive"), + Expr::all( + mk_name("_"), + nat_type(), + Expr::sort(Level::param(u.clone())), + BinderInfo::Default, + ), + Expr::sort(Level::param(u.clone())), + BinderInfo::Implicit, + ); + let rec = ConstantInfo::RecInfo(RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec2"), + level_params: vec![u], + typ: rec_type, + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(99u64), // WRONG: Nat has 0 params + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![ + RecursorRule { + ctor: mk_name2("Nat", "zero"), + 
n_fields: Nat::from(0u64), + rhs: Expr::sort(Level::zero()), + }, + RecursorRule { + ctor: mk_name2("Nat", "succ"), + n_fields: Nat::from(1u64), + rhs: Expr::sort(Level::zero()), + }, + ], + k: false, + is_unsafe: false, + }); + assert!(tc.check_declar(&rec).is_err()); + } + + #[test] + fn check_rec_valid_rules_passes() { + // Full Nat.rec declaration from mk_nat_env passes check_declar + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let nat_rec = env.get(&mk_name2("Nat", "rec")).unwrap(); + assert!(tc.check_declar(nat_rec).is_ok()); + } + + // ========================================================================== + // check_declar: K-flag via check_declar + // ========================================================================== + + /// Build an env with an Eq-like Prop inductive that supports K. + fn mk_k_env() -> Env { + let mut env = mk_nat_env(); + let u = mk_name("u"); + let eq_name = mk_name("MyEq"); + let eq_refl = mk_name2("MyEq", "refl"); + + let eq_ty = Expr::all( + mk_name("α"), + Expr::sort(Level::param(u.clone())), + Expr::all( + mk_name("a"), + Expr::bvar(Nat::from(0u64)), + Expr::all( + mk_name("b"), + Expr::bvar(Nat::from(1u64)), + Expr::sort(Level::zero()), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + env.insert( + eq_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: eq_name.clone(), + level_params: vec![u.clone()], + typ: eq_ty, + }, + num_params: Nat::from(2u64), + num_indices: Nat::from(1u64), + all: vec![eq_name.clone()], + ctors: vec![eq_refl.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: true, + }), + ); + let refl_ty = Expr::all( + mk_name("α"), + Expr::sort(Level::param(u.clone())), + Expr::all( + mk_name("a"), + Expr::bvar(Nat::from(0u64)), + Expr::app( + Expr::app( + Expr::app( + Expr::cnst(eq_name.clone(), vec![Level::param(u.clone())]), + Expr::bvar(Nat::from(1u64)), + ), + 
Expr::bvar(Nat::from(0u64)), + ), + Expr::bvar(Nat::from(0u64)), + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + env.insert( + eq_refl.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: eq_refl, + level_params: vec![u], + typ: refl_ty, + }, + induct: eq_name, + cidx: Nat::from(0u64), + num_params: Nat::from(2u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + env + } + + #[test] + fn check_k_flag_valid_via_check_declar() { + let env = mk_k_env(); + let mut tc = TypeChecker::new(&env); + let rec = ConstantInfo::RecInfo(RecursorVal { + cnst: ConstantVal { + name: mk_name2("MyEq", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("MyEq")], + num_params: Nat::from(2u64), + num_indices: Nat::from(1u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(1u64), + rules: vec![RecursorRule { + ctor: mk_name2("MyEq", "refl"), + n_fields: Nat::from(0u64), + rhs: Expr::sort(Level::zero()), + }], + k: true, + is_unsafe: false, + }); + assert!(tc.check_declar(&rec).is_ok()); + } + + #[test] + fn check_k_flag_invalid_on_nat_via_check_declar() { + // K=true on Nat (Sort 1, 2 ctors) should fail + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let rec = ConstantInfo::RecInfo(RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "recK"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![ + RecursorRule { + ctor: mk_name2("Nat", "zero"), + n_fields: Nat::from(0u64), + rhs: Expr::sort(Level::zero()), + }, + RecursorRule { + ctor: mk_name2("Nat", "succ"), + n_fields: Nat::from(1u64), + rhs: Expr::sort(Level::zero()), + }, + ], + k: true, + is_unsafe: false, + }); + assert!(tc.check_declar(&rec).is_err()); + } } diff --git 
a/src/ix/kernel/whnf.rs b/src/ix/kernel/whnf.rs index d7cef49a..d4500e85 100644 --- a/src/ix/kernel/whnf.rs +++ b/src/ix/kernel/whnf.rs @@ -509,9 +509,8 @@ pub(crate) fn whnf_dag(dag: &mut DAG, env: &Env, no_delta: bool) { eprintln!("[whnf_dag] depth={depth} total={total} no_delta={no_delta}"); } if depth > 200 { - eprintln!("[whnf_dag] DEPTH LIMIT depth={depth}, bailing"); WHNF_DEPTH.fetch_sub(1, Ordering::Relaxed); - return; + panic!("[whnf_dag] DEPTH LIMIT exceeded (depth={depth}): possible infinite reduction or extremely deep term"); } const WHNF_STEP_LIMIT: u64 = 100_000; @@ -520,9 +519,8 @@ pub(crate) fn whnf_dag(dag: &mut DAG, env: &Env, no_delta: bool) { loop { steps += 1; if steps > WHNF_STEP_LIMIT { - eprintln!("[whnf_dag] step limit exceeded ({steps}) depth={depth}"); whnf_done(depth); - return; + panic!("[whnf_dag] step limit exceeded ({steps} steps at depth={depth}): possible infinite reduction"); } if steps <= 5 || steps % 10_000 == 0 { let head_variant = match dag.head { @@ -925,7 +923,9 @@ pub(crate) fn try_reduce_nat_dag( } else if *name == mk_name2("Nat", "ble") { Some(bool_to_dag(a <= b)) } else if *name == mk_name2("Nat", "pow") { + // Limit exponent to prevent OOM (matches yatima's 2^24 limit) let exp = u32::try_from(&b).unwrap_or(u32::MAX); + if exp > (1 << 24) { return None; } Some(nat_lit_dag(Nat(a.pow(exp)))) } else if *name == mk_name2("Nat", "land") { Some(nat_lit_dag(Nat(a & b))) @@ -934,7 +934,9 @@ pub(crate) fn try_reduce_nat_dag( } else if *name == mk_name2("Nat", "xor") { Some(nat_lit_dag(Nat(a ^ b))) } else if *name == mk_name2("Nat", "shiftLeft") { + // Limit shift to prevent OOM let shift = u64::try_from(&b).unwrap_or(u64::MAX); + if shift > (1 << 24) { return None; } Some(nat_lit_dag(Nat(a << shift))) } else if *name == mk_name2("Nat", "shiftRight") { let shift = u64::try_from(&b).unwrap_or(u64::MAX); @@ -1094,7 +1096,8 @@ pub fn try_unfold_def(e: &Expr, env: &Env) -> Option { } (&d.cnst.level_params, &d.value) }, - 
ConstantInfo::ThmInfo(t) => (&t.cnst.level_params, &t.value), + // Theorems are never unfolded — proof irrelevance handles them. + // ConstantInfo::ThmInfo(_) => return None, _ => return None, };