diff --git a/Cargo.lock b/Cargo.lock index 9e293fd2b8..f65f0ac59e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4946,9 +4946,9 @@ dependencies = [ [[package]] name = "num-conv" -version = "0.2.0" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" [[package]] name = "num-derive" @@ -11003,30 +11003,30 @@ dependencies = [ [[package]] name = "time" -version = "0.3.47" +version = "0.3.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" +checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" dependencies = [ "deranged", "itoa", "num-conv", "powerfmt", - "serde_core", + "serde", "time-core", "time-macros", ] [[package]] name = "time-core" -version = "0.1.8" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" +checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" [[package]] name = "time-macros" -version = "0.2.27" +version = "0.2.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215" +checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3" dependencies = [ "num-conv", "time-core", diff --git a/forester-utils/src/address_staging_tree.rs b/forester-utils/src/address_staging_tree.rs index 786ddb6ac0..a6b1aa89bd 100644 --- a/forester-utils/src/address_staging_tree.rs +++ b/forester-utils/src/address_staging_tree.rs @@ -121,7 +121,7 @@ impl AddressStagingTree { low_element_next_values: &[[u8; 32]], low_element_indices: &[u64], low_element_next_indices: &[u64], - low_element_proofs: &[Vec<[u8; 32]>], + low_element_proofs: &[[[u8; 32]; HEIGHT]], leaves_hashchain: [u8; 32], zkp_batch_size: usize, epoch: u64, @@ -145,15 +145,12 @@ impl AddressStagingTree { let inputs = get_batch_address_append_circuit_inputs::( next_index, old_root, - low_element_values.to_vec(), - low_element_next_values.to_vec(), - low_element_indices.iter().map(|v| *v as usize).collect(), - low_element_next_indices - .iter() - .map(|v| *v as usize) - .collect(), - low_element_proofs.to_vec(), - addresses.to_vec(), + low_element_values, + low_element_next_values, + low_element_indices, + low_element_next_indices, + low_element_proofs, + addresses, &mut self.sparse_tree, leaves_hashchain, zkp_batch_size, diff --git a/forester/src/processor/v2/helpers.rs b/forester/src/processor/v2/helpers.rs index ed135cb6a4..a0f3e3bb5b 100644 --- a/forester/src/processor/v2/helpers.rs +++ b/forester/src/processor/v2/helpers.rs @@ -9,6 +9,7 @@ use light_client::{ indexer::{AddressQueueData, Indexer, QueueElementsV2Options, StateQueueData}, rpc::Rpc, }; +use light_hasher::hash_chain::create_hash_chain_from_slice; use crate::processor::v2::{common::clamp_to_u16, BatchContext}; @@ -22,6 +23,17 @@ pub(crate) fn lock_recover<'a, T>(mutex: &'a Mutex, name: &'static str) -> Mu } } +#[derive(Debug, Clone)] +pub struct AddressBatchSnapshot { + pub addresses: Vec<[u8; 32]>, + pub low_element_values: Vec<[u8; 32]>, + pub low_element_next_values: Vec<[u8; 32]>, + pub low_element_indices: Vec, + pub low_element_next_indices: Vec, + pub low_element_proofs: Vec<[[u8; 32]; HEIGHT]>, + pub leaves_hashchain: [u8; 32], +} + pub async fn fetch_zkp_batch_size(context: &BatchContext) -> crate::Result { let rpc = context.rpc_pool.get_connection().await?; let mut account = rpc @@ -474,20 +486,52 @@ impl StreamingAddressQueue { } } - pub fn get_batch_data(&self, start: usize, end: usize) -> Option { + pub fn get_batch_snapshot( + &self, + start: usize, + end: usize, + hashchain_idx: usize, + ) -> crate::Result>> { let available = self.wait_for_batch(end); if start >= available { - return None; + return Ok(None); } let actual_end = end.min(available); let data = lock_recover(&self.data, "streaming_address_queue.data"); - Some(BatchDataSlice { - addresses: data.addresses[start..actual_end].to_vec(), + + let addresses = data.addresses[start..actual_end].to_vec(); + if addresses.is_empty() { + return Err(anyhow!("Empty batch at start={}", start)); + } + + let leaves_hashchain = match data.leaves_hash_chains.get(hashchain_idx).copied() { + Some(hashchain) => hashchain, + None => { + tracing::debug!( + "Missing leaves_hash_chain for batch {} (available: {}), deriving from addresses", + hashchain_idx, + data.leaves_hash_chains.len() + ); + create_hash_chain_from_slice(&addresses).map_err(|error| { + anyhow!( + "Failed to derive leaves_hash_chain for batch {} from {} addresses: {}", + hashchain_idx, + addresses.len(), + error + ) + })? + } + }; + + Ok(Some(AddressBatchSnapshot { low_element_values: data.low_element_values[start..actual_end].to_vec(), low_element_next_values: data.low_element_next_values[start..actual_end].to_vec(), low_element_indices: data.low_element_indices[start..actual_end].to_vec(), low_element_next_indices: data.low_element_next_indices[start..actual_end].to_vec(), - }) + low_element_proofs: data.reconstruct_proofs::(start..actual_end)?, + addresses, + leaves_hashchain, + })) } pub fn into_data(self) -> AddressQueueData { @@ -553,15 +597,6 @@ impl StreamingAddressQueue { } } -#[derive(Debug, Clone)] -pub struct BatchDataSlice { - pub addresses: Vec<[u8; 32]>, - pub low_element_values: Vec<[u8; 32]>, - pub low_element_next_values: Vec<[u8; 32]>, - pub low_element_indices: Vec, - pub low_element_next_indices: Vec, -} - pub async fn fetch_streaming_address_batches( context: &BatchContext, total_elements: u64, diff --git a/forester/src/processor/v2/strategy/address.rs b/forester/src/processor/v2/strategy/address.rs index 06e94d5500..51ab05143a 100644 --- a/forester/src/processor/v2/strategy/address.rs +++ b/forester/src/processor/v2/strategy/address.rs @@ -14,11 +14,10 @@ use tracing::{debug, info, instrument}; use crate::processor::v2::{ batch_job_builder::BatchJobBuilder, - common::get_leaves_hashchain, errors::V2Error, helpers::{ fetch_address_zkp_batch_size, fetch_onchain_address_root, fetch_streaming_address_batches, - lock_recover, StreamingAddressQueue, + AddressBatchSnapshot, StreamingAddressQueue, }, proof_worker::ProofInput, root_guard::{reconcile_alignment, AlignmentDecision}, @@ -267,9 +266,23 @@ impl BatchJobBuilder for AddressQueueData { let batch_end = start + zkp_batch_size_usize; - let batch_data = self - .streaming_queue - .get_batch_data(start, batch_end) + let streaming_queue = &self.streaming_queue; + let staging_tree = &mut self.staging_tree; + let hashchain_idx = start / zkp_batch_size_usize; + let AddressBatchSnapshot { + addresses, + low_element_values, + low_element_next_values, + low_element_indices, + low_element_next_indices, + low_element_proofs, + leaves_hashchain, + } = streaming_queue + .get_batch_snapshot::<{ DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }>( + start, + batch_end, + hashchain_idx, + )? .ok_or_else(|| { anyhow!( "Batch data not available: start={}, end={}, available={}", @@ -278,31 +291,21 @@ impl BatchJobBuilder for AddressQueueData { self.streaming_queue.available_batches() * zkp_batch_size_usize ) })?; - - let addresses = &batch_data.addresses; let zkp_batch_size_actual = addresses.len(); - - if zkp_batch_size_actual == 0 { - return Err(anyhow!("Empty batch at start={}", start)); - } - - let low_element_values = &batch_data.low_element_values; - let low_element_next_values = &batch_data.low_element_next_values; - let low_element_indices = &batch_data.low_element_indices; - let low_element_next_indices = &batch_data.low_element_next_indices; - - let low_element_proofs: Vec> = { - let data = lock_recover(self.streaming_queue.data.as_ref(), "streaming_queue.data"); - (start..start + zkp_batch_size_actual) - .map(|i| data.reconstruct_proof(i, DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as u8)) - .collect::, _>>()? - }; - - let hashchain_idx = start / zkp_batch_size_usize; - let leaves_hashchain = { - let data = lock_recover(self.streaming_queue.data.as_ref(), "streaming_queue.data"); - get_leaves_hashchain(&data.leaves_hash_chains, hashchain_idx)? - }; + let result = staging_tree + .process_batch( + &addresses, + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + leaves_hashchain, + zkp_batch_size_actual, + epoch, + tree, + ) + .map_err(|err| map_address_staging_error(tree, err))?; let tree_batch = tree_next_index / zkp_batch_size_usize; let absolute_index = data_start + start; @@ -318,24 +321,6 @@ impl BatchJobBuilder for AddressQueueData { self.streaming_queue.is_complete() ); - let result = self.staging_tree.process_batch( - addresses, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - &low_element_proofs, - leaves_hashchain, - zkp_batch_size_actual, - epoch, - tree, - ); - - let result = match result { - Ok(r) => r, - Err(err) => return Err(map_address_staging_error(tree, err)), - }; - Ok(Some(( ProofInput::AddressAppend(result.circuit_inputs), result.new_root, diff --git a/program-tests/utils/src/e2e_test_env.rs b/program-tests/utils/src/e2e_test_env.rs index a16a7925a7..097ae1f9a9 100644 --- a/program-tests/utils/src/e2e_test_env.rs +++ b/program-tests/utils/src/e2e_test_env.rs @@ -73,7 +73,6 @@ use account_compression::{ use anchor_lang::{prelude::AccountMeta, AnchorSerialize, Discriminator}; use create_address_test_program::create_invoke_cpi_instruction; use forester_utils::{ - account_zero_copy::AccountZeroCopy, address_merkle_tree_config::{address_tree_ready_for_rollover, state_tree_ready_for_rollover}, forester_epoch::{Epoch, Forester, TreeAccounts}, utils::airdrop_lamports, @@ -194,6 +193,7 @@ use crate::{ }, test_batch_forester::{perform_batch_append, perform_batch_nullify}, test_forester::{empty_address_queue_test, nullify_compressed_accounts}, + AccountZeroCopy, }; pub struct User { @@ -748,70 +748,67 @@ where .with_address_queue(None, Some(batch.batch_size as u16)); let result = self .indexer - .get_queue_elements(merkle_tree_pubkey.to_bytes(), options, None) + .get_queue_elements( + merkle_tree_pubkey.to_bytes(), + options, + None, + ) .await .unwrap(); - let addresses = result - .value - .address_queue - .map(|aq| aq.addresses) - .unwrap_or_default(); + let address_queue = result.value.address_queue.unwrap(); + let low_element_proofs = address_queue + .reconstruct_all_proofs::<{ + DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize + }>() + .unwrap(); // // local_leaves_hash_chain is only used for a test assertion. // let local_nullifier_hash_chain = create_hash_chain_from_array(&addresses); // assert_eq!(leaves_hash_chain, local_nullifier_hash_chain); - let start_index = merkle_tree.next_index as usize; + let start_index = address_queue.start_index as usize; assert!( start_index >= 2, "start index should be greater than 2 else tree is not inited" ); let current_root = *merkle_tree.root_history.last().unwrap(); - let mut low_element_values = Vec::new(); - let mut low_element_indices = Vec::new(); - let mut low_element_next_indices = Vec::new(); - let mut low_element_next_values = Vec::new(); - let mut low_element_proofs: Vec> = Vec::new(); - let non_inclusion_proofs = self - .indexer - .get_multiple_new_address_proofs( - merkle_tree_pubkey.to_bytes(), - addresses.clone(), - None, - ) - .await - .unwrap(); - for non_inclusion_proof in &non_inclusion_proofs.value.items { - low_element_values.push(non_inclusion_proof.low_address_value); - low_element_indices - .push(non_inclusion_proof.low_address_index as usize); - low_element_next_indices - .push(non_inclusion_proof.low_address_next_index as usize); - low_element_next_values - .push(non_inclusion_proof.low_address_next_value); - - low_element_proofs - .push(non_inclusion_proof.low_address_proof.to_vec()); - } - - let subtrees = self.indexer - .get_subtrees(merkle_tree_pubkey.to_bytes(), None) - .await - .unwrap(); - let mut sparse_merkle_tree = SparseMerkleTree::::new(<[[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize]>::try_from(subtrees.value.items).unwrap(), start_index); + assert_eq!(address_queue.initial_root, current_root); + let light_client::indexer::AddressQueueData { + addresses, + low_element_values, + low_element_next_values, + low_element_indices, + low_element_next_indices, + subtrees, + .. + } = address_queue; + let mut sparse_merkle_tree = SparseMerkleTree::< + Poseidon, + { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, + >::new( + subtrees.as_slice().try_into().unwrap(), + start_index, + ); - let mut changelog: Vec> = Vec::new(); - let mut indexed_changelog: Vec> = Vec::new(); + let mut changelog: Vec< + ChangelogEntry<{ DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }>, + > = Vec::new(); + let mut indexed_changelog: Vec< + IndexedChangelogEntry< + usize, + { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, + >, + > = Vec::new(); let inputs = get_batch_address_append_circuit_inputs::< { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, >( start_index, current_root, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - low_element_proofs, - addresses, + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + &addresses, &mut sparse_merkle_tree, leaves_hash_chain, batch.zkp_batch_size as usize, @@ -834,9 +831,13 @@ where if response_result.status().is_success() { let body = response_result.text().await.unwrap(); - let proof_json = deserialize_gnark_proof_json(&body).unwrap(); - let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json); - let (proof_a, proof_b, proof_c) = compress_proof(&proof_a, &proof_b, &proof_c); + let proof_json = deserialize_gnark_proof_json(&body) + .map_err(|error| RpcError::CustomError(error.to_string())) + .unwrap(); + let (proof_a, proof_b, proof_c) = + proof_from_json_struct(proof_json); + let (proof_a, proof_b, proof_c) = + compress_proof(&proof_a, &proof_b, &proof_c); let instruction_data = InstructionDataBatchNullifyInputs { new_root: circuit_inputs_new_root, compressed_proof: CompressedProof { diff --git a/program-tests/utils/src/mock_batched_forester.rs b/program-tests/utils/src/mock_batched_forester.rs index 4458aa03b3..0101b235aa 100644 --- a/program-tests/utils/src/mock_batched_forester.rs +++ b/program-tests/utils/src/mock_batched_forester.rs @@ -260,7 +260,7 @@ impl MockBatchedAddressForester { let mut low_element_indices = Vec::new(); let mut low_element_next_indices = Vec::new(); let mut low_element_next_values = Vec::new(); - let mut low_element_proofs: Vec> = Vec::new(); + let mut low_element_proofs: Vec<[[u8; 32]; HEIGHT]> = Vec::new(); for new_element_value in &new_element_values { let non_inclusion_proof = self .merkle_tree @@ -270,7 +270,18 @@ impl MockBatchedAddressForester { low_element_indices.push(non_inclusion_proof.leaf_index); low_element_next_indices.push(non_inclusion_proof.next_index); low_element_next_values.push(non_inclusion_proof.leaf_higher_range_value); - low_element_proofs.push(non_inclusion_proof.merkle_proof.as_slice().to_vec()); + let proof = non_inclusion_proof + .merkle_proof + .as_slice() + .try_into() + .map_err(|_| { + ProverClientError::InvalidProofData(format!( + "invalid low element proof length: expected {}, got {}", + HEIGHT, + non_inclusion_proof.merkle_proof.len() + )) + })?; + low_element_proofs.push(proof); } let subtrees = self.merkle_tree.merkle_tree.get_subtrees(); let mut merkle_tree = match <[[u8; 32]; HEIGHT]>::try_from(subtrees) { @@ -287,12 +298,12 @@ impl MockBatchedAddressForester { let inputs = match get_batch_address_append_circuit_inputs::( start_index, current_root, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - low_element_proofs, - new_element_values.clone(), + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + &new_element_values, &mut merkle_tree, leaves_hashchain, zkp_batch_size as usize, diff --git a/program-tests/utils/src/test_batch_forester.rs b/program-tests/utils/src/test_batch_forester.rs index 8e6909704f..a28efa9ca3 100644 --- a/program-tests/utils/src/test_batch_forester.rs +++ b/program-tests/utils/src/test_batch_forester.rs @@ -319,13 +319,13 @@ pub async fn get_batched_nullify_ix_data( }) } -use forester_utils::{ - account_zero_copy::AccountZeroCopy, instructions::create_account::create_account_instruction, -}; +use forester_utils::instructions::create_account::create_account_instruction; use light_client::indexer::{Indexer, QueueElementsV2Options}; use light_program_test::indexer::state_tree::StateMerkleTreeBundle; use light_sparse_merkle_tree::SparseMerkleTree; +use crate::AccountZeroCopy; + pub async fn assert_registry_created_batched_state_merkle_tree( rpc: &mut R, payer_pubkey: Pubkey, @@ -663,50 +663,33 @@ pub async fn create_batch_update_address_tree_instruction_data_with_proof() + .unwrap(); // // local_leaves_hash_chain is only used for a test assertion. // let local_nullifier_hash_chain = create_hash_chain_from_slice(addresses.as_slice()).unwrap(); // assert_eq!(leaves_hash_chain, local_nullifier_hash_chain); - let start_index = merkle_tree.next_index as usize; + let start_index = address_queue.start_index as usize; assert!( start_index >= 1, "start index should be greater than 2 else tree is not inited" ); let current_root = *merkle_tree.root_history.last().unwrap(); - let mut low_element_values = Vec::new(); - let mut low_element_indices = Vec::new(); - let mut low_element_next_indices = Vec::new(); - let mut low_element_next_values = Vec::new(); - let mut low_element_proofs: Vec> = Vec::new(); - let non_inclusion_proofs = indexer - .get_multiple_new_address_proofs(merkle_tree_pubkey.to_bytes(), addresses.clone(), None) - .await - .unwrap(); - for non_inclusion_proof in &non_inclusion_proofs.value.items { - low_element_values.push(non_inclusion_proof.low_address_value); - low_element_indices.push(non_inclusion_proof.low_address_index as usize); - low_element_next_indices.push(non_inclusion_proof.low_address_next_index as usize); - low_element_next_values.push(non_inclusion_proof.low_address_next_value); - - low_element_proofs.push(non_inclusion_proof.low_address_proof.to_vec()); - } - - let subtrees = indexer - .get_subtrees(merkle_tree_pubkey.to_bytes(), None) - .await - .unwrap(); + assert_eq!(address_queue.initial_root, current_root); + let light_client::indexer::AddressQueueData { + addresses, + low_element_values, + low_element_indices, + low_element_next_indices, + low_element_next_values, + subtrees, + .. + } = address_queue; let mut sparse_merkle_tree = SparseMerkleTree::< Poseidon, { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, - >::new( - <[[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize]>::try_from(subtrees.value.items) - .unwrap(), - start_index, - ); + >::new(subtrees.as_slice().try_into().unwrap(), start_index); let mut changelog: Vec> = Vec::new(); @@ -718,12 +701,12 @@ pub async fn create_batch_update_address_tree_instruction_data_with_proof( start_index, current_root, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - low_element_proofs, - addresses, + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + &addresses, &mut sparse_merkle_tree, leaves_hash_chain, batch.zkp_batch_size as usize, diff --git a/prover/client/src/constants.rs b/prover/client/src/constants.rs index 18a5c05a45..151bf87918 100644 --- a/prover/client/src/constants.rs +++ b/prover/client/src/constants.rs @@ -1,4 +1,4 @@ -pub const SERVER_ADDRESS: &str = "http://localhost:3001"; +pub const SERVER_ADDRESS: &str = "http://127.0.0.1:3001"; pub const HEALTH_CHECK: &str = "/health"; pub const PROVE_PATH: &str = "/prove"; diff --git a/prover/client/src/errors.rs b/prover/client/src/errors.rs index 85c1bc8fbe..e095bf3579 100644 --- a/prover/client/src/errors.rs +++ b/prover/client/src/errors.rs @@ -37,6 +37,9 @@ pub enum ProverClientError { #[error("Invalid proof data: {0}")] InvalidProofData(String), + #[error("Integer conversion failed: {0}")] + IntegerConversion(String), + #[error("Hashchain mismatch: computed {computed:?} != expected {expected:?} (batch_size={batch_size}, next_index={next_index})")] HashchainMismatch { computed: [u8; 32], diff --git a/prover/client/src/helpers.rs b/prover/client/src/helpers.rs index 6ea223e79f..98457479e2 100644 --- a/prover/client/src/helpers.rs +++ b/prover/client/src/helpers.rs @@ -6,6 +6,8 @@ use num_bigint::{BigInt, BigUint}; use num_traits::{Num, ToPrimitive}; use serde::Serialize; +use crate::errors::ProverClientError; + pub fn get_project_root() -> Option { let output = Command::new("git") .args(["rev-parse", "--show-toplevel"]) @@ -48,7 +50,7 @@ pub fn compute_root_from_merkle_proof( leaf: [u8; 32], path_elements: &[[u8; 32]; HEIGHT], path_index: u32, -) -> ([u8; 32], ChangelogEntry) { +) -> Result<([u8; 32], ChangelogEntry), ProverClientError> { let mut changelog_entry = ChangelogEntry::default_with_index(path_index as usize); let mut current_hash = leaf; @@ -56,14 +58,14 @@ pub fn compute_root_from_merkle_proof( for (level, path_element) in path_elements.iter().enumerate() { changelog_entry.path[level] = Some(current_hash); if current_index.is_multiple_of(2) { - current_hash = Poseidon::hashv(&[¤t_hash, path_element]).unwrap(); + current_hash = Poseidon::hashv(&[¤t_hash, path_element])?; } else { - current_hash = Poseidon::hashv(&[path_element, ¤t_hash]).unwrap(); + current_hash = Poseidon::hashv(&[path_element, ¤t_hash])?; } current_index /= 2; } - (current_hash, changelog_entry) + Ok((current_hash, changelog_entry)) } pub fn big_uint_to_string(big_uint: &BigUint) -> String { diff --git a/prover/client/src/proof_types/batch_address_append/proof_inputs.rs b/prover/client/src/proof_types/batch_address_append/proof_inputs.rs index f80e8d49e4..32408fdc02 100644 --- a/prover/client/src/proof_types/batch_address_append/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_address_append/proof_inputs.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::{collections::HashMap, fmt::Debug}; use light_hasher::{ bigint::bigint_to_be_bytes_array, @@ -187,21 +187,28 @@ impl BatchAddressAppendInputs { pub fn get_batch_address_append_circuit_inputs( next_index: usize, current_root: [u8; 32], - low_element_values: Vec<[u8; 32]>, - low_element_next_values: Vec<[u8; 32]>, - low_element_indices: Vec, - low_element_next_indices: Vec, - low_element_proofs: Vec>, - new_element_values: Vec<[u8; 32]>, + low_element_values: &[[u8; 32]], + low_element_next_values: &[[u8; 32]], + low_element_indices: &[impl Copy + TryInto + Debug], + low_element_next_indices: &[impl Copy + TryInto + Debug], + low_element_proofs: &[[[u8; 32]; HEIGHT]], + new_element_values: &[[u8; 32]], sparse_merkle_tree: &mut SparseMerkleTree, leaves_hashchain: [u8; 32], zkp_batch_size: usize, changelog: &mut Vec>, indexed_changelog: &mut Vec>, ) -> Result { - let new_element_values = new_element_values[0..zkp_batch_size].to_vec(); - - let computed_hashchain = create_hash_chain_from_slice(&new_element_values).map_err(|e| { + let new_element_values = &new_element_values[..zkp_batch_size]; + let mut new_root = [0u8; 32]; + let mut low_element_circuit_merkle_proofs = Vec::with_capacity(new_element_values.len()); + let mut new_element_circuit_merkle_proofs = Vec::with_capacity(new_element_values.len()); + let mut patched_low_element_next_values = Vec::with_capacity(new_element_values.len()); + let mut patched_low_element_next_indices = Vec::with_capacity(new_element_values.len()); + let mut patched_low_element_values = Vec::with_capacity(new_element_values.len()); + let mut patched_low_element_indices = Vec::with_capacity(new_element_values.len()); + + let computed_hashchain = create_hash_chain_from_slice(new_element_values).map_err(|e| { ProverClientError::GenericError(format!("Failed to compute hashchain: {}", e)) })?; if computed_hashchain != leaves_hashchain { @@ -229,15 +236,6 @@ pub fn get_batch_address_append_circuit_inputs( next_index ); - let mut new_root = [0u8; 32]; - let mut low_element_circuit_merkle_proofs = vec![]; - let mut new_element_circuit_merkle_proofs = vec![]; - - let mut patched_low_element_next_values: Vec<[u8; 32]> = Vec::new(); - let mut patched_low_element_next_indices: Vec = Vec::new(); - let mut patched_low_element_values: Vec<[u8; 32]> = Vec::new(); - let mut patched_low_element_indices: Vec = Vec::new(); - let mut patcher = ChangelogProofPatcher::new::(changelog); let is_first_batch = indexed_changelog.is_empty(); @@ -245,21 +243,33 @@ pub fn get_batch_address_append_circuit_inputs( for i in 0..new_element_values.len() { let mut changelog_index = 0; + let low_element_index = low_element_indices[i].try_into().map_err(|_| { + ProverClientError::IntegerConversion(format!( + "low element index {:?} does not fit into usize", + low_element_indices[i] + )) + })?; + let low_element_next_index = low_element_next_indices[i].try_into().map_err(|_| { + ProverClientError::IntegerConversion(format!( + "low element next index {:?} does not fit into usize", + low_element_next_indices[i] + )) + })?; let new_element_index = next_index + i; let mut low_element: IndexedElement = IndexedElement { - index: low_element_indices[i], + index: low_element_index, value: BigUint::from_bytes_be(&low_element_values[i]), - next_index: low_element_next_indices[i], + next_index: low_element_next_index, }; let mut new_element: IndexedElement = IndexedElement { index: new_element_index, value: BigUint::from_bytes_be(&new_element_values[i]), - next_index: low_element_next_indices[i], + next_index: low_element_next_index, }; - let mut low_element_proof = low_element_proofs[i].to_vec(); + let mut low_element_proof = low_element_proofs[i]; let mut low_element_next_value = BigUint::from_bytes_be(&low_element_next_values[i]); patch_indexed_changelogs( 0, @@ -293,18 +303,10 @@ pub fn get_batch_address_append_circuit_inputs( next_value: bigint_to_be_bytes_array::<32>(&new_element.value)?, index: new_low_element.index, }; + let low_element_changelog_proof = low_element_proof; let intermediate_root = { - let mut low_element_proof_arr: [[u8; 32]; HEIGHT] = low_element_proof - .clone() - .try_into() - .map_err(|v: Vec<[u8; 32]>| { - ProverClientError::ProofPatchFailed(format!( - "low element proof length mismatch: expected {}, got {}", - HEIGHT, - v.len() - )) - })?; + let mut low_element_proof_arr = low_element_changelog_proof; patcher.update_proof::(low_element.index(), &mut low_element_proof_arr); let merkle_proof = low_element_proof_arr; @@ -321,7 +323,7 @@ pub fn get_batch_address_append_circuit_inputs( old_low_leaf_hash, &merkle_proof, low_element.index as u32, - ); + )?; if computed_root != expected_root_for_low { let low_value_bytes = bigint_to_be_bytes_array::<32>(&low_element.value) .map_err(|e| { @@ -362,7 +364,7 @@ pub fn get_batch_address_append_circuit_inputs( new_low_leaf_hash, &merkle_proof, new_low_element.index as u32, - ); + )?; patcher.push_changelog_entry::(changelog, changelog_entry); low_element_circuit_merkle_proofs.push( @@ -376,13 +378,7 @@ pub fn get_batch_address_append_circuit_inputs( }; let low_element_changelog_entry = IndexedChangelogEntry { element: new_low_element_raw, - proof: low_element_proof.as_slice()[..HEIGHT] - .try_into() - .map_err(|_| { - ProverClientError::ProofPatchFailed( - "low_element_proof slice conversion failed".to_string(), - ) - })?, + proof: low_element_changelog_proof, changelog_index: indexed_changelog.len(), //change_log_index, }; @@ -409,7 +405,7 @@ pub fn get_batch_address_append_circuit_inputs( new_element_leaf_hash, &merkle_proof_array, current_index as u32, - ); + )?; if i == 0 && changelog.len() == 1 { if sparse_next_idx_before != current_index { @@ -436,7 +432,7 @@ pub fn get_batch_address_append_circuit_inputs( zero_hash, &merkle_proof_array, current_index as u32, - ); + )?; if root_with_zero != intermediate_root { tracing::error!( "ELEMENT {} NEW_PROOF MISMATCH: proof + ZERO = {:?}[..4] but expected \ diff --git a/prover/client/src/proof_types/batch_append/proof_inputs.rs b/prover/client/src/proof_types/batch_append/proof_inputs.rs index ef0327ac1d..7dd578e599 100644 --- a/prover/client/src/proof_types/batch_append/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_append/proof_inputs.rs @@ -187,8 +187,11 @@ pub fn get_batch_append_inputs( }; // Update the root based on the current proof and nullifier - let (updated_root, changelog_entry) = - compute_root_from_merkle_proof(final_leaf, &merkle_proof_array, start_index + i as u32); + let (updated_root, changelog_entry) = compute_root_from_merkle_proof( + final_leaf, + &merkle_proof_array, + start_index + i as u32, + )?; new_root = updated_root; changelog.push(changelog_entry); circuit_merkle_proofs.push( diff --git a/prover/client/src/proof_types/batch_update/proof_inputs.rs b/prover/client/src/proof_types/batch_update/proof_inputs.rs index 2136d01d10..7f8c08e0d1 100644 --- a/prover/client/src/proof_types/batch_update/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_update/proof_inputs.rs @@ -175,7 +175,7 @@ pub fn get_batch_update_inputs( index_bytes[28..].copy_from_slice(&(*index).to_be_bytes()); let nullifier = Poseidon::hashv(&[leaf, &index_bytes, &tx_hashes[i]]).unwrap(); let (root, changelog_entry) = - compute_root_from_merkle_proof(nullifier, &merkle_proof_array, *index); + compute_root_from_merkle_proof(nullifier, &merkle_proof_array, *index)?; new_root = root; changelog.push(changelog_entry); circuit_merkle_proofs.push( diff --git a/prover/client/src/prover.rs b/prover/client/src/prover.rs index 3bf1bab785..56ae20d98a 100644 --- a/prover/client/src/prover.rs +++ b/prover/client/src/prover.rs @@ -51,7 +51,10 @@ pub async fn spawn_prover() { } pub async fn health_check(retries: usize, timeout: usize) -> bool { - let client = reqwest::Client::new(); + let client = match reqwest::Client::builder().no_proxy().build() { + Ok(client) => client, + Err(_) => return false, + }; let mut result = false; for _ in 0..retries { match client diff --git a/prover/client/tests/batch_address_append.rs b/prover/client/tests/batch_address_append.rs index 22f58d5362..ac73c3809e 100644 --- a/prover/client/tests/batch_address_append.rs +++ b/prover/client/tests/batch_address_append.rs @@ -45,7 +45,8 @@ async fn prove_batch_address_append() { let mut low_element_indices = Vec::new(); let mut low_element_next_indices = Vec::new(); let mut low_element_next_values = Vec::new(); - let mut low_element_proofs: Vec> = Vec::new(); + let mut low_element_proofs: Vec<[[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize]> = + Vec::new(); // Generate non-inclusion proofs for each element for new_element_value in &new_element_values { @@ -57,7 +58,13 @@ async fn prove_batch_address_append() { low_element_indices.push(non_inclusion_proof.leaf_index); low_element_next_indices.push(non_inclusion_proof.next_index); low_element_next_values.push(non_inclusion_proof.leaf_higher_range_value); - low_element_proofs.push(non_inclusion_proof.merkle_proof.as_slice().to_vec()); + low_element_proofs.push( + non_inclusion_proof + .merkle_proof + .as_slice() + .try_into() + .unwrap(), + ); } // Convert big integers to byte arrays @@ -87,12 +94,12 @@ async fn prove_batch_address_append() { get_batch_address_append_circuit_inputs::<{ DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }>( start_index, current_root, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - low_element_proofs, - new_element_values, + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + &new_element_values, &mut sparse_merkle_tree, hash_chain, zkp_batch_size, diff --git a/sdk-libs/client/src/indexer/photon_indexer.rs b/sdk-libs/client/src/indexer/photon_indexer.rs index 26d16ae235..5698719c8f 100644 --- a/sdk-libs/client/src/indexer/photon_indexer.rs +++ b/sdk-libs/client/src/indexer/photon_indexer.rs @@ -1142,17 +1142,16 @@ impl Indexer for PhotonIndexer { .value .iter() .map(|x| { - let mut proof_vec = x.proof.clone(); - if proof_vec.len() < STATE_MERKLE_TREE_CANOPY_DEPTH { + if x.proof.len() < STATE_MERKLE_TREE_CANOPY_DEPTH { return Err(IndexerError::InvalidParameters(format!( "Merkle proof length ({}) is less than canopy depth ({})", - proof_vec.len(), + x.proof.len(), STATE_MERKLE_TREE_CANOPY_DEPTH, ))); } - proof_vec.truncate(proof_vec.len() - STATE_MERKLE_TREE_CANOPY_DEPTH); + let proof_len = x.proof.len() - STATE_MERKLE_TREE_CANOPY_DEPTH; - let proof = proof_vec + let proof = x.proof[..proof_len] .iter() .map(|s| Hash::from_base58(s)) .collect::, IndexerError>>() @@ -1703,15 +1702,13 @@ impl Indexer for PhotonIndexer { async fn get_subtrees( &self, - _merkle_tree_pubkey: [u8; 32], + merkle_tree_pubkey: [u8; 32], _config: Option, ) -> Result>, IndexerError> { - #[cfg(not(feature = "v2"))] - unimplemented!(); - #[cfg(feature = "v2")] - { - todo!(); - } + Err(IndexerError::NotImplemented(format!( + "PhotonIndexer::get_subtrees is not implemented for merkle tree {}", + solana_pubkey::Pubkey::new_from_array(merkle_tree_pubkey) + ))) } } diff --git a/sdk-libs/client/src/indexer/types/queue.rs b/sdk-libs/client/src/indexer/types/queue.rs index 40e7cc0f6e..3f52d72798 100644 --- a/sdk-libs/client/src/indexer/types/queue.rs +++ b/sdk-libs/client/src/indexer/types/queue.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use super::super::IndexerError; #[derive(Debug, Clone, PartialEq, Default)] @@ -64,13 +66,14 @@ pub struct AddressQueueData { } impl AddressQueueData { + const ADDRESS_TREE_HEIGHT: usize = 40; + /// Reconstruct a merkle proof for a given low_element_index from the deduplicated nodes. - /// The tree_height is needed to know how many levels to traverse. - pub fn reconstruct_proof( + pub fn reconstruct_proof( &self, address_idx: usize, - tree_height: u8, - ) -> Result, IndexerError> { + ) -> Result<[[u8; 32]; HEIGHT], IndexerError> { + self.validate_proof_height::()?; let leaf_index = *self.low_element_indices.get(address_idx).ok_or_else(|| { IndexerError::MissingResult { context: "reconstruct_proof".to_string(), @@ -81,10 +84,10 @@ impl AddressQueueData { ), } })?; - let mut proof = Vec::with_capacity(tree_height as usize); + let mut proof = [[0u8; 32]; HEIGHT]; let mut pos = leaf_index; - for level in 0..tree_height { + for (level, proof_element) in proof.iter_mut().enumerate() { let sibling_pos = if pos.is_multiple_of(2) { pos + 1 } else { @@ -114,28 +117,224 @@ impl AddressQueueData { self.node_hashes.len(), ), })?; - proof.push(*hash); + *proof_element = *hash; pos /= 2; } Ok(proof) } + /// Reconstruct a contiguous batch of proofs while reusing a single node lookup table. + pub fn reconstruct_proofs( + &self, + address_range: std::ops::Range, + ) -> Result, IndexerError> { + self.validate_proof_height::()?; + let node_lookup = self.build_node_lookup(); + let mut proofs = Vec::with_capacity(address_range.len()); + + for address_idx in address_range { + proofs.push(self.reconstruct_proof_with_lookup::(address_idx, &node_lookup)?); + } + + Ok(proofs) + } + /// Reconstruct all proofs for all addresses - pub fn reconstruct_all_proofs( + pub fn reconstruct_all_proofs( &self, - tree_height: u8, - ) -> Result>, IndexerError> { - (0..self.addresses.len()) - .map(|i| self.reconstruct_proof(i, tree_height)) - .collect() + ) -> Result, IndexerError> { + self.validate_proof_height::()?; + self.reconstruct_proofs::(0..self.addresses.len()) + } + + fn build_node_lookup(&self) -> HashMap { + let mut lookup = HashMap::with_capacity(self.nodes.len()); + for (idx, node) in self.nodes.iter().copied().enumerate() { + lookup.entry(node).or_insert(idx); + } + lookup + } + + fn reconstruct_proof_with_lookup( + &self, + address_idx: usize, + node_lookup: &HashMap, + ) -> Result<[[u8; 32]; HEIGHT], IndexerError> { + self.validate_proof_height::()?; + let leaf_index = *self.low_element_indices.get(address_idx).ok_or_else(|| { + IndexerError::MissingResult { + context: "reconstruct_proof".to_string(), + message: format!( + "address_idx {} out of bounds for low_element_indices (len {})", + address_idx, + self.low_element_indices.len(), + ), + } + })?; + let mut proof = [[0u8; 32]; HEIGHT]; + let mut pos = leaf_index; + + for (level, proof_element) in proof.iter_mut().enumerate() { + let sibling_pos = if pos.is_multiple_of(2) { + pos + 1 + } else { + pos - 1 + }; + let sibling_idx = Self::encode_node_index(level, sibling_pos); + let hash_idx = node_lookup.get(&sibling_idx).copied().ok_or_else(|| { + IndexerError::MissingResult { + context: "reconstruct_proof".to_string(), + message: format!( + "Missing proof node at level {} position {} (encoded: {})", + level, sibling_pos, sibling_idx + ), + } + })?; + let hash = + self.node_hashes + .get(hash_idx) + .ok_or_else(|| IndexerError::MissingResult { + context: "reconstruct_proof".to_string(), + message: format!( + "node_hashes index {} out of bounds (len {})", + hash_idx, + self.node_hashes.len(), + ), + })?; + *proof_element = *hash; + pos /= 2; + } + + Ok(proof) } /// Encode node index: (level << 56) | position #[inline] - fn encode_node_index(level: u8, position: u64) -> u64 { + fn encode_node_index(level: usize, position: u64) -> u64 { ((level as u64) << 56) | position } + + fn validate_proof_height(&self) -> Result<(), IndexerError> { + if HEIGHT == Self::ADDRESS_TREE_HEIGHT { + return Ok(()); + } + + Err(IndexerError::InvalidParameters(format!( + "address queue proofs require HEIGHT={} but got HEIGHT={}", + Self::ADDRESS_TREE_HEIGHT, + HEIGHT + ))) + } +} + +#[cfg(test)] +mod tests { + use std::{collections::BTreeMap, hint::black_box, time::Instant}; + + use super::AddressQueueData; + + fn hash_from_node(node_index: u64) -> [u8; 32] { + let mut hash = [0u8; 32]; + hash[..8].copy_from_slice(&node_index.to_le_bytes()); + hash[8..16].copy_from_slice(&node_index.rotate_left(17).to_le_bytes()); + hash[16..24].copy_from_slice(&node_index.rotate_right(9).to_le_bytes()); + hash[24..32].copy_from_slice(&(node_index ^ 0xA5A5_A5A5_A5A5_A5A5).to_le_bytes()); + hash + } + + fn build_queue_data(num_addresses: usize) -> AddressQueueData { + let low_element_indices = (0..num_addresses) + .map(|i| (i as u64).saturating_mul(2)) + .collect::>(); + let mut nodes = BTreeMap::new(); + + for &leaf_index in &low_element_indices { + let mut pos = leaf_index; + for level in 0..HEIGHT { + let sibling_pos = if pos.is_multiple_of(2) { + pos + 1 + } else { + pos - 1 + }; + let node_index = ((level as u64) << 56) | sibling_pos; + nodes + .entry(node_index) + .or_insert_with(|| hash_from_node(node_index)); + pos /= 2; + } + } + + let (nodes, node_hashes): (Vec<_>, Vec<_>) = nodes.into_iter().unzip(); + + AddressQueueData { + addresses: vec![[0u8; 32]; num_addresses], + low_element_values: vec![[1u8; 32]; num_addresses], + low_element_next_values: vec![[2u8; 32]; num_addresses], + low_element_indices, + low_element_next_indices: (0..num_addresses).map(|i| (i as u64) + 1).collect(), + nodes, + node_hashes, + initial_root: [9u8; 32], + leaves_hash_chains: vec![[3u8; 32]; num_addresses.max(1)], + subtrees: vec![[4u8; 32]; HEIGHT], + start_index: 0, + root_seq: 0, + } + } + + #[test] + fn batched_reconstruction_matches_individual_reconstruction() { + let queue = build_queue_data::<40>(128); + + let expected = (0..queue.addresses.len()) + .map(|i| queue.reconstruct_proof::<40>(i).unwrap()) + .collect::>(); + let actual = queue + .reconstruct_proofs::<40>(0..queue.addresses.len()) + .unwrap(); + + assert_eq!(actual, expected); + } + + #[test] + #[ignore = "profiling helper"] + fn profile_reconstruct_proofs_batch() { + const HEIGHT: usize = 40; + const NUM_ADDRESSES: usize = 2_048; + const ITERS: usize = 25; + + let queue = build_queue_data::(NUM_ADDRESSES); + + let baseline_start = Instant::now(); + for _ in 0..ITERS { + let proofs = (0..queue.addresses.len()) + .map(|i| queue.reconstruct_proof::(i).unwrap()) + .collect::>(); + black_box(proofs); + } + let baseline = baseline_start.elapsed(); + + let batched_start = Instant::now(); + for _ in 0..ITERS { + black_box( + queue + .reconstruct_proofs::(0..queue.addresses.len()) + .unwrap(), + ); + } + let batched = batched_start.elapsed(); + + println!( + "queue reconstruction profile: addresses={}, height={}, iters={}, individual={:?}, batched={:?}, speedup={:.2}x", + NUM_ADDRESSES, + HEIGHT, + ITERS, + baseline, + batched, + baseline.as_secs_f64() / batched.as_secs_f64(), + ); + } } /// V2 Queue Elements Result with deduplicated node data diff --git a/sdk-libs/client/src/interface/load_accounts.rs b/sdk-libs/client/src/interface/load_accounts.rs index 061ad5074b..c70088dd40 100644 --- a/sdk-libs/client/src/interface/load_accounts.rs +++ b/sdk-libs/client/src/interface/load_accounts.rs @@ -53,6 +53,9 @@ pub enum LoadAccountsError { #[error("Cold PDA at index {index} (pubkey {pubkey}) missing data")] MissingPdaCompressed { index: usize, pubkey: Pubkey }, + #[error("Cold PDA (pubkey {pubkey}) missing data")] + MissingPdaCompressedData { pubkey: Pubkey }, + #[error("Cold ATA at index {index} (pubkey {pubkey}) missing data")] MissingAtaCompressed { index: usize, pubkey: Pubkey }, @@ -67,6 +70,7 @@ pub enum LoadAccountsError { } const MAX_ATAS_PER_IX: usize = 8; +const MAX_PDAS_PER_IX: usize = 8; /// Build load instructions for cold accounts. Returns empty vec if all hot. /// @@ -113,14 +117,18 @@ where }) .collect(); - let pda_hashes = collect_pda_hashes(&cold_pdas)?; + let pda_groups = group_pda_specs(&cold_pdas, MAX_PDAS_PER_IX); + let pda_hashes = pda_groups + .iter() + .map(|group| collect_pda_hashes(group)) + .collect::, _>>()?; let ata_hashes = collect_ata_hashes(&cold_atas)?; let mint_hashes = collect_mint_hashes(&cold_mints)?; let (pda_proofs, ata_proofs, mint_proofs) = futures::join!( - fetch_proofs(&pda_hashes, indexer), + fetch_proof_batches(&pda_hashes, indexer), fetch_proofs_batched(&ata_hashes, MAX_ATAS_PER_IX, indexer), - fetch_proofs(&mint_hashes, indexer), + fetch_individual_proofs(&mint_hashes, indexer), ); let pda_proofs = pda_proofs?; @@ -136,9 +144,9 @@ where // 2. DecompressAccountsIdempotent for all cold PDAs (including token PDAs). // Token PDAs are created on-chain via CPI inside DecompressVariant. - for (spec, proof) in cold_pdas.iter().zip(pda_proofs) { + for (group, proof) in pda_groups.into_iter().zip(pda_proofs) { out.push(build_pda_load( - &[spec], + &group, proof, fee_payer, compression_config, @@ -146,8 +154,7 @@ where } // 3. ATA loads (CreateAssociatedTokenAccount + Transfer2) - requires mint to exist - let ata_chunks: Vec<_> = cold_atas.chunks(MAX_ATAS_PER_IX).collect(); - for (chunk, proof) in ata_chunks.into_iter().zip(ata_proofs) { + for (chunk, proof) in cold_atas.chunks(MAX_ATAS_PER_IX).zip(ata_proofs) { out.extend(build_ata_load(chunk, proof, fee_payer)?); } @@ -195,23 +202,77 @@ fn collect_mint_hashes(ifaces: &[&AccountInterface]) -> Result, Lo .collect() } -async fn fetch_proofs( +/// Groups already-ordered PDA specs into contiguous runs of the same program id. +/// +/// This preserves input order rather than globally regrouping by program. Callers that +/// want maximal batching across interleaved program ids should sort before calling. +fn group_pda_specs<'a, V>( + specs: &[&'a PdaSpec], + max_per_group: usize, +) -> Vec>> { + assert!(max_per_group > 0, "max_per_group must be non-zero"); + if specs.is_empty() { + return Vec::new(); + } + + let mut groups = Vec::new(); + let mut current = Vec::with_capacity(max_per_group); + let mut current_program: Option = None; + + for spec in specs { + let program_id = spec.program_id(); + let should_split = current_program + .map(|existing| existing != program_id || current.len() >= max_per_group) + .unwrap_or(false); + + if should_split { + groups.push(current); + current = Vec::with_capacity(max_per_group); + } + + current_program = Some(program_id); + current.push(*spec); + } + + if !current.is_empty() { + groups.push(current); + } + + groups +} + +async fn fetch_individual_proofs( hashes: &[[u8; 32]], indexer: &I, ) -> Result, IndexerError> { if hashes.is_empty() { return Ok(vec![]); } - let mut proofs = Vec::with_capacity(hashes.len()); - for hash in hashes { - proofs.push( - indexer - .get_validity_proof(vec![*hash], vec![], None) - .await? - .value, - ); + + futures::future::try_join_all(hashes.iter().map(|hash| async move { + indexer + .get_validity_proof(vec![*hash], vec![], None) + .await + .map(|response| response.value) + })) + .await +} + +async fn fetch_proof_batches( + hash_batches: &[Vec<[u8; 32]>], + indexer: &I, +) -> Result, IndexerError> { + if hash_batches.is_empty() { + return Ok(vec![]); } - Ok(proofs) + + futures::future::try_join_all(hash_batches.iter().map(|hashes| async move { + indexer + .get_validity_proof(hashes.clone(), vec![], None) + .await + .map(|response| response.value) + })) + .await } async fn fetch_proofs_batched( @@ -222,16 +283,13 @@ async fn fetch_proofs_batched( if hashes.is_empty() { return Ok(vec![]); } - let mut proofs = Vec::with_capacity(hashes.len().div_ceil(batch_size)); - for chunk in hashes.chunks(batch_size) { - proofs.push( - indexer - .get_validity_proof(chunk.to_vec(), vec![], None) - .await? - .value, - ); - } - Ok(proofs) + + let hash_batches = hashes + .chunks(batch_size) + .map(|chunk| chunk.to_vec()) + .collect::>(); + + fetch_proof_batches(&hash_batches, indexer).await } fn build_pda_load( @@ -262,11 +320,16 @@ where let hot_addresses: Vec = specs.iter().map(|s| s.address()).collect(); let cold_accounts: Vec<(CompressedAccount, V)> = specs .iter() - .map(|s| { - let compressed = s.compressed().expect("cold spec must have data").clone(); - (compressed, s.variant.clone()) + .map(|s| -> Result<_, LoadAccountsError> { + let compressed = + s.compressed() + .cloned() + .ok_or(LoadAccountsError::MissingPdaCompressedData { + pubkey: s.address(), + })?; + Ok((compressed, s.variant.clone())) }) - .collect(); + .collect::, _>>()?; let program_id = specs.first().map(|s| s.program_id()).unwrap_or_default(); diff --git a/sdk-libs/program-test/src/indexer/test_indexer.rs b/sdk-libs/program-test/src/indexer/test_indexer.rs index 0b5b0583a3..c51298b9cd 100644 --- a/sdk-libs/program-test/src/indexer/test_indexer.rs +++ b/sdk-libs/program-test/src/indexer/test_indexer.rs @@ -726,7 +726,9 @@ impl Indexer for TestIndexer { initial_root: address_tree_bundle.root(), leaves_hash_chains: Vec::new(), subtrees: address_tree_bundle.get_subtrees(), - start_index: start as u64, + // Consumers use start_index as the sparse tree's next insertion index, + // not the pagination offset used for queue slicing. + start_index: address_tree_bundle.right_most_index() as u64, root_seq: address_tree_bundle.sequence_number(), }) } else { diff --git a/sparse-merkle-tree/src/indexed_changelog.rs b/sparse-merkle-tree/src/indexed_changelog.rs index bbd30e1ee6..7e6a26cff7 100644 --- a/sparse-merkle-tree/src/indexed_changelog.rs +++ b/sparse-merkle-tree/src/indexed_changelog.rs @@ -29,7 +29,7 @@ pub fn patch_indexed_changelogs( low_element: &mut IndexedElement, new_element: &mut IndexedElement, low_element_next_value: &mut BigUint, - low_leaf_proof: &mut Vec<[u8; 32]>, + low_leaf_proof: &mut [[u8; 32]; HEIGHT], ) -> Result<(), SparseMerkleTreeError> { // Tests are in program-tests/merkle-tree/tests/indexed_changelog.rs let next_indexed_changelog_indices: Vec = (*indexed_changelogs) @@ -69,7 +69,7 @@ pub fn patch_indexed_changelogs( // Patch the next value. *low_element_next_value = BigUint::from_bytes_be(&changelog_entry.element.next_value); // Patch the proof. - *low_leaf_proof = changelog_entry.proof.to_vec(); + *low_leaf_proof = changelog_entry.proof; } // If we found a new low element. @@ -82,7 +82,7 @@ pub fn patch_indexed_changelogs( next_index: new_low_element_changelog_entry.element.next_index, }; - *low_leaf_proof = new_low_element_changelog_entry.proof.to_vec(); + *low_leaf_proof = new_low_element_changelog_entry.proof; new_element.next_index = low_element.next_index; if new_low_element_changelog_index == indexed_changelogs.len() - 1 { return Ok(()); diff --git a/sparse-merkle-tree/tests/indexed_changelog.rs b/sparse-merkle-tree/tests/indexed_changelog.rs index 7d37142b46..59efda6fde 100644 --- a/sparse-merkle-tree/tests/indexed_changelog.rs +++ b/sparse-merkle-tree/tests/indexed_changelog.rs @@ -92,7 +92,8 @@ fn test_indexed_changelog() { next_index: low_element_next_indices[i], }; println!("unpatched new_element: {:?}", new_element); - let mut low_element_proof = low_element_proofs[i].to_vec(); + let mut low_element_proof: [[u8; 32]; 8] = + low_element_proofs[i].as_slice().try_into().unwrap(); let mut low_element_next_value = BigUint::from_bytes_be(&low_element_next_values[i]); if i > 0 { @@ -114,7 +115,7 @@ fn test_indexed_changelog() { next_value: bigint_to_be_bytes_array::<32>(&new_element.value).unwrap(), index: low_element.index, }, - proof: low_element_proof.as_slice().to_vec().try_into().unwrap(), + proof: low_element_proof, changelog_index: indexed_changelog.len(), }); indexed_changelog.push(IndexedChangelogEntry { @@ -124,7 +125,7 @@ fn test_indexed_changelog() { next_value: bigint_to_be_bytes_array::<32>(&low_element_next_value).unwrap(), index: new_element.index, }, - proof: low_element_proof.as_slice().to_vec().try_into().unwrap(), + proof: low_element_proof, changelog_index: indexed_changelog.len(), }); println!("patched -------------------"); @@ -206,7 +207,8 @@ fn debug_test_indexed_changelog() { next_index: low_element_next_indices[i], }; println!("unpatched new_element: {:?}", new_element); - let mut low_element_proof = low_element_proofs[i].to_vec(); + let mut low_element_proof: [[u8; 32]; 8] = + low_element_proofs[i].as_slice().try_into().unwrap(); let mut low_element_next_value = BigUint::from_bytes_be(&low_element_next_values[i]); if i > 0 { @@ -228,7 +230,7 @@ fn debug_test_indexed_changelog() { next_value: bigint_to_be_bytes_array::<32>(&new_element.value).unwrap(), index: low_element.index, }, - proof: low_element_proof.as_slice().to_vec().try_into().unwrap(), + proof: low_element_proof, changelog_index: indexed_changelog.len(), }); indexed_changelog.push(IndexedChangelogEntry { @@ -238,7 +240,7 @@ fn debug_test_indexed_changelog() { next_value: bigint_to_be_bytes_array::<32>(&low_element_next_value).unwrap(), index: new_element.index, }, - proof: low_element_proof.as_slice().to_vec().try_into().unwrap(), + proof: low_element_proof, changelog_index: indexed_changelog.len(), }); man_indexed_array.elements[low_element.index()] = low_element.clone();