diff --git a/src/node/redeem.rs b/src/node/redeem.rs index aa0bf45e..d118992e 100644 --- a/src/node/redeem.rs +++ b/src/node/redeem.rs @@ -361,7 +361,7 @@ impl RedeemNode { fn convert_data( &mut self, - _: &PostOrderIterItem<&RedeemNode>, + data: &PostOrderIterItem<&RedeemNode>, inner: Inner< &Arc>, J, @@ -369,11 +369,29 @@ impl RedeemNode { &Option, >, ) -> Result, Self::Error> { + let preserved_case_target = matches!(data.node.inner(), Inner::Case(..)) + && matches!(inner.as_ref(), Inner::AssertL(..) | Inner::AssertR(..)); let converted_inner = inner .map(|node| node.cached_data()) .map_witness(Option::::clone); let retyped = ConstructData::from_inner(&self.inference_context, converted_inner) .expect("pruned types should check out if unpruned types check out"); + if preserved_case_target { + // Re-inference of assertl/assertr intentionally forgets the hidden branch. + // When that assert came from pruning a case, keep the original case target + // so final witness shrinking does not over-specialize the surviving branch. + let original_target = types::Type::complete( + &self.inference_context, + Arc::clone(&data.node.arrow().target), + ); + self.inference_context + .unify( + &retyped.arrow().target, + &original_target, + "pruned assert target = original case target", + ) + .expect("pruned assert target should stay compatible with the original case target"); + } Ok(retyped) } } @@ -607,9 +625,10 @@ impl RedeemNode { #[cfg(test)] mod tests { use super::*; + use crate::dag::{DagLike, NoSharing}; use crate::human_encoding::Forest; use crate::jet::Core; - use crate::node::SimpleFinalizer; + use crate::node::{ConstructNode, CoreConstructible, SimpleFinalizer, WitnessConstructible}; use crate::types::Final; use hex::DisplayHex; use std::collections::HashMap; @@ -1129,6 +1148,86 @@ main := comp input comp process jet_verify : 1 -> 1"#; &env, ); } + + #[test] + fn prune_case_to_assert_preserves_nested_sum_witness() { + type Node<'brand> = Arc>; + + crate::types::Context::with_context(|ctx| { + let original_witness = Value::left( + Value::product( + Value::right(Final::u1(), Value::u1(1)), + Value::right(Final::u1(), Value::u1(0)), + ), + Final::unit(), + ); + let expected_pruned_witness = Value::left( + Value::product(Value::unit(), Value::right(Final::u1(), Value::u1(0))), + Final::unit(), + ); + + let witness = Node::witness(&ctx, Some(original_witness.clone())); + let witness_ty = + crate::types::Type::complete(&ctx, Arc::new(original_witness.ty().clone())); + ctx.unify( + &witness.arrow().target, + &witness_ty, + "bind regression witness target type", + ) + .expect("regression witness type should be consistent"); + let unit = Node::unit(&ctx); + let input = Node::pair(&witness, &unit).unwrap(); + + let iden = Node::iden(&ctx); + let keep_right = Node::take(&Node::drop_(&iden)); + let keep_unit = Node::drop_(&Node::unit(&ctx)); + let left = Node::pair(&keep_unit, &keep_right).unwrap(); + // The hidden branch is what pins the surviving nested sum target to 1 + 1. + let right_value = + Value::product(Value::unit(), Value::right(Final::u1(), Value::u1(1))); + let right_const = Node::scribe(&ctx, &right_value); + let right = Node::drop_(&right_const); + let process = Node::case(&left, &right).unwrap(); + let main = Node::comp(&input, &process).unwrap(); + + let unpruned = main + .finalize_unpruned() + .expect("unpruned regression program should finalize"); + let mut mac = + BitMachine::for_program(&unpruned).expect("unpruned program has reasonable bounds"); + let unpruned_output = mac + .exec(&unpruned, &()) + .expect("unpruned program should execute"); + + let pruned = unpruned + .prune(&()) + .expect("pruning should succeed for the regression program"); + assert!( + pruned + .as_ref() + .post_order_iter::() + .any(|data| matches!(data.node.inner(), Inner::AssertL(..))), + "pruning should turn the outer case into assertl", + ); + + let pruned_witness = pruned + .as_ref() + .post_order_iter::() + .find_map(|data| match data.node.inner() { + Inner::Witness(value) => Some(value.clone()), + _ => None, + }) + .expect("pruned program should still contain the witness node"); + assert_eq!(pruned_witness, expected_pruned_witness); + + let mut mac = + BitMachine::for_program(&pruned).expect("pruned program has reasonable bounds"); + let pruned_output = mac + .exec(&pruned, &()) + .expect("pruned program should execute"); + assert_eq!(pruned_output, unpruned_output); + }); + } } #[cfg(bench)]