Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 101 additions & 2 deletions src/node/redeem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,19 +361,37 @@ impl<J: Jet> RedeemNode<J> {

fn convert_data(
&mut self,
_: &PostOrderIterItem<&RedeemNode<J>>,
data: &PostOrderIterItem<&RedeemNode<J>>,
inner: Inner<
&Arc<ConstructNode<'brand, J>>,
J,
&Option<Arc<ConstructNode<'brand, J>>>,
&Option<Value>,
>,
) -> Result<ConstructData<'brand, J>, Self::Error> {
let preserved_case_target = matches!(data.node.inner(), Inner::Case(..))
&& matches!(inner.as_ref(), Inner::AssertL(..) | Inner::AssertR(..));
let converted_inner = inner
.map(|node| node.cached_data())
.map_witness(Option::<Value>::clone);
let retyped = ConstructData::from_inner(&self.inference_context, converted_inner)
.expect("pruned types should check out if unpruned types check out");
if preserved_case_target {
// Re-inference of assertl/assertr intentionally forgets the hidden branch.
// When that assert came from pruning a case, keep the original case target
// so final witness shrinking does not over-specialize the surviving branch.
let original_target = types::Type::complete(
&self.inference_context,
Arc::clone(&data.node.arrow().target),
);
self.inference_context
.unify(
&retyped.arrow().target,
&original_target,
"pruned assert target = original case target",
)
.expect("pruned assert target should stay compatible with the original case target");
}
Ok(retyped)
}
}
Expand Down Expand Up @@ -607,9 +625,10 @@ impl<J: Jet> RedeemNode<J> {
#[cfg(test)]
mod tests {
use super::*;
use crate::dag::{DagLike, NoSharing};
use crate::human_encoding::Forest;
use crate::jet::Core;
use crate::node::SimpleFinalizer;
use crate::node::{ConstructNode, CoreConstructible, SimpleFinalizer, WitnessConstructible};
use crate::types::Final;
use hex::DisplayHex;
use std::collections::HashMap;
Expand Down Expand Up @@ -1129,6 +1148,86 @@ main := comp input comp process jet_verify : 1 -> 1"#;
&env,
);
}

#[test]
fn prune_case_to_assert_preserves_nested_sum_witness() {
type Node<'brand> = Arc<ConstructNode<'brand, Core>>;

crate::types::Context::with_context(|ctx| {
let original_witness = Value::left(
Value::product(
Value::right(Final::u1(), Value::u1(1)),
Value::right(Final::u1(), Value::u1(0)),
),
Final::unit(),
);
let expected_pruned_witness = Value::left(
Value::product(Value::unit(), Value::right(Final::u1(), Value::u1(0))),
Final::unit(),
);

let witness = Node::witness(&ctx, Some(original_witness.clone()));
let witness_ty =
crate::types::Type::complete(&ctx, Arc::new(original_witness.ty().clone()));
ctx.unify(
&witness.arrow().target,
&witness_ty,
"bind regression witness target type",
)
.expect("regression witness type should be consistent");
let unit = Node::unit(&ctx);
let input = Node::pair(&witness, &unit).unwrap();

let iden = Node::iden(&ctx);
let keep_right = Node::take(&Node::drop_(&iden));
let keep_unit = Node::drop_(&Node::unit(&ctx));
let left = Node::pair(&keep_unit, &keep_right).unwrap();
// The hidden branch is what pins the surviving nested sum target to 1 + 1.
let right_value =
Value::product(Value::unit(), Value::right(Final::u1(), Value::u1(1)));
let right_const = Node::scribe(&ctx, &right_value);
let right = Node::drop_(&right_const);
let process = Node::case(&left, &right).unwrap();
let main = Node::comp(&input, &process).unwrap();

let unpruned = main
.finalize_unpruned()
.expect("unpruned regression program should finalize");
let mut mac =
BitMachine::for_program(&unpruned).expect("unpruned program has reasonable bounds");
let unpruned_output = mac
.exec(&unpruned, &())
.expect("unpruned program should execute");

let pruned = unpruned
.prune(&())
.expect("pruning should succeed for the regression program");
assert!(
pruned
.as_ref()
.post_order_iter::<NoSharing>()
.any(|data| matches!(data.node.inner(), Inner::AssertL(..))),
"pruning should turn the outer case into assertl",
);

let pruned_witness = pruned
.as_ref()
.post_order_iter::<NoSharing>()
.find_map(|data| match data.node.inner() {
Inner::Witness(value) => Some(value.clone()),
_ => None,
})
.expect("pruned program should still contain the witness node");
assert_eq!(pruned_witness, expected_pruned_witness);

let mut mac =
BitMachine::for_program(&pruned).expect("pruned program has reasonable bounds");
let pruned_output = mac
.exec(&pruned, &())
.expect("pruned program should execute");
assert_eq!(pruned_output, unpruned_output);
});
}
}

#[cfg(bench)]
Expand Down
Loading