8 changes: 4 additions & 4 deletions crates/lean_vm/src/isa/hint.rs
@@ -114,14 +114,14 @@ impl CustomHint {
         let remaining_ptr = args[1].read_value(ctx.memory, ctx.fp)?.to_usize();
         let to_decompose_ptr = args[2].read_value(ctx.memory, ctx.fp)?.to_usize();
         let num_to_decompose = args[3].read_value(ctx.memory, ctx.fp)?.to_usize();
-        let w = args[4].read_value(ctx.memory, ctx.fp)?.to_usize();
-        assert!(w == 2 || w == 3 || w == 4);
+        let chunk_size = args[4].read_value(ctx.memory, ctx.fp)?.to_usize();
+        assert!(24_usize.is_multiple_of(chunk_size));
         let mut memory_index_decomposed = decomposed_ptr;
         let mut memory_index_remaining = remaining_ptr;
         for i in 0..num_to_decompose {
             let value = ctx.memory.get(to_decompose_ptr + i)?.to_usize();
-            for i in 0..24 / w {
-                let value = F::from_usize((value >> (w * i)) & ((1 << w) - 1));
+            for i in 0..24 / chunk_size {
+                let value = F::from_usize((value >> (chunk_size * i)) & ((1 << chunk_size) - 1));
                 ctx.memory.set(memory_index_decomposed, value)?;
                 memory_index_decomposed += 1;
             }
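For context, the hint above writes, for each input value, 24 / chunk_size little-endian chunks of chunk_size bits into the decomposed region, and whatever sits above bit 24 goes into the remaining slot. A minimal standalone sketch of that split in plain Rust (illustrative names, not the crate's API):

    /// Split the low 24 bits of `value` into 24 / chunk_size little-endian chunks
    /// of `chunk_size` bits each; everything above bit 24 is the "remaining" part.
    fn decompose(value: u32, chunk_size: u32) -> (Vec<u32>, u32) {
        assert_eq!(24 % chunk_size, 0); // mirrors the is_multiple_of check above
        let mask = (1u32 << chunk_size) - 1;
        let chunks = (0..24 / chunk_size)
            .map(|i| (value >> (chunk_size * i)) & mask)
            .collect();
        (chunks, value >> 24)
    }

    fn main() {
        let value = 0x00AB_CDEF;
        let (chunks, remaining) = decompose(value, 6);
        // remaining * 2^24 + sum_i chunk_i * 2^(6 * i) reconstructs the value,
        // which is what the in-circuit partial_sum check enforces.
        let rebuilt = chunks.iter().rev().fold(remaining, |acc, &c| (acc << 6) | c);
        assert_eq!(rebuilt, value);
        println!("chunks = {chunks:?}, remaining = {remaining}");
    }
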
9 changes: 8 additions & 1 deletion crates/rec_aggregation/src/benchmark.rs
@@ -277,12 +277,19 @@ fn build_aggregation(
         slot,
         topology.log_inv_rate,
         prox_gaps_conjecture,
-        tracing,
     );
     let elapsed = time.elapsed();
 
     if tracing {
         println!("{}", result.metadata.display());
+        if topology.children.is_empty() {
+            println!(
+                "{} XMSS/s",
+                (topology.raw_xmss as f64 / elapsed.as_secs_f64()).round() as usize
+            );
+        } else {
+            println!("{}s the final aggregation step", elapsed.as_secs_f64());
+        }
     }
 
     if !tracing {
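The throughput line above is simply signatures aggregated divided by wall-clock seconds. A tiny sketch of the same measurement pattern (the sleep stands in for the real aggregation call; the count is a placeholder):

    use std::time::{Duration, Instant};

    fn main() {
        let num_signatures: usize = 700; // placeholder leaf count
        let time = Instant::now();
        std::thread::sleep(Duration::from_millis(50)); // stand-in for the aggregation work
        let elapsed = time.elapsed();
        println!(
            "{} XMSS/s",
            (num_signatures as f64 / elapsed.as_secs_f64()).round() as usize
        );
    }
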
5 changes: 0 additions & 5 deletions crates/rec_aggregation/src/lib.rs
@@ -152,7 +152,6 @@ pub fn aggregate(
     slot: u32,
     log_inv_rate: usize,
     prox_gaps_conjecture: bool,
-    tracing: bool,
 ) -> AggregatedSigs {
     raw_xmss.sort_by_key(|(a, _)| Digest(a.merkle_root));
     raw_xmss.dedup_by(|(a, _), (b, _)| a.merkle_root == b.merkle_root);
@@ -351,10 +350,6 @@ pub fn aggregate(
         false,
     );
 
-    if tracing {
-        println!("{}", execution_proof.metadata.display());
-    }
-
     AggregatedSigs {
         pub_keys: global_pub_keys,
         proof: execution_proof.proof,
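The context lines above deduplicate the incoming signatures by public-key Merkle root with a sort-then-dedup pass. A small self-contained sketch of that pattern (a u64 stands in for the Digest key; not the crate's types):

    fn main() {
        // (merkle_root, signature) pairs; the two entries with root 7 are duplicates.
        let mut raw_xmss: Vec<(u64, &str)> = vec![(7, "sig_a"), (3, "sig_b"), (7, "sig_c")];

        // Sort by the key, then drop adjacent entries sharing it (keeping the first),
        // mirroring the sort_by_key + dedup_by calls above.
        raw_xmss.sort_by_key(|(root, _)| *root);
        raw_xmss.dedup_by(|(a, _), (b, _)| a == b);

        assert_eq!(raw_xmss, vec![(3, "sig_b"), (7, "sig_a")]);
    }
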
97 changes: 63 additions & 34 deletions crates/rec_aggregation/xmss_aggregate.py
@@ -36,66 +36,95 @@ def xmss_verify(merkle_root, message, signature, slot_lo, slot_hi, merkle_chunks
     encoding_fe = Array(DIGEST_LEN)
     poseidon16(b_input, b_input + DIGEST_LEN, encoding_fe)
 
-    encoding = Array(NUM_ENCODING_FE * 24 / W)
+    encoding = Array(NUM_ENCODING_FE * 24 / (2 * W))
     remaining = Array(NUM_ENCODING_FE)
 
-    # TODO: decompose by chunks of 2.w bits (or even 3.w bits) and use a big match on the w^2 (or w^3) possibilities
     hint_decompose_bits_xmss(
         encoding,
         remaining,
         encoding_fe,
         NUM_ENCODING_FE,
-        W
+        2 * W
     )
 
     # check that the decomposition is correct
     for i in unroll(0, NUM_ENCODING_FE):
-        for j in unroll(0, 24 / W):
-            assert encoding[i * (24 / W) + j] < CHAIN_LENGTH
+        for j in unroll(0, 24 / (2 * W)):
+            assert encoding[i * (24 / (2 * W)) + j] < CHAIN_LENGTH**2
 
-        assert remaining[i] < 2**7 - 1
+        assert remaining[i] < 2**7 - 1 # ensures uniformity + prevent overflow
 
         partial_sum: Mut = remaining[i] * 2**24
-        for j in unroll(0, 24/W):
-            partial_sum += encoding[i * (24 / W) + j] * CHAIN_LENGTH ** j
+        for j in unroll(0, 24/(2*W)):
+            partial_sum += encoding[i * (24 / (2 * W)) + j] * (CHAIN_LENGTH ** 2) ** j
         assert partial_sum == encoding_fe[i]
 
-    # we need to check the target sum
-    target_sum: Mut = encoding[0]
-    for i in unroll(1, V):
-        target_sum += encoding[i]
-    assert target_sum == TARGET_SUM
 
     # grinding
-    for i in unroll(V, V + V_GRINDING):
-        assert encoding[i] == CHAIN_LENGTH - 1
+    debug_assert(V_GRINDING % 2 == 0)
+    debug_assert(V % 2 == 0)
+    for i in unroll(V / 2, (V + V_GRINDING) / 2):
+        assert encoding[i] == CHAIN_LENGTH**2 - 1
 
+    target_sum: Mut = 0
 
     wots_public_key = Array(V * DIGEST_LEN)
 
-    for i in unroll(0, V):
-        num_hashes = (CHAIN_LENGTH - 1) - encoding[i]
-        chain_start = chain_starts + i * DIGEST_LEN
-        chain_end = wots_public_key + i * DIGEST_LEN
-        match_range(num_hashes,
-            range(0, 1), lambda _: copy_8(chain_start, chain_end),
-            range(1, 2), lambda _: poseidon16(chain_start, ZERO_VEC_PTR, chain_end),
-            range(2, CHAIN_LENGTH), lambda num_hashes_const: chain_hash(chain_start, num_hashes_const, chain_end))
+    for i in unroll(0, V / 2):
+        # num_hashes = (CHAIN_LENGTH - 1) - encoding[i]
+        chain_start = chain_starts + i * (DIGEST_LEN * 2)
+        chain_end = wots_public_key + i * (DIGEST_LEN * 2)
+        pair_chain_length_sum_ptr = Array(1)
+        match_range(encoding[i], range(0, CHAIN_LENGTH**2), lambda n: chain_hash(chain_start, n, chain_end, pair_chain_length_sum_ptr))
+        target_sum += pair_chain_length_sum_ptr[0]
+
+    assert target_sum == TARGET_SUM
 
     wots_pubkey_hashed = slice_hash(wots_public_key, V)
     xmss_merkle_verify(wots_pubkey_hashed, merkle_path, merkle_chunks, merkle_root)
     return
 
 
-def chain_hash(input, n: Const, output):
-    debug_assert(2 <= n)
-    states = Array((n-1) * DIGEST_LEN)
-    poseidon16(input, ZERO_VEC_PTR, states)
-    state_indexes = Array(n - 1)
-    state_indexes[0] = states
-    for i in unroll(1, n-1):
-        state_indexes[i] = state_indexes[i - 1] + DIGEST_LEN
-        poseidon16(state_indexes[i - 1], ZERO_VEC_PTR, state_indexes[i])
-    poseidon16(state_indexes[n - 2], ZERO_VEC_PTR, output)
+def chain_hash(input_left, n: Const, output_left, pair_chain_length_sum_ptr):
+    debug_assert(n < CHAIN_LENGTH**2)
+
+    raw_left = n % CHAIN_LENGTH
+    raw_right = (n - raw_left) / CHAIN_LENGTH
+
+    n_left = (CHAIN_LENGTH - 1) - raw_left
+    if n_left == 0:
+        copy_8(input_left, output_left)
+    elif n_left == 1:
+        poseidon16(input_left, ZERO_VEC_PTR, output_left)
+    else:
+        states_left = Array((n_left-1) * DIGEST_LEN)
+        poseidon16(input_left, ZERO_VEC_PTR, states_left)
+        state_indexes_left = Array(n_left - 1)
+        state_indexes_left[0] = states_left
+        for i in unroll(1, n_left-1):
+            state_indexes_left[i] = state_indexes_left[i - 1] + DIGEST_LEN
+            poseidon16(state_indexes_left[i - 1], ZERO_VEC_PTR, state_indexes_left[i])
+        poseidon16(state_indexes_left[n_left - 2], ZERO_VEC_PTR, output_left)
+
+    n_right = (CHAIN_LENGTH - 1) - raw_right
+    debug_assert(raw_right < CHAIN_LENGTH)
+    input_right = input_left + DIGEST_LEN
+    output_right = output_left + DIGEST_LEN
+    if n_right == 0:
+        copy_8(input_right, output_right)
+    elif n_right == 1:
+        poseidon16(input_right, ZERO_VEC_PTR, output_right)
+    else:
+        states_right = Array((n_right-1) * DIGEST_LEN)
+        poseidon16(input_right, ZERO_VEC_PTR, states_right)
+        state_indexes_right = Array(n_right - 1)
+        state_indexes_right[0] = states_right
+        for i in unroll(1, n_right-1):
+            state_indexes_right[i] = state_indexes_right[i - 1] + DIGEST_LEN
+            poseidon16(state_indexes_right[i - 1], ZERO_VEC_PTR, state_indexes_right[i])
+        poseidon16(state_indexes_right[n_right - 2], ZERO_VEC_PTR, output_right)
+
+    pair_chain_length_sum_ptr[0] = raw_left + raw_right
+
+    return


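To make the pairing in the rewritten chain_hash concrete: each encoded value n in [0, CHAIN_LENGTH^2) packs two base-CHAIN_LENGTH WOTS digits, each of the two chains is advanced (CHAIN_LENGTH - 1) - digit times from the signature value, and the per-pair digit sum feeds the TARGET_SUM check. A minimal sketch with a toy hash standing in for poseidon16 (illustrative only, not the DSL's semantics):

    const W: usize = 3;
    const CHAIN_LENGTH: usize = 1 << W; // 8, so a pair digit is in 0..64

    // Toy stand-in for one poseidon16 chain step.
    fn step(x: u64) -> u64 {
        x.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407)
    }

    // Walk a chain forward by `num_hashes` steps.
    fn chain(start: u64, num_hashes: usize) -> u64 {
        (0..num_hashes).fold(start, |s, _| step(s))
    }

    fn main() {
        let e = 53usize; // one pair digit of the encoding
        let raw_left = e % CHAIN_LENGTH;               // low base-8 digit
        let raw_right = (e - raw_left) / CHAIN_LENGTH; // high base-8 digit
        assert_eq!(e, raw_left + raw_right * CHAIN_LENGTH);

        // Each chain is completed from the signature value to its public end point.
        let (sig_left, sig_right) = (0xAAu64, 0xBBu64);
        let pk_left = chain(sig_left, (CHAIN_LENGTH - 1) - raw_left);
        let pk_right = chain(sig_right, (CHAIN_LENGTH - 1) - raw_right);

        // The verifier adds raw_left + raw_right per pair and checks the total
        // against TARGET_SUM at the end, as in the loop above.
        let pair_chain_length_sum = raw_left + raw_right;
        println!("digits ({raw_left}, {raw_right}), sum {pair_chain_length_sum}, ends ({pk_left:#x}, {pk_right:#x})");
    }
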
6 changes: 3 additions & 3 deletions crates/xmss/src/lib.rs
@@ -13,12 +13,12 @@ type F = KoalaBear;
 type Digest = [F; DIGEST_SIZE];
 
 // WOTS
-pub const V: usize = 40;
+pub const V: usize = 42;
 pub const W: usize = 3;
 pub const CHAIN_LENGTH: usize = 1 << W;
-pub const NUM_CHAIN_HASHES: usize = 120;
+pub const NUM_CHAIN_HASHES: usize = 110;
 pub const TARGET_SUM: usize = V * (CHAIN_LENGTH - 1) - NUM_CHAIN_HASHES;
-pub const V_GRINDING: usize = 3;
+pub const V_GRINDING: usize = 2;
 pub const LOG_LIFETIME: usize = 32;
 pub const RANDOMNESS_LEN_FE: usize = 7;
 pub const MESSAGE_LEN_FE: usize = 9;
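As a sanity check on the new constants: TARGET_SUM works out to 42 * 7 - 110 = 184 (previously 40 * 7 - 120 = 160), the paired-chain verifier needs V and V_GRINDING to be even, and 24 must be divisible by 2 * W. A standalone snippet mirroring the definitions above:

    const V: usize = 42;
    const W: usize = 3;
    const CHAIN_LENGTH: usize = 1 << W; // 8
    const NUM_CHAIN_HASHES: usize = 110;
    const TARGET_SUM: usize = V * (CHAIN_LENGTH - 1) - NUM_CHAIN_HASHES;
    const V_GRINDING: usize = 2;

    fn main() {
        assert_eq!(TARGET_SUM, 184); // 42 * 7 - 110
        // The new verifier processes chains in pairs, so both counts must be even.
        assert_eq!(V % 2, 0);
        assert_eq!(V_GRINDING % 2, 0);
        // Each pair digit covers 2 * W = 6 bits, i.e. four pair digits per 24-bit limb.
        assert_eq!(24 % (2 * W), 0);
        println!("TARGET_SUM = {}", TARGET_SUM);
    }
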
2 changes: 1 addition & 1 deletion crates/xmss/test_data/benchmark_signers.json

Large diffs are not rendered by default.

3 changes: 0 additions & 3 deletions src/lib.rs
@@ -59,7 +59,6 @@ mod tests {
             slot,
             log_inv_rate,
             prox_gaps_conjecture,
-            false,
         );
 
         let pub_keys_and_sigs_b: Vec<_> = (3..5)
@@ -73,7 +72,6 @@
             slot,
             log_inv_rate,
             prox_gaps_conjecture,
-            false,
         );
 
         let pub_keys_and_sigs_c: Vec<_> = (5..6)
@@ -88,7 +86,6 @@
             slot,
             log_inv_rate,
             prox_gaps_conjecture,
-            false,
         );
 
         verify_aggregation(&aggregated_final, &message, slot, prox_gaps_conjecture).unwrap();
6 changes: 3 additions & 3 deletions src/main.rs
@@ -72,7 +72,7 @@ fn main() {
         raw_xmss: 0,
         children: vec![
             AggregationTopology {
-                raw_xmss: 675,
+                raw_xmss: 700,
                 children: vec![],
                 log_inv_rate,
             };
@@ -100,7 +100,7 @@ fn main() {
         raw_xmss: 25,
         children: vec![
             AggregationTopology {
-                raw_xmss: 1350,
+                raw_xmss: 1400,
                 children: vec![],
                 log_inv_rate: 1,
             };
@@ -114,7 +114,7 @@
         raw_xmss: 0,
         children: vec![
             AggregationTopology {
-                raw_xmss: 1350,
+                raw_xmss: 1400,
                 children: vec![],
                 log_inv_rate: 2,
             };