diff --git a/Cargo.lock b/Cargo.lock index 9e293fd2b8..f65f0ac59e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4946,9 +4946,9 @@ dependencies = [ [[package]] name = "num-conv" -version = "0.2.0" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" [[package]] name = "num-derive" @@ -11003,30 +11003,30 @@ dependencies = [ [[package]] name = "time" -version = "0.3.47" +version = "0.3.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" +checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" dependencies = [ "deranged", "itoa", "num-conv", "powerfmt", - "serde_core", + "serde", "time-core", "time-macros", ] [[package]] name = "time-core" -version = "0.1.8" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" +checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" [[package]] name = "time-macros" -version = "0.2.27" +version = "0.2.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215" +checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3" dependencies = [ "num-conv", "time-core", diff --git a/forester-utils/src/account_zero_copy.rs b/forester-utils/src/account_zero_copy.rs index bd72ec245f..3af6692111 100644 --- a/forester-utils/src/account_zero_copy.rs +++ b/forester-utils/src/account_zero_copy.rs @@ -1,4 +1,4 @@ -use std::{fmt, marker::PhantomData, mem, pin::Pin}; +use std::{fmt, mem}; use light_client::rpc::Rpc; use light_concurrent_merkle_tree::{ @@ -8,7 +8,7 @@ use light_hash_set::HashSet; use light_hasher::Hasher; use light_indexed_merkle_tree::{copy::IndexedMerkleTreeCopy, errors::IndexedMerkleTreeError}; use num_traits::{CheckedAdd, CheckedSub, ToBytes, Unsigned}; -use solana_sdk::{account::Account, pubkey::Pubkey}; +use solana_sdk::pubkey::Pubkey; use thiserror::Error; #[derive(Error, Debug)] @@ -19,52 +19,15 @@ pub enum AccountZeroCopyError { AccountNotFound(Pubkey), } -#[derive(Debug, Clone)] -pub struct AccountZeroCopy<'a, T> { - pub account: Pin>, - deserialized: *const T, - _phantom_data: PhantomData<&'a T>, -} - -impl<'a, T> AccountZeroCopy<'a, T> { - pub async fn new( - rpc: &mut R, - address: Pubkey, - ) -> Result, AccountZeroCopyError> { - let account = rpc - .get_account(address) - .await - .map_err(|e| AccountZeroCopyError::RpcError(e.to_string()))? - .ok_or(AccountZeroCopyError::AccountNotFound(address))?; - let account = Box::pin(account); - let deserialized = account.data[8..].as_ptr() as *const T; - - Ok(Self { - account, - deserialized, - _phantom_data: PhantomData, - }) - } - - // Safe method to access `deserialized` ensuring the lifetime is respected - pub fn deserialized(&self) -> &'a T { - unsafe { &*self.deserialized } - } +fn copy_hash_set_from_account_bytes(bytes: &[u8]) -> Result { + let mut owned = bytes.to_vec(); + // SAFETY: We pass an owned, mutable copy of bytes that are expected to contain a serialized + // hash set account image. + unsafe { HashSet::from_bytes_copy(&mut owned) } } /// Fetches the given account, then copies and serializes it as a `HashSet`. -/// -/// # Safety -/// -/// This is highly unsafe. Ensuring that: -/// -/// * The correct account is used. -/// * The account has enough space to be treated as a HashSet with specified -/// parameters. -/// * The account data is aligned. -/// -/// Is the caller's responsibility. -pub async unsafe fn get_hash_set( +pub async fn get_hash_set( rpc: &mut R, pubkey: Pubkey, ) -> Result { @@ -73,9 +36,7 @@ pub async unsafe fn get_hash_set( .await .map_err(|e| AccountZeroCopyError::RpcError(e.to_string()))? .ok_or(AccountZeroCopyError::AccountNotFound(pubkey))?; - let mut data = account.data.clone(); - - HashSet::from_bytes_copy(&mut data[8 + mem::size_of::()..]) + copy_hash_set_from_account_bytes(&account.data[8 + mem::size_of::()..]) .map_err(|e| AccountZeroCopyError::RpcError(format!("HashSet parse error: {:?}", e))) } @@ -171,17 +132,11 @@ where IndexedMerkleTreeCopy::from_bytes_copy(&data[offset..]) } -/// Parse HashSet from raw queue account data bytes -/// -/// # Safety -/// Same safety requirements as `get_hash_set`. -pub unsafe fn parse_hash_set_from_bytes( - data: &[u8], -) -> Result { +/// Parse HashSet from raw queue account data bytes. +pub fn parse_hash_set_from_bytes(data: &[u8]) -> Result { let offset = 8 + mem::size_of::(); if data.len() <= offset { return Err(light_hash_set::HashSetError::BufferSize(offset, data.len())); } - let mut data_copy = data[offset..].to_vec(); - HashSet::from_bytes_copy(&mut data_copy) + copy_hash_set_from_account_bytes(&data[offset..]) } diff --git a/forester-utils/src/address_merkle_tree_config.rs b/forester-utils/src/address_merkle_tree_config.rs index 8d8018c92e..14895859e2 100644 --- a/forester-utils/src/address_merkle_tree_config.rs +++ b/forester-utils/src/address_merkle_tree_config.rs @@ -2,7 +2,7 @@ use account_compression::{ AddressMerkleTreeAccount, AddressMerkleTreeConfig, AddressQueueConfig, NullifierQueueConfig, QueueAccount, StateMerkleTreeAccount, StateMerkleTreeConfig, }; -use anchor_lang::Discriminator; +use anchor_lang::{AccountDeserialize, Discriminator}; use light_account_checks::discriminator::Discriminator as LightDiscriminator; use light_batched_merkle_tree::merkle_tree::BatchedMerkleTreeAccount; use light_client::{ @@ -14,40 +14,66 @@ use num_traits::Zero; use solana_sdk::pubkey::Pubkey; use crate::account_zero_copy::{ - get_concurrent_merkle_tree, get_hash_set, get_indexed_merkle_tree, AccountZeroCopy, - AccountZeroCopyError, + get_concurrent_merkle_tree, get_indexed_merkle_tree, parse_concurrent_merkle_tree_from_bytes, + parse_hash_set_from_bytes, parse_indexed_merkle_tree_from_bytes, AccountZeroCopyError, }; +fn deserialize_account( + data: &[u8], + pubkey: Pubkey, +) -> Result { + if data.len() < 8 { + return Err(AccountZeroCopyError::RpcError(format!( + "Account {} data too short: {}", + pubkey, + data.len() + ))); + } + + T::try_deserialize(&mut &data[..]).map_err(|e| { + AccountZeroCopyError::RpcError(format!("Failed to deserialize account {}: {}", pubkey, e)) + }) +} + pub async fn get_address_bundle_config( rpc: &mut R, address_bundle: AddressMerkleTreeAccounts, ) -> Result<(AddressMerkleTreeConfig, AddressQueueConfig), AccountZeroCopyError> { - // Get queue metadata - don't hold AccountZeroCopy across await points - let address_queue_meta_data = { - let account = AccountZeroCopy::::new(rpc, address_bundle.queue).await?; - account.deserialized().metadata - }; - let address_queue = - unsafe { get_hash_set::(rpc, address_bundle.queue).await? }; + let address_queue_account = rpc + .get_account(address_bundle.queue) + .await + .map_err(|e| AccountZeroCopyError::RpcError(e.to_string()))? + .ok_or(AccountZeroCopyError::AccountNotFound(address_bundle.queue))?; + let address_queue_meta_data = + deserialize_account::(&address_queue_account.data, address_bundle.queue)? + .metadata; + let address_queue = parse_hash_set_from_bytes::(&address_queue_account.data) + .map_err(|e| AccountZeroCopyError::RpcError(format!("HashSet parse error: {:?}", e)))?; let queue_config = AddressQueueConfig { network_fee: Some(address_queue_meta_data.rollover_metadata.network_fee), // rollover_threshold: address_queue_meta_data.rollover_threshold, capacity: address_queue.get_capacity() as u16, sequence_threshold: address_queue.sequence_threshold as u64, }; - // Get tree metadata - don't hold AccountZeroCopy across await points - let address_tree_meta_data = { - let account = - AccountZeroCopy::::new(rpc, address_bundle.merkle_tree) - .await?; - account.deserialized().metadata - }; - let address_tree = - get_indexed_merkle_tree::( - rpc, + let address_tree_account = rpc + .get_account(address_bundle.merkle_tree) + .await + .map_err(|e| AccountZeroCopyError::RpcError(e.to_string()))? + .ok_or(AccountZeroCopyError::AccountNotFound( address_bundle.merkle_tree, + ))?; + let address_tree_meta_data = deserialize_account::( + &address_tree_account.data, + address_bundle.merkle_tree, + )? + .metadata; + let address_tree = + parse_indexed_merkle_tree_from_bytes::( + &address_tree_account.data, ) - .await?; + .map_err(|e| { + AccountZeroCopyError::RpcError(format!("IndexedMerkleTree parse error: {:?}", e)) + })?; let address_merkle_tree_config = AddressMerkleTreeConfig { height: address_tree.height as u32, changelog_size: address_tree.merkle_tree.changelog.capacity() as u64, @@ -73,64 +99,77 @@ pub async fn get_state_bundle_config( rpc: &mut R, state_tree_bundle: StateMerkleTreeAccounts, ) -> Result<(StateMerkleTreeConfig, NullifierQueueConfig), AccountZeroCopyError> { - // Get queue metadata - don't hold AccountZeroCopy across await points - let address_queue_meta_data = { - let account = - AccountZeroCopy::::new(rpc, state_tree_bundle.nullifier_queue).await?; - account.deserialized().metadata - }; - let address_queue = - unsafe { get_hash_set::(rpc, state_tree_bundle.nullifier_queue).await? }; + let queue_account = rpc + .get_account(state_tree_bundle.nullifier_queue) + .await + .map_err(|e| AccountZeroCopyError::RpcError(e.to_string()))? + .ok_or(AccountZeroCopyError::AccountNotFound( + state_tree_bundle.nullifier_queue, + ))?; + let nullifier_queue_metadata = deserialize_account::( + &queue_account.data, + state_tree_bundle.nullifier_queue, + )? + .metadata; + let nullifier_queue = parse_hash_set_from_bytes::(&queue_account.data) + .map_err(|e| AccountZeroCopyError::RpcError(format!("HashSet parse error: {:?}", e)))?; let queue_config = NullifierQueueConfig { - network_fee: Some(address_queue_meta_data.rollover_metadata.network_fee), - capacity: address_queue.get_capacity() as u16, - sequence_threshold: address_queue.sequence_threshold as u64, - }; - // Get tree metadata - don't hold AccountZeroCopy across await points - let address_tree_meta_data = { - let account = - AccountZeroCopy::::new(rpc, state_tree_bundle.merkle_tree) - .await?; - account.deserialized().metadata + network_fee: Some(nullifier_queue_metadata.rollover_metadata.network_fee), + capacity: nullifier_queue.get_capacity() as u16, + sequence_threshold: nullifier_queue.sequence_threshold as u64, }; - let address_tree = get_concurrent_merkle_tree::( - rpc, + let state_tree_account = rpc + .get_account(state_tree_bundle.merkle_tree) + .await + .map_err(|e| AccountZeroCopyError::RpcError(e.to_string()))? + .ok_or(AccountZeroCopyError::AccountNotFound( + state_tree_bundle.merkle_tree, + ))?; + let state_tree_metadata = deserialize_account::( + &state_tree_account.data, state_tree_bundle.merkle_tree, - ) - .await?; - let address_merkle_tree_config = StateMerkleTreeConfig { - height: address_tree.height as u32, - changelog_size: address_tree.changelog.capacity() as u64, - roots_size: address_tree.roots.capacity() as u64, - canopy_depth: address_tree.canopy_depth as u64, - rollover_threshold: if address_tree_meta_data + )? + .metadata; + let state_tree = + parse_concurrent_merkle_tree_from_bytes::( + &state_tree_account.data, + ) + .map_err(|e| { + AccountZeroCopyError::RpcError(format!("ConcurrentMerkleTree parse error: {:?}", e)) + })?; + let state_merkle_tree_config = StateMerkleTreeConfig { + height: state_tree.height as u32, + changelog_size: state_tree.changelog.capacity() as u64, + roots_size: state_tree.roots.capacity() as u64, + canopy_depth: state_tree.canopy_depth as u64, + rollover_threshold: if state_tree_metadata .rollover_metadata .rollover_threshold .is_zero() { None } else { - Some(address_tree_meta_data.rollover_metadata.rollover_threshold) + Some(state_tree_metadata.rollover_metadata.rollover_threshold) }, - network_fee: Some(address_tree_meta_data.rollover_metadata.network_fee), + network_fee: Some(state_tree_metadata.rollover_metadata.network_fee), close_threshold: None, }; - Ok((address_merkle_tree_config, queue_config)) + Ok((state_merkle_tree_config, queue_config)) } pub async fn address_tree_ready_for_rollover( rpc: &mut R, merkle_tree: Pubkey, ) -> Result { - // Get account data - don't hold AccountZeroCopy across await points - let (address_tree_meta_data, account_data_len, account_lamports) = { - let account = AccountZeroCopy::::new(rpc, merkle_tree).await?; - ( - account.deserialized().metadata, - account.account.data.len(), - account.account.lamports, - ) - }; + let account = rpc + .get_account(merkle_tree) + .await + .map_err(|e| AccountZeroCopyError::RpcError(e.to_string()))? + .ok_or(AccountZeroCopyError::AccountNotFound(merkle_tree))?; + let address_tree_meta_data = + deserialize_account::(&account.data, merkle_tree)?.metadata; + let account_data_len = account.data.len(); + let account_lamports = account.lamports; let rent_exemption = rpc .get_minimum_balance_for_rent_exemption(account_data_len) .await @@ -166,15 +205,18 @@ pub async fn state_tree_ready_for_rollover( .get_minimum_balance_for_rent_exemption(account.data.len()) .await .map_err(|e| AccountZeroCopyError::RpcError(e.to_string()))?; + if account.data.len() < 8 { + return Err(AccountZeroCopyError::RpcError(format!( + "Account {} data too short: {}", + merkle_tree, + account.data.len() + ))); + } let discriminator = &account.data[0..8]; let (next_index, tree_meta_data, height) = match discriminator { d if d == StateMerkleTreeAccount::DISCRIMINATOR => { - // Get tree metadata - don't hold AccountZeroCopy across await points - let tree_meta_data = { - let account = - AccountZeroCopy::::new(rpc, merkle_tree).await?; - account.deserialized().metadata - }; + let tree_meta_data = + deserialize_account::(&account.data, merkle_tree)?.metadata; let tree = get_concurrent_merkle_tree::( rpc, merkle_tree, diff --git a/forester-utils/src/address_staging_tree.rs b/forester-utils/src/address_staging_tree.rs index 786ddb6ac0..a6b1aa89bd 100644 --- a/forester-utils/src/address_staging_tree.rs +++ b/forester-utils/src/address_staging_tree.rs @@ -121,7 +121,7 @@ impl AddressStagingTree { low_element_next_values: &[[u8; 32]], low_element_indices: &[u64], low_element_next_indices: &[u64], - low_element_proofs: &[Vec<[u8; 32]>], + low_element_proofs: &[[[u8; 32]; HEIGHT]], leaves_hashchain: [u8; 32], zkp_batch_size: usize, epoch: u64, @@ -145,15 +145,12 @@ impl AddressStagingTree { let inputs = get_batch_address_append_circuit_inputs::( next_index, old_root, - low_element_values.to_vec(), - low_element_next_values.to_vec(), - low_element_indices.iter().map(|v| *v as usize).collect(), - low_element_next_indices - .iter() - .map(|v| *v as usize) - .collect(), - low_element_proofs.to_vec(), - addresses.to_vec(), + low_element_values, + low_element_next_values, + low_element_indices, + low_element_next_indices, + low_element_proofs, + addresses, &mut self.sparse_tree, leaves_hashchain, zkp_batch_size, diff --git a/forester-utils/src/forester_epoch.rs b/forester-utils/src/forester_epoch.rs index d92bac8fdb..2a2f3a10d9 100644 --- a/forester-utils/src/forester_epoch.rs +++ b/forester-utils/src/forester_epoch.rs @@ -111,7 +111,7 @@ pub fn get_schedule_for_queue( total_epoch_weight, epoch, ) - .unwrap(); + .map_err(|e| ForesterUtilsError::Parse(e.to_string()))?; vec.push(Some(ForesterSlot { slot: light_slot, start_solana_slot, @@ -326,7 +326,7 @@ impl Epoch { let mut epoch = protocol_config .get_latest_register_epoch(current_solana_slot) - .unwrap(); + .map_err(|e| RpcError::CustomError(e.to_string()))?; let registration_start_slot = protocol_config.genesis_slot + epoch * protocol_config.active_phase_length; @@ -388,7 +388,7 @@ impl Epoch { let epoch_pda = rpc .get_anchor_account::(&epoch_pda_pubkey) .await? - .unwrap(); + .ok_or_else(|| RpcError::AccountDoesNotExist(epoch_pda_pubkey.to_string()))?; let forester_epoch_pda_pubkey = get_forester_epoch_pda_from_authority(derivation, target_epoch).0; @@ -426,11 +426,11 @@ impl Epoch { let epoch_pda = rpc .get_anchor_account::(&self.epoch_pda) .await? - .unwrap(); + .ok_or_else(|| RpcError::AccountDoesNotExist(self.epoch_pda.to_string()))?; let mut forester_epoch_pda = rpc .get_anchor_account::(&self.forester_epoch_pda) .await? - .unwrap(); + .ok_or_else(|| RpcError::AccountDoesNotExist(self.forester_epoch_pda.to_string()))?; // IF active phase has started and total_epoch_weight is not set, set it now to if forester_epoch_pda.total_epoch_weight.is_none() { forester_epoch_pda.total_epoch_weight = Some(epoch_pda.registered_weight); diff --git a/forester-utils/src/instructions/compress_and_close_mint.rs b/forester-utils/src/instructions/compress_and_close_mint.rs index 4dcd3c2639..ca6a8abce6 100644 --- a/forester-utils/src/instructions/compress_and_close_mint.rs +++ b/forester-utils/src/instructions/compress_and_close_mint.rs @@ -56,17 +56,23 @@ pub async fn create_compress_and_close_mint_instruction( .get_validity_proof(vec![compressed_mint_account.hash], vec![], None) .await? .value; + let proof_account = rpc_proof_result.accounts.first().ok_or_else(|| { + RpcError::CustomError("Missing compressed mint proof account".to_string()) + })?; // Build MintWithContext let compressed_mint_inputs = MintWithContext { - prove_by_index: rpc_proof_result.accounts[0].root_index.proof_by_index(), + prove_by_index: proof_account.root_index.proof_by_index(), leaf_index: compressed_mint_account.leaf_index, - root_index: rpc_proof_result.accounts[0] - .root_index - .root_index() - .unwrap_or_default(), + root_index: proof_account.root_index.root_index().unwrap_or_default(), address: compressed_mint_address, - mint: compressed_mint.map(|m| m.try_into().unwrap()), + mint: compressed_mint + .map(|mint| { + mint.try_into().map_err(|e| { + RpcError::CustomError(format!("Failed to convert compressed mint: {:?}", e)) + }) + }) + .transpose()?, }; // Build instruction data with CompressAndCloseMint action @@ -91,7 +97,7 @@ pub async fn create_compress_and_close_mint_instruction( })?; // Build account metas configuration - let state_tree_info = rpc_proof_result.accounts[0].tree_info; + let state_tree_info = proof_account.tree_info; let config = MintActionMetaConfig::new( payer, payer, // authority - permissionless, using payer diff --git a/forester-utils/src/rate_limiter.rs b/forester-utils/src/rate_limiter.rs index 546f4726a1..badff7ee79 100644 --- a/forester-utils/src/rate_limiter.rs +++ b/forester-utils/src/rate_limiter.rs @@ -16,6 +16,9 @@ pub trait UseRateLimiter { pub enum RateLimiterError { #[error("Rate limit exceeded")] RateLimitExceeded, + + #[error("Invalid requests_per_second value: {0}")] + InvalidRequestsPerSecond(u32), } #[derive(Clone, Debug)] @@ -24,18 +27,26 @@ pub struct RateLimiter { } impl RateLimiter { - pub fn new(requests_per_second: u32) -> Self { - // Create a quota that allows exactly one request per 1/requests_per_second seconds - let quota = Quota::with_period(Duration::from_secs_f64(1.0 / requests_per_second as f64)) - .unwrap() - .allow_burst(NonZeroU32::new(1).unwrap()); - RateLimiter { + pub fn new(requests_per_second: u32) -> Result { + if requests_per_second == 0 { + return Err(RateLimiterError::InvalidRequestsPerSecond( + requests_per_second, + )); + } + + let period = Duration::from_secs_f64(1.0 / requests_per_second as f64); + let quota = Quota::with_period(period) + .ok_or(RateLimiterError::InvalidRequestsPerSecond( + requests_per_second, + ))? + .allow_burst(NonZeroU32::MIN); + Ok(RateLimiter { governor: Arc::new(Governor::new( quota, InMemoryState::default(), DefaultClock::default(), )), - } + }) } pub async fn acquire(&self) -> Result<(), RateLimiterError> { @@ -98,7 +109,7 @@ mod tests { #[tokio::test] async fn test_rate_limiter_basic() { - let limiter = RateLimiter::new(10); + let limiter = RateLimiter::new(10).unwrap(); let mut successes = 0; for _ in 0..20 { @@ -120,6 +131,7 @@ mod tests { } let rate_limiter = RateLimiter::new(10); + let rate_limiter = rate_limiter.unwrap(); let client = RateLimitedClient::new(MockClient, rate_limiter); let result = client @@ -131,7 +143,7 @@ mod tests { #[tokio::test] async fn test_rate_limiter_concurrent() { - let rate_limiter = RateLimiter::new(10); + let rate_limiter = RateLimiter::new(10).unwrap(); let test_duration = Duration::from_secs(3); let start_time = Instant::now(); let mut total_successful = 0; @@ -162,7 +174,7 @@ mod tests { #[tokio::test] async fn test_rate_limiter_with_wait() { - let rate_limiter = RateLimiter::new(10); + let rate_limiter = RateLimiter::new(10).unwrap(); let start_time = Instant::now(); for _ in 0..15 { diff --git a/forester-utils/src/registry.rs b/forester-utils/src/registry.rs index b7b49308bc..38db4cb787 100644 --- a/forester-utils/src/registry.rs +++ b/forester-utils/src/registry.rs @@ -38,7 +38,9 @@ pub async fn update_test_forester( let mut pre_account_state = rpc .get_anchor_account::(&get_forester_pda(derivation_key).0) .await? - .unwrap(); + .ok_or_else(|| { + RpcError::AccountDoesNotExist(get_forester_pda(derivation_key).0.to_string()) + })?; let (signers, new_forester_authority) = if let Some(new_authority) = new_forester_authority { pre_account_state.authority = new_authority.pubkey(); @@ -69,7 +71,10 @@ pub async fn assert_registered_forester( expected_account: ForesterPda, ) -> Result<(), RpcError> { let pda = get_forester_pda(forester).0; - let account_data = rpc.get_anchor_account::(&pda).await?.unwrap(); + let account_data = rpc + .get_anchor_account::(&pda) + .await? + .ok_or_else(|| RpcError::AccountDoesNotExist(pda.to_string()))?; if account_data != expected_account { return Err(RpcError::AssertRpcError(format!( "Expected account data: {:?}, got: {:?}", diff --git a/forester/src/config.rs b/forester/src/config.rs index 550828b7c3..8e1109c861 100644 --- a/forester/src/config.rs +++ b/forester/src/config.rs @@ -338,7 +338,13 @@ impl ForesterConfig { .into()); } - valid.into_iter().map(|r| r.unwrap()).collect() + valid + .into_iter() + .collect::, _>>() + .map_err(|_| ConfigError::InvalidArguments { + field: "tree_ids", + invalid_values: vec!["failed to parse tree_ids".to_string()], + })? }, sleep_after_processing_ms: 10_000, sleep_when_idle_ms: 45_000, diff --git a/forester/src/epoch_manager.rs b/forester/src/epoch_manager.rs index f52efa1b13..c1ee6544fc 100644 --- a/forester/src/epoch_manager.rs +++ b/forester/src/epoch_manager.rs @@ -1,4 +1,5 @@ use std::{ + any::Any, collections::HashMap, sync::{ atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}, @@ -14,7 +15,7 @@ use forester_utils::{ forester_epoch::{get_epoch_phases, Epoch, ForesterSlot, TreeAccounts, TreeForesterSchedule}, rpc_pool::SolanaRpcPool, }; -use futures::future::join_all; +use futures::{future::join_all, stream::FuturesUnordered, FutureExt, StreamExt}; use light_client::{ indexer::{Indexer, MerkleProof, NewAddressProofWithContext}, rpc::{LightClient, LightClientConfig, RetryConfig, Rpc, RpcError}, @@ -39,7 +40,8 @@ use solana_sdk::{ transaction::TransactionError, }; use tokio::{ - sync::{mpsc, oneshot, Mutex}, + runtime::Handle, + sync::{mpsc, oneshot, Mutex, Semaphore}, task::JoinHandle, time::{sleep, Instant, MissedTickBehavior}, }; @@ -94,8 +96,6 @@ type StateBatchProcessorMap = type AddressBatchProcessorMap = Arc>>)>>; type ProcessorInitLockMap = Arc>>>; -type TreeProcessingTask = JoinHandle>; - /// Coordinates re-finalization across parallel `process_queue` tasks when new /// foresters register mid-epoch. Only one task performs the on-chain /// `finalize_registration` tx; others wait for it to complete. @@ -221,6 +221,46 @@ impl std::ops::AddAssign for ProcessingMetrics { } } +fn panic_payload_message(payload: &(dyn Any + Send)) -> String { + if let Some(message) = payload.downcast_ref::() { + message.clone() + } else if let Some(message) = payload.downcast_ref::<&'static str>() { + (*message).to_string() + } else { + "non-string panic payload".to_string() + } +} + +const NEW_TREE_WORKER_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5); + +fn max_parallel_tree_workers(tree_count: usize) -> usize { + if tree_count == 0 { + return 0; + } + + let cpu_count = std::thread::available_parallelism() + .map(|parallelism| parallelism.get()) + .unwrap_or(4); + tree_count.min(std::cmp::max(1, cpu_count / 2)) +} + +struct NewTreeWorker { + tree: Pubkey, + epoch: u64, + cancel: Option>, + completion: oneshot::Receiver<()>, + thread_handle: std::thread::JoinHandle<()>, +} + +impl std::fmt::Debug for NewTreeWorker { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NewTreeWorker") + .field("tree", &self.tree) + .field("epoch", &self.epoch) + .finish_non_exhaustive() + } +} + #[derive(Copy, Clone, Debug)] pub struct WorkReport { pub epoch: u64, @@ -280,6 +320,7 @@ pub struct EpochManager { run_id: Arc, /// Per-epoch registration trackers to coordinate re-finalization when new foresters register mid-epoch registration_trackers: Arc>>, + new_tree_workers: Arc>>, } impl Clone for EpochManager { @@ -310,6 +351,7 @@ impl Clone for EpochManager { heartbeat: self.heartbeat.clone(), run_id: self.run_id.clone(), registration_trackers: self.registration_trackers.clone(), + new_tree_workers: self.new_tree_workers.clone(), } } } @@ -359,9 +401,129 @@ impl EpochManager { heartbeat, run_id: Arc::::from(run_id), registration_trackers: Arc::new(DashMap::new()), + new_tree_workers: Arc::new(Mutex::new(Vec::new())), }) } + fn join_new_tree_worker_with_run_id(run_id: Arc, worker: NewTreeWorker) { + if let Err(payload) = worker.thread_handle.join() { + error!( + event = "new_tree_worker_join_panicked", + run_id = %run_id, + tree = %worker.tree, + epoch = worker.epoch, + panic = %panic_payload_message(payload.as_ref()), + "New tree worker panicked while joining" + ); + } + } + + fn join_new_tree_worker(&self, worker: NewTreeWorker) { + Self::join_new_tree_worker_with_run_id(self.run_id.clone(), worker); + } + + fn detach_new_tree_worker_join(&self, worker: NewTreeWorker) { + let run_id = self.run_id.clone(); + let tree = worker.tree; + let epoch = worker.epoch; + std::thread::spawn(move || { + warn!( + event = "new_tree_worker_join_deferred", + run_id = %run_id, + tree = %tree, + epoch, + "Deferring timed-out new-tree worker join to background thread" + ); + Self::join_new_tree_worker_with_run_id(run_id, worker); + }); + } + + async fn reap_finished_new_tree_workers(&self) { + let finished_workers = { + let mut workers = self.new_tree_workers.lock().await; + let mut pending = Vec::with_capacity(workers.len()); + let mut finished = Vec::new(); + + for worker in workers.drain(..) { + if worker.thread_handle.is_finished() { + finished.push(worker); + } else { + pending.push(worker); + } + } + + *workers = pending; + finished + }; + + for mut worker in finished_workers { + let _ = worker.completion.try_recv(); + self.join_new_tree_worker(worker); + } + } + + async fn register_new_tree_worker(&self, worker: NewTreeWorker) { + self.reap_finished_new_tree_workers().await; + self.new_tree_workers.lock().await.push(worker); + } + + async fn shutdown_new_tree_workers(&self, timeout_duration: Duration) { + let mut workers = { + let mut guard = self.new_tree_workers.lock().await; + std::mem::take(&mut *guard) + }; + + if workers.is_empty() { + return; + } + + info!( + event = "new_tree_workers_shutdown_started", + run_id = %self.run_id, + worker_count = workers.len(), + timeout_secs = timeout_duration.as_secs_f64(), + "Shutting down tracked new-tree workers" + ); + + for worker in &mut workers { + if let Some(cancel) = worker.cancel.take() { + let _ = cancel.send(()); + } + } + + let deadline = Instant::now() + timeout_duration; + for mut worker in workers { + let remaining = deadline.saturating_duration_since(Instant::now()); + if remaining.is_zero() { + warn!( + event = "new_tree_worker_shutdown_timed_out", + run_id = %self.run_id, + tree = %worker.tree, + epoch = worker.epoch, + "Timed out waiting for new-tree worker shutdown" + ); + self.detach_new_tree_worker_join(worker); + continue; + } + + match tokio::time::timeout(remaining, &mut worker.completion).await { + Ok(Ok(())) | Ok(Err(_)) => { + self.join_new_tree_worker(worker); + } + Err(_) => { + warn!( + event = "new_tree_worker_shutdown_timed_out", + run_id = %self.run_id, + tree = %worker.tree, + epoch = worker.epoch, + "Timed out waiting for new-tree worker shutdown" + ); + self.detach_new_tree_worker_join(worker); + } + } + } + } + pub async fn run(self: Arc) -> Result<()> { let (tx, mut rx) = mpsc::channel(100); let tx = Arc::new(tx); @@ -411,8 +573,40 @@ impl EpochManager { }, ); + let mut epoch_tasks = FuturesUnordered::new(); let result = loop { tokio::select! { + Some((epoch, result)) = epoch_tasks.next(), if !epoch_tasks.is_empty() => { + match result { + Ok(Ok(())) => { + debug!( + event = "epoch_processing_completed", + run_id = %self.run_id, + epoch, + "Epoch processed successfully" + ); + } + Ok(Err(e)) => { + error!( + event = "epoch_processing_failed", + run_id = %self.run_id, + epoch, + error = ?e, + "Error processing epoch" + ); + } + Err(payload) => { + let payload: Box = payload; + error!( + event = "epoch_processing_panicked", + run_id = %self.run_id, + epoch, + panic = %panic_payload_message(payload.as_ref()), + "Epoch processing panicked" + ); + } + } + } epoch_opt = rx.recv() => { match epoch_opt { Some(epoch) => { @@ -423,16 +617,11 @@ impl EpochManager { "Received epoch from monitor" ); let self_clone = Arc::clone(&self); - tokio::spawn(async move { - if let Err(e) = self_clone.process_epoch(epoch).await { - error!( - event = "epoch_processing_failed", - run_id = %self_clone.run_id, - epoch, - error = ?e, - "Error processing epoch" - ); - } + epoch_tasks.push(async move { + let result = std::panic::AssertUnwindSafe(self_clone.process_epoch(epoch)) + .catch_unwind() + .await; + (epoch, result) }); } None => { @@ -488,6 +677,8 @@ impl EpochManager { // Abort monitor_handle on exit monitor_handle.abort(); + self.shutdown_new_tree_workers(NEW_TREE_WORKER_SHUTDOWN_TIMEOUT) + .await; result } @@ -720,33 +911,87 @@ impl EpochManager { epoch = current_epoch, "Spawning task to process new tree in current epoch" ); - tokio::spawn(async move { + let (cancel_tx, mut cancel_rx) = oneshot::channel(); + let (completion_tx, completion_rx) = oneshot::channel(); + let thread_handle = std::thread::spawn(move || { let tree_pubkey = tree_schedule.tree_accounts.merkle_tree; - if let Err(e) = self_clone - .process_queue( - &epoch_info.epoch, - epoch_info.forester_epoch_pda.clone(), - tree_schedule, - tracker, - ) - .await - { - error!( - event = "new_tree_process_queue_failed", - run_id = %self_clone.run_id, - tree = %tree_pubkey, - error = ?e, - "Error processing queue for new tree" - ); - } else { - info!( - event = "new_tree_process_queue_succeeded", - run_id = %self_clone.run_id, - tree = %tree_pubkey, - "Successfully processed new tree in current epoch" - ); + let run_id = self_clone.run_id.clone(); + let thread_run_id = run_id.clone(); + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(2) + .enable_all() + .build()?; + runtime.block_on(async move { + tokio::select! { + result = self_clone + .clone() + .process_queue( + epoch_info.epoch.clone(), + epoch_info.forester_epoch_pda.clone(), + tree_schedule, + tracker, + ) => { + if let Err(e) = result { + error!( + event = "new_tree_process_queue_failed", + run_id = %thread_run_id, + tree = %tree_pubkey, + error = ?e, + "Error processing queue for new tree" + ); + } else { + info!( + event = "new_tree_process_queue_succeeded", + run_id = %thread_run_id, + tree = %tree_pubkey, + "Successfully processed new tree in current epoch" + ); + } + } + _ = &mut cancel_rx => { + info!( + event = "new_tree_process_queue_cancelled", + run_id = %thread_run_id, + tree = %tree_pubkey, + "Cancellation requested for new tree worker" + ); + } + } + Ok::<(), anyhow::Error>(()) + }) + })); + let _ = completion_tx.send(()); + match result { + Ok(Ok(())) => {} + Ok(Err(error)) => { + error!( + event = "new_tree_runtime_build_failed", + run_id = %run_id, + tree = %tree_pubkey, + error = ?error, + "Failed to build background runtime for new tree processing" + ); + } + Err(payload) => { + error!( + event = "new_tree_processing_task_panicked", + run_id = %run_id, + tree = %tree_pubkey, + panic = %panic_payload_message(payload.as_ref()), + "New tree processing thread panicked" + ); + } } }); + self.register_new_tree_worker(NewTreeWorker { + tree: new_tree.merkle_tree, + epoch: current_epoch, + cancel: Some(cancel_tx), + completion: completion_rx, + thread_handle, + }) + .await; } Ok(None) => { debug!( @@ -1615,16 +1860,19 @@ impl EpochManager { .cloned() .collect(); + let max_parallel_tree_workers = max_parallel_tree_workers(trees_to_process.len()); info!( event = "active_work_cycle_started", run_id = %self.run_id, current_slot, active_phase_end, tree_count = trees_to_process.len(), + parallel_tree_worker_limit = max_parallel_tree_workers, "Starting active work cycle" ); let self_arc = Arc::new(self.clone()); + let worker_slots = Arc::new(Semaphore::new(max_parallel_tree_workers)); let registration_tracker = self .registration_trackers .entry(epoch_info.epoch.epoch) @@ -1636,13 +1884,15 @@ impl EpochManager { .value() .clone(); - let mut handles: Vec = Vec::with_capacity(trees_to_process.len()); + let runtime_handle = Handle::current(); + let mut tasks = Vec::with_capacity(trees_to_process.len()); for tree in trees_to_process { + let tree_pubkey = tree.tree_accounts.merkle_tree; debug!( event = "tree_processing_task_spawned", run_id = %self.run_id, - tree = %tree.tree_accounts.merkle_tree, + tree = %tree_pubkey, tree_type = ?tree.tree_accounts.tree_type, "Spawning tree processing task" ); @@ -1652,41 +1902,68 @@ impl EpochManager { let epoch_clone = epoch_info.epoch.clone(); let forester_epoch_pda = epoch_info.forester_epoch_pda.clone(); let tracker = registration_tracker.clone(); - - let handle = tokio::spawn(async move { - self_clone - .process_queue(&epoch_clone, forester_epoch_pda, tree, tracker) - .await + let worker_slots = worker_slots.clone(); + let runtime_handle = runtime_handle.clone(); + tasks.push(async move { + let permit = match worker_slots.acquire_owned().await { + Ok(permit) => permit, + Err(_) => { + return Ok(( + tree_pubkey, + Err(anyhow!("tree worker semaphore was closed unexpectedly")), + )); + } + }; + tokio::task::spawn_blocking(move || { + let _permit = permit; + let result = runtime_handle.block_on(self_clone.process_queue( + epoch_clone, + forester_epoch_pda, + tree, + tracker, + )); + (tree_pubkey, result) + }) + .await }); - - handles.push(handle); } - debug!("Waiting for {} tree processing tasks", handles.len()); - let results = join_all(handles).await; + debug!("Waiting for {} tree processing tasks", tasks.len()); + let results = join_all(tasks).await; let mut success_count = 0usize; let mut error_count = 0usize; let mut panic_count = 0usize; for result in results { match result { - Ok(Ok(())) => success_count += 1, - Ok(Err(e)) => { + Ok((_, Ok(()))) => success_count += 1, + Ok((tree_pubkey, Err(e))) => { error_count += 1; error!( event = "tree_processing_task_failed", run_id = %self.run_id, + tree = %tree_pubkey, error = ?e, "Error processing queue" ); } - Err(e) => { + Err(join_error) => { panic_count += 1; - error!( - event = "tree_processing_task_panicked", - run_id = %self.run_id, - error = ?e, - "Tree processing task panicked" - ); + if join_error.is_panic() { + let payload = join_error.into_panic(); + error!( + event = "tree_processing_task_join_panicked", + run_id = %self.run_id, + panic = %panic_payload_message(payload.as_ref()), + "Tree processing task panicked before completion" + ); + } else { + error!( + event = "tree_processing_task_join_failed", + run_id = %self.run_id, + error = ?join_error, + "Tree processing task failed to join" + ); + } } } } @@ -1713,15 +1990,9 @@ impl EpochManager { Ok(current_slot) } - #[instrument( - level = "debug", - skip(self, epoch_info, forester_epoch_pda, tree_schedule, registration_tracker), - fields(forester = %self.config.payer_keypair.pubkey(), epoch = epoch_info.epoch, - tree = %tree_schedule.tree_accounts.merkle_tree) - )] pub(crate) async fn process_queue( - &self, - epoch_info: &Epoch, + self: Arc, + epoch_info: Epoch, mut forester_epoch_pda: ForesterEpochPda, mut tree_schedule: TreeForesterSchedule, registration_tracker: Arc, @@ -1758,26 +2029,28 @@ impl EpochManager { if let Some((slot_idx, light_slot_details)) = next_slot_to_process { let result = match tree_type { TreeType::StateV1 | TreeType::AddressV1 | TreeType::Unknown => { - self.process_light_slot( - epoch_info, - &forester_epoch_pda, - &tree_schedule.tree_accounts, - &light_slot_details, - ) - .await + self.clone() + .process_light_slot( + epoch_info.clone(), + forester_epoch_pda.clone(), + tree_schedule.tree_accounts, + light_slot_details.clone(), + ) + .await } TreeType::StateV2 | TreeType::AddressV2 => { let consecutive_end = tree_schedule .get_consecutive_eligibility_end(slot_idx) .unwrap_or(light_slot_details.end_solana_slot); - self.process_light_slot_v2( - epoch_info, - &forester_epoch_pda, - &tree_schedule.tree_accounts, - &light_slot_details, - consecutive_end, - ) - .await + self.clone() + .process_light_slot_v2( + epoch_info.clone(), + forester_epoch_pda.clone(), + tree_schedule.tree_accounts, + light_slot_details.clone(), + consecutive_end, + ) + .await } }; @@ -1817,23 +2090,30 @@ impl EpochManager { // where cached_weight is correct but schedule was never recomputed. if force_refinalize || last_weight_check.elapsed() >= WEIGHT_CHECK_INTERVAL { last_weight_check = Instant::now(); - if let Err(e) = self + match self + .clone() .maybe_refinalize( - epoch_info, - &mut forester_epoch_pda, - &mut tree_schedule, - ®istration_tracker, + epoch_info.clone(), + forester_epoch_pda.clone(), + tree_schedule.clone(), + registration_tracker.clone(), force_refinalize, ) .await { - warn!( - event = "refinalize_check_failed", - run_id = %self.run_id, - forced = force_refinalize, - error = ?e, - "Failed to check/perform re-finalization" - ); + Ok((updated_pda, updated_schedule)) => { + forester_epoch_pda = updated_pda; + tree_schedule = updated_schedule; + } + Err(e) => { + warn!( + event = "refinalize_check_failed", + run_id = %self.run_id, + forced = force_refinalize, + error = ?e, + "Failed to check/perform re-finalization" + ); + } } } } else { @@ -1866,13 +2146,13 @@ impl EpochManager { /// When `force` is true (e.g. after a ForesterNotEligible error), skips /// the weight-change check and unconditionally refreshes the schedule. async fn maybe_refinalize( - &self, - epoch_info: &Epoch, - forester_epoch_pda: &mut ForesterEpochPda, - tree_schedule: &mut TreeForesterSchedule, - registration_tracker: &RegistrationTracker, + self: Arc, + epoch_info: Epoch, + forester_epoch_pda: ForesterEpochPda, + tree_schedule: TreeForesterSchedule, + registration_tracker: Arc, force: bool, - ) -> Result<()> { + ) -> Result<(ForesterEpochPda, TreeForesterSchedule)> { let mut rpc = self.rpc_pool.get_connection().await?; let epoch_pda_address = get_epoch_pda_address(epoch_info.epoch); let on_chain_epoch_pda: EpochPda = rpc @@ -1885,7 +2165,7 @@ impl EpochManager { let weight_changed = on_chain_weight != cached_weight; if !weight_changed && !force { - return Ok(()); + return Ok((forester_epoch_pda, tree_schedule)); } if weight_changed { @@ -1921,7 +2201,7 @@ impl EpochManager { "Skipping re-finalization because not enough active-phase time remains for confirmation" ); registration_tracker.complete_refinalize(cached_weight); - return Ok(()); + return Ok((forester_epoch_pda, tree_schedule)); }; let payer = self.config.payer_keypair.pubkey(); let signers = [&self.config.payer_keypair]; @@ -1999,33 +2279,24 @@ impl EpochManager { &refreshed_epoch_pda, )?; - *forester_epoch_pda = updated_pda; - *tree_schedule = new_schedule; - info!( event = "schedule_recomputed_after_refinalize", run_id = %self.run_id, epoch = epoch_info.epoch, - tree = %tree_schedule.tree_accounts.merkle_tree, - new_eligible_slots = tree_schedule.slots.iter().filter(|s| s.is_some()).count(), + tree = %new_schedule.tree_accounts.merkle_tree, + new_eligible_slots = new_schedule.slots.iter().filter(|s| s.is_some()).count(), "Recomputed schedule after re-finalization" ); - Ok(()) + Ok((updated_pda, new_schedule)) } - #[instrument( - level = "debug", - skip(self, epoch_info, epoch_pda, tree_accounts, forester_slot_details), - fields(forester = %self.config.payer_keypair.pubkey(), epoch = epoch_info.epoch, - tree = %tree_accounts.merkle_tree) - )] async fn process_light_slot( - &self, - epoch_info: &Epoch, - epoch_pda: &ForesterEpochPda, - tree_accounts: &TreeAccounts, - forester_slot_details: &ForesterSlot, + self: Arc, + epoch_info: Epoch, + epoch_pda: ForesterEpochPda, + tree_accounts: TreeAccounts, + forester_slot_details: ForesterSlot, ) -> std::result::Result<(), ForesterError> { debug!( event = "light_slot_processing_started", @@ -2070,26 +2341,24 @@ impl EpochManager { break 'inner_processing_loop; } - if !self - .check_forester_eligibility( - epoch_pda, - current_light_slot, - &tree_accounts.queue, - epoch_info.epoch, - epoch_info, - ) - .await? - { + if !self.check_forester_eligibility( + &epoch_pda, + current_light_slot, + &tree_accounts.queue, + epoch_info.epoch, + &epoch_info, + )? { break 'inner_processing_loop; } let processing_start_time = Instant::now(); let items_processed_this_iteration = match self + .clone() .dispatch_tree_processing( - epoch_info, - epoch_pda, + epoch_info.clone(), + epoch_pda.clone(), tree_accounts, - forester_slot_details, + forester_slot_details.clone(), forester_slot_details.end_solana_slot, estimated_slot, ) @@ -2158,17 +2427,12 @@ impl EpochManager { Ok(()) } - #[instrument( - level = "debug", - skip(self, epoch_info, epoch_pda, tree_accounts, forester_slot_details, consecutive_eligibility_end), - fields(tree = %tree_accounts.merkle_tree) - )] async fn process_light_slot_v2( - &self, - epoch_info: &Epoch, - epoch_pda: &ForesterEpochPda, - tree_accounts: &TreeAccounts, - forester_slot_details: &ForesterSlot, + self: Arc, + epoch_info: Epoch, + epoch_pda: ForesterEpochPda, + tree_accounts: TreeAccounts, + forester_slot_details: ForesterSlot, consecutive_eligibility_end: u64, ) -> std::result::Result<(), ForesterError> { debug!( @@ -2195,7 +2459,7 @@ impl EpochManager { // Try to send any cached proofs first let cached_send_start = Instant::now(); if let Some(items_sent) = self - .try_send_cached_proofs(epoch_info, tree_accounts, consecutive_eligibility_end) + .try_send_cached_proofs(&epoch_info, &tree_accounts, consecutive_eligibility_end) .await? { if items_sent > 0 { @@ -2242,27 +2506,25 @@ impl EpochManager { break 'inner_processing_loop; } - if !self - .check_forester_eligibility( - epoch_pda, - current_light_slot, - &tree_accounts.merkle_tree, - epoch_info.epoch, - epoch_info, - ) - .await? - { + if !self.check_forester_eligibility( + &epoch_pda, + current_light_slot, + &tree_accounts.merkle_tree, + epoch_info.epoch, + &epoch_info, + )? { break 'inner_processing_loop; } // Process directly - the processor fetches queue data from the indexer let processing_start_time = Instant::now(); match self + .clone() .dispatch_tree_processing( - epoch_info, - epoch_pda, + epoch_info.clone(), + epoch_pda.clone(), tree_accounts, - forester_slot_details, + forester_slot_details.clone(), consecutive_eligibility_end, estimated_slot, ) @@ -2327,7 +2589,7 @@ impl EpochManager { Ok(()) } - async fn check_forester_eligibility( + fn check_forester_eligibility( &self, epoch_pda: &ForesterEpochPda, current_light_slot: u64, @@ -2389,16 +2651,17 @@ impl EpochManager { #[allow(clippy::too_many_arguments)] async fn dispatch_tree_processing( - &self, - epoch_info: &Epoch, - epoch_pda: &ForesterEpochPda, - tree_accounts: &TreeAccounts, - forester_slot_details: &ForesterSlot, + self: Arc, + epoch_info: Epoch, + epoch_pda: ForesterEpochPda, + tree_accounts: TreeAccounts, + forester_slot_details: ForesterSlot, consecutive_eligibility_end: u64, current_solana_slot: u64, ) -> std::result::Result { match tree_accounts.tree_type { TreeType::Unknown => self + .clone() .dispatch_compression( epoch_info, epoch_pda, @@ -2408,18 +2671,24 @@ impl EpochManager { .await .map_err(ForesterError::from), TreeType::StateV1 | TreeType::AddressV1 => { - self.process_v1( - epoch_info, - epoch_pda, - tree_accounts, - forester_slot_details, - current_solana_slot, - ) - .await + self.clone() + .process_v1( + epoch_info, + epoch_pda, + tree_accounts, + forester_slot_details, + current_solana_slot, + ) + .await } TreeType::StateV2 | TreeType::AddressV2 => { let result = self - .process_v2(epoch_info, tree_accounts, consecutive_eligibility_end) + .clone() + .process_v2( + epoch_info.clone(), + tree_accounts, + consecutive_eligibility_end, + ) .await?; // Accumulate processing metrics for this epoch self.add_processing_metrics(epoch_info.epoch, result.metrics) @@ -2430,10 +2699,10 @@ impl EpochManager { } async fn dispatch_compression( - &self, - epoch_info: &Epoch, - epoch_pda: &ForesterEpochPda, - forester_slot_details: &ForesterSlot, + self: Arc, + epoch_info: Epoch, + epoch_pda: ForesterEpochPda, + forester_slot_details: ForesterSlot, consecutive_eligibility_end: u64, ) -> Result { let current_slot = self.slot_tracker.estimated_current_slot(); @@ -2455,16 +2724,13 @@ impl EpochManager { let current_light_slot = current_slot.saturating_sub(epoch_info.phases.active.start) / epoch_pda.protocol_config.slot_length; - if !self - .check_forester_eligibility( - epoch_pda, - current_light_slot, - &Pubkey::default(), - epoch_info.epoch, - epoch_info, - ) - .await? - { + if !self.check_forester_eligibility( + &epoch_pda, + current_light_slot, + &Pubkey::default(), + epoch_info.epoch, + &epoch_info, + )? { debug!( "Skipping compression: forester not eligible for current light slot {}", current_light_slot @@ -2518,85 +2784,85 @@ impl EpochManager { // Create parallel compression futures use futures::stream::StreamExt; - // Collect chunks into owned vectors to avoid lifetime issues - let batches: Vec<(usize, Vec<_>)> = accounts - .chunks(config.batch_size) - .enumerate() - .map(|(idx, chunk)| (idx, chunk.to_vec())) - .collect(); - + let run_id = self.run_id.clone(); let slot_tracker = self.slot_tracker.clone(); // Shared cancellation flag - when set, all pending futures should skip processing let cancelled = Arc::new(AtomicBool::new(false)); - let compression_futures = batches.into_iter().map(|(batch_idx, batch)| { - let compressor = compressor.clone(); - let slot_tracker = slot_tracker.clone(); - let cancelled = cancelled.clone(); - async move { - // Check if already cancelled by another future - if cancelled.load(Ordering::Relaxed) { - debug!( - "Skipping compression batch {}/{}: cancelled", - batch_idx + 1, - num_batches - ); - return Err((batch_idx, batch.len(), Cancelled.into())); - } - - // Check forester is still eligible before processing this batch - let current_slot = slot_tracker.estimated_current_slot(); - if current_slot >= consecutive_eligibility_end { - // Signal cancellation to all other futures - cancelled.store(true, Ordering::Relaxed); - warn!( - event = "compression_ctoken_cancelled_not_eligible", - run_id = %self.run_id, - current_slot, - eligibility_end_slot = consecutive_eligibility_end, - "Cancelling compression because forester is no longer eligible" - ); - return Err(( - batch_idx, - batch.len(), - anyhow!("Forester no longer eligible"), - )); - } + let compression_futures = + accounts + .chunks(config.batch_size) + .enumerate() + .map(|(batch_idx, chunk)| { + let batch = chunk.to_vec(); + let compressor = compressor.clone(); + let run_id = run_id.clone(); + let slot_tracker = slot_tracker.clone(); + let cancelled = cancelled.clone(); + async move { + // Check if already cancelled by another future + if cancelled.load(Ordering::Relaxed) { + debug!( + "Skipping compression batch {}/{}: cancelled", + batch_idx + 1, + num_batches + ); + return Err((batch_idx, batch.len(), Cancelled.into())); + } - debug!( - "Processing compression batch {}/{} with {} accounts", - batch_idx + 1, - num_batches, - batch.len() - ); + // Check forester is still eligible before processing this batch + let current_slot = slot_tracker.estimated_current_slot(); + if current_slot >= consecutive_eligibility_end { + // Signal cancellation to all other futures + cancelled.store(true, Ordering::Relaxed); + warn!( + event = "compression_ctoken_cancelled_not_eligible", + run_id = %run_id, + current_slot, + eligibility_end_slot = consecutive_eligibility_end, + "Cancelling compression because forester is no longer eligible" + ); + return Err(( + batch_idx, + batch.len(), + anyhow!("Forester no longer eligible"), + )); + } - match compressor - .compress_batch(&batch, registered_forester_pda) - .await - { - Ok(sig) => { debug!( - "Compression batch {}/{} succeeded: {}", + "Processing compression batch {}/{} with {} accounts", batch_idx + 1, num_batches, - sig + batch.len() ); - Ok((batch_idx, batch.len(), sig)) - } - Err(e) => { - error!( - event = "compression_ctoken_batch_failed", - run_id = %self.run_id, - batch = batch_idx + 1, - total_batches = num_batches, - error = ?e, - "Compression batch failed" - ); - Err((batch_idx, batch.len(), e)) + + match compressor + .compress_batch(&batch, registered_forester_pda) + .await + { + Ok(sig) => { + debug!( + "Compression batch {}/{} succeeded: {}", + batch_idx + 1, + num_batches, + sig + ); + Ok((batch_idx, batch.len(), sig)) + } + Err(e) => { + error!( + event = "compression_ctoken_batch_failed", + run_id = %run_id, + batch = batch_idx + 1, + total_batches = num_batches, + error = ?e, + "Compression batch failed" + ); + Err((batch_idx, batch.len(), e)) + } + } } - } - } - }); + }); // Execute batches in parallel with concurrency limit let results = futures::stream::iter(compression_futures) @@ -2644,7 +2910,7 @@ impl EpochManager { // Process PDA compression if configured let pda_compressed = self - .dispatch_pda_compression(epoch_info, epoch_pda, consecutive_eligibility_end) + .dispatch_pda_compression(&epoch_info, &epoch_pda, consecutive_eligibility_end) .await .unwrap_or_else(|e| { error!( @@ -2658,7 +2924,7 @@ impl EpochManager { // Process Mint compression let mint_compressed = self - .dispatch_mint_compression(epoch_info, epoch_pda, consecutive_eligibility_end) + .dispatch_mint_compression(&epoch_info, &epoch_pda, consecutive_eligibility_end) .await .unwrap_or_else(|e| { error!( @@ -2943,16 +3209,13 @@ impl EpochManager { let current_light_slot = current_slot.saturating_sub(epoch_info.phases.active.start) / epoch_pda.protocol_config.slot_length; - if !self - .check_forester_eligibility( - epoch_pda, - current_light_slot, - &Pubkey::default(), - epoch_info.epoch, - epoch_info, - ) - .await? - { + if !self.check_forester_eligibility( + epoch_pda, + current_light_slot, + &Pubkey::default(), + epoch_info.epoch, + epoch_info, + )? { debug!( "Skipping {} compression: forester not eligible for current light slot {}", label, current_light_slot @@ -2964,11 +3227,11 @@ impl EpochManager { } async fn process_v1( - &self, - epoch_info: &Epoch, - epoch_pda: &ForesterEpochPda, - tree_accounts: &TreeAccounts, - forester_slot_details: &ForesterSlot, + self: Arc, + epoch_info: Epoch, + epoch_pda: ForesterEpochPda, + tree_accounts: TreeAccounts, + forester_slot_details: ForesterSlot, current_solana_slot: u64, ) -> std::result::Result { let slots_remaining = forester_slot_details @@ -3018,7 +3281,7 @@ impl EpochManager { &self.config.derivation_pubkey, self.rpc_pool.clone(), &batched_tx_config, - *tree_accounts, + tree_accounts, transaction_builder, ) .await?; @@ -3033,7 +3296,7 @@ impl EpochManager { ); } - match self.rollover_if_needed(tree_accounts).await { + match self.rollover_if_needed(&tree_accounts).await { Ok(_) => Ok(num_sent), Err(e) => { error!( @@ -3283,15 +3546,15 @@ impl EpochManager { } async fn process_v2( - &self, - epoch_info: &Epoch, - tree_accounts: &TreeAccounts, + self: Arc, + epoch_info: Epoch, + tree_accounts: TreeAccounts, consecutive_eligibility_end: u64, ) -> std::result::Result { match tree_accounts.tree_type { TreeType::StateV2 => { let processor = self - .get_or_create_state_processor(epoch_info, tree_accounts) + .get_or_create_state_processor(&epoch_info, &tree_accounts) .await?; let cache = self @@ -3358,7 +3621,7 @@ impl EpochManager { } TreeType::AddressV2 => { let processor = self - .get_or_create_address_processor(epoch_info, tree_accounts) + .get_or_create_address_processor(&epoch_info, &tree_accounts) .await?; let cache = self diff --git a/forester/src/forester_status.rs b/forester/src/forester_status.rs index 80c4539075..799374392b 100644 --- a/forester/src/forester_status.rs +++ b/forester/src/forester_status.rs @@ -671,7 +671,7 @@ fn parse_tree_status( let (queue_len, queue_cap) = queue_account .map(|acc| { - unsafe { parse_hash_set_from_bytes::(&acc.data) } + parse_hash_set_from_bytes::(&acc.data) .ok() .map(|hs| { let len = hs @@ -726,7 +726,7 @@ fn parse_tree_status( let (queue_len, queue_cap) = queue_account .map(|acc| { - unsafe { parse_hash_set_from_bytes::(&acc.data) } + parse_hash_set_from_bytes::(&acc.data) .ok() .map(|hs| { let len = hs diff --git a/forester/src/main.rs b/forester/src/main.rs index 870964a9c2..9c646ff4b5 100644 --- a/forester/src/main.rs +++ b/forester/src/main.rs @@ -82,11 +82,19 @@ async fn main() -> Result<(), ForesterError> { let rpc_rate_limiter = config .external_services .rpc_rate_limit - .map(RateLimiter::new); + .map(RateLimiter::new) + .transpose() + .map_err(|e| ForesterError::General { + error: format!("invalid rpc rate limit: {}", e), + })?; let send_tx_limiter = config .external_services .send_tx_rate_limit - .map(RateLimiter::new); + .map(RateLimiter::new) + .transpose() + .map_err(|e| ForesterError::General { + error: format!("invalid send_tx rate limit: {}", e), + })?; let mut shutdown_sender_compressible: Option> = None; let mut shutdown_sender_bootstrap: Option> = None; diff --git a/forester/src/metrics.rs b/forester/src/metrics.rs index a7be12dd83..ece7d6a577 100644 --- a/forester/src/metrics.rs +++ b/forester/src/metrics.rs @@ -440,21 +440,19 @@ pub async fn metrics_handler() -> Result { if let Err(e) = encoder.encode(®ISTRY.gather(), &mut buffer) { error!("could not encode custom metrics: {}", e); }; - let mut res = String::from_utf8(buffer.clone()).unwrap_or_else(|e| { + let mut res = String::from_utf8(buffer).unwrap_or_else(|e| { error!("custom metrics could not be from_utf8'd: {}", e); String::new() }); - buffer.clear(); let mut buffer = Vec::new(); if let Err(e) = encoder.encode(&prometheus::gather(), &mut buffer) { error!("could not encode prometheus metrics: {}", e); }; - let res_prometheus = String::from_utf8(buffer.clone()).unwrap_or_else(|e| { + let res_prometheus = String::from_utf8(buffer).unwrap_or_else(|e| { error!("prometheus metrics could not be from_utf8'd: {}", e); String::new() }); - buffer.clear(); res.push_str(&res_prometheus); Ok(res) diff --git a/forester/src/priority_fee.rs b/forester/src/priority_fee.rs index c788712e1c..0b9ee4049c 100644 --- a/forester/src/priority_fee.rs +++ b/forester/src/priority_fee.rs @@ -208,7 +208,7 @@ pub async fn request_priority_fee_estimate( .map_err(|error| PriorityFeeEstimateError::ClientBuild(error.clone()))?; let response = http_client - .post(url.clone()) + .post(url.as_str()) .header("Content-Type", "application/json") .json(&rpc_request) .send() diff --git a/forester/src/processor/v2/helpers.rs b/forester/src/processor/v2/helpers.rs index ed135cb6a4..a0f3e3bb5b 100644 --- a/forester/src/processor/v2/helpers.rs +++ b/forester/src/processor/v2/helpers.rs @@ -9,6 +9,7 @@ use light_client::{ indexer::{AddressQueueData, Indexer, QueueElementsV2Options, StateQueueData}, rpc::Rpc, }; +use light_hasher::hash_chain::create_hash_chain_from_slice; use crate::processor::v2::{common::clamp_to_u16, BatchContext}; @@ -22,6 +23,17 @@ pub(crate) fn lock_recover<'a, T>(mutex: &'a Mutex, name: &'static str) -> Mu } } +#[derive(Debug, Clone)] +pub struct AddressBatchSnapshot { + pub addresses: Vec<[u8; 32]>, + pub low_element_values: Vec<[u8; 32]>, + pub low_element_next_values: Vec<[u8; 32]>, + pub low_element_indices: Vec, + pub low_element_next_indices: Vec, + pub low_element_proofs: Vec<[[u8; 32]; HEIGHT]>, + pub leaves_hashchain: [u8; 32], +} + pub async fn fetch_zkp_batch_size(context: &BatchContext) -> crate::Result { let rpc = context.rpc_pool.get_connection().await?; let mut account = rpc @@ -474,20 +486,52 @@ impl StreamingAddressQueue { } } - pub fn get_batch_data(&self, start: usize, end: usize) -> Option { + pub fn get_batch_snapshot( + &self, + start: usize, + end: usize, + hashchain_idx: usize, + ) -> crate::Result>> { let available = self.wait_for_batch(end); if start >= available { - return None; + return Ok(None); } let actual_end = end.min(available); let data = lock_recover(&self.data, "streaming_address_queue.data"); - Some(BatchDataSlice { - addresses: data.addresses[start..actual_end].to_vec(), + + let addresses = data.addresses[start..actual_end].to_vec(); + if addresses.is_empty() { + return Err(anyhow!("Empty batch at start={}", start)); + } + + let leaves_hashchain = match data.leaves_hash_chains.get(hashchain_idx).copied() { + Some(hashchain) => hashchain, + None => { + tracing::debug!( + "Missing leaves_hash_chain for batch {} (available: {}), deriving from addresses", + hashchain_idx, + data.leaves_hash_chains.len() + ); + create_hash_chain_from_slice(&addresses).map_err(|error| { + anyhow!( + "Failed to derive leaves_hash_chain for batch {} from {} addresses: {}", + hashchain_idx, + addresses.len(), + error + ) + })? + } + }; + + Ok(Some(AddressBatchSnapshot { low_element_values: data.low_element_values[start..actual_end].to_vec(), low_element_next_values: data.low_element_next_values[start..actual_end].to_vec(), low_element_indices: data.low_element_indices[start..actual_end].to_vec(), low_element_next_indices: data.low_element_next_indices[start..actual_end].to_vec(), - }) + low_element_proofs: data.reconstruct_proofs::(start..actual_end)?, + addresses, + leaves_hashchain, + })) } pub fn into_data(self) -> AddressQueueData { @@ -553,15 +597,6 @@ impl StreamingAddressQueue { } } -#[derive(Debug, Clone)] -pub struct BatchDataSlice { - pub addresses: Vec<[u8; 32]>, - pub low_element_values: Vec<[u8; 32]>, - pub low_element_next_values: Vec<[u8; 32]>, - pub low_element_indices: Vec, - pub low_element_next_indices: Vec, -} - pub async fn fetch_streaming_address_batches( context: &BatchContext, total_elements: u64, diff --git a/forester/src/processor/v2/proof_worker.rs b/forester/src/processor/v2/proof_worker.rs index b7afeacf0b..81b32f0d1f 100644 --- a/forester/src/processor/v2/proof_worker.rs +++ b/forester/src/processor/v2/proof_worker.rs @@ -37,7 +37,7 @@ impl ProofInput { } } - fn to_json(&self, tree_id: &str, batch_index: u64) -> String { + fn to_json(&self, tree_id: &str, batch_index: u64) -> Result { match self { ProofInput::Append(inputs) => BatchAppendInputsJson::from_inputs(inputs) .with_tree_id(tree_id.to_string()) @@ -194,7 +194,24 @@ async fn run_proof_pipeline( async fn submit_and_poll_proof(clients: Arc, job: ProofJob) { let client = clients.get_client(&job.inputs); // Use seq as batch_index for ordering in the prover queue - let inputs_json = job.inputs.to_json(&job.tree_id, job.seq); + let inputs_json = match job.inputs.to_json(&job.tree_id, job.seq) { + Ok(inputs_json) => inputs_json, + Err(error) => { + let _ = job + .result_tx + .send(ProofJobResult { + seq: job.seq, + result: Err(format!("Failed to serialize proof inputs: {}", error)), + old_root: [0u8; 32], + new_root: [0u8; 32], + proof_duration_ms: 0, + round_trip_ms: 0, + submitted_at: std::time::Instant::now(), + }) + .await; + return; + } + }; let circuit_type = job.inputs.circuit_type(); let round_trip_start = std::time::Instant::now(); @@ -260,7 +277,7 @@ async fn poll_and_send_result( let client = clients.get_client(&inputs); // Poll; on job_not_found, resubmit once and poll the new job. - let result = match client.poll_proof_completion(job_id.clone()).await { + let result = match client.poll_proof_completion(&job_id).await { Ok(proof) => { let round_trip_ms = round_trip_start.elapsed().as_millis() as u64; debug!( @@ -276,7 +293,22 @@ async fn poll_and_send_result( ); tokio::time::sleep(Duration::from_millis(200)).await; - let inputs_json = inputs.to_json(&tree_id, seq); + let inputs_json = match inputs.to_json(&tree_id, seq) { + Ok(inputs_json) => inputs_json, + Err(error) => { + let result = ProofJobResult { + seq, + result: Err(format!("Failed to serialize proof inputs: {}", error)), + old_root: [0u8; 32], + new_root: [0u8; 32], + proof_duration_ms: 0, + round_trip_ms: round_trip_start.elapsed().as_millis() as u64, + submitted_at: round_trip_start, + }; + let _ = result_tx.send(result).await; + return; + } + }; let circuit_type = inputs.circuit_type(); match client.submit_proof_async(inputs_json, circuit_type).await { Ok(SubmitProofResult::Queued(new_job_id)) => { @@ -284,7 +316,7 @@ async fn poll_and_send_result( "Resubmitted proof job seq={} type={} new_job_id={}", seq, circuit_type, new_job_id ); - match client.poll_proof_completion(new_job_id.clone()).await { + match client.poll_proof_completion(&new_job_id).await { Ok(proof) => { let round_trip_ms = round_trip_start.elapsed().as_millis() as u64; debug!( diff --git a/forester/src/processor/v2/strategy/address.rs b/forester/src/processor/v2/strategy/address.rs index 06e94d5500..51ab05143a 100644 --- a/forester/src/processor/v2/strategy/address.rs +++ b/forester/src/processor/v2/strategy/address.rs @@ -14,11 +14,10 @@ use tracing::{debug, info, instrument}; use crate::processor::v2::{ batch_job_builder::BatchJobBuilder, - common::get_leaves_hashchain, errors::V2Error, helpers::{ fetch_address_zkp_batch_size, fetch_onchain_address_root, fetch_streaming_address_batches, - lock_recover, StreamingAddressQueue, + AddressBatchSnapshot, StreamingAddressQueue, }, proof_worker::ProofInput, root_guard::{reconcile_alignment, AlignmentDecision}, @@ -267,9 +266,23 @@ impl BatchJobBuilder for AddressQueueData { let batch_end = start + zkp_batch_size_usize; - let batch_data = self - .streaming_queue - .get_batch_data(start, batch_end) + let streaming_queue = &self.streaming_queue; + let staging_tree = &mut self.staging_tree; + let hashchain_idx = start / zkp_batch_size_usize; + let AddressBatchSnapshot { + addresses, + low_element_values, + low_element_next_values, + low_element_indices, + low_element_next_indices, + low_element_proofs, + leaves_hashchain, + } = streaming_queue + .get_batch_snapshot::<{ DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }>( + start, + batch_end, + hashchain_idx, + )? .ok_or_else(|| { anyhow!( "Batch data not available: start={}, end={}, available={}", @@ -278,31 +291,21 @@ impl BatchJobBuilder for AddressQueueData { self.streaming_queue.available_batches() * zkp_batch_size_usize ) })?; - - let addresses = &batch_data.addresses; let zkp_batch_size_actual = addresses.len(); - - if zkp_batch_size_actual == 0 { - return Err(anyhow!("Empty batch at start={}", start)); - } - - let low_element_values = &batch_data.low_element_values; - let low_element_next_values = &batch_data.low_element_next_values; - let low_element_indices = &batch_data.low_element_indices; - let low_element_next_indices = &batch_data.low_element_next_indices; - - let low_element_proofs: Vec> = { - let data = lock_recover(self.streaming_queue.data.as_ref(), "streaming_queue.data"); - (start..start + zkp_batch_size_actual) - .map(|i| data.reconstruct_proof(i, DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as u8)) - .collect::, _>>()? - }; - - let hashchain_idx = start / zkp_batch_size_usize; - let leaves_hashchain = { - let data = lock_recover(self.streaming_queue.data.as_ref(), "streaming_queue.data"); - get_leaves_hashchain(&data.leaves_hash_chains, hashchain_idx)? - }; + let result = staging_tree + .process_batch( + &addresses, + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + leaves_hashchain, + zkp_batch_size_actual, + epoch, + tree, + ) + .map_err(|err| map_address_staging_error(tree, err))?; let tree_batch = tree_next_index / zkp_batch_size_usize; let absolute_index = data_start + start; @@ -318,24 +321,6 @@ impl BatchJobBuilder for AddressQueueData { self.streaming_queue.is_complete() ); - let result = self.staging_tree.process_batch( - addresses, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - &low_element_proofs, - leaves_hashchain, - zkp_batch_size_actual, - epoch, - tree, - ); - - let result = match result { - Ok(r) => r, - Err(err) => return Err(map_address_staging_error(tree, err)), - }; - Ok(Some(( ProofInput::AddressAppend(result.circuit_inputs), result.new_root, diff --git a/forester/src/queue_helpers.rs b/forester/src/queue_helpers.rs index f4a7dac704..44b65aca67 100644 --- a/forester/src/queue_helpers.rs +++ b/forester/src/queue_helpers.rs @@ -1,9 +1,9 @@ use account_compression::QueueAccount; +use forester_utils::account_zero_copy::parse_hash_set_from_bytes; use light_batched_merkle_tree::{ batch::BatchState, merkle_tree::BatchedMerkleTreeAccount, queue::BatchedQueueAccount, }; use light_client::rpc::Rpc; -use light_hash_set::HashSet; use serde::{Deserialize, Serialize}; use solana_sdk::pubkey::Pubkey; use tracing::trace; @@ -181,7 +181,7 @@ pub async fn fetch_queue_item_data( ) -> Result { trace!("Fetching queue data for {:?}", queue_pubkey); let account = rpc.get_account(*queue_pubkey).await?; - let mut account = match account { + let account = match account { Some(acc) => acc, None => { tracing::warn!( @@ -209,7 +209,7 @@ pub async fn fetch_queue_item_data( total_pending: 0, }); } - let queue: HashSet = unsafe { HashSet::from_bytes_copy(&mut account.data[offset..])? }; + let queue = parse_hash_set_from_bytes::(&account.data)?; let end_index = queue.get_capacity(); let capacity = end_index as u64; diff --git a/forester/tests/e2e_test.rs b/forester/tests/e2e_test.rs index 1727ed108b..735cf06925 100644 --- a/forester/tests/e2e_test.rs +++ b/forester/tests/e2e_test.rs @@ -277,7 +277,7 @@ async fn e2e_test() { validator_args: vec![], })) .await; - spawn_prover().await; + spawn_prover().await.unwrap(); } let mut rpc = setup_rpc_connection(&env.protocol.forester).await; @@ -799,15 +799,22 @@ async fn setup_forester_pipeline( let (shutdown_bootstrap_sender, shutdown_bootstrap_receiver) = oneshot::channel(); let (work_report_sender, work_report_receiver) = mpsc::channel(100); - let service_handle = tokio::spawn(run_pipeline::( - Arc::from(config.clone()), - None, - None, - shutdown_receiver, - Some(shutdown_compressible_receiver), - Some(shutdown_bootstrap_receiver), - work_report_sender, - )); + let config = Arc::new(config.clone()); + let service_handle = tokio::task::spawn_blocking(move || { + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(2) + .enable_all() + .build()?; + runtime.block_on(run_pipeline::( + config, + None, + None, + shutdown_receiver, + Some(shutdown_compressible_receiver), + Some(shutdown_bootstrap_receiver), + work_report_sender, + )) + }); ( service_handle, diff --git a/forester/tests/legacy/batched_state_async_indexer_test.rs b/forester/tests/legacy/batched_state_async_indexer_test.rs index fe599a39a8..178927fea8 100644 --- a/forester/tests/legacy/batched_state_async_indexer_test.rs +++ b/forester/tests/legacy/batched_state_async_indexer_test.rs @@ -87,7 +87,7 @@ async fn test_state_indexer_async_batched() { validator_args: vec![], })) .await; - spawn_prover().await; + spawn_prover().await.unwrap(); let env = TestAccounts::get_local_test_validator_accounts(); let mut config = forester_config(); diff --git a/forester/tests/legacy/test_utils.rs b/forester/tests/legacy/test_utils.rs index d535665d71..76b146a160 100644 --- a/forester/tests/legacy/test_utils.rs +++ b/forester/tests/legacy/test_utils.rs @@ -26,7 +26,7 @@ pub async fn init(config: Option) { #[allow(dead_code)] pub async fn spawn_test_validator(config: Option) { let config = config.unwrap_or_default(); - spawn_validator(config).await; + spawn_validator(config).await.unwrap(); } #[allow(dead_code)] diff --git a/forester/tests/test_batch_append_spent.rs b/forester/tests/test_batch_append_spent.rs index e53c2b64eb..fe3bb8bef2 100644 --- a/forester/tests/test_batch_append_spent.rs +++ b/forester/tests/test_batch_append_spent.rs @@ -328,15 +328,22 @@ async fn run_forester(config: &ForesterConfig, duration: Duration) { tokio::sync::broadcast::channel(1); let (work_report_sender, _) = mpsc::channel(100); - let service_handle = tokio::spawn(run_pipeline::( - Arc::from(config.clone()), - None, - None, - shutdown_receiver, - Some(shutdown_compressible_receiver), - None, // shutdown_bootstrap - work_report_sender, - )); + let config = Arc::new(config.clone()); + let service_handle = tokio::task::spawn_blocking(move || { + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(2) + .enable_all() + .build()?; + runtime.block_on(run_pipeline::( + config, + None, + None, + shutdown_receiver, + Some(shutdown_compressible_receiver), + None, // shutdown_bootstrap + work_report_sender, + )) + }); tokio::time::sleep(duration).await; diff --git a/forester/tests/test_compressible_pda.rs b/forester/tests/test_compressible_pda.rs index 97783d537c..2e134e7569 100644 --- a/forester/tests/test_compressible_pda.rs +++ b/forester/tests/test_compressible_pda.rs @@ -274,7 +274,8 @@ async fn test_compressible_pda_bootstrap() { use_surfpool: true, validator_args: vec![], }) - .await; + .await + .unwrap(); let mut rpc = LightClient::new(LightClientConfig::local()) .await @@ -308,7 +309,8 @@ async fn test_compressible_pda_bootstrap() { RENT_SPONSOR, authority.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &authority.pubkey(), &[&authority]) .await @@ -468,7 +470,8 @@ async fn test_compressible_pda_compression() { use_surfpool: true, validator_args: vec![], }) - .await; + .await + .unwrap(); let mut rpc = LightClient::new(LightClientConfig::local()) .await @@ -502,7 +505,8 @@ async fn test_compressible_pda_compression() { RENT_SPONSOR, authority.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &authority.pubkey(), &[&authority]) .await @@ -706,7 +710,8 @@ async fn test_compressible_pda_subscription() { use_surfpool: true, validator_args: vec![], }) - .await; + .await + .unwrap(); let mut rpc = LightClient::new(LightClientConfig::local()) .await @@ -742,7 +747,8 @@ async fn test_compressible_pda_subscription() { RENT_SPONSOR, authority.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &authority.pubkey(), &[&authority]) .await diff --git a/forester/tests/test_indexer_interface.rs b/forester/tests/test_indexer_interface.rs index 6916cf0a3b..eaf56bf191 100644 --- a/forester/tests/test_indexer_interface.rs +++ b/forester/tests/test_indexer_interface.rs @@ -65,7 +65,8 @@ async fn test_indexer_interface_scenarios() { validator_args: vec![], use_surfpool: true, }) - .await; + .await + .unwrap(); let mut rpc = LightClient::new(LightClientConfig::local()) .await diff --git a/forester/tests/test_utils.rs b/forester/tests/test_utils.rs index 4225503a19..3f50ad3695 100644 --- a/forester/tests/test_utils.rs +++ b/forester/tests/test_utils.rs @@ -36,7 +36,7 @@ pub async fn init(config: Option) { #[allow(dead_code)] pub async fn spawn_test_validator(config: Option) { let config = config.unwrap_or_default(); - spawn_validator(config).await; + spawn_validator(config).await.unwrap(); } #[allow(dead_code)] diff --git a/program-tests/account-compression-test/tests/address_merkle_tree_tests.rs b/program-tests/account-compression-test/tests/address_merkle_tree_tests.rs index 2964ba9053..c294b6bbc7 100644 --- a/program-tests/account-compression-test/tests/address_merkle_tree_tests.rs +++ b/program-tests/account-compression-test/tests/address_merkle_tree_tests.rs @@ -75,10 +75,10 @@ async fn address_queue_and_tree_functional( ) .await .unwrap(); - let address_queue = unsafe { - get_hash_set::(&mut context, address_queue_pubkey).await - } - .unwrap(); + let address_queue = + get_hash_set::(&mut context, address_queue_pubkey) + .await + .unwrap(); assert!(address_queue.contains(&address1, None).unwrap()); assert!(address_queue.contains(&address2, None).unwrap()); @@ -105,10 +105,10 @@ async fn address_queue_and_tree_functional( ) .await .unwrap(); - let address_queue = unsafe { - get_hash_set::(&mut context, address_queue_pubkey).await - } - .unwrap(); + let address_queue = + get_hash_set::(&mut context, address_queue_pubkey) + .await + .unwrap(); address_queue .find_element(&address3, None) .unwrap() @@ -596,10 +596,10 @@ async fn update_address_merkle_tree_failing_tests( ) .await .unwrap(); - let address_queue = unsafe { - get_hash_set::(&mut context, address_queue_pubkey).await - } - .unwrap(); + let address_queue = + get_hash_set::(&mut context, address_queue_pubkey) + .await + .unwrap(); // CHECK: 2.1 cannot insert an address with an invalid low address test_with_invalid_low_element( &mut context, @@ -1087,10 +1087,10 @@ async fn update_address_merkle_tree_wrap_around( .await .unwrap(); - let address_queue = unsafe { - get_hash_set::(&mut context, address_queue_pubkey).await - } - .unwrap(); + let address_queue = + get_hash_set::(&mut context, address_queue_pubkey) + .await + .unwrap(); let value_index = address_queue .find_element_index(&address1, None) .unwrap() diff --git a/program-tests/account-compression-test/tests/merkle_tree_tests.rs b/program-tests/account-compression-test/tests/merkle_tree_tests.rs index ebcd59001a..f06990d218 100644 --- a/program-tests/account-compression-test/tests/merkle_tree_tests.rs +++ b/program-tests/account-compression-test/tests/merkle_tree_tests.rs @@ -312,13 +312,12 @@ async fn test_full_nullifier_queue( .unwrap(); assert_eq!(merkle_tree.root(), reference_merkle_tree.root()); let leaf_index = reference_merkle_tree.get_leaf_index(&leaf).unwrap() as u64; - let element_index = unsafe { + let element_index = get_hash_set::(&mut rpc, nullifier_queue_pubkey) .await .unwrap() .find_element_index(&BigUint::from_bytes_be(&leaf), None) - .unwrap() - }; + .unwrap(); // CHECK 2 nullify( &mut rpc, @@ -1225,8 +1224,9 @@ async fn functional_2_test_insert_into_nullifier_queues( ) .await .unwrap(); - let array = - unsafe { get_hash_set::(rpc, *nullifier_queue_pubkey).await }.unwrap(); + let array = get_hash_set::(rpc, *nullifier_queue_pubkey) + .await + .unwrap(); let element_0 = BigUint::from_bytes_be(&elements[0]); let (array_element_0, _) = array.find_element(&element_0, None).unwrap().unwrap(); assert_eq!(array_element_0.value_bytes(), [1u8; 32]); @@ -1307,8 +1307,9 @@ async fn functional_5_test_insert_into_nullifier_queue( ) .await .unwrap(); - let array = - unsafe { get_hash_set::(rpc, *nullifier_queue_pubkey).await }.unwrap(); + let array = get_hash_set::(rpc, *nullifier_queue_pubkey) + .await + .unwrap(); let (array_element, _) = array.find_element(&element, None).unwrap().unwrap(); assert_eq!(array_element.value_biguint(), element); @@ -2165,10 +2166,9 @@ pub async fn assert_element_inserted_in_nullifier_queue( nullifier_queue_pubkey: &Pubkey, nullifier: [u8; 32], ) { - let array = unsafe { - get_hash_set::(rpc, *nullifier_queue_pubkey).await - } - .unwrap(); + let array = get_hash_set::(rpc, *nullifier_queue_pubkey) + .await + .unwrap(); let nullifier_bn = BigUint::from_bytes_be(&nullifier); let (array_element, _) = array.find_element(&nullifier_bn, None).unwrap().unwrap(); assert_eq!(array_element.value_bytes(), nullifier); diff --git a/program-tests/batched-merkle-tree-test/tests/merkle_tree.rs b/program-tests/batched-merkle-tree-test/tests/merkle_tree.rs index da3c8d047d..6678af53e2 100644 --- a/program-tests/batched-merkle-tree-test/tests/merkle_tree.rs +++ b/program-tests/batched-merkle-tree-test/tests/merkle_tree.rs @@ -447,7 +447,7 @@ pub fn simulate_transaction( #[serial] #[tokio::test] async fn test_simulate_transactions() { - spawn_prover().await; + spawn_prover().await.unwrap(); let mut mock_indexer = MockBatchedForester::<{ DEFAULT_BATCH_STATE_TREE_HEIGHT as usize }>::default(); @@ -886,7 +886,7 @@ pub fn get_random_leaf(rng: &mut StdRng, active_leaves: &mut Vec<[u8; 32]>) -> ( #[serial] #[tokio::test] async fn test_e2e() { - spawn_prover().await; + spawn_prover().await.unwrap(); let mut mock_indexer = MockBatchedForester::<{ DEFAULT_BATCH_STATE_TREE_HEIGHT as usize }>::default(); @@ -1581,7 +1581,7 @@ pub fn get_rnd_bytes(rng: &mut StdRng) -> [u8; 32] { #[serial] #[tokio::test] async fn test_fill_state_queues_completely() { - spawn_prover().await; + spawn_prover().await.unwrap(); let mut current_slot = 1; let roothistory_capacity = vec![17, 80]; for root_history_capacity in roothistory_capacity { @@ -1981,7 +1981,7 @@ async fn test_fill_state_queues_completely() { #[serial] #[tokio::test] async fn test_fill_address_tree_completely() { - spawn_prover().await; + spawn_prover().await.unwrap(); let mut current_slot = 1; let roothistory_capacity = vec![17, 80]; // for root_history_capacity in roothistory_capacity { diff --git a/program-tests/compressed-token-test/tests/freeze/functional.rs b/program-tests/compressed-token-test/tests/freeze/functional.rs index 5d43d66015..42df6384c0 100644 --- a/program-tests/compressed-token-test/tests/freeze/functional.rs +++ b/program-tests/compressed-token-test/tests/freeze/functional.rs @@ -99,7 +99,7 @@ async fn freeze_or_thaw( #[tokio::test] #[serial] async fn test_freeze_thaw_v1_no_tlv_and_decompress() { - spawn_prover().await; + spawn_prover().await.unwrap(); let result = run_freeze_thaw_test(TokenDataVersion::V1).await; assert!(result.is_ok(), "Test failed: {:?}", result.err()); } @@ -109,7 +109,7 @@ async fn test_freeze_thaw_v1_no_tlv_and_decompress() { #[tokio::test] #[serial] async fn test_freeze_thaw_v2_no_tlv_and_decompress() { - spawn_prover().await; + spawn_prover().await.unwrap(); let result = run_freeze_thaw_test(TokenDataVersion::V2).await; assert!(result.is_ok(), "Test failed: {:?}", result.err()); } @@ -119,7 +119,7 @@ async fn test_freeze_thaw_v2_no_tlv_and_decompress() { #[tokio::test] #[serial] async fn test_freeze_thaw_sha_flat_no_tlv_and_decompress() { - spawn_prover().await; + spawn_prover().await.unwrap(); let result = run_freeze_thaw_test(TokenDataVersion::ShaFlat).await; assert!(result.is_ok(), "Test failed: {:?}", result.err()); } diff --git a/program-tests/compressed-token-test/tests/v1.rs b/program-tests/compressed-token-test/tests/v1.rs index 81c01c82fd..8af56cb82f 100644 --- a/program-tests/compressed-token-test/tests/v1.rs +++ b/program-tests/compressed-token-test/tests/v1.rs @@ -76,7 +76,7 @@ use spl_token::error::TokenError; #[serial] #[tokio::test] async fn test_wrapped_sol() { - spawn_prover().await; + spawn_prover().await.unwrap(); // is token 22 fails with Instruction: InitializeAccount, Program log: Error: Invalid Mint line 216 for is_token_22 in [false] { let mut rpc = LightProgramTest::new(ProgramTestConfig::new(false, None)) @@ -729,7 +729,7 @@ async fn test_mint_to_failing() { #[serial] #[tokio::test] async fn test_transfers() { - spawn_prover().await; + spawn_prover().await.unwrap(); let possible_inputs = [1, 2, 3, 4, 8]; for input_num in possible_inputs { for output_num in 1..8 { @@ -892,7 +892,7 @@ async fn perform_transfer_22_test( #[serial] #[tokio::test] async fn test_decompression() { - spawn_prover().await; + spawn_prover().await.unwrap(); for is_token_22 in [false, true] { println!("is_token_22: {}", is_token_22); let mut context = LightProgramTest::new(ProgramTestConfig::new(false, None)) @@ -1036,7 +1036,7 @@ pub async fn assert_minted_to_all_token_pools( #[serial] #[tokio::test] async fn test_mint_to_and_burn_from_all_token_pools() { - spawn_prover().await; + spawn_prover().await.unwrap(); for is_token_22 in [false, true] { let mut rpc = LightProgramTest::new(ProgramTestConfig::new(false, None)) .await @@ -1111,7 +1111,7 @@ async fn test_mint_to_and_burn_from_all_token_pools() { #[serial] #[tokio::test] async fn test_multiple_decompression() { - spawn_prover().await; + spawn_prover().await.unwrap(); let rng = &mut thread_rng(); for is_token_22 in [false, true] { println!("is_token_22: {}", is_token_22); @@ -2286,7 +2286,7 @@ async fn test_revoke_failing() { #[serial] #[tokio::test] async fn test_burn() { - spawn_prover().await; + spawn_prover().await.unwrap(); for is_token_22 in [false, true] { println!("is_token_22: {}", is_token_22); let mut rpc = LightProgramTest::new(ProgramTestConfig::new(false, None)) @@ -2558,7 +2558,7 @@ async fn test_burn() { #[serial] #[tokio::test] async fn failing_tests_burn() { - spawn_prover().await; + spawn_prover().await.unwrap(); for is_token_22 in [false, true] { let mut rpc = LightProgramTest::new(ProgramTestConfig::new(false, None)) .await @@ -2913,7 +2913,7 @@ async fn failing_tests_burn() { /// 4. Freeze delegated tokens /// 5. Thaw delegated tokens async fn test_freeze_and_thaw(mint_amount: u64, delegated_amount: u64) { - spawn_prover().await; + spawn_prover().await.unwrap(); for is_token_22 in [false, true] { let mut rpc = LightProgramTest::new(ProgramTestConfig::new(false, None)) .await @@ -3097,7 +3097,7 @@ async fn test_freeze_and_thaw_10000() { #[serial] #[tokio::test] async fn test_failing_freeze() { - spawn_prover().await; + spawn_prover().await.unwrap(); for is_token_22 in [false, true] { let mut rpc = LightProgramTest::new(ProgramTestConfig::new(false, None)) .await @@ -3360,7 +3360,7 @@ async fn test_failing_freeze() { #[serial] #[tokio::test] async fn test_failing_thaw() { - spawn_prover().await; + spawn_prover().await.unwrap(); for is_token_22 in [false, true] { let mut rpc = LightProgramTest::new(ProgramTestConfig::new(false, None)) .await @@ -3651,7 +3651,7 @@ async fn test_failing_thaw() { #[serial] #[tokio::test] async fn test_failing_decompression() { - spawn_prover().await; + spawn_prover().await.unwrap(); for is_token_22 in [false, true] { let mut context = LightProgramTest::new(ProgramTestConfig::new(false, None)) .await @@ -4895,7 +4895,8 @@ async fn test_transfer_with_photon_and_batched_tree() { use_surfpool: true, validator_args: vec![], }) - .await; + .await + .unwrap(); let mut rpc = LightClient::new(LightClientConfig::local_no_indexer()) .await diff --git a/program-tests/system-cpi-v2-test/tests/event.rs b/program-tests/system-cpi-v2-test/tests/event.rs index 9425d72144..2624847c9f 100644 --- a/program-tests/system-cpi-v2-test/tests/event.rs +++ b/program-tests/system-cpi-v2-test/tests/event.rs @@ -546,7 +546,8 @@ async fn generate_photon_test_data_multiple_events() { use_surfpool: true, validator_args: vec![], }) - .await; + .await + .unwrap(); let mut rpc = LightClient::new(LightClientConfig::local_no_indexer()) .await diff --git a/program-tests/system-cpi-v2-test/tests/invoke_cpi_with_read_only.rs b/program-tests/system-cpi-v2-test/tests/invoke_cpi_with_read_only.rs index f6d52b1d2e..b8552e1f47 100644 --- a/program-tests/system-cpi-v2-test/tests/invoke_cpi_with_read_only.rs +++ b/program-tests/system-cpi-v2-test/tests/invoke_cpi_with_read_only.rs @@ -47,7 +47,7 @@ use solana_sdk::pubkey::Pubkey; #[serial] #[tokio::test] async fn functional_read_only() { - spawn_prover().await; + spawn_prover().await.unwrap(); for (batched, is_v2_ix) in [(true, false), (true, true), (false, false), (false, true)] { let config = if batched { let mut config = ProgramTestConfig::default_with_batched_trees(false); @@ -348,7 +348,7 @@ async fn functional_read_only() { #[serial] #[tokio::test] async fn functional_account_infos() { - spawn_prover().await; + spawn_prover().await.unwrap(); for (batched, is_v2_ix) in [(true, false), (true, true), (false, false), (false, true)].into_iter() { @@ -664,7 +664,7 @@ async fn functional_account_infos() { #[serial] #[tokio::test] async fn create_addresses_with_account_info() { - spawn_prover().await; + spawn_prover().await.unwrap(); let with_transaction_hash = true; for (batched, is_v2_ix) in [(true, false), (true, true), (false, false), (false, true)].into_iter() @@ -1264,7 +1264,7 @@ async fn create_addresses_with_account_info() { #[serial] #[tokio::test] async fn create_addresses_with_read_only() { - spawn_prover().await; + spawn_prover().await.unwrap(); let with_transaction_hash = true; for (batched, is_v2_ix) in [(true, false), (true, true), (false, false), (false, true)].into_iter() @@ -2024,7 +2024,7 @@ async fn compress_sol_with_account_info() { #[serial] #[tokio::test] async fn cpi_context_with_read_only() { - spawn_prover().await; + spawn_prover().await.unwrap(); let with_transaction_hash = false; let batched = true; for is_v2_ix in [true, false].into_iter() { @@ -2322,7 +2322,7 @@ async fn cpi_context_with_read_only() { #[serial] #[tokio::test] async fn cpi_context_with_account_info() { - spawn_prover().await; + spawn_prover().await.unwrap(); let with_transaction_hash = false; let batched = true; for is_v2_ix in [true, false].into_iter() { @@ -2841,7 +2841,7 @@ fn get_output_account_info(output_merkle_tree_index: u8) -> OutAccountInfo { #[serial] #[tokio::test] async fn test_duplicate_account_in_inputs_and_read_only() { - spawn_prover().await; + spawn_prover().await.unwrap(); let mut config = ProgramTestConfig::default_with_batched_trees(false); config.with_prover = false; diff --git a/program-tests/utils/src/account_zero_copy.rs b/program-tests/utils/src/account_zero_copy.rs new file mode 100644 index 0000000000..7e0ddaeb98 --- /dev/null +++ b/program-tests/utils/src/account_zero_copy.rs @@ -0,0 +1,113 @@ +use std::marker::PhantomData; + +use account_compression::{AddressMerkleTreeAccount, QueueAccount, StateMerkleTreeAccount}; +use anchor_lang::AccountDeserialize; +use borsh::BorshDeserialize; +use forester_utils::account_zero_copy::AccountZeroCopyError; +use light_batched_merkle_tree::{ + merkle_tree_metadata::BatchedMerkleTreeMetadata, queue::BatchedQueueMetadata, +}; +use light_client::rpc::Rpc; +use solana_sdk::{account::Account, pubkey::Pubkey}; + +pub trait AccountZeroCopyDeserialize: Sized { + fn deserialize_account(data: &[u8], pubkey: Pubkey) -> Result; +} + +fn deserialize_anchor_account( + data: &[u8], + pubkey: Pubkey, +) -> Result { + if data.len() < 8 { + return Err(AccountZeroCopyError::RpcError(format!( + "Account {} data too short: {}", + pubkey, + data.len() + ))); + } + + T::try_deserialize(&mut &data[..]).map_err(|error| { + AccountZeroCopyError::RpcError(format!( + "Failed to deserialize account {}: {}", + pubkey, error + )) + }) +} + +fn deserialize_borsh_account( + data: &[u8], + pubkey: Pubkey, +) -> Result { + if data.len() < 8 { + return Err(AccountZeroCopyError::RpcError(format!( + "Account {} data too short: {}", + pubkey, + data.len() + ))); + } + + T::try_from_slice(&data[8..]).map_err(|error| { + AccountZeroCopyError::RpcError(format!( + "Failed to deserialize account {}: {}", + pubkey, error + )) + }) +} + +impl AccountZeroCopyDeserialize for AddressMerkleTreeAccount { + fn deserialize_account(data: &[u8], pubkey: Pubkey) -> Result { + deserialize_anchor_account(data, pubkey) + } +} + +impl AccountZeroCopyDeserialize for QueueAccount { + fn deserialize_account(data: &[u8], pubkey: Pubkey) -> Result { + deserialize_anchor_account(data, pubkey) + } +} + +impl AccountZeroCopyDeserialize for StateMerkleTreeAccount { + fn deserialize_account(data: &[u8], pubkey: Pubkey) -> Result { + deserialize_anchor_account(data, pubkey) + } +} + +impl AccountZeroCopyDeserialize for BatchedMerkleTreeMetadata { + fn deserialize_account(data: &[u8], pubkey: Pubkey) -> Result { + deserialize_borsh_account(data, pubkey) + } +} + +impl AccountZeroCopyDeserialize for BatchedQueueMetadata { + fn deserialize_account(data: &[u8], pubkey: Pubkey) -> Result { + deserialize_borsh_account(data, pubkey) + } +} + +pub struct AccountZeroCopy { + pub account: Account, + pub pubkey: Pubkey, + _marker: PhantomData, +} + +impl AccountZeroCopy { + pub async fn new(rpc: &mut R, pubkey: Pubkey) -> Result { + let account = rpc + .get_account(pubkey) + .await + .map_err(|error| AccountZeroCopyError::RpcError(error.to_string()))? + .ok_or(AccountZeroCopyError::AccountNotFound(pubkey))?; + + Ok(Self { + account, + pubkey, + _marker: PhantomData, + }) + } +} + +impl AccountZeroCopy { + pub fn try_deserialized(&self) -> Result { + T::deserialize_account(&self.account.data, self.pubkey) + } +} diff --git a/program-tests/utils/src/actions/legacy/instructions/transfer2.rs b/program-tests/utils/src/actions/legacy/instructions/transfer2.rs index 1ff92eeda9..5b1da2d229 100644 --- a/program-tests/utils/src/actions/legacy/instructions/transfer2.rs +++ b/program-tests/utils/src/actions/legacy/instructions/transfer2.rs @@ -211,7 +211,9 @@ pub async fn create_generic_transfer2_instruction( let mut packed_tree_accounts = PackedAccounts::default(); // tree infos must be packed before packing the token input accounts - let packed_tree_infos = rpc_proof_result.pack_tree_infos(&mut packed_tree_accounts); + let packed_tree_infos = rpc_proof_result + .pack_tree_infos(&mut packed_tree_accounts) + .map_err(|error| TokenSdkError::CpiError(error.to_string()))?; // We use a single shared output queue for all compress/compress-and-close operations to avoid ordering failures. let shared_output_queue = if packed_tree_infos.address_trees.is_empty() { diff --git a/program-tests/utils/src/address_tree_rollover.rs b/program-tests/utils/src/address_tree_rollover.rs index efc52387d5..402c4c52ea 100644 --- a/program-tests/utils/src/address_tree_rollover.rs +++ b/program-tests/utils/src/address_tree_rollover.rs @@ -258,10 +258,12 @@ pub async fn assert_rolled_over_address_merkle_tree_and_queue( // rent is reimbursed, 3 signatures cost 3 x 5000 lamports assert_eq!(*fee_payer_prior_balance, fee_payer_post_balance + 15000); { - let old_address_queue = - unsafe { get_hash_set::(rpc, *old_queue_pubkey).await }.unwrap(); - let new_address_queue = - unsafe { get_hash_set::(rpc, *new_queue_pubkey).await }.unwrap(); + let old_address_queue = get_hash_set::(rpc, *old_queue_pubkey) + .await + .unwrap(); + let new_address_queue = get_hash_set::(rpc, *new_queue_pubkey) + .await + .unwrap(); assert_eq!( old_address_queue.get_capacity(), diff --git a/program-tests/utils/src/assert_compressed_tx.rs b/program-tests/utils/src/assert_compressed_tx.rs index 8d6c2ced9d..ecc9c28fe4 100644 --- a/program-tests/utils/src/assert_compressed_tx.rs +++ b/program-tests/utils/src/assert_compressed_tx.rs @@ -1,8 +1,6 @@ use account_compression::{state::QueueAccount, StateMerkleTreeAccount}; use anchor_lang::Discriminator; -use forester_utils::account_zero_copy::{ - get_concurrent_merkle_tree, get_hash_set, AccountZeroCopy, -}; +use forester_utils::account_zero_copy::{get_concurrent_merkle_tree, get_hash_set}; use light_account_checks::discriminator::Discriminator as LightDiscriminator; use light_batched_merkle_tree::{ batch::Batch, merkle_tree::BatchedMerkleTreeAccount, queue::BatchedQueueMetadata, @@ -22,7 +20,7 @@ use num_bigint::BigUint; use num_traits::FromBytes; use solana_sdk::{account::ReadableAccount, pubkey::Pubkey}; -use crate::system_program::get_sol_pool_pda; +use crate::{system_program::get_sol_pool_pda, AccountZeroCopy}; pub struct AssertCompressedTransactionInputs<'a, R: Rpc, I: Indexer + TestIndexerExtensions> { pub rpc: &'a mut R, @@ -133,11 +131,10 @@ pub async fn assert_nullifiers_exist_in_hash_sets( for (i, hash) in input_compressed_account_hashes.iter().enumerate() { match snapshots[i].tree_type { TreeType::StateV1 => { - let nullifier_queue = unsafe { + let nullifier_queue = get_hash_set::(rpc, snapshots[i].accounts.nullifier_queue) .await - } - .unwrap(); + .unwrap(); assert!(nullifier_queue .contains(&BigUint::from_be_bytes(hash.as_slice()), None) .unwrap()); @@ -183,8 +180,7 @@ pub async fn assert_addresses_exist_in_hash_sets( let discriminator = &account.data[0..8]; match discriminator { QueueAccount::DISCRIMINATOR => { - let address_queue = - unsafe { get_hash_set::(rpc, *pubkey).await }.unwrap(); + let address_queue = get_hash_set::(rpc, *pubkey).await.unwrap(); assert!(address_queue .contains(&BigUint::from_be_bytes(address), None) .unwrap()); @@ -490,7 +486,11 @@ pub async fn get_merkle_tree_snapshots( snapshots.push(MerkleTreeTestSnapShot { accounts: *account_bundle, root, - next_index: output_queue.deserialized().batch_metadata.next_index as usize, + next_index: output_queue + .try_deserialized() + .unwrap() + .batch_metadata + .next_index as usize, num_added_accounts: accounts .iter() .filter(|x| x.merkle_tree == account_bundle.merkle_tree) diff --git a/program-tests/utils/src/assert_merkle_tree.rs b/program-tests/utils/src/assert_merkle_tree.rs index aa5efbf44b..bf754391c4 100644 --- a/program-tests/utils/src/assert_merkle_tree.rs +++ b/program-tests/utils/src/assert_merkle_tree.rs @@ -1,10 +1,12 @@ use account_compression::StateMerkleTreeAccount; -use forester_utils::account_zero_copy::{get_concurrent_merkle_tree, AccountZeroCopy}; +use forester_utils::account_zero_copy::get_concurrent_merkle_tree; use light_client::rpc::Rpc; use light_hasher::Poseidon; use light_merkle_tree_metadata::fee::compute_rollover_fee; use solana_sdk::pubkey::Pubkey; +use crate::AccountZeroCopy; + #[allow(clippy::too_many_arguments)] pub async fn assert_merkle_tree_initialized( rpc: &mut R, @@ -29,7 +31,7 @@ pub async fn assert_merkle_tree_initialized( ) .await .unwrap(); - let merkle_tree_account = merkle_tree_account.deserialized(); + let merkle_tree_account = merkle_tree_account.try_deserialized().unwrap(); let balance_merkle_tree = rpc .get_account(*merkle_tree_pubkey) diff --git a/program-tests/utils/src/assert_queue.rs b/program-tests/utils/src/assert_queue.rs index 042a7bd523..f5a9e2717d 100644 --- a/program-tests/utils/src/assert_queue.rs +++ b/program-tests/utils/src/assert_queue.rs @@ -1,5 +1,5 @@ use account_compression::QueueAccount; -use forester_utils::account_zero_copy::{get_hash_set, AccountZeroCopy}; +use forester_utils::account_zero_copy::get_hash_set; use light_client::rpc::Rpc; use light_merkle_tree_metadata::{ access::AccessMetadata, fee::compute_rollover_fee, queue::QueueMetadata, @@ -7,6 +7,8 @@ use light_merkle_tree_metadata::{ }; use solana_sdk::pubkey::Pubkey; +use crate::AccountZeroCopy; + #[allow(clippy::too_many_arguments)] pub async fn assert_address_queue_initialized( rpc: &mut R, @@ -157,7 +159,7 @@ pub async fn assert_queue( let queue = AccountZeroCopy::::new(rpc, *queue_pubkey) .await .unwrap(); - let queue_account = queue.deserialized(); + let queue_account = queue.try_deserialized().unwrap(); let expected_rollover_meta_data = RolloverMetadata { index: expected_index, @@ -184,7 +186,9 @@ pub async fn assert_queue( }; assert_eq!(queue_account.metadata, expected_queue_meta_data); - let queue = unsafe { get_hash_set::(rpc, *queue_pubkey).await }.unwrap(); + let queue = get_hash_set::(rpc, *queue_pubkey) + .await + .unwrap(); assert_eq!(queue.get_capacity(), queue_config.capacity as usize); assert_eq!( queue.sequence_threshold, diff --git a/program-tests/utils/src/batched_address_tree.rs b/program-tests/utils/src/batched_address_tree.rs index 979b5b9d37..80473762ce 100644 --- a/program-tests/utils/src/batched_address_tree.rs +++ b/program-tests/utils/src/batched_address_tree.rs @@ -1,7 +1,7 @@ use std::cmp; use account_compression::{AddressMerkleTreeConfig, AddressQueueConfig, RegisteredProgram}; -use forester_utils::account_zero_copy::{get_hash_set, get_indexed_merkle_tree, AccountZeroCopy}; +use forester_utils::account_zero_copy::{get_hash_set, get_indexed_merkle_tree}; use light_client::rpc::{Rpc, RpcError}; use light_hasher::Poseidon; use light_merkle_tree_metadata::{ @@ -15,6 +15,8 @@ use solana_sdk::{ signature::{Keypair, Signature, Signer}, }; +use crate::AccountZeroCopy; + #[allow(clippy::too_many_arguments)] #[inline(never)] pub async fn create_address_merkle_tree_and_queue_account_with_assert( @@ -141,7 +143,7 @@ pub async fn assert_address_merkle_tree_initialized( ) .await .unwrap(); - let merkle_tree_account = merkle_tree.deserialized(); + let merkle_tree_account = merkle_tree.try_deserialized().unwrap(); assert_eq!( merkle_tree_account @@ -354,7 +356,7 @@ pub async fn assert_queue( let queue = AccountZeroCopy::::new(rpc, *queue_pubkey) .await .unwrap(); - let queue_account = queue.deserialized(); + let queue_account = queue.try_deserialized().unwrap(); let expected_rollover_meta_data = RolloverMetadata { index: expected_index, @@ -381,9 +383,9 @@ pub async fn assert_queue( }; assert_eq!(queue_account.metadata, expected_queue_meta_data); - let queue = - unsafe { get_hash_set::(rpc, *queue_pubkey).await } - .unwrap(); + let queue = get_hash_set::(rpc, *queue_pubkey) + .await + .unwrap(); assert_eq!(queue.get_capacity(), queue_config.capacity as usize); assert_eq!( queue.sequence_threshold, diff --git a/program-tests/utils/src/e2e_test_env.rs b/program-tests/utils/src/e2e_test_env.rs index a16a7925a7..46e4c50e94 100644 --- a/program-tests/utils/src/e2e_test_env.rs +++ b/program-tests/utils/src/e2e_test_env.rs @@ -73,7 +73,6 @@ use account_compression::{ use anchor_lang::{prelude::AccountMeta, AnchorSerialize, Discriminator}; use create_address_test_program::create_invoke_cpi_instruction; use forester_utils::{ - account_zero_copy::AccountZeroCopy, address_merkle_tree_config::{address_tree_ready_for_rollover, state_tree_ready_for_rollover}, forester_epoch::{Epoch, Forester, TreeAccounts}, utils::airdrop_lamports, @@ -194,6 +193,7 @@ use crate::{ }, test_batch_forester::{perform_batch_append, perform_batch_nullify}, test_forester::{empty_address_queue_test, nullify_compressed_accounts}, + AccountZeroCopy, }; pub struct User { @@ -748,70 +748,67 @@ where .with_address_queue(None, Some(batch.batch_size as u16)); let result = self .indexer - .get_queue_elements(merkle_tree_pubkey.to_bytes(), options, None) + .get_queue_elements( + merkle_tree_pubkey.to_bytes(), + options, + None, + ) .await .unwrap(); - let addresses = result - .value - .address_queue - .map(|aq| aq.addresses) - .unwrap_or_default(); + let address_queue = result.value.address_queue.unwrap(); + let low_element_proofs = address_queue + .reconstruct_all_proofs::<{ + DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize + }>() + .unwrap(); // // local_leaves_hash_chain is only used for a test assertion. // let local_nullifier_hash_chain = create_hash_chain_from_array(&addresses); // assert_eq!(leaves_hash_chain, local_nullifier_hash_chain); - let start_index = merkle_tree.next_index as usize; + let start_index = address_queue.start_index as usize; assert!( start_index >= 2, "start index should be greater than 2 else tree is not inited" ); let current_root = *merkle_tree.root_history.last().unwrap(); - let mut low_element_values = Vec::new(); - let mut low_element_indices = Vec::new(); - let mut low_element_next_indices = Vec::new(); - let mut low_element_next_values = Vec::new(); - let mut low_element_proofs: Vec> = Vec::new(); - let non_inclusion_proofs = self - .indexer - .get_multiple_new_address_proofs( - merkle_tree_pubkey.to_bytes(), - addresses.clone(), - None, - ) - .await - .unwrap(); - for non_inclusion_proof in &non_inclusion_proofs.value.items { - low_element_values.push(non_inclusion_proof.low_address_value); - low_element_indices - .push(non_inclusion_proof.low_address_index as usize); - low_element_next_indices - .push(non_inclusion_proof.low_address_next_index as usize); - low_element_next_values - .push(non_inclusion_proof.low_address_next_value); - - low_element_proofs - .push(non_inclusion_proof.low_address_proof.to_vec()); - } - - let subtrees = self.indexer - .get_subtrees(merkle_tree_pubkey.to_bytes(), None) - .await - .unwrap(); - let mut sparse_merkle_tree = SparseMerkleTree::::new(<[[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize]>::try_from(subtrees.value.items).unwrap(), start_index); + assert_eq!(address_queue.initial_root, current_root); + let light_client::indexer::AddressQueueData { + addresses, + low_element_values, + low_element_next_values, + low_element_indices, + low_element_next_indices, + subtrees, + .. + } = address_queue; + let mut sparse_merkle_tree = SparseMerkleTree::< + Poseidon, + { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, + >::new( + subtrees.as_slice().try_into().unwrap(), + start_index, + ); - let mut changelog: Vec> = Vec::new(); - let mut indexed_changelog: Vec> = Vec::new(); + let mut changelog: Vec< + ChangelogEntry<{ DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }>, + > = Vec::new(); + let mut indexed_changelog: Vec< + IndexedChangelogEntry< + usize, + { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, + >, + > = Vec::new(); let inputs = get_batch_address_append_circuit_inputs::< { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, >( start_index, current_root, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - low_element_proofs, - addresses, + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + &addresses, &mut sparse_merkle_tree, leaves_hash_chain, batch.zkp_batch_size as usize, @@ -822,7 +819,7 @@ where let client = Client::new(); let circuit_inputs_new_root = bigint_to_be_bytes_array::<32>(&inputs.new_root).unwrap(); - let inputs = to_json(&inputs); + let inputs = to_json(&inputs).unwrap(); let response_result = client .post(format!("{}{}", SERVER_ADDRESS, PROVE_PATH)) @@ -834,9 +831,13 @@ where if response_result.status().is_success() { let body = response_result.text().await.unwrap(); - let proof_json = deserialize_gnark_proof_json(&body).unwrap(); - let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json); - let (proof_a, proof_b, proof_c) = compress_proof(&proof_a, &proof_b, &proof_c); + let proof_json = deserialize_gnark_proof_json(&body) + .map_err(|error| RpcError::CustomError(error.to_string())) + .unwrap(); + let (proof_a, proof_b, proof_c) = + proof_from_json_struct(proof_json).unwrap(); + let (proof_a, proof_b, proof_c) = + compress_proof(&proof_a, &proof_b, &proof_c).unwrap(); let instruction_data = InstructionDataBatchNullifyInputs { new_root: circuit_inputs_new_root, compressed_proof: CompressedProof { @@ -1278,7 +1279,8 @@ where .get_state_merkle_trees_mut() .push(StateMerkleTreeBundle { rollover_fee: state_tree_account - .deserialized() + .try_deserialized() + .unwrap() .metadata .rollover_metadata .rollover_fee as i64, @@ -1377,7 +1379,8 @@ where }) .unwrap(); bundle.rollover_fee = queue_account - .deserialized() + .try_deserialized() + .unwrap() .metadata .rollover_metadata .rollover_fee as i64; diff --git a/program-tests/utils/src/lib.rs b/program-tests/utils/src/lib.rs index a7fa07c4f8..d20237f96b 100644 --- a/program-tests/utils/src/lib.rs +++ b/program-tests/utils/src/lib.rs @@ -1,11 +1,10 @@ use std::cmp; use account_compression::{AddressMerkleTreeConfig, AddressQueueConfig, RegisteredProgram}; +pub use account_zero_copy::AccountZeroCopy; use batched_address_tree::assert_address_merkle_tree_initialized; pub use forester_utils::{ - account_zero_copy::{ - get_concurrent_merkle_tree, get_hash_set, get_indexed_merkle_tree, AccountZeroCopy, - }, + account_zero_copy::{get_concurrent_merkle_tree, get_hash_set, get_indexed_merkle_tree}, instructions::create_account_instruction, utils::airdrop_lamports, }; @@ -16,6 +15,7 @@ use solana_sdk::{ signature::{Keypair, Signature, Signer}, transaction, }; +pub mod account_zero_copy; pub mod actions; pub mod address; pub mod address_tree_rollover; diff --git a/program-tests/utils/src/mock_batched_forester.rs b/program-tests/utils/src/mock_batched_forester.rs index 4458aa03b3..0101b235aa 100644 --- a/program-tests/utils/src/mock_batched_forester.rs +++ b/program-tests/utils/src/mock_batched_forester.rs @@ -260,7 +260,7 @@ impl MockBatchedAddressForester { let mut low_element_indices = Vec::new(); let mut low_element_next_indices = Vec::new(); let mut low_element_next_values = Vec::new(); - let mut low_element_proofs: Vec> = Vec::new(); + let mut low_element_proofs: Vec<[[u8; 32]; HEIGHT]> = Vec::new(); for new_element_value in &new_element_values { let non_inclusion_proof = self .merkle_tree @@ -270,7 +270,18 @@ impl MockBatchedAddressForester { low_element_indices.push(non_inclusion_proof.leaf_index); low_element_next_indices.push(non_inclusion_proof.next_index); low_element_next_values.push(non_inclusion_proof.leaf_higher_range_value); - low_element_proofs.push(non_inclusion_proof.merkle_proof.as_slice().to_vec()); + let proof = non_inclusion_proof + .merkle_proof + .as_slice() + .try_into() + .map_err(|_| { + ProverClientError::InvalidProofData(format!( + "invalid low element proof length: expected {}, got {}", + HEIGHT, + non_inclusion_proof.merkle_proof.len() + )) + })?; + low_element_proofs.push(proof); } let subtrees = self.merkle_tree.merkle_tree.get_subtrees(); let mut merkle_tree = match <[[u8; 32]; HEIGHT]>::try_from(subtrees) { @@ -287,12 +298,12 @@ impl MockBatchedAddressForester { let inputs = match get_batch_address_append_circuit_inputs::( start_index, current_root, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - low_element_proofs, - new_element_values.clone(), + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + &new_element_values, &mut merkle_tree, leaves_hashchain, zkp_batch_size as usize, diff --git a/program-tests/utils/src/state_tree_rollover.rs b/program-tests/utils/src/state_tree_rollover.rs index 1402c2f749..f0959a611a 100644 --- a/program-tests/utils/src/state_tree_rollover.rs +++ b/program-tests/utils/src/state_tree_rollover.rs @@ -283,10 +283,12 @@ pub async fn assert_rolled_over_pair( *fee_payer_prior_balance, fee_payer_post_balance + 5000 * num_signatures + additional_rent ); - let old_address_queue = - unsafe { get_hash_set::(rpc, *old_nullifier_queue_pubkey).await }.unwrap(); - let new_address_queue = - unsafe { get_hash_set::(rpc, *new_nullifier_queue_pubkey).await }.unwrap(); + let old_address_queue = get_hash_set::(rpc, *old_nullifier_queue_pubkey) + .await + .unwrap(); + let new_address_queue = get_hash_set::(rpc, *new_nullifier_queue_pubkey) + .await + .unwrap(); assert_eq!( old_address_queue.get_capacity(), diff --git a/program-tests/utils/src/test_batch_forester.rs b/program-tests/utils/src/test_batch_forester.rs index 8e6909704f..c11ddb3db1 100644 --- a/program-tests/utils/src/test_batch_forester.rs +++ b/program-tests/utils/src/test_batch_forester.rs @@ -165,7 +165,9 @@ pub async fn create_append_batch_ix_data( bundle.merkle_tree.root() ); let proof_client = ProofClient::local(); - let inputs_json = BatchAppendInputsJson::from_inputs(&circuit_inputs).to_string(); + let inputs_json = BatchAppendInputsJson::from_inputs(&circuit_inputs) + .to_string() + .unwrap(); match proof_client.generate_proof(inputs_json).await { Ok(compressed_proof) => ( @@ -296,7 +298,7 @@ pub async fn get_batched_nullify_ix_data( let proof_client = ProofClient::local(); let circuit_inputs_new_root = bigint_to_be_bytes_array::<32>(&inputs.new_root.to_biguint().unwrap()).unwrap(); - let inputs_json = update_inputs_string(&inputs); + let inputs_json = update_inputs_string(&inputs).unwrap(); let new_root = bundle.merkle_tree.root(); assert_eq!(circuit_inputs_new_root, new_root); @@ -319,13 +321,13 @@ pub async fn get_batched_nullify_ix_data( }) } -use forester_utils::{ - account_zero_copy::AccountZeroCopy, instructions::create_account::create_account_instruction, -}; +use forester_utils::instructions::create_account::create_account_instruction; use light_client::indexer::{Indexer, QueueElementsV2Options}; use light_program_test::indexer::state_tree::StateMerkleTreeBundle; use light_sparse_merkle_tree::SparseMerkleTree; +use crate::AccountZeroCopy; + pub async fn assert_registry_created_batched_state_merkle_tree( rpc: &mut R, payer_pubkey: Pubkey, @@ -663,50 +665,33 @@ pub async fn create_batch_update_address_tree_instruction_data_with_proof() + .unwrap(); // // local_leaves_hash_chain is only used for a test assertion. // let local_nullifier_hash_chain = create_hash_chain_from_slice(addresses.as_slice()).unwrap(); // assert_eq!(leaves_hash_chain, local_nullifier_hash_chain); - let start_index = merkle_tree.next_index as usize; + let start_index = address_queue.start_index as usize; assert!( start_index >= 1, "start index should be greater than 2 else tree is not inited" ); let current_root = *merkle_tree.root_history.last().unwrap(); - let mut low_element_values = Vec::new(); - let mut low_element_indices = Vec::new(); - let mut low_element_next_indices = Vec::new(); - let mut low_element_next_values = Vec::new(); - let mut low_element_proofs: Vec> = Vec::new(); - let non_inclusion_proofs = indexer - .get_multiple_new_address_proofs(merkle_tree_pubkey.to_bytes(), addresses.clone(), None) - .await - .unwrap(); - for non_inclusion_proof in &non_inclusion_proofs.value.items { - low_element_values.push(non_inclusion_proof.low_address_value); - low_element_indices.push(non_inclusion_proof.low_address_index as usize); - low_element_next_indices.push(non_inclusion_proof.low_address_next_index as usize); - low_element_next_values.push(non_inclusion_proof.low_address_next_value); - - low_element_proofs.push(non_inclusion_proof.low_address_proof.to_vec()); - } - - let subtrees = indexer - .get_subtrees(merkle_tree_pubkey.to_bytes(), None) - .await - .unwrap(); + assert_eq!(address_queue.initial_root, current_root); + let light_client::indexer::AddressQueueData { + addresses, + low_element_values, + low_element_indices, + low_element_next_indices, + low_element_next_values, + subtrees, + .. + } = address_queue; let mut sparse_merkle_tree = SparseMerkleTree::< Poseidon, { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, - >::new( - <[[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize]>::try_from(subtrees.value.items) - .unwrap(), - start_index, - ); + >::new(subtrees.as_slice().try_into().unwrap(), start_index); let mut changelog: Vec> = Vec::new(); @@ -718,12 +703,12 @@ pub async fn create_batch_update_address_tree_instruction_data_with_proof( start_index, current_root, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - low_element_proofs, - addresses, + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + &addresses, &mut sparse_merkle_tree, leaves_hash_chain, batch.zkp_batch_size as usize, @@ -734,7 +719,7 @@ pub async fn create_batch_update_address_tree_instruction_data_with_proof(&inputs.new_root).unwrap(); - let inputs_json = to_json(&inputs); + let inputs_json = to_json(&inputs).unwrap(); match proof_client.generate_proof(inputs_json).await { Ok(compressed_proof) => { diff --git a/program-tests/utils/src/test_forester.rs b/program-tests/utils/src/test_forester.rs index 13051a0517..dcba8e56c5 100644 --- a/program-tests/utils/src/test_forester.rs +++ b/program-tests/utils/src/test_forester.rs @@ -376,8 +376,9 @@ pub async fn empty_address_queue_test( .await .unwrap(); assert_eq!(address_tree_bundle.root(), address_merkle_tree.root()); - let address_queue = - unsafe { get_hash_set::(rpc, address_queue_pubkey).await }.unwrap(); + let address_queue = get_hash_set::(rpc, address_queue_pubkey) + .await + .unwrap(); let address = address_queue.first_no_seq().unwrap(); @@ -562,9 +563,9 @@ pub async fn empty_address_queue_test( let address_bundle = address_tree_bundle .new_element_with_low_element_index(old_low_address.index, &address.value_biguint()) .unwrap(); - let address_queue = - unsafe { get_hash_set::(rpc, address_queue_pubkey).await } - .unwrap(); + let address_queue = get_hash_set::(rpc, address_queue_pubkey) + .await + .unwrap(); assert_eq!( address_queue diff --git a/prover/client/src/constants.rs b/prover/client/src/constants.rs index 18a5c05a45..151bf87918 100644 --- a/prover/client/src/constants.rs +++ b/prover/client/src/constants.rs @@ -1,4 +1,4 @@ -pub const SERVER_ADDRESS: &str = "http://localhost:3001"; +pub const SERVER_ADDRESS: &str = "http://127.0.0.1:3001"; pub const HEALTH_CHECK: &str = "/health"; pub const PROVE_PATH: &str = "/prove"; diff --git a/prover/client/src/errors.rs b/prover/client/src/errors.rs index 85c1bc8fbe..e1ea89624f 100644 --- a/prover/client/src/errors.rs +++ b/prover/client/src/errors.rs @@ -37,6 +37,24 @@ pub enum ProverClientError { #[error("Invalid proof data: {0}")] InvalidProofData(String), + #[error("Integer conversion failed: {0}")] + IntegerConversion(String), + + #[error("JSON serialization failed: {0}")] + JsonSerialization(String), + + #[error("Failed to start prover process: {0}")] + ProcessStart(String), + + #[error("Failed to wait for prover process: {0}")] + ProcessWait(String), + + #[error("Project root not found")] + ProjectRootNotFound, + + #[error("Prover health check failed after startup")] + HealthCheckFailed, + #[error("Hashchain mismatch: computed {computed:?} != expected {expected:?} (batch_size={batch_size}, next_index={next_index})")] HashchainMismatch { computed: [u8; 32], diff --git a/prover/client/src/helpers.rs b/prover/client/src/helpers.rs index 6ea223e79f..4124f91489 100644 --- a/prover/client/src/helpers.rs +++ b/prover/client/src/helpers.rs @@ -6,6 +6,8 @@ use num_bigint::{BigInt, BigUint}; use num_traits::{Num, ToPrimitive}; use serde::Serialize; +use crate::errors::ProverClientError; + pub fn get_project_root() -> Option { let output = Command::new("git") .args(["rev-parse", "--show-toplevel"]) @@ -33,10 +35,12 @@ pub fn convert_endianness_128(bytes: &[u8]) -> Vec { .collect::>() } -pub fn bigint_to_u8_32(n: &BigInt) -> Result<[u8; 32], Box> { +pub fn bigint_to_u8_32(n: &BigInt) -> Result<[u8; 32], ProverClientError> { let (_, bytes_be) = n.to_bytes_be(); if bytes_be.len() > 32 { - Err("Number too large to fit in [u8; 32]")?; + return Err(ProverClientError::InvalidProofData( + "Number too large to fit in [u8; 32]".to_string(), + )); } let mut array = [0; 32]; let bytes = &bytes_be[..bytes_be.len()]; @@ -48,7 +52,7 @@ pub fn compute_root_from_merkle_proof( leaf: [u8; 32], path_elements: &[[u8; 32]; HEIGHT], path_index: u32, -) -> ([u8; 32], ChangelogEntry) { +) -> Result<([u8; 32], ChangelogEntry), ProverClientError> { let mut changelog_entry = ChangelogEntry::default_with_index(path_index as usize); let mut current_hash = leaf; @@ -56,14 +60,14 @@ pub fn compute_root_from_merkle_proof( for (level, path_element) in path_elements.iter().enumerate() { changelog_entry.path[level] = Some(current_hash); if current_index.is_multiple_of(2) { - current_hash = Poseidon::hashv(&[¤t_hash, path_element]).unwrap(); + current_hash = Poseidon::hashv(&[¤t_hash, path_element])?; } else { - current_hash = Poseidon::hashv(&[path_element, ¤t_hash]).unwrap(); + current_hash = Poseidon::hashv(&[path_element, ¤t_hash])?; } current_index /= 2; } - (current_hash, changelog_entry) + Ok((current_hash, changelog_entry)) } pub fn big_uint_to_string(big_uint: &BigUint) -> String { @@ -85,8 +89,14 @@ pub fn create_vec_of_string(number_of_utxos: usize, element: &BigInt) -> Vec Vec { - vec![element.to_u32().unwrap(); number_of_utxos] +pub fn create_vec_of_u32( + number_of_utxos: usize, + element: &BigInt, +) -> Result, ProverClientError> { + let value = element.to_u32().ok_or_else(|| { + ProverClientError::IntegerConversion(format!("cannot convert {} to u32", element)) + })?; + Ok(vec![value; number_of_utxos]) } pub fn create_vec_of_vec_of_string( @@ -100,9 +110,10 @@ pub fn create_vec_of_vec_of_string( vec![vec; number_of_utxos] } -pub fn create_json_from_struct(json_struct: &T) -> String +pub fn create_json_from_struct(json_struct: &T) -> Result where T: Serialize, { - serde_json::to_string(json_struct).expect("JSON serialization failed for valid struct") + serde_json::to_string(json_struct) + .map_err(|e| ProverClientError::JsonSerialization(e.to_string())) } diff --git a/prover/client/src/proof.rs b/prover/client/src/proof.rs index c415a4d108..4349d24d51 100644 --- a/prover/client/src/proof.rs +++ b/prover/client/src/proof.rs @@ -19,6 +19,9 @@ pub struct ProofCompressed { pub c: [u8; 32], } +pub type CompressedProofParts = ([u8; 32], [u8; 64], [u8; 32]); +pub type UncompressedProofParts = ([u8; 64], [u8; 128], [u8; 64]); + #[derive(Debug, Clone, Copy)] pub struct ProofResult { pub proof: ProofCompressed, @@ -66,16 +69,23 @@ pub fn deserialize_gnark_proof_json(json_data: &str) -> serde_json::Result [u8; 32] { +pub fn deserialize_hex_string_to_be_bytes(hex_str: &str) -> Result<[u8; 32], ProverClientError> { let trimmed_str = hex_str.trim_start_matches("0x"); - let big_int = num_bigint::BigInt::from_str_radix(trimmed_str, 16).unwrap(); + let big_int = num_bigint::BigInt::from_str_radix(trimmed_str, 16) + .map_err(|e| ProverClientError::InvalidHexString(format!("{} ({})", hex_str, e)))?; let big_int_bytes = big_int.to_bytes_be().1; if big_int_bytes.len() < 32 { let mut result = [0u8; 32]; result[32 - big_int_bytes.len()..].copy_from_slice(&big_int_bytes); - result + Ok(result) } else { - big_int_bytes.try_into().unwrap() + let len = big_int_bytes.len(); + big_int_bytes.try_into().map_err(|_| { + ProverClientError::InvalidHexString(format!( + "expected at most 32 bytes, got {} for {}", + len, hex_str + )) + }) } } @@ -83,47 +93,92 @@ pub fn compress_proof( proof_a: &[u8; 64], proof_b: &[u8; 128], proof_c: &[u8; 64], -) -> ([u8; 32], [u8; 64], [u8; 32]) { - let proof_a = alt_bn128_g1_compress(proof_a).unwrap(); - let proof_b = alt_bn128_g2_compress(proof_b).unwrap(); - let proof_c = alt_bn128_g1_compress(proof_c).unwrap(); - (proof_a, proof_b, proof_c) +) -> Result { + let proof_a = alt_bn128_g1_compress(proof_a)?; + let proof_b = alt_bn128_g2_compress(proof_b)?; + let proof_c = alt_bn128_g1_compress(proof_c)?; + Ok((proof_a, proof_b, proof_c)) } -pub fn proof_from_json_struct(json: GnarkProofJson) -> ([u8; 64], [u8; 128], [u8; 64]) { - let proof_a_x = deserialize_hex_string_to_be_bytes(&json.ar[0]); - let proof_a_y = deserialize_hex_string_to_be_bytes(&json.ar[1]); - let proof_a: [u8; 64] = [proof_a_x, proof_a_y].concat().try_into().unwrap(); - let proof_a = negate_g1(&proof_a); - let proof_b_x_0 = deserialize_hex_string_to_be_bytes(&json.bs[0][0]); - let proof_b_x_1 = deserialize_hex_string_to_be_bytes(&json.bs[0][1]); - let proof_b_y_0 = deserialize_hex_string_to_be_bytes(&json.bs[1][0]); - let proof_b_y_1 = deserialize_hex_string_to_be_bytes(&json.bs[1][1]); +pub fn proof_from_json_struct( + json: GnarkProofJson, +) -> Result { + let proof_a_x = deserialize_hex_string_to_be_bytes( + json.ar + .first() + .ok_or_else(|| ProverClientError::InvalidProofData("missing ar[0]".to_string()))?, + )?; + let proof_a_y = deserialize_hex_string_to_be_bytes( + json.ar + .get(1) + .ok_or_else(|| ProverClientError::InvalidProofData("missing ar[1]".to_string()))?, + )?; + let proof_a: [u8; 64] = [proof_a_x, proof_a_y] + .concat() + .try_into() + .map_err(|_| ProverClientError::InvalidProofData("invalid proof_a length".to_string()))?; + let proof_a = negate_g1(&proof_a)?; + let proof_b_x_0 = deserialize_hex_string_to_be_bytes( + json.bs + .first() + .and_then(|row| row.first()) + .ok_or_else(|| ProverClientError::InvalidProofData("missing bs[0][0]".to_string()))?, + )?; + let proof_b_x_1 = deserialize_hex_string_to_be_bytes( + json.bs + .first() + .and_then(|row| row.get(1)) + .ok_or_else(|| ProverClientError::InvalidProofData("missing bs[0][1]".to_string()))?, + )?; + let proof_b_y_0 = deserialize_hex_string_to_be_bytes( + json.bs + .get(1) + .and_then(|row| row.first()) + .ok_or_else(|| ProverClientError::InvalidProofData("missing bs[1][0]".to_string()))?, + )?; + let proof_b_y_1 = deserialize_hex_string_to_be_bytes( + json.bs + .get(1) + .and_then(|row| row.get(1)) + .ok_or_else(|| ProverClientError::InvalidProofData("missing bs[1][1]".to_string()))?, + )?; let proof_b: [u8; 128] = [proof_b_x_0, proof_b_x_1, proof_b_y_0, proof_b_y_1] .concat() .try_into() - .unwrap(); + .map_err(|_| ProverClientError::InvalidProofData("invalid proof_b length".to_string()))?; - let proof_c_x = deserialize_hex_string_to_be_bytes(&json.krs[0]); - let proof_c_y = deserialize_hex_string_to_be_bytes(&json.krs[1]); - let proof_c: [u8; 64] = [proof_c_x, proof_c_y].concat().try_into().unwrap(); - (proof_a, proof_b, proof_c) + let proof_c_x = deserialize_hex_string_to_be_bytes( + json.krs + .first() + .ok_or_else(|| ProverClientError::InvalidProofData("missing krs[0]".to_string()))?, + )?; + let proof_c_y = deserialize_hex_string_to_be_bytes( + json.krs + .get(1) + .ok_or_else(|| ProverClientError::InvalidProofData("missing krs[1]".to_string()))?, + )?; + let proof_c: [u8; 64] = [proof_c_x, proof_c_y] + .concat() + .try_into() + .map_err(|_| ProverClientError::InvalidProofData("invalid proof_c length".to_string()))?; + Ok((proof_a, proof_b, proof_c)) } -pub fn negate_g1(g1_be: &[u8; 64]) -> [u8; 64] { +pub fn negate_g1(g1_be: &[u8; 64]) -> Result<[u8; 64], ProverClientError> { let g1_le = convert_endianness::<32, 64>(g1_be); - let g1: G1 = G1::deserialize_with_mode(g1_le.as_slice(), Compress::No, Validate::No).unwrap(); + let g1: G1 = G1::deserialize_with_mode(g1_le.as_slice(), Compress::No, Validate::No) + .map_err(|e| ProverClientError::InvalidProofData(e.to_string()))?; let g1_neg = g1.neg(); let mut g1_neg_be = [0u8; 64]; g1_neg .x .serialize_with_mode(&mut g1_neg_be[..32], Compress::No) - .unwrap(); + .map_err(|e| ProverClientError::InvalidProofData(e.to_string()))?; g1_neg .y .serialize_with_mode(&mut g1_neg_be[32..], Compress::No) - .unwrap(); + .map_err(|e| ProverClientError::InvalidProofData(e.to_string()))?; let g1_neg_be: [u8; 64] = convert_endianness::<32, 64>(&g1_neg_be); - g1_neg_be + Ok(g1_neg_be) } diff --git a/prover/client/src/proof_client.rs b/prover/client/src/proof_client.rs index 1d557407bd..f95f5fac80 100644 --- a/prover/client/src/proof_client.rs +++ b/prover/client/src/proof_client.rs @@ -161,9 +161,9 @@ impl ProofClient { pub async fn poll_proof_completion( &self, - job_id: String, + job_id: &str, ) -> Result { - self.poll_for_result(&job_id, Duration::ZERO).await + self.poll_for_result(job_id, Duration::ZERO).await } pub async fn generate_proof( @@ -655,8 +655,8 @@ impl ProofClient { ProverClientError::ProverServerError(format!("Failed to deserialize proof JSON: {}", e)) })?; - let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json); - let (proof_a, proof_b, proof_c) = compress_proof(&proof_a, &proof_b, &proof_c); + let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json)?; + let (proof_a, proof_b, proof_c) = compress_proof(&proof_a, &proof_b, &proof_c)?; Ok(ProofResult { proof: ProofCompressed { @@ -673,7 +673,7 @@ impl ProofClient { inputs: BatchAddressAppendInputs, ) -> Result<(ProofResult, [u8; 32]), ProverClientError> { let new_root = light_hasher::bigint::bigint_to_be_bytes_array::<32>(&inputs.new_root)?; - let inputs_json = to_json(&inputs); + let inputs_json = to_json(&inputs)?; let proof = self.generate_proof(inputs_json).await?; Ok((proof, new_root)) } @@ -682,10 +682,11 @@ impl ProofClient { &self, circuit_inputs: BatchAppendsCircuitInputs, ) -> Result<(ProofResult, [u8; 32]), ProverClientError> { - let new_root = light_hasher::bigint::bigint_to_be_bytes_array::<32>( - &circuit_inputs.new_root.to_biguint().unwrap(), - )?; - let inputs_json = BatchAppendInputsJson::from_inputs(&circuit_inputs).to_string(); + let new_root_biguint = circuit_inputs.new_root.to_biguint().ok_or_else(|| { + ProverClientError::InvalidProofData("new_root must be non-negative".to_string()) + })?; + let new_root = light_hasher::bigint::bigint_to_be_bytes_array::<32>(&new_root_biguint)?; + let inputs_json = BatchAppendInputsJson::from_inputs(&circuit_inputs).to_string()?; let proof = self.generate_proof(inputs_json).await?; Ok((proof, new_root)) } @@ -694,10 +695,11 @@ impl ProofClient { &self, circuit_inputs: BatchUpdateCircuitInputs, ) -> Result<(ProofResult, [u8; 32]), ProverClientError> { - let new_root = light_hasher::bigint::bigint_to_be_bytes_array::<32>( - &circuit_inputs.new_root.to_biguint().unwrap(), - )?; - let json_str = update_inputs_string(&circuit_inputs); + let new_root_biguint = circuit_inputs.new_root.to_biguint().ok_or_else(|| { + ProverClientError::InvalidProofData("new_root must be non-negative".to_string()) + })?; + let new_root = light_hasher::bigint::bigint_to_be_bytes_array::<32>(&new_root_biguint)?; + let json_str = update_inputs_string(&circuit_inputs)?; let proof = self.generate_proof(json_str).await?; Ok((proof, new_root)) } diff --git a/prover/client/src/proof_types/batch_address_append/json.rs b/prover/client/src/proof_types/batch_address_append/json.rs index cd31a326e8..4c1c0e5747 100644 --- a/prover/client/src/proof_types/batch_address_append/json.rs +++ b/prover/client/src/proof_types/batch_address_append/json.rs @@ -1,6 +1,7 @@ use serde::Serialize; use crate::{ + errors::ProverClientError, helpers::{big_uint_to_string, create_json_from_struct}, proof_types::{batch_address_append::BatchAddressAppendInputs, circuit_type::CircuitType}, }; @@ -102,12 +103,12 @@ impl BatchAddressAppendInputsJson { } #[allow(clippy::inherent_to_string)] - pub fn to_string(&self) -> String { + pub fn to_string(&self) -> Result { create_json_from_struct(&self) } } -pub fn to_json(inputs: &BatchAddressAppendInputs) -> String { +pub fn to_json(inputs: &BatchAddressAppendInputs) -> Result { let json_struct = BatchAddressAppendInputsJson::from_inputs(inputs); json_struct.to_string() } diff --git a/prover/client/src/proof_types/batch_address_append/proof_inputs.rs b/prover/client/src/proof_types/batch_address_append/proof_inputs.rs index f80e8d49e4..32408fdc02 100644 --- a/prover/client/src/proof_types/batch_address_append/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_address_append/proof_inputs.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::{collections::HashMap, fmt::Debug}; use light_hasher::{ bigint::bigint_to_be_bytes_array, @@ -187,21 +187,28 @@ impl BatchAddressAppendInputs { pub fn get_batch_address_append_circuit_inputs( next_index: usize, current_root: [u8; 32], - low_element_values: Vec<[u8; 32]>, - low_element_next_values: Vec<[u8; 32]>, - low_element_indices: Vec, - low_element_next_indices: Vec, - low_element_proofs: Vec>, - new_element_values: Vec<[u8; 32]>, + low_element_values: &[[u8; 32]], + low_element_next_values: &[[u8; 32]], + low_element_indices: &[impl Copy + TryInto + Debug], + low_element_next_indices: &[impl Copy + TryInto + Debug], + low_element_proofs: &[[[u8; 32]; HEIGHT]], + new_element_values: &[[u8; 32]], sparse_merkle_tree: &mut SparseMerkleTree, leaves_hashchain: [u8; 32], zkp_batch_size: usize, changelog: &mut Vec>, indexed_changelog: &mut Vec>, ) -> Result { - let new_element_values = new_element_values[0..zkp_batch_size].to_vec(); - - let computed_hashchain = create_hash_chain_from_slice(&new_element_values).map_err(|e| { + let new_element_values = &new_element_values[..zkp_batch_size]; + let mut new_root = [0u8; 32]; + let mut low_element_circuit_merkle_proofs = Vec::with_capacity(new_element_values.len()); + let mut new_element_circuit_merkle_proofs = Vec::with_capacity(new_element_values.len()); + let mut patched_low_element_next_values = Vec::with_capacity(new_element_values.len()); + let mut patched_low_element_next_indices = Vec::with_capacity(new_element_values.len()); + let mut patched_low_element_values = Vec::with_capacity(new_element_values.len()); + let mut patched_low_element_indices = Vec::with_capacity(new_element_values.len()); + + let computed_hashchain = create_hash_chain_from_slice(new_element_values).map_err(|e| { ProverClientError::GenericError(format!("Failed to compute hashchain: {}", e)) })?; if computed_hashchain != leaves_hashchain { @@ -229,15 +236,6 @@ pub fn get_batch_address_append_circuit_inputs( next_index ); - let mut new_root = [0u8; 32]; - let mut low_element_circuit_merkle_proofs = vec![]; - let mut new_element_circuit_merkle_proofs = vec![]; - - let mut patched_low_element_next_values: Vec<[u8; 32]> = Vec::new(); - let mut patched_low_element_next_indices: Vec = Vec::new(); - let mut patched_low_element_values: Vec<[u8; 32]> = Vec::new(); - let mut patched_low_element_indices: Vec = Vec::new(); - let mut patcher = ChangelogProofPatcher::new::(changelog); let is_first_batch = indexed_changelog.is_empty(); @@ -245,21 +243,33 @@ pub fn get_batch_address_append_circuit_inputs( for i in 0..new_element_values.len() { let mut changelog_index = 0; + let low_element_index = low_element_indices[i].try_into().map_err(|_| { + ProverClientError::IntegerConversion(format!( + "low element index {:?} does not fit into usize", + low_element_indices[i] + )) + })?; + let low_element_next_index = low_element_next_indices[i].try_into().map_err(|_| { + ProverClientError::IntegerConversion(format!( + "low element next index {:?} does not fit into usize", + low_element_next_indices[i] + )) + })?; let new_element_index = next_index + i; let mut low_element: IndexedElement = IndexedElement { - index: low_element_indices[i], + index: low_element_index, value: BigUint::from_bytes_be(&low_element_values[i]), - next_index: low_element_next_indices[i], + next_index: low_element_next_index, }; let mut new_element: IndexedElement = IndexedElement { index: new_element_index, value: BigUint::from_bytes_be(&new_element_values[i]), - next_index: low_element_next_indices[i], + next_index: low_element_next_index, }; - let mut low_element_proof = low_element_proofs[i].to_vec(); + let mut low_element_proof = low_element_proofs[i]; let mut low_element_next_value = BigUint::from_bytes_be(&low_element_next_values[i]); patch_indexed_changelogs( 0, @@ -293,18 +303,10 @@ pub fn get_batch_address_append_circuit_inputs( next_value: bigint_to_be_bytes_array::<32>(&new_element.value)?, index: new_low_element.index, }; + let low_element_changelog_proof = low_element_proof; let intermediate_root = { - let mut low_element_proof_arr: [[u8; 32]; HEIGHT] = low_element_proof - .clone() - .try_into() - .map_err(|v: Vec<[u8; 32]>| { - ProverClientError::ProofPatchFailed(format!( - "low element proof length mismatch: expected {}, got {}", - HEIGHT, - v.len() - )) - })?; + let mut low_element_proof_arr = low_element_changelog_proof; patcher.update_proof::(low_element.index(), &mut low_element_proof_arr); let merkle_proof = low_element_proof_arr; @@ -321,7 +323,7 @@ pub fn get_batch_address_append_circuit_inputs( old_low_leaf_hash, &merkle_proof, low_element.index as u32, - ); + )?; if computed_root != expected_root_for_low { let low_value_bytes = bigint_to_be_bytes_array::<32>(&low_element.value) .map_err(|e| { @@ -362,7 +364,7 @@ pub fn get_batch_address_append_circuit_inputs( new_low_leaf_hash, &merkle_proof, new_low_element.index as u32, - ); + )?; patcher.push_changelog_entry::(changelog, changelog_entry); low_element_circuit_merkle_proofs.push( @@ -376,13 +378,7 @@ pub fn get_batch_address_append_circuit_inputs( }; let low_element_changelog_entry = IndexedChangelogEntry { element: new_low_element_raw, - proof: low_element_proof.as_slice()[..HEIGHT] - .try_into() - .map_err(|_| { - ProverClientError::ProofPatchFailed( - "low_element_proof slice conversion failed".to_string(), - ) - })?, + proof: low_element_changelog_proof, changelog_index: indexed_changelog.len(), //change_log_index, }; @@ -409,7 +405,7 @@ pub fn get_batch_address_append_circuit_inputs( new_element_leaf_hash, &merkle_proof_array, current_index as u32, - ); + )?; if i == 0 && changelog.len() == 1 { if sparse_next_idx_before != current_index { @@ -436,7 +432,7 @@ pub fn get_batch_address_append_circuit_inputs( zero_hash, &merkle_proof_array, current_index as u32, - ); + )?; if root_with_zero != intermediate_root { tracing::error!( "ELEMENT {} NEW_PROOF MISMATCH: proof + ZERO = {:?}[..4] but expected \ diff --git a/prover/client/src/proof_types/batch_append/json.rs b/prover/client/src/proof_types/batch_append/json.rs index 7f68d0899c..53b78e649a 100644 --- a/prover/client/src/proof_types/batch_append/json.rs +++ b/prover/client/src/proof_types/batch_append/json.rs @@ -1,6 +1,7 @@ use serde::Serialize; use crate::{ + errors::ProverClientError, helpers::{big_int_to_string, create_json_from_struct}, proof_types::{batch_append::BatchAppendsCircuitInputs, circuit_type::CircuitType}, }; @@ -73,7 +74,7 @@ impl BatchAppendInputsJson { } #[allow(clippy::inherent_to_string)] - pub fn to_string(&self) -> String { + pub fn to_string(&self) -> Result { create_json_from_struct(&self) } } diff --git a/prover/client/src/proof_types/batch_append/proof_inputs.rs b/prover/client/src/proof_types/batch_append/proof_inputs.rs index ef0327ac1d..c6a889acf7 100644 --- a/prover/client/src/proof_types/batch_append/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_append/proof_inputs.rs @@ -25,8 +25,8 @@ pub struct BatchAppendsCircuitInputs { } impl BatchAppendsCircuitInputs { - pub fn public_inputs_arr(&self) -> [u8; 32] { - bigint_to_u8_32(&self.public_input_hash).unwrap() + pub fn public_inputs_arr(&self) -> Result<[u8; 32], ProverClientError> { + bigint_to_u8_32(&self.public_input_hash) } pub fn new( @@ -177,7 +177,13 @@ pub fn get_batch_append_inputs( } } - let merkle_proof_array = merkle_proof.try_into().unwrap(); + let merkle_proof_array = merkle_proof.as_slice().try_into().map_err(|_| { + ProverClientError::InvalidProofData(format!( + "invalid merkle proof length: expected {}, got {}", + HEIGHT, + merkle_proof.len() + )) + })?; // Determine if we use the old or new leaf based on whether the old leaf is nullified (zeroed). let is_old_leaf_zero = old_leaf.iter().all(|&byte| byte == 0); let final_leaf = if is_old_leaf_zero { @@ -187,8 +193,11 @@ pub fn get_batch_append_inputs( }; // Update the root based on the current proof and nullifier - let (updated_root, changelog_entry) = - compute_root_from_merkle_proof(final_leaf, &merkle_proof_array, start_index + i as u32); + let (updated_root, changelog_entry) = compute_root_from_merkle_proof( + final_leaf, + &merkle_proof_array, + start_index + i as u32, + )?; new_root = updated_root; changelog.push(changelog_entry); circuit_merkle_proofs.push( diff --git a/prover/client/src/proof_types/batch_update/json.rs b/prover/client/src/proof_types/batch_update/json.rs index d03a93dc47..3faf7e676c 100644 --- a/prover/client/src/proof_types/batch_update/json.rs +++ b/prover/client/src/proof_types/batch_update/json.rs @@ -1,6 +1,7 @@ use serde::Serialize; use crate::{ + errors::ProverClientError, helpers::{big_int_to_string, create_json_from_struct}, proof_types::{batch_update::BatchUpdateCircuitInputs, circuit_type::CircuitType}, }; @@ -103,12 +104,14 @@ impl BatchUpdateProofInputsJson { } #[allow(clippy::inherent_to_string)] - pub fn to_string(&self) -> String { + pub fn to_string(&self) -> Result { create_json_from_struct(&self) } } -pub fn update_inputs_string(inputs: &BatchUpdateCircuitInputs) -> String { +pub fn update_inputs_string( + inputs: &BatchUpdateCircuitInputs, +) -> Result { let json_struct = BatchUpdateProofInputsJson::from_update_inputs(inputs); json_struct.to_string() } diff --git a/prover/client/src/proof_types/batch_update/proof_inputs.rs b/prover/client/src/proof_types/batch_update/proof_inputs.rs index 2136d01d10..1bd8acad43 100644 --- a/prover/client/src/proof_types/batch_update/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_update/proof_inputs.rs @@ -31,8 +31,8 @@ pub struct BatchUpdateCircuitInputs { } impl BatchUpdateCircuitInputs { - pub fn public_inputs_arr(&self) -> [u8; 32] { - bigint_to_u8_32(&self.public_input_hash).unwrap() + pub fn public_inputs_arr(&self) -> Result<[u8; 32], ProverClientError> { + bigint_to_u8_32(&self.public_input_hash) } pub fn new( @@ -112,9 +112,9 @@ impl BatchUpdateCircuitInputs { pub struct BatchUpdateInputs<'a>(pub &'a [BatchUpdateCircuitInputs]); impl BatchUpdateInputs<'_> { - pub fn public_inputs(&self) -> Vec<[u8; 32]> { + pub fn public_inputs(&self) -> Result, ProverClientError> { // Concatenate all public inputs into a single flat vector - vec![self.0[0].public_inputs_arr()] + Ok(vec![self.0[0].public_inputs_arr()?]) } } @@ -168,14 +168,20 @@ pub fn get_batch_update_inputs( } } - let merkle_proof_array = merkle_proof.try_into().unwrap(); + let merkle_proof_array = merkle_proof.as_slice().try_into().map_err(|_| { + ProverClientError::InvalidProofData(format!( + "invalid merkle proof length: expected {}, got {}", + HEIGHT, + merkle_proof.len() + )) + })?; // Use the adjusted index bytes for computing the nullifier. let mut index_bytes = [0u8; 32]; index_bytes[28..].copy_from_slice(&(*index).to_be_bytes()); - let nullifier = Poseidon::hashv(&[leaf, &index_bytes, &tx_hashes[i]]).unwrap(); + let nullifier = Poseidon::hashv(&[leaf, &index_bytes, &tx_hashes[i]])?; let (root, changelog_entry) = - compute_root_from_merkle_proof(nullifier, &merkle_proof_array, *index); + compute_root_from_merkle_proof(nullifier, &merkle_proof_array, *index)?; new_root = root; changelog.push(changelog_entry); circuit_merkle_proofs.push( diff --git a/prover/client/src/proof_types/combined/v1/json.rs b/prover/client/src/proof_types/combined/v1/json.rs index 377848bd5b..1186e628ef 100644 --- a/prover/client/src/proof_types/combined/v1/json.rs +++ b/prover/client/src/proof_types/combined/v1/json.rs @@ -1,6 +1,7 @@ use serde::Serialize; use crate::{ + errors::ProverClientError, helpers::create_json_from_struct, proof_types::{ circuit_type::CircuitType, @@ -24,22 +25,23 @@ pub struct CombinedJsonStruct { } impl CombinedJsonStruct { - pub fn from_combined_inputs(inputs: &CombinedProofInputs) -> Self { + pub fn from_combined_inputs(inputs: &CombinedProofInputs) -> Result { let inclusion_parameters = - BatchInclusionJsonStruct::from_inclusion_proof_inputs(&inputs.inclusion_parameters); - let non_inclusion_parameters = BatchNonInclusionJsonStruct::from_non_inclusion_proof_inputs( - &inputs.non_inclusion_parameters, - ); - Self { + BatchInclusionJsonStruct::from_inclusion_proof_inputs(&inputs.inclusion_parameters)?; + let non_inclusion_parameters = + BatchNonInclusionJsonStruct::from_non_inclusion_proof_inputs( + &inputs.non_inclusion_parameters, + )?; + Ok(Self { circuit_type: CircuitType::Combined.to_string(), state_tree_height: inclusion_parameters.state_tree_height, address_tree_height: non_inclusion_parameters.address_tree_height, inclusion: inclusion_parameters.inputs, non_inclusion: non_inclusion_parameters.inputs, - } + }) } #[allow(clippy::inherent_to_string)] - pub fn to_string(&self) -> String { + pub fn to_string(&self) -> Result { create_json_from_struct(&self) } } diff --git a/prover/client/src/proof_types/combined/v2/json.rs b/prover/client/src/proof_types/combined/v2/json.rs index 322a5ee8ec..a15496dae3 100644 --- a/prover/client/src/proof_types/combined/v2/json.rs +++ b/prover/client/src/proof_types/combined/v2/json.rs @@ -2,6 +2,7 @@ use serde::Serialize; use crate::{ constants::{DEFAULT_BATCH_ADDRESS_TREE_HEIGHT, DEFAULT_BATCH_STATE_TREE_HEIGHT}, + errors::ProverClientError, helpers::{big_int_to_string, create_json_from_struct}, proof_types::{ circuit_type::CircuitType, @@ -29,25 +30,26 @@ pub struct CombinedJsonStruct { } impl CombinedJsonStruct { - pub fn from_combined_inputs(inputs: &CombinedProofInputs) -> Self { + pub fn from_combined_inputs(inputs: &CombinedProofInputs) -> Result { let inclusion_parameters = - BatchInclusionJsonStruct::from_inclusion_proof_inputs(&inputs.inclusion_parameters); - let non_inclusion_parameters = BatchNonInclusionJsonStruct::from_non_inclusion_proof_inputs( - &inputs.non_inclusion_parameters, - ); + BatchInclusionJsonStruct::from_inclusion_proof_inputs(&inputs.inclusion_parameters)?; + let non_inclusion_parameters = + BatchNonInclusionJsonStruct::from_non_inclusion_proof_inputs( + &inputs.non_inclusion_parameters, + )?; - Self { + Ok(Self { circuit_type: CircuitType::Combined.to_string(), state_tree_height: DEFAULT_BATCH_STATE_TREE_HEIGHT, address_tree_height: DEFAULT_BATCH_ADDRESS_TREE_HEIGHT, public_input_hash: big_int_to_string(&inputs.public_input_hash), inclusion: inclusion_parameters.inputs, non_inclusion: non_inclusion_parameters.inputs, - } + }) } #[allow(clippy::inherent_to_string)] - pub fn to_string(&self) -> String { + pub fn to_string(&self) -> Result { create_json_from_struct(&self) } } diff --git a/prover/client/src/proof_types/combined/v2/proof_inputs.rs b/prover/client/src/proof_types/combined/v2/proof_inputs.rs index 67a65c4e57..faea10ecdd 100644 --- a/prover/client/src/proof_types/combined/v2/proof_inputs.rs +++ b/prover/client/src/proof_types/combined/v2/proof_inputs.rs @@ -37,8 +37,8 @@ impl<'a> CombinedProofInputs<'a> { Ok(BigInt::from_bytes_be( num_bigint::Sign::Plus, &create_hash_chain_from_array([ - bigint_to_u8_32(&inclusion_parameters.public_input_hash).unwrap(), - bigint_to_u8_32(&non_inclusion_parameters.public_input_hash).unwrap(), + bigint_to_u8_32(&inclusion_parameters.public_input_hash)?, + bigint_to_u8_32(&non_inclusion_parameters.public_input_hash)?, ])?, )) } diff --git a/prover/client/src/proof_types/inclusion/v1/json.rs b/prover/client/src/proof_types/inclusion/v1/json.rs index 1a5f5defc5..d97ae28a55 100644 --- a/prover/client/src/proof_types/inclusion/v1/json.rs +++ b/prover/client/src/proof_types/inclusion/v1/json.rs @@ -2,6 +2,7 @@ use num_traits::ToPrimitive; use serde::Serialize; use crate::{ + errors::ProverClientError, helpers::{big_int_to_string, create_json_from_struct}, proof_types::{ circuit_type::CircuitType, @@ -21,25 +22,32 @@ pub struct BatchInclusionJsonStruct { impl BatchInclusionJsonStruct { #[allow(clippy::inherent_to_string)] - pub fn to_string(&self) -> String { + pub fn to_string(&self) -> Result { create_json_from_struct(&self) } - pub fn from_inclusion_proof_inputs(inputs: &InclusionProofInputs) -> Self { + pub fn from_inclusion_proof_inputs( + inputs: &InclusionProofInputs, + ) -> Result { let mut proof_inputs: Vec = Vec::new(); for input in inputs.0.iter() { let proof_input = InclusionJsonStruct { root: big_int_to_string(&input.root), leaf: big_int_to_string(&input.leaf), - pathIndex: input.path_index.to_u32().unwrap(), + pathIndex: input.path_index.to_u32().ok_or_else(|| { + ProverClientError::IntegerConversion(format!( + "path index {} does not fit into u32", + input.path_index + )) + })?, pathElements: input.path_elements.iter().map(big_int_to_string).collect(), }; proof_inputs.push(proof_input); } - Self { + Ok(Self { circuit_type: CircuitType::Inclusion.to_string(), state_tree_height: 26, inputs: proof_inputs, - } + }) } } diff --git a/prover/client/src/proof_types/inclusion/v1/proof_inputs.rs b/prover/client/src/proof_types/inclusion/v1/proof_inputs.rs index 77a6777452..18758e31db 100644 --- a/prover/client/src/proof_types/inclusion/v1/proof_inputs.rs +++ b/prover/client/src/proof_types/inclusion/v1/proof_inputs.rs @@ -1,10 +1,13 @@ -use crate::{helpers::bigint_to_u8_32, proof_types::inclusion::v2::InclusionMerkleProofInputs}; +use crate::{ + errors::ProverClientError, helpers::bigint_to_u8_32, + proof_types::inclusion::v2::InclusionMerkleProofInputs, +}; impl InclusionMerkleProofInputs { - pub fn public_inputs_arr(&self) -> [[u8; 32]; 2] { - let root = bigint_to_u8_32(&self.root).unwrap(); - let leaf = bigint_to_u8_32(&self.leaf).unwrap(); - [root, leaf] + pub fn public_inputs_arr(&self) -> Result<[[u8; 32]; 2], ProverClientError> { + let root = bigint_to_u8_32(&self.root)?; + let leaf = bigint_to_u8_32(&self.leaf)?; + Ok([root, leaf]) } } @@ -12,14 +15,14 @@ impl InclusionMerkleProofInputs { pub struct InclusionProofInputs<'a>(pub &'a [InclusionMerkleProofInputs]); impl InclusionProofInputs<'_> { - pub fn public_inputs(&self) -> Vec<[u8; 32]> { + pub fn public_inputs(&self) -> Result, ProverClientError> { let mut roots = Vec::new(); let mut leaves = Vec::new(); for input in self.0 { - let input_arr = input.public_inputs_arr(); + let input_arr = input.public_inputs_arr()?; roots.push(input_arr[0]); leaves.push(input_arr[1]); } - [roots, leaves].concat() + Ok([roots, leaves].concat()) } } diff --git a/prover/client/src/proof_types/inclusion/v2/json.rs b/prover/client/src/proof_types/inclusion/v2/json.rs index 7fc9d34731..c69af3b3dd 100644 --- a/prover/client/src/proof_types/inclusion/v2/json.rs +++ b/prover/client/src/proof_types/inclusion/v2/json.rs @@ -3,6 +3,7 @@ use serde::Serialize; use crate::{ constants::DEFAULT_BATCH_STATE_TREE_HEIGHT, + errors::ProverClientError, helpers::{big_int_to_string, create_json_from_struct}, proof_types::{circuit_type::CircuitType, inclusion::v2::InclusionProofInputs}, }; @@ -34,26 +35,33 @@ pub struct InclusionJsonStruct { impl BatchInclusionJsonStruct { #[allow(clippy::inherent_to_string)] - pub fn to_string(&self) -> String { + pub fn to_string(&self) -> Result { create_json_from_struct(&self) } - pub fn from_inclusion_proof_inputs(inputs: &InclusionProofInputs) -> Self { + pub fn from_inclusion_proof_inputs( + inputs: &InclusionProofInputs, + ) -> Result { let mut proof_inputs: Vec = Vec::new(); for input in inputs.inputs.iter() { let prof_input = InclusionJsonStruct { root: big_int_to_string(&input.root), leaf: big_int_to_string(&input.leaf), - pathIndex: input.path_index.to_u32().unwrap(), + pathIndex: input.path_index.to_u32().ok_or_else(|| { + ProverClientError::IntegerConversion(format!( + "path index {} does not fit into u32", + input.path_index + )) + })?, pathElements: input.path_elements.iter().map(big_int_to_string).collect(), }; proof_inputs.push(prof_input); } - Self { + Ok(Self { circuit_type: CircuitType::Inclusion.to_string(), state_tree_height: DEFAULT_BATCH_STATE_TREE_HEIGHT, public_input_hash: big_int_to_string(&inputs.public_input_hash), inputs: proof_inputs, - } + }) } } diff --git a/prover/client/src/proof_types/inclusion/v2/proof_inputs.rs b/prover/client/src/proof_types/inclusion/v2/proof_inputs.rs index d6d4c78043..523163890f 100644 --- a/prover/client/src/proof_types/inclusion/v2/proof_inputs.rs +++ b/prover/client/src/proof_types/inclusion/v2/proof_inputs.rs @@ -28,16 +28,13 @@ impl<'a> InclusionProofInputs<'a> { pub fn public_input( inputs: &'a [InclusionMerkleProofInputs], ) -> Result { - let public_input_hash = create_two_inputs_hash_chain( - &inputs - .iter() - .map(|x| bigint_to_u8_32(&x.root).unwrap()) - .collect::>(), - &inputs - .iter() - .map(|x| bigint_to_u8_32(&x.leaf).unwrap()) - .collect::>(), - )?; + let mut roots = Vec::with_capacity(inputs.len()); + let mut leaves = Vec::with_capacity(inputs.len()); + for input in inputs { + roots.push(bigint_to_u8_32(&input.root)?); + leaves.push(bigint_to_u8_32(&input.leaf)?); + } + let public_input_hash = create_two_inputs_hash_chain(&roots, &leaves)?; Ok(BigInt::from_bytes_be( num_bigint::Sign::Plus, &public_input_hash, diff --git a/prover/client/src/proof_types/non_inclusion/v1/json.rs b/prover/client/src/proof_types/non_inclusion/v1/json.rs index 3cdc3c093f..9d712aa140 100644 --- a/prover/client/src/proof_types/non_inclusion/v1/json.rs +++ b/prover/client/src/proof_types/non_inclusion/v1/json.rs @@ -2,6 +2,7 @@ use num_traits::ToPrimitive; use serde::Serialize; use crate::{ + errors::ProverClientError, helpers::{big_int_to_string, create_json_from_struct}, proof_types::{circuit_type::CircuitType, non_inclusion::v1::NonInclusionProofInputs}, }; @@ -39,33 +40,48 @@ pub struct LegacyNonInclusionJsonStruct { impl BatchNonInclusionJsonStruct { #[allow(clippy::inherent_to_string)] - pub fn to_string(&self) -> String { + pub fn to_string(&self) -> Result { create_json_from_struct(&self) } - pub fn from_non_inclusion_proof_inputs(inputs: &NonInclusionProofInputs) -> Self { + pub fn from_non_inclusion_proof_inputs( + inputs: &NonInclusionProofInputs, + ) -> Result { let mut proof_inputs: Vec = Vec::new(); for input in inputs.0 { let prof_input = LegacyNonInclusionJsonStruct { root: big_int_to_string(&input.root), value: big_int_to_string(&input.value), - path_index: input.index_hashed_indexed_element_leaf.to_u32().unwrap(), + path_index: input + .index_hashed_indexed_element_leaf + .to_u32() + .ok_or_else(|| { + ProverClientError::IntegerConversion(format!( + "path index {} does not fit into u32", + input.index_hashed_indexed_element_leaf + )) + })?, path_elements: input .merkle_proof_hashed_indexed_element_leaf .iter() .map(big_int_to_string) .collect(), - next_index: input.next_index.to_u32().unwrap(), + next_index: input.next_index.to_u32().ok_or_else(|| { + ProverClientError::IntegerConversion(format!( + "next index {} does not fit into u32", + input.next_index + )) + })?, leaf_lower_range_value: big_int_to_string(&input.leaf_lower_range_value), leaf_higher_range_value: big_int_to_string(&input.leaf_higher_range_value), }; proof_inputs.push(prof_input); } - Self { + Ok(Self { circuit_type: CircuitType::NonInclusion.to_string(), address_tree_height: 26, inputs: proof_inputs, - } + }) } } diff --git a/prover/client/src/proof_types/non_inclusion/v1/proof_inputs.rs b/prover/client/src/proof_types/non_inclusion/v1/proof_inputs.rs index 1a2cb8153a..56eed84fbd 100644 --- a/prover/client/src/proof_types/non_inclusion/v1/proof_inputs.rs +++ b/prover/client/src/proof_types/non_inclusion/v1/proof_inputs.rs @@ -1,12 +1,13 @@ use crate::{ - helpers::bigint_to_u8_32, proof_types::non_inclusion::v2::NonInclusionMerkleProofInputs, + errors::ProverClientError, helpers::bigint_to_u8_32, + proof_types::non_inclusion::v2::NonInclusionMerkleProofInputs, }; impl NonInclusionMerkleProofInputs { - pub fn public_inputs_legacy(&self) -> [[u8; 32]; 2] { - let root = bigint_to_u8_32(&self.root).unwrap(); - let value = bigint_to_u8_32(&self.value).unwrap(); - [root, value] + pub fn public_inputs_legacy(&self) -> Result<[[u8; 32]; 2], ProverClientError> { + let root = bigint_to_u8_32(&self.root)?; + let value = bigint_to_u8_32(&self.value)?; + Ok([root, value]) } } diff --git a/prover/client/src/proof_types/non_inclusion/v2/json.rs b/prover/client/src/proof_types/non_inclusion/v2/json.rs index f6174e724d..ec65c3c619 100644 --- a/prover/client/src/proof_types/non_inclusion/v2/json.rs +++ b/prover/client/src/proof_types/non_inclusion/v2/json.rs @@ -2,6 +2,7 @@ use num_traits::ToPrimitive; use serde::Serialize; use crate::{ + errors::ProverClientError, helpers::{big_int_to_string, create_json_from_struct}, proof_types::{circuit_type::CircuitType, non_inclusion::v2::NonInclusionProofInputs}, }; @@ -38,17 +39,27 @@ pub struct NonInclusionJsonStruct { impl BatchNonInclusionJsonStruct { #[allow(clippy::inherent_to_string)] - pub fn to_string(&self) -> String { + pub fn to_string(&self) -> Result { create_json_from_struct(&self) } - pub fn from_non_inclusion_proof_inputs(inputs: &NonInclusionProofInputs) -> Self { + pub fn from_non_inclusion_proof_inputs( + inputs: &NonInclusionProofInputs, + ) -> Result { let mut proof_inputs: Vec = Vec::new(); for input in inputs.inputs.iter() { let prof_input = NonInclusionJsonStruct { root: big_int_to_string(&input.root), value: big_int_to_string(&input.value), - path_index: input.index_hashed_indexed_element_leaf.to_u32().unwrap(), + path_index: input + .index_hashed_indexed_element_leaf + .to_u32() + .ok_or_else(|| { + ProverClientError::IntegerConversion(format!( + "path index {} does not fit into u32", + input.index_hashed_indexed_element_leaf + )) + })?, path_elements: input .merkle_proof_hashed_indexed_element_leaf .iter() @@ -60,11 +71,11 @@ impl BatchNonInclusionJsonStruct { proof_inputs.push(prof_input); } - Self { + Ok(Self { circuit_type: CircuitType::NonInclusion.to_string(), address_tree_height: 40, public_input_hash: big_int_to_string(&inputs.public_input_hash), inputs: proof_inputs, - } + }) } } diff --git a/prover/client/src/proof_types/non_inclusion/v2/proof_inputs.rs b/prover/client/src/proof_types/non_inclusion/v2/proof_inputs.rs index 807eae3484..15395d3f90 100644 --- a/prover/client/src/proof_types/non_inclusion/v2/proof_inputs.rs +++ b/prover/client/src/proof_types/non_inclusion/v2/proof_inputs.rs @@ -34,16 +34,13 @@ impl<'a> NonInclusionProofInputs<'a> { pub fn public_input( inputs: &'a [NonInclusionMerkleProofInputs], ) -> Result { - let public_input_hash = create_two_inputs_hash_chain( - &inputs - .iter() - .map(|x| bigint_to_u8_32(&x.root).unwrap()) - .collect::>(), - &inputs - .iter() - .map(|x| bigint_to_u8_32(&x.value).unwrap()) - .collect::>(), - )?; + let mut roots = Vec::with_capacity(inputs.len()); + let mut values = Vec::with_capacity(inputs.len()); + for input in inputs { + roots.push(bigint_to_u8_32(&input.root)?); + values.push(bigint_to_u8_32(&input.value)?); + } + let public_input_hash = create_two_inputs_hash_chain(&roots, &values)?; Ok(BigInt::from_bytes_be( num_bigint::Sign::Plus, &public_input_hash, diff --git a/prover/client/src/prover.rs b/prover/client/src/prover.rs index 3bf1bab785..e136875c91 100644 --- a/prover/client/src/prover.rs +++ b/prover/client/src/prover.rs @@ -7,51 +7,76 @@ use std::{ use tracing::info; +#[cfg(feature = "devenv")] +use crate::helpers::get_project_root; use crate::{ constants::{HEALTH_CHECK, SERVER_ADDRESS}, - helpers::get_project_root, + errors::ProverClientError, }; static IS_LOADING: AtomicBool = AtomicBool::new(false); -pub async fn spawn_prover() { - if let Some(_project_root) = get_project_root() { - let prover_path: &str = { - #[cfg(feature = "devenv")] - { - &format!("{}/{}", _project_root.trim(), "cli/test_bin/run") - } - #[cfg(not(feature = "devenv"))] - { - println!("Running in production mode, using prover binary"); - "light" - } - }; +pub async fn spawn_prover() -> Result<(), ProverClientError> { + #[cfg(feature = "devenv")] + let project_root = get_project_root().ok_or(ProverClientError::ProjectRootNotFound)?; - if !health_check(10, 1).await && !IS_LOADING.load(Ordering::Relaxed) { - IS_LOADING.store(true, Ordering::Relaxed); + let prover_path: String = { + #[cfg(feature = "devenv")] + { + format!("{}/{}", project_root.trim(), "cli/test_bin/run") + } + #[cfg(not(feature = "devenv"))] + { + println!("Running in production mode, using prover binary"); + "light".to_string() + } + }; - let command = Command::new(prover_path) - .arg("start-prover") - .spawn() - .expect("Failed to start prover process"); + if health_check(10, 1).await { + return Ok(()); + } - let _ = command.wait_with_output(); + let loading_guard = IS_LOADING + .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) + .is_ok(); - let health_result = health_check(120, 1).await; - if health_result { - info!("Prover started successfully"); - } else { - panic!("Failed to start prover, health check failed."); - } + if !loading_guard { + return if health_check(120, 1).await { + Ok(()) + } else { + Err(ProverClientError::HealthCheckFailed) + }; + } + + let spawn_result = async { + let command = Command::new(&prover_path) + .arg("start-prover") + .spawn() + .map_err(|error| ProverClientError::ProcessStart(error.to_string()))?; + + command + .wait_with_output() + .map_err(|error| ProverClientError::ProcessWait(error.to_string()))?; + + if health_check(120, 1).await { + info!("Prover started successfully"); + Ok(()) + } else { + Err(ProverClientError::HealthCheckFailed) } - } else { - panic!("Failed to find project root."); - }; + } + .await; + + IS_LOADING.store(false, Ordering::Release); + + spawn_result } pub async fn health_check(retries: usize, timeout: usize) -> bool { - let client = reqwest::Client::new(); + let client = match reqwest::Client::builder().no_proxy().build() { + Ok(client) => client, + Err(_) => return false, + }; let mut result = false; for _ in 0..retries { match client diff --git a/prover/client/tests/batch_address_append.rs b/prover/client/tests/batch_address_append.rs index 22f58d5362..ac73c3809e 100644 --- a/prover/client/tests/batch_address_append.rs +++ b/prover/client/tests/batch_address_append.rs @@ -45,7 +45,8 @@ async fn prove_batch_address_append() { let mut low_element_indices = Vec::new(); let mut low_element_next_indices = Vec::new(); let mut low_element_next_values = Vec::new(); - let mut low_element_proofs: Vec> = Vec::new(); + let mut low_element_proofs: Vec<[[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize]> = + Vec::new(); // Generate non-inclusion proofs for each element for new_element_value in &new_element_values { @@ -57,7 +58,13 @@ async fn prove_batch_address_append() { low_element_indices.push(non_inclusion_proof.leaf_index); low_element_next_indices.push(non_inclusion_proof.next_index); low_element_next_values.push(non_inclusion_proof.leaf_higher_range_value); - low_element_proofs.push(non_inclusion_proof.merkle_proof.as_slice().to_vec()); + low_element_proofs.push( + non_inclusion_proof + .merkle_proof + .as_slice() + .try_into() + .unwrap(), + ); } // Convert big integers to byte arrays @@ -87,12 +94,12 @@ async fn prove_batch_address_append() { get_batch_address_append_circuit_inputs::<{ DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }>( start_index, current_root, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - low_element_proofs, - new_element_values, + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + &new_element_values, &mut sparse_merkle_tree, hash_chain, zkp_batch_size, diff --git a/prover/client/tests/batch_append.rs b/prover/client/tests/batch_append.rs index 801e155e92..38ccdf9424 100644 --- a/prover/client/tests/batch_append.rs +++ b/prover/client/tests/batch_append.rs @@ -12,7 +12,7 @@ mod init_merkle_tree; #[serial] #[tokio::test] async fn prove_batch_append_with_proofs() { - spawn_prover().await; + spawn_prover().await.unwrap(); const HEIGHT: usize = DEFAULT_BATCH_STATE_TREE_HEIGHT as usize; const CANOPY: usize = 0; @@ -67,7 +67,9 @@ async fn prove_batch_append_with_proofs() { // Serialize inputs to JSON let client = Client::new(); - let inputs_json = BatchAppendInputsJson::from_inputs(&inputs).to_string(); + let inputs_json = BatchAppendInputsJson::from_inputs(&inputs) + .to_string() + .unwrap(); // Send proof request to server let response_result = client .post(format!("{}{}", SERVER_ADDRESS, PROVE_PATH)) diff --git a/prover/client/tests/batch_update.rs b/prover/client/tests/batch_update.rs index c4aeeaa1c0..9d86da1bde 100644 --- a/prover/client/tests/batch_update.rs +++ b/prover/client/tests/batch_update.rs @@ -12,7 +12,7 @@ mod init_merkle_tree; #[serial] #[tokio::test] async fn prove_batch_update() { - spawn_prover().await; + spawn_prover().await.unwrap(); const HEIGHT: usize = DEFAULT_BATCH_STATE_TREE_HEIGHT as usize; const CANOPY: usize = 0; let num_insertions = 10; @@ -59,7 +59,7 @@ async fn prove_batch_update() { ) .unwrap(); let client = Client::new(); - let inputs = update_inputs_string(&inputs); + let inputs = update_inputs_string(&inputs).unwrap(); let response_result = client .post(format!("{}{}", SERVER_ADDRESS, PROVE_PATH)) .header("Content-Type", "text/plain; charset=utf-8") diff --git a/prover/client/tests/combined.rs b/prover/client/tests/combined.rs index 92b9cc2644..ada34d5b5c 100644 --- a/prover/client/tests/combined.rs +++ b/prover/client/tests/combined.rs @@ -10,7 +10,7 @@ use crate::init_merkle_tree::{combined_inputs_string_v1, combined_inputs_string_ #[serial] #[tokio::test] async fn prove_combined() { - spawn_prover().await; + spawn_prover().await.unwrap(); let client = Client::new(); { for i in 1..=4 { diff --git a/prover/client/tests/inclusion.rs b/prover/client/tests/inclusion.rs index be50a6d3cb..850e5c14b9 100644 --- a/prover/client/tests/inclusion.rs +++ b/prover/client/tests/inclusion.rs @@ -10,7 +10,7 @@ use crate::init_merkle_tree::{inclusion_inputs_string_v1, inclusion_inputs_strin #[serial] #[tokio::test] async fn prove_inclusion() { - spawn_prover().await; + spawn_prover().await.unwrap(); let client = Client::new(); // v2 - test all keys from 1 to 20 diff --git a/prover/client/tests/init_merkle_tree.rs b/prover/client/tests/init_merkle_tree.rs index 3bb5584cd3..79b1d04762 100644 --- a/prover/client/tests/init_merkle_tree.rs +++ b/prover/client/tests/init_merkle_tree.rs @@ -165,7 +165,7 @@ pub fn non_inclusion_inputs_string_v1( number_of_utxos: usize, ) -> (String, NonInclusionMerkleProofInputs) { let (json_struct, public_inputs) = non_inclusion_new_with_public_inputs_v1(number_of_utxos); - (json_struct.to_string(), public_inputs) + (json_struct.to_string().unwrap(), public_inputs) } pub fn non_inclusion_new_with_public_inputs_v1( @@ -203,7 +203,7 @@ pub fn non_inclusion_new_with_public_inputs_v1( pub fn non_inclusion_inputs_string_v2(number_of_utxos: usize) -> String { let (json_struct, _) = non_inclusion_new_with_public_inputs_v2(number_of_utxos).unwrap(); - json_struct.to_string() + json_struct.to_string().unwrap() } pub fn non_inclusion_new_with_public_inputs_v2( @@ -270,7 +270,7 @@ pub fn inclusion_new_with_public_inputs_v1(number_of_utxos: usize) -> BatchInclu pub fn inclusion_inputs_string_v1(number_of_utxos: usize) -> String { let json_struct = inclusion_new_with_public_inputs_v1(number_of_utxos); - json_struct.to_string() + json_struct.to_string().unwrap() } pub fn inclusion_new_with_public_inputs_v2( @@ -312,7 +312,7 @@ pub fn inclusion_new_with_public_inputs_v2( pub fn inclusion_inputs_string_v2(number_of_utxos: usize) -> String { let (json_struct, _) = inclusion_new_with_public_inputs_v2(number_of_utxos); - json_struct.to_string() + json_struct.to_string().unwrap() } fn combined_new_with_public_inputs_v1( @@ -334,7 +334,7 @@ fn combined_new_with_public_inputs_v1( pub fn combined_inputs_string_v1(num_inclusion: usize, num_non_inclusion: usize) -> String { let json_struct = combined_new_with_public_inputs_v1(num_inclusion, num_non_inclusion); - json_struct.to_string() + json_struct.to_string().unwrap() } fn combined_new_with_public_inputs_v2( @@ -366,5 +366,5 @@ fn combined_new_with_public_inputs_v2( pub fn combined_inputs_string_v2(num_inclusion: usize, num_non_inclusion: usize) -> String { let json_struct = combined_new_with_public_inputs_v2(num_inclusion, num_non_inclusion); - json_struct.unwrap().to_string() + json_struct.unwrap().to_string().unwrap() } diff --git a/prover/client/tests/non_inclusion.rs b/prover/client/tests/non_inclusion.rs index 39fc0d21dd..9dd874d08b 100644 --- a/prover/client/tests/non_inclusion.rs +++ b/prover/client/tests/non_inclusion.rs @@ -10,7 +10,7 @@ use crate::init_merkle_tree::{non_inclusion_inputs_string_v1, non_inclusion_inpu #[serial] #[tokio::test] async fn prove_non_inclusion() { - spawn_prover().await; + spawn_prover().await.unwrap(); let client = Client::new(); // legacy height 26 { diff --git a/sdk-libs/client/src/fee.rs b/sdk-libs/client/src/fee.rs index 46d48fb231..e4367493f3 100644 --- a/sdk-libs/client/src/fee.rs +++ b/sdk-libs/client/src/fee.rs @@ -65,7 +65,11 @@ pub async fn assert_transaction_params( if let Some(transaction_params) = params { let mut deduped_signers = signers.to_vec(); deduped_signers.dedup(); - let post_balance = rpc.get_account(*payer).await?.unwrap().lamports; + let post_balance = rpc + .get_account(*payer) + .await? + .ok_or_else(|| RpcError::AccountDoesNotExist(payer.to_string()))? + .lamports; // Network fee is charged per input and per address let mut network_fee: i64 = 0; diff --git a/sdk-libs/client/src/indexer/error.rs b/sdk-libs/client/src/indexer/error.rs index d83b057989..10510e2fb8 100644 --- a/sdk-libs/client/src/indexer/error.rs +++ b/sdk-libs/client/src/indexer/error.rs @@ -135,7 +135,7 @@ impl Clone for IndexerError { IndexerError::CustomError("IndexedMerkleTreeError".to_string()) } IndexerError::InvalidResponseData => IndexerError::InvalidResponseData, - IndexerError::CustomError(_) => IndexerError::CustomError("IndexerError".to_string()), + IndexerError::CustomError(msg) => IndexerError::CustomError(msg.clone()), IndexerError::NotInitialized => IndexerError::NotInitialized, IndexerError::IndexerNotSyncedToSlot => IndexerError::IndexerNotSyncedToSlot, IndexerError::InvalidPackTreeType => IndexerError::InvalidPackTreeType, diff --git a/sdk-libs/client/src/indexer/options.rs b/sdk-libs/client/src/indexer/options.rs index 76a05d8a94..bc11e818a4 100644 --- a/sdk-libs/client/src/indexer/options.rs +++ b/sdk-libs/client/src/indexer/options.rs @@ -40,13 +40,12 @@ pub struct GetCompressedAccountsFilter { pub offset: u32, } -#[allow(clippy::from_over_into)] -impl Into for GetCompressedAccountsFilter { - fn into(self) -> FilterSelector { +impl From<&GetCompressedAccountsFilter> for FilterSelector { + fn from(filter: &GetCompressedAccountsFilter) -> FilterSelector { FilterSelector { memcmp: Some(Memcmp { - offset: self.offset as u64, - bytes: photon_api::types::Base58String(bs58::encode(&self.bytes).into_string()), + offset: filter.offset as u64, + bytes: photon_api::types::Base58String(bs58::encode(&filter.bytes).into_string()), }), } } @@ -56,7 +55,7 @@ impl GetCompressedAccountsByOwnerConfig { pub fn filters_to_photon(&self) -> Option> { self.filters .as_ref() - .map(|filters| filters.iter().map(|f| f.clone().into()).collect()) + .map(|filters| filters.iter().map(Into::into).collect()) } } diff --git a/sdk-libs/client/src/indexer/photon_indexer.rs b/sdk-libs/client/src/indexer/photon_indexer.rs index 26d16ae235..5698719c8f 100644 --- a/sdk-libs/client/src/indexer/photon_indexer.rs +++ b/sdk-libs/client/src/indexer/photon_indexer.rs @@ -1142,17 +1142,16 @@ impl Indexer for PhotonIndexer { .value .iter() .map(|x| { - let mut proof_vec = x.proof.clone(); - if proof_vec.len() < STATE_MERKLE_TREE_CANOPY_DEPTH { + if x.proof.len() < STATE_MERKLE_TREE_CANOPY_DEPTH { return Err(IndexerError::InvalidParameters(format!( "Merkle proof length ({}) is less than canopy depth ({})", - proof_vec.len(), + x.proof.len(), STATE_MERKLE_TREE_CANOPY_DEPTH, ))); } - proof_vec.truncate(proof_vec.len() - STATE_MERKLE_TREE_CANOPY_DEPTH); + let proof_len = x.proof.len() - STATE_MERKLE_TREE_CANOPY_DEPTH; - let proof = proof_vec + let proof = x.proof[..proof_len] .iter() .map(|s| Hash::from_base58(s)) .collect::, IndexerError>>() @@ -1703,15 +1702,13 @@ impl Indexer for PhotonIndexer { async fn get_subtrees( &self, - _merkle_tree_pubkey: [u8; 32], + merkle_tree_pubkey: [u8; 32], _config: Option, ) -> Result>, IndexerError> { - #[cfg(not(feature = "v2"))] - unimplemented!(); - #[cfg(feature = "v2")] - { - todo!(); - } + Err(IndexerError::NotImplemented(format!( + "PhotonIndexer::get_subtrees is not implemented for merkle tree {}", + solana_pubkey::Pubkey::new_from_array(merkle_tree_pubkey) + ))) } } diff --git a/sdk-libs/client/src/indexer/types/proof.rs b/sdk-libs/client/src/indexer/types/proof.rs index 0b45e00986..0335470929 100644 --- a/sdk-libs/client/src/indexer/types/proof.rs +++ b/sdk-libs/client/src/indexer/types/proof.rs @@ -189,7 +189,10 @@ pub struct PackedTreeInfos { } impl ValidityProofWithContext { - pub fn pack_tree_infos(&self, packed_accounts: &mut PackedAccounts) -> PackedTreeInfos { + pub fn pack_tree_infos( + &self, + packed_accounts: &mut PackedAccounts, + ) -> Result { let mut packed_tree_infos = Vec::new(); let mut address_trees = Vec::new(); let mut output_tree_index = None; @@ -209,19 +212,12 @@ impl ValidityProofWithContext { // If a next Merkle tree exists the Merkle tree is full -> use the next Merkle tree for new state. // Else use the current Merkle tree for new state. if let Some(next) = account.tree_info.next_tree_info { - // SAFETY: account will always have a state Merkle tree context. - // pack_output_tree_index only panics on an address Merkle tree context. - let index = next.pack_output_tree_index(packed_accounts).unwrap(); + let index = next.pack_output_tree_index(packed_accounts)?; if output_tree_index.is_none() { output_tree_index = Some(index); } } else { - // SAFETY: account will always have a state Merkle tree context. - // pack_output_tree_index only panics on an address Merkle tree context. - let index = account - .tree_info - .pack_output_tree_index(packed_accounts) - .unwrap(); + let index = account.tree_info.pack_output_tree_index(packed_accounts)?; if output_tree_index.is_none() { output_tree_index = Some(index); } @@ -244,13 +240,18 @@ impl ValidityProofWithContext { } else { Some(PackedStateTreeInfos { packed_tree_infos, - output_tree_index: output_tree_index.unwrap(), + output_tree_index: output_tree_index.ok_or_else(|| { + IndexerError::missing_result( + "pack_tree_infos", + "missing output tree index for state proof", + ) + })?, }) }; - PackedTreeInfos { + Ok(PackedTreeInfos { state_trees: packed_tree_infos, address_trees, - } + }) } pub fn from_api_model( diff --git a/sdk-libs/client/src/indexer/types/queue.rs b/sdk-libs/client/src/indexer/types/queue.rs index 40e7cc0f6e..3f52d72798 100644 --- a/sdk-libs/client/src/indexer/types/queue.rs +++ b/sdk-libs/client/src/indexer/types/queue.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use super::super::IndexerError; #[derive(Debug, Clone, PartialEq, Default)] @@ -64,13 +66,14 @@ pub struct AddressQueueData { } impl AddressQueueData { + const ADDRESS_TREE_HEIGHT: usize = 40; + /// Reconstruct a merkle proof for a given low_element_index from the deduplicated nodes. - /// The tree_height is needed to know how many levels to traverse. - pub fn reconstruct_proof( + pub fn reconstruct_proof( &self, address_idx: usize, - tree_height: u8, - ) -> Result, IndexerError> { + ) -> Result<[[u8; 32]; HEIGHT], IndexerError> { + self.validate_proof_height::()?; let leaf_index = *self.low_element_indices.get(address_idx).ok_or_else(|| { IndexerError::MissingResult { context: "reconstruct_proof".to_string(), @@ -81,10 +84,10 @@ impl AddressQueueData { ), } })?; - let mut proof = Vec::with_capacity(tree_height as usize); + let mut proof = [[0u8; 32]; HEIGHT]; let mut pos = leaf_index; - for level in 0..tree_height { + for (level, proof_element) in proof.iter_mut().enumerate() { let sibling_pos = if pos.is_multiple_of(2) { pos + 1 } else { @@ -114,28 +117,224 @@ impl AddressQueueData { self.node_hashes.len(), ), })?; - proof.push(*hash); + *proof_element = *hash; pos /= 2; } Ok(proof) } + /// Reconstruct a contiguous batch of proofs while reusing a single node lookup table. + pub fn reconstruct_proofs( + &self, + address_range: std::ops::Range, + ) -> Result, IndexerError> { + self.validate_proof_height::()?; + let node_lookup = self.build_node_lookup(); + let mut proofs = Vec::with_capacity(address_range.len()); + + for address_idx in address_range { + proofs.push(self.reconstruct_proof_with_lookup::(address_idx, &node_lookup)?); + } + + Ok(proofs) + } + /// Reconstruct all proofs for all addresses - pub fn reconstruct_all_proofs( + pub fn reconstruct_all_proofs( &self, - tree_height: u8, - ) -> Result>, IndexerError> { - (0..self.addresses.len()) - .map(|i| self.reconstruct_proof(i, tree_height)) - .collect() + ) -> Result, IndexerError> { + self.validate_proof_height::()?; + self.reconstruct_proofs::(0..self.addresses.len()) + } + + fn build_node_lookup(&self) -> HashMap { + let mut lookup = HashMap::with_capacity(self.nodes.len()); + for (idx, node) in self.nodes.iter().copied().enumerate() { + lookup.entry(node).or_insert(idx); + } + lookup + } + + fn reconstruct_proof_with_lookup( + &self, + address_idx: usize, + node_lookup: &HashMap, + ) -> Result<[[u8; 32]; HEIGHT], IndexerError> { + self.validate_proof_height::()?; + let leaf_index = *self.low_element_indices.get(address_idx).ok_or_else(|| { + IndexerError::MissingResult { + context: "reconstruct_proof".to_string(), + message: format!( + "address_idx {} out of bounds for low_element_indices (len {})", + address_idx, + self.low_element_indices.len(), + ), + } + })?; + let mut proof = [[0u8; 32]; HEIGHT]; + let mut pos = leaf_index; + + for (level, proof_element) in proof.iter_mut().enumerate() { + let sibling_pos = if pos.is_multiple_of(2) { + pos + 1 + } else { + pos - 1 + }; + let sibling_idx = Self::encode_node_index(level, sibling_pos); + let hash_idx = node_lookup.get(&sibling_idx).copied().ok_or_else(|| { + IndexerError::MissingResult { + context: "reconstruct_proof".to_string(), + message: format!( + "Missing proof node at level {} position {} (encoded: {})", + level, sibling_pos, sibling_idx + ), + } + })?; + let hash = + self.node_hashes + .get(hash_idx) + .ok_or_else(|| IndexerError::MissingResult { + context: "reconstruct_proof".to_string(), + message: format!( + "node_hashes index {} out of bounds (len {})", + hash_idx, + self.node_hashes.len(), + ), + })?; + *proof_element = *hash; + pos /= 2; + } + + Ok(proof) } /// Encode node index: (level << 56) | position #[inline] - fn encode_node_index(level: u8, position: u64) -> u64 { + fn encode_node_index(level: usize, position: u64) -> u64 { ((level as u64) << 56) | position } + + fn validate_proof_height(&self) -> Result<(), IndexerError> { + if HEIGHT == Self::ADDRESS_TREE_HEIGHT { + return Ok(()); + } + + Err(IndexerError::InvalidParameters(format!( + "address queue proofs require HEIGHT={} but got HEIGHT={}", + Self::ADDRESS_TREE_HEIGHT, + HEIGHT + ))) + } +} + +#[cfg(test)] +mod tests { + use std::{collections::BTreeMap, hint::black_box, time::Instant}; + + use super::AddressQueueData; + + fn hash_from_node(node_index: u64) -> [u8; 32] { + let mut hash = [0u8; 32]; + hash[..8].copy_from_slice(&node_index.to_le_bytes()); + hash[8..16].copy_from_slice(&node_index.rotate_left(17).to_le_bytes()); + hash[16..24].copy_from_slice(&node_index.rotate_right(9).to_le_bytes()); + hash[24..32].copy_from_slice(&(node_index ^ 0xA5A5_A5A5_A5A5_A5A5).to_le_bytes()); + hash + } + + fn build_queue_data(num_addresses: usize) -> AddressQueueData { + let low_element_indices = (0..num_addresses) + .map(|i| (i as u64).saturating_mul(2)) + .collect::>(); + let mut nodes = BTreeMap::new(); + + for &leaf_index in &low_element_indices { + let mut pos = leaf_index; + for level in 0..HEIGHT { + let sibling_pos = if pos.is_multiple_of(2) { + pos + 1 + } else { + pos - 1 + }; + let node_index = ((level as u64) << 56) | sibling_pos; + nodes + .entry(node_index) + .or_insert_with(|| hash_from_node(node_index)); + pos /= 2; + } + } + + let (nodes, node_hashes): (Vec<_>, Vec<_>) = nodes.into_iter().unzip(); + + AddressQueueData { + addresses: vec![[0u8; 32]; num_addresses], + low_element_values: vec![[1u8; 32]; num_addresses], + low_element_next_values: vec![[2u8; 32]; num_addresses], + low_element_indices, + low_element_next_indices: (0..num_addresses).map(|i| (i as u64) + 1).collect(), + nodes, + node_hashes, + initial_root: [9u8; 32], + leaves_hash_chains: vec![[3u8; 32]; num_addresses.max(1)], + subtrees: vec![[4u8; 32]; HEIGHT], + start_index: 0, + root_seq: 0, + } + } + + #[test] + fn batched_reconstruction_matches_individual_reconstruction() { + let queue = build_queue_data::<40>(128); + + let expected = (0..queue.addresses.len()) + .map(|i| queue.reconstruct_proof::<40>(i).unwrap()) + .collect::>(); + let actual = queue + .reconstruct_proofs::<40>(0..queue.addresses.len()) + .unwrap(); + + assert_eq!(actual, expected); + } + + #[test] + #[ignore = "profiling helper"] + fn profile_reconstruct_proofs_batch() { + const HEIGHT: usize = 40; + const NUM_ADDRESSES: usize = 2_048; + const ITERS: usize = 25; + + let queue = build_queue_data::(NUM_ADDRESSES); + + let baseline_start = Instant::now(); + for _ in 0..ITERS { + let proofs = (0..queue.addresses.len()) + .map(|i| queue.reconstruct_proof::(i).unwrap()) + .collect::>(); + black_box(proofs); + } + let baseline = baseline_start.elapsed(); + + let batched_start = Instant::now(); + for _ in 0..ITERS { + black_box( + queue + .reconstruct_proofs::(0..queue.addresses.len()) + .unwrap(), + ); + } + let batched = batched_start.elapsed(); + + println!( + "queue reconstruction profile: addresses={}, height={}, iters={}, individual={:?}, batched={:?}, speedup={:.2}x", + NUM_ADDRESSES, + HEIGHT, + ITERS, + baseline, + batched, + baseline.as_secs_f64() / batched.as_secs_f64(), + ); + } } /// V2 Queue Elements Result with deduplicated node data diff --git a/sdk-libs/client/src/indexer/types/token.rs b/sdk-libs/client/src/indexer/types/token.rs index 92903d234f..78162de8c0 100644 --- a/sdk-libs/client/src/indexer/types/token.rs +++ b/sdk-libs/client/src/indexer/types/token.rs @@ -90,7 +90,7 @@ impl Into> |token_account| light_token::compat::TokenDataWithMerkleContext { token_data: token_account.token, compressed_account: CompressedAccountWithMerkleContext::from( - token_account.account.clone(), + token_account.account, ), }, ) diff --git a/sdk-libs/client/src/interface/initialize_config.rs b/sdk-libs/client/src/interface/initialize_config.rs index 7b5919cdb1..82c8733682 100644 --- a/sdk-libs/client/src/interface/initialize_config.rs +++ b/sdk-libs/client/src/interface/initialize_config.rs @@ -7,6 +7,8 @@ use borsh::{BorshDeserialize as AnchorDeserialize, BorshSerialize as AnchorSeria use solana_instruction::{AccountMeta, Instruction}; use solana_pubkey::Pubkey; +use crate::interface::serialize::serialize_anchor_data; + /// Default address tree v2 pubkey. pub const ADDRESS_TREE_V2: Pubkey = solana_pubkey::pubkey!("amt2kaJA14v3urZbZvnc5v2np8jqvc4Z8zDep5wbtzx"); @@ -85,7 +87,7 @@ impl InitializeRentFreeConfig { self } - pub fn build(self) -> (Instruction, Pubkey) { + pub fn build(self) -> std::io::Result<(Instruction, Pubkey)> { let authority = self.authority.unwrap_or(self.fee_payer); let config_bump_u16 = self.config_bump as u16; let (config_pda, _) = Pubkey::find_program_address( @@ -119,9 +121,7 @@ impl InitializeRentFreeConfig { // SHA256("global:initialize_compression_config")[..8] const DISCRIMINATOR: [u8; 8] = [133, 228, 12, 169, 56, 76, 222, 61]; - let serialized_data = instruction_data - .try_to_vec() - .expect("Failed to serialize instruction data"); + let serialized_data = serialize_anchor_data(&instruction_data)?; let mut data = Vec::with_capacity(DISCRIMINATOR.len() + serialized_data.len()); data.extend_from_slice(&DISCRIMINATOR); @@ -133,6 +133,6 @@ impl InitializeRentFreeConfig { data, }; - (instruction, config_pda) + Ok((instruction, config_pda)) } } diff --git a/sdk-libs/client/src/interface/instructions.rs b/sdk-libs/client/src/interface/instructions.rs index f6d754b9b1..02560c51be 100644 --- a/sdk-libs/client/src/interface/instructions.rs +++ b/sdk-libs/client/src/interface/instructions.rs @@ -18,7 +18,10 @@ use light_token::constants::{ use solana_instruction::{AccountMeta, Instruction}; use solana_pubkey::Pubkey; -use crate::indexer::{CompressedAccount, TreeInfo, ValidityProofWithContext}; +use crate::{ + indexer::{CompressedAccount, TreeInfo, ValidityProofWithContext}, + interface::serialize::serialize_anchor_data, +}; #[inline] fn get_output_queue(tree_info: &TreeInfo) -> Pubkey { @@ -98,7 +101,7 @@ pub fn initialize_config( rent_sponsor: Pubkey, address_space: Vec, config_bump: Option, -) -> Instruction { +) -> std::io::Result { let config_bump = config_bump.unwrap_or(0); let config_bump_u16 = config_bump as u16; let (config_pda, _) = Pubkey::find_program_address( @@ -129,16 +132,16 @@ pub fn initialize_config( address_space: address_space.iter().map(|p| p.to_bytes()).collect(), config_bump, }; - let serialized = params.try_to_vec().expect("serialize params"); + let serialized = serialize_anchor_data(¶ms)?; let mut data = Vec::with_capacity(discriminator.len() + serialized.len()); data.extend_from_slice(discriminator); data.extend_from_slice(&serialized); - Instruction { + Ok(Instruction { program_id: *program_id, accounts, data, - } + }) } pub fn update_config( @@ -148,7 +151,7 @@ pub fn update_config( new_rent_sponsor: Option, new_address_space: Option>, new_update_authority: Option, -) -> Instruction { +) -> std::io::Result { let (config_pda, _) = Pubkey::find_program_address( &[light_account::LIGHT_CONFIG_SEED, &0u16.to_le_bytes()], program_id, @@ -167,16 +170,16 @@ pub fn update_config( new_write_top_up: None, new_address_space: new_address_space.map(|v| v.iter().map(|p| p.to_bytes()).collect()), }; - let serialized = params.try_to_vec().expect("serialize params"); + let serialized = serialize_anchor_data(¶ms)?; let mut data = Vec::with_capacity(discriminator.len() + serialized.len()); data.extend_from_slice(discriminator); data.extend_from_slice(&serialized); - Instruction { + Ok(Instruction { program_id: *program_id, accounts, data, - } + }) } /// Build load (decompress) instruction. @@ -234,7 +237,7 @@ where let output_queue = get_output_queue(&cold_accounts[0].0.tree_info); let output_state_tree_index = remaining_accounts.insert_or_get(output_queue); - let packed_tree_infos = proof.pack_tree_infos(&mut remaining_accounts); + let packed_tree_infos = proof.pack_tree_infos(&mut remaining_accounts)?; let tree_infos = &packed_tree_infos .state_trees .as_ref() @@ -309,7 +312,7 @@ pub fn build_compress_accounts_idempotent( let output_queue = get_output_queue(&proof.accounts[0].tree_info); let output_state_tree_index = remaining_accounts.insert_or_get(output_queue); - let packed_tree_infos = proof.pack_tree_infos(&mut remaining_accounts); + let packed_tree_infos = proof.pack_tree_infos(&mut remaining_accounts)?; let tree_infos = packed_tree_infos .state_trees .as_ref() diff --git a/sdk-libs/client/src/interface/load_accounts.rs b/sdk-libs/client/src/interface/load_accounts.rs index 061ad5074b..260ca2b66c 100644 --- a/sdk-libs/client/src/interface/load_accounts.rs +++ b/sdk-libs/client/src/interface/load_accounts.rs @@ -53,6 +53,9 @@ pub enum LoadAccountsError { #[error("Cold PDA at index {index} (pubkey {pubkey}) missing data")] MissingPdaCompressed { index: usize, pubkey: Pubkey }, + #[error("Cold PDA (pubkey {pubkey}) missing data")] + MissingPdaCompressedData { pubkey: Pubkey }, + #[error("Cold ATA at index {index} (pubkey {pubkey}) missing data")] MissingAtaCompressed { index: usize, pubkey: Pubkey }, @@ -67,6 +70,7 @@ pub enum LoadAccountsError { } const MAX_ATAS_PER_IX: usize = 8; +const MAX_PDAS_PER_IX: usize = 8; /// Build load instructions for cold accounts. Returns empty vec if all hot. /// @@ -113,14 +117,18 @@ where }) .collect(); - let pda_hashes = collect_pda_hashes(&cold_pdas)?; + let pda_groups = group_pda_specs(&cold_pdas, MAX_PDAS_PER_IX); + let pda_hashes = pda_groups + .iter() + .map(|group| collect_pda_hashes(group)) + .collect::, _>>()?; let ata_hashes = collect_ata_hashes(&cold_atas)?; let mint_hashes = collect_mint_hashes(&cold_mints)?; let (pda_proofs, ata_proofs, mint_proofs) = futures::join!( - fetch_proofs(&pda_hashes, indexer), + fetch_proof_batches(&pda_hashes, indexer), fetch_proofs_batched(&ata_hashes, MAX_ATAS_PER_IX, indexer), - fetch_proofs(&mint_hashes, indexer), + fetch_individual_proofs(&mint_hashes, indexer), ); let pda_proofs = pda_proofs?; @@ -136,9 +144,9 @@ where // 2. DecompressAccountsIdempotent for all cold PDAs (including token PDAs). // Token PDAs are created on-chain via CPI inside DecompressVariant. - for (spec, proof) in cold_pdas.iter().zip(pda_proofs) { + for (group, proof) in pda_groups.into_iter().zip(pda_proofs) { out.push(build_pda_load( - &[spec], + &group, proof, fee_payer, compression_config, @@ -146,8 +154,7 @@ where } // 3. ATA loads (CreateAssociatedTokenAccount + Transfer2) - requires mint to exist - let ata_chunks: Vec<_> = cold_atas.chunks(MAX_ATAS_PER_IX).collect(); - for (chunk, proof) in ata_chunks.into_iter().zip(ata_proofs) { + for (chunk, proof) in cold_atas.chunks(MAX_ATAS_PER_IX).zip(ata_proofs) { out.extend(build_ata_load(chunk, proof, fee_payer)?); } @@ -195,23 +202,77 @@ fn collect_mint_hashes(ifaces: &[&AccountInterface]) -> Result, Lo .collect() } -async fn fetch_proofs( +/// Groups already-ordered PDA specs into contiguous runs of the same program id. +/// +/// This preserves input order rather than globally regrouping by program. Callers that +/// want maximal batching across interleaved program ids should sort before calling. +fn group_pda_specs<'a, V>( + specs: &[&'a PdaSpec], + max_per_group: usize, +) -> Vec>> { + assert!(max_per_group > 0, "max_per_group must be non-zero"); + if specs.is_empty() { + return Vec::new(); + } + + let mut groups = Vec::new(); + let mut current = Vec::with_capacity(max_per_group); + let mut current_program: Option = None; + + for spec in specs { + let program_id = spec.program_id(); + let should_split = current_program + .map(|existing| existing != program_id || current.len() >= max_per_group) + .unwrap_or(false); + + if should_split { + groups.push(current); + current = Vec::with_capacity(max_per_group); + } + + current_program = Some(program_id); + current.push(*spec); + } + + if !current.is_empty() { + groups.push(current); + } + + groups +} + +async fn fetch_individual_proofs( hashes: &[[u8; 32]], indexer: &I, ) -> Result, IndexerError> { if hashes.is_empty() { return Ok(vec![]); } - let mut proofs = Vec::with_capacity(hashes.len()); - for hash in hashes { - proofs.push( - indexer - .get_validity_proof(vec![*hash], vec![], None) - .await? - .value, - ); + + futures::future::try_join_all(hashes.iter().map(|hash| async move { + indexer + .get_validity_proof(vec![*hash], vec![], None) + .await + .map(|response| response.value) + })) + .await +} + +async fn fetch_proof_batches( + hash_batches: &[Vec<[u8; 32]>], + indexer: &I, +) -> Result, IndexerError> { + if hash_batches.is_empty() { + return Ok(vec![]); } - Ok(proofs) + + futures::future::try_join_all(hash_batches.iter().map(|hashes| async move { + indexer + .get_validity_proof(hashes.clone(), vec![], None) + .await + .map(|response| response.value) + })) + .await } async fn fetch_proofs_batched( @@ -222,16 +283,13 @@ async fn fetch_proofs_batched( if hashes.is_empty() { return Ok(vec![]); } - let mut proofs = Vec::with_capacity(hashes.len().div_ceil(batch_size)); - for chunk in hashes.chunks(batch_size) { - proofs.push( - indexer - .get_validity_proof(chunk.to_vec(), vec![], None) - .await? - .value, - ); - } - Ok(proofs) + + let hash_batches = hashes + .chunks(batch_size) + .map(|chunk| chunk.to_vec()) + .collect::>(); + + fetch_proof_batches(&hash_batches, indexer).await } fn build_pda_load( @@ -262,11 +320,16 @@ where let hot_addresses: Vec = specs.iter().map(|s| s.address()).collect(); let cold_accounts: Vec<(CompressedAccount, V)> = specs .iter() - .map(|s| { - let compressed = s.compressed().expect("cold spec must have data").clone(); - (compressed, s.variant.clone()) + .map(|s| -> Result<_, LoadAccountsError> { + let compressed = + s.compressed() + .cloned() + .ok_or(LoadAccountsError::MissingPdaCompressedData { + pubkey: s.address(), + })?; + Ok((compressed, s.variant.clone())) }) - .collect(); + .collect::, _>>()?; let program_id = specs.first().map(|s| s.program_id()).unwrap_or_default(); @@ -345,7 +408,9 @@ fn build_transfer2( fee_payer: Pubkey, ) -> Result { let mut packed = PackedAccounts::default(); - let packed_trees = proof.pack_tree_infos(&mut packed); + let packed_trees = proof + .pack_tree_infos(&mut packed) + .map_err(|e| LoadAccountsError::BuildInstruction(e.to_string()))?; let tree_infos = packed_trees .state_trees .as_ref() diff --git a/sdk-libs/client/src/interface/mod.rs b/sdk-libs/client/src/interface/mod.rs index 5a587556f4..f2b1c14689 100644 --- a/sdk-libs/client/src/interface/mod.rs +++ b/sdk-libs/client/src/interface/mod.rs @@ -8,6 +8,7 @@ pub mod instructions; pub mod light_program_interface; pub mod load_accounts; pub mod pack; +mod serialize; pub mod tx_size; pub use account_interface::{AccountInterface, AccountInterfaceError, TokenAccountInterface}; diff --git a/sdk-libs/client/src/interface/pack.rs b/sdk-libs/client/src/interface/pack.rs index 804a48751d..52a8491fbd 100644 --- a/sdk-libs/client/src/interface/pack.rs +++ b/sdk-libs/client/src/interface/pack.rs @@ -6,12 +6,15 @@ use solana_instruction::AccountMeta; use solana_pubkey::Pubkey; use thiserror::Error; -use crate::indexer::{TreeInfo, ValidityProofWithContext}; +use crate::indexer::{IndexerError, TreeInfo, ValidityProofWithContext}; #[derive(Debug, Error)] pub enum PackError { #[error("Failed to add system accounts: {0}")] SystemAccounts(#[from] light_sdk::error::LightSdkError), + + #[error("Indexer error: {0}")] + Indexer(#[from] IndexerError), } /// Packed state tree infos from validity proof. @@ -87,7 +90,7 @@ fn pack_proof_internal( // For mint creation: pack address tree first (index 1), then state tree. let (client_packed_tree_infos, state_tree_index) = if include_state_tree { // Pack tree infos first to ensure address tree is at index 1 - let tree_infos = proof.pack_tree_infos(&mut packed); + let tree_infos = proof.pack_tree_infos(&mut packed)?; // Then add state tree (will be after address tree) let state_tree = output_tree @@ -99,7 +102,7 @@ fn pack_proof_internal( (tree_infos, Some(state_idx)) } else { - let tree_infos = proof.pack_tree_infos(&mut packed); + let tree_infos = proof.pack_tree_infos(&mut packed)?; (tree_infos, None) }; let (remaining_accounts, system_offset, _) = packed.to_account_metas(); diff --git a/sdk-libs/client/src/interface/serialize.rs b/sdk-libs/client/src/interface/serialize.rs new file mode 100644 index 0000000000..c3aa8e1a30 --- /dev/null +++ b/sdk-libs/client/src/interface/serialize.rs @@ -0,0 +1,10 @@ +#[cfg(feature = "anchor")] +use anchor_lang::AnchorSerialize; +#[cfg(not(feature = "anchor"))] +use borsh::BorshSerialize as AnchorSerialize; + +pub(crate) fn serialize_anchor_data(value: &T) -> std::io::Result> { + let mut serialized = Vec::new(); + value.serialize(&mut serialized)?; + Ok(serialized) +} diff --git a/sdk-libs/client/src/local_test_validator.rs b/sdk-libs/client/src/local_test_validator.rs index 36ed7c04b3..cecb55efa3 100644 --- a/sdk-libs/client/src/local_test_validator.rs +++ b/sdk-libs/client/src/local_test_validator.rs @@ -1,4 +1,7 @@ -use std::process::{Command, Stdio}; +use std::{ + io, + process::{Command, Stdio}, +}; use light_prover_client::helpers::get_project_root; @@ -55,73 +58,83 @@ impl Default for LightValidatorConfig { } } -pub async fn spawn_validator(config: LightValidatorConfig) { - if let Some(project_root) = get_project_root() { - let path = "cli/test_bin/run test-validator"; - let mut path = format!("{}/{}", project_root.trim(), path); - if !config.enable_indexer { - path.push_str(" --skip-indexer"); - } +pub async fn spawn_validator(config: LightValidatorConfig) -> io::Result<()> { + let project_root = get_project_root().ok_or_else(|| { + io::Error::new( + io::ErrorKind::NotFound, + "Failed to determine project root for validator startup", + ) + })?; - if let Some(limit_ledger_size) = config.limit_ledger_size { - path.push_str(&format!(" --limit-ledger-size {}", limit_ledger_size)); - } + let path = "cli/test_bin/run test-validator"; + let mut path = format!("{}/{}", project_root.trim(), path); + if !config.enable_indexer { + path.push_str(" --skip-indexer"); + } - for sbf_program in config.sbf_programs.iter() { - path.push_str(&format!( - " --sbf-program {} {}", - sbf_program.0, sbf_program.1 - )); - } + if let Some(limit_ledger_size) = config.limit_ledger_size { + path.push_str(&format!(" --limit-ledger-size {}", limit_ledger_size)); + } - for upgradeable_program in config.upgradeable_programs.iter() { - path.push_str(&format!( - " --upgradeable-program {} {} {}", - upgradeable_program.program_id, - upgradeable_program.program_path, - upgradeable_program.upgrade_authority - )); - } + for sbf_program in config.sbf_programs.iter() { + path.push_str(&format!( + " --sbf-program {} {}", + sbf_program.0, sbf_program.1 + )); + } - if !config.enable_prover { - path.push_str(" --skip-prover"); - } + for upgradeable_program in config.upgradeable_programs.iter() { + path.push_str(&format!( + " --upgradeable-program {} {} {}", + upgradeable_program.program_id, + upgradeable_program.program_path, + upgradeable_program.upgrade_authority + )); + } - if config.use_surfpool { - path.push_str(" --use-surfpool"); - } + if !config.enable_prover { + path.push_str(" --skip-prover"); + } - for arg in config.validator_args.iter() { - path.push_str(&format!(" {}", arg)); - } + if config.use_surfpool { + path.push_str(" --use-surfpool"); + } - println!("Starting validator with command: {}", path); - - if config.use_surfpool { - // The CLI starts surfpool, prover, and photon, then exits once all - // services are ready. Wait for it to finish so we know everything - // is up before the test proceeds. - let mut child = Command::new("sh") - .arg("-c") - .arg(path) - .stdin(Stdio::null()) - .stdout(Stdio::inherit()) - .stderr(Stdio::inherit()) - .spawn() - .expect("Failed to start server process"); - let status = child.wait().expect("Failed to wait for CLI process"); - assert!(status.success(), "CLI exited with error: {}", status); - } else { - let child = Command::new("sh") - .arg("-c") - .arg(path) - .stdin(Stdio::null()) - .stdout(Stdio::null()) - .stderr(Stdio::null()) - .spawn() - .expect("Failed to start server process"); - std::mem::drop(child); - tokio::time::sleep(tokio::time::Duration::from_secs(config.wait_time)).await; + for arg in config.validator_args.iter() { + path.push_str(&format!(" {}", arg)); + } + + println!("Starting validator with command: {}", path); + + if config.use_surfpool { + // The CLI starts surfpool, prover, and photon, then exits once all + // services are ready. Wait for it to finish so we know everything + // is up before the test proceeds. + let mut child = Command::new("sh") + .arg("-c") + .arg(path) + .stdin(Stdio::null()) + .stdout(Stdio::inherit()) + .stderr(Stdio::inherit()) + .spawn()?; + let status = child.wait()?; + if !status.success() { + return Err(io::Error::other(format!( + "validator CLI exited with status {}", + status + ))); } + } else { + let child = Command::new("sh") + .arg("-c") + .arg(path) + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .spawn()?; + std::mem::drop(child); + tokio::time::sleep(tokio::time::Duration::from_secs(config.wait_time)).await; } + + Ok(()) } diff --git a/sdk-libs/client/src/rpc/client.rs b/sdk-libs/client/src/rpc/client.rs index 6f9e842520..6dc8f8f3ee 100644 --- a/sdk-libs/client/src/rpc/client.rs +++ b/sdk-libs/client/src/rpc/client.rs @@ -277,16 +277,16 @@ impl LightClient { .client .get_transaction_with_config(&signature, rpc_transaction_config) .map_err(|e| RpcError::CustomError(e.to_string()))?; - let decoded_transaction = transaction - .transaction - .transaction - .decode() - .clone() - .ok_or_else(|| { - RpcError::CustomError( - "Failed to decode transaction from RPC response".to_string(), - ) - })?; + let decoded_transaction = + transaction + .transaction + .transaction + .decode() + .ok_or_else(|| { + RpcError::CustomError( + "Failed to decode transaction from RPC response".to_string(), + ) + })?; let account_keys = decoded_transaction.message.static_account_keys(); let meta = transaction.transaction.meta.as_ref().ok_or_else(|| { RpcError::CustomError("Transaction missing metadata information".to_string()) diff --git a/sdk-libs/client/src/utils.rs b/sdk-libs/client/src/utils.rs index b8f2e05ecb..f5f488f5d9 100644 --- a/sdk-libs/client/src/utils.rs +++ b/sdk-libs/client/src/utils.rs @@ -7,10 +7,7 @@ pub fn find_light_bin() -> Option { { println!("Running 'which light' (feature 'devenv' is not enabled)"); use std::process::Command; - let output = Command::new("which") - .arg("light") - .output() - .expect("Failed to execute 'which light'"); + let output = Command::new("which").arg("light").output().ok()?; if !output.status.success() { return None; @@ -30,16 +27,15 @@ pub fn find_light_bin() -> Option { #[cfg(feature = "devenv")] { println!("Use only in light protocol monorepo. Using 'git rev-parse --show-toplevel' to find the location of 'light' binary"); - let light_protocol_toplevel = String::from_utf8_lossy( - &std::process::Command::new("git") - .arg("rev-parse") - .arg("--show-toplevel") - .output() - .expect("Failed to get top-level directory") - .stdout, - ) - .trim() - .to_string(); + let output = std::process::Command::new("git") + .arg("rev-parse") + .arg("--show-toplevel") + .output() + .ok()?; + if !output.status.success() { + return None; + } + let light_protocol_toplevel = String::from_utf8_lossy(&output.stdout).trim().to_string(); let light_path = PathBuf::from(format!("{}/target/deploy/", light_protocol_toplevel)); Some(light_path) } diff --git a/sdk-libs/macros/src/light_pdas/seeds/extract.rs b/sdk-libs/macros/src/light_pdas/seeds/extract.rs index 6d60aa2e5e..d16c5a6dbf 100644 --- a/sdk-libs/macros/src/light_pdas/seeds/extract.rs +++ b/sdk-libs/macros/src/light_pdas/seeds/extract.rs @@ -113,7 +113,7 @@ fn check_light_account_type(attrs: &[syn::Attribute]) -> (bool, bool, bool, bool _ => continue, }; - let token_vec: Vec<_> = tokens.clone().into_iter().collect(); + let token_vec: Vec<_> = tokens.into_iter().collect(); // Helper to check for a namespace prefix (e.g., "mint", "token", "associated_token") let has_namespace_prefix = |namespace: &str| { diff --git a/sdk-libs/program-test/src/indexer/test_indexer.rs b/sdk-libs/program-test/src/indexer/test_indexer.rs index 0b5b0583a3..fdd984c124 100644 --- a/sdk-libs/program-test/src/indexer/test_indexer.rs +++ b/sdk-libs/program-test/src/indexer/test_indexer.rs @@ -726,7 +726,9 @@ impl Indexer for TestIndexer { initial_root: address_tree_bundle.root(), leaves_hash_chains: Vec::new(), subtrees: address_tree_bundle.get_subtrees(), - start_index: start as u64, + // Consumers use start_index as the sparse tree's next insertion index, + // not the pagination offset used for queue slicing. + start_index: address_tree_bundle.right_most_index() as u64, root_seq: address_tree_bundle.sequence_number(), }) } else { @@ -2170,18 +2172,22 @@ impl TestIndexer { let inclusion_proof_inputs = InclusionProofInputs::new(inclusion_proofs.as_slice()).unwrap(); ( - Some(BatchInclusionJsonStruct::from_inclusion_proof_inputs( - &inclusion_proof_inputs, - )), + Some( + BatchInclusionJsonStruct::from_inclusion_proof_inputs(&inclusion_proof_inputs) + .map_err(|e| IndexerError::CustomError(e.to_string()))?, + ), None, ) } else if height == STATE_MERKLE_TREE_HEIGHT as usize { let inclusion_proof_inputs = InclusionProofInputsLegacy(inclusion_proofs.as_slice()); ( None, - Some(BatchInclusionJsonStructLegacy::from_inclusion_proof_inputs( - &inclusion_proof_inputs, - )), + Some( + BatchInclusionJsonStructLegacy::from_inclusion_proof_inputs( + &inclusion_proof_inputs, + ) + .map_err(|e| IndexerError::CustomError(e.to_string()))?, + ), ) } else { return Err(IndexerError::CustomError( @@ -2259,7 +2265,8 @@ impl TestIndexer { Some( BatchNonInclusionJsonStructLegacy::from_non_inclusion_proof_inputs( &non_inclusion_proof_inputs, - ), + ) + .map_err(|e| IndexerError::CustomError(e.to_string()))?, ), ) } else if tree_heights[0] == 40 { @@ -2269,7 +2276,8 @@ impl TestIndexer { Some( BatchNonInclusionJsonStruct::from_non_inclusion_proof_inputs( &non_inclusion_proof_inputs, - ), + ) + .map_err(|e| IndexerError::CustomError(e.to_string()))?, ), None, ) @@ -2356,9 +2364,22 @@ impl TestIndexer { ) .await?; if let Some(payload) = payload { - (indices, Vec::new(), payload.to_string()) + ( + indices, + Vec::new(), + payload + .to_string() + .map_err(|e| IndexerError::CustomError(e.to_string()))?, + ) } else { - (indices, Vec::new(), payload_legacy.unwrap().to_string()) + ( + indices, + Vec::new(), + payload_legacy + .unwrap() + .to_string() + .map_err(|e| IndexerError::CustomError(e.to_string()))?, + ) } } (None, Some(addresses)) => { @@ -2369,9 +2390,14 @@ impl TestIndexer { ) .await?; let payload_string = if let Some(payload) = payload { - payload.to_string() + payload + .to_string() + .map_err(|e| IndexerError::CustomError(e.to_string()))? } else { - payload_legacy.unwrap().to_string() + payload_legacy + .unwrap() + .to_string() + .map_err(|e| IndexerError::CustomError(e.to_string()))? }; (Vec::new(), indices, payload_string) } @@ -2448,6 +2474,7 @@ impl TestIndexer { non_inclusion: non_inclusion_payload.inputs, } .to_string() + .map_err(|e| IndexerError::CustomError(e.to_string()))? } else if let Some(non_inclusion_payload) = non_inclusion_payload_legacy { CombinedJsonStructLegacy { circuit_type: ProofType::Combined.to_string(), @@ -2457,6 +2484,7 @@ impl TestIndexer { non_inclusion: non_inclusion_payload.inputs, } .to_string() + .map_err(|e| IndexerError::CustomError(e.to_string()))? } else { panic!("Unsupported tree height") }; @@ -2481,9 +2509,11 @@ impl TestIndexer { if response_result.status().is_success() { let body = response_result.text().await.unwrap(); let proof_json = deserialize_gnark_proof_json(&body).unwrap(); - let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json); + let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json) + .map_err(|e| IndexerError::CustomError(e.to_string()))?; let (proof_a, proof_b, proof_c) = - compress_proof(&proof_a, &proof_b, &proof_c); + compress_proof(&proof_a, &proof_b, &proof_c) + .map_err(|e| IndexerError::CustomError(e.to_string()))?; return Ok(ValidityProofWithContext { accounts: account_proof_inputs, addresses: address_proof_inputs, diff --git a/sdk-libs/program-test/src/program_test/compressible_setup.rs b/sdk-libs/program-test/src/program_test/compressible_setup.rs index 3aede79a4c..92c2a45daf 100644 --- a/sdk-libs/program-test/src/program_test/compressible_setup.rs +++ b/sdk-libs/program-test/src/program_test/compressible_setup.rs @@ -67,7 +67,7 @@ pub async fn initialize_compression_config( rent_sponsor, address_space, config_bump, - ); + )?; let signers = if payer.pubkey() == authority.pubkey() { vec![payer] @@ -97,7 +97,7 @@ pub async fn update_compression_config( new_rent_sponsor, new_address_space, new_update_authority, - ); + )?; rpc.create_and_send_transaction(&[instruction], &payer.pubkey(), &[payer, authority]) .await diff --git a/sdk-libs/program-test/src/program_test/light_program_test.rs b/sdk-libs/program-test/src/program_test/light_program_test.rs index b7179d2530..cf5bfa2a88 100644 --- a/sdk-libs/program-test/src/program_test/light_program_test.rs +++ b/sdk-libs/program-test/src/program_test/light_program_test.rs @@ -393,11 +393,15 @@ impl LightProgramTest { #[cfg(feature = "devenv")] { - spawn_prover().await; + spawn_prover() + .await + .map_err(|error| RpcError::CustomError(error.to_string()))?; } #[cfg(not(feature = "devenv"))] if config.with_prover { - spawn_prover().await; + spawn_prover() + .await + .map_err(|error| RpcError::CustomError(error.to_string()))?; } Ok(context) diff --git a/sdk-tests/anchor-manual-test/tests/shared.rs b/sdk-tests/anchor-manual-test/tests/shared.rs index fc947c7ee5..c4ceb8441d 100644 --- a/sdk-tests/anchor-manual-test/tests/shared.rs +++ b/sdk-tests/anchor-manual-test/tests/shared.rs @@ -36,7 +36,8 @@ pub async fn setup_test_env() -> (LightProgramTest, Keypair, Pubkey) { rent_sponsor, payer.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &payer.pubkey(), &[&payer]) .await diff --git a/sdk-tests/anchor-semi-manual-test/tests/shared/mod.rs b/sdk-tests/anchor-semi-manual-test/tests/shared/mod.rs index 293c26399c..2680923b59 100644 --- a/sdk-tests/anchor-semi-manual-test/tests/shared/mod.rs +++ b/sdk-tests/anchor-semi-manual-test/tests/shared/mod.rs @@ -40,7 +40,8 @@ pub async fn setup_test_env() -> TestEnv { rent_sponsor, payer.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &payer.pubkey(), &[&payer]) .await diff --git a/sdk-tests/anchor-semi-manual-test/tests/stress_test.rs b/sdk-tests/anchor-semi-manual-test/tests/stress_test.rs index 8804f890be..7f70bc9466 100644 --- a/sdk-tests/anchor-semi-manual-test/tests/stress_test.rs +++ b/sdk-tests/anchor-semi-manual-test/tests/stress_test.rs @@ -93,7 +93,8 @@ async fn setup() -> (StressTestContext, TestPdas) { rent_sponsor, payer.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &payer.pubkey(), &[&payer]) .await diff --git a/sdk-tests/client-test/tests/light_client.rs b/sdk-tests/client-test/tests/light_client.rs index c8d1bb28bc..1cc7fe33bd 100644 --- a/sdk-tests/client-test/tests/light_client.rs +++ b/sdk-tests/client-test/tests/light_client.rs @@ -59,7 +59,7 @@ async fn test_all_endpoints() { validator_args: vec![], }; - spawn_validator(config).await; + spawn_validator(config).await.unwrap(); let test_accounts = TestAccounts::get_local_test_validator_accounts(); let mut rpc: LightClient = LightClient::new(LightClientConfig::local()).await.unwrap(); diff --git a/sdk-tests/csdk-anchor-full-derived-test/tests/amm_stress_test.rs b/sdk-tests/csdk-anchor-full-derived-test/tests/amm_stress_test.rs index 1bd04c0cbc..e7e2ce7d68 100644 --- a/sdk-tests/csdk-anchor-full-derived-test/tests/amm_stress_test.rs +++ b/sdk-tests/csdk-anchor-full-derived-test/tests/amm_stress_test.rs @@ -116,7 +116,8 @@ async fn setup() -> AmmTestContext { csdk_anchor_full_derived_test::program_rent_sponsor(), payer.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &payer.pubkey(), &[&payer]) .await diff --git a/sdk-tests/csdk-anchor-full-derived-test/tests/amm_test.rs b/sdk-tests/csdk-anchor-full-derived-test/tests/amm_test.rs index 050531f970..29a5fa180b 100644 --- a/sdk-tests/csdk-anchor-full-derived-test/tests/amm_test.rs +++ b/sdk-tests/csdk-anchor-full-derived-test/tests/amm_test.rs @@ -98,7 +98,8 @@ async fn setup() -> AmmTestContext { csdk_anchor_full_derived_test::program_rent_sponsor(), payer.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &payer.pubkey(), &[&payer]) .await diff --git a/sdk-tests/csdk-anchor-full-derived-test/tests/basic_test.rs b/sdk-tests/csdk-anchor-full-derived-test/tests/basic_test.rs index 747c75ce32..cf8a8f9a31 100644 --- a/sdk-tests/csdk-anchor-full-derived-test/tests/basic_test.rs +++ b/sdk-tests/csdk-anchor-full-derived-test/tests/basic_test.rs @@ -67,7 +67,8 @@ async fn test_create_pdas_and_mint_auto() { rent_sponsor, payer.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &payer.pubkey(), &[&payer]) .await diff --git a/sdk-tests/csdk-anchor-full-derived-test/tests/d10_ata_idempotent_test.rs b/sdk-tests/csdk-anchor-full-derived-test/tests/d10_ata_idempotent_test.rs index 26fcf0d6d2..29d2949070 100644 --- a/sdk-tests/csdk-anchor-full-derived-test/tests/d10_ata_idempotent_test.rs +++ b/sdk-tests/csdk-anchor-full-derived-test/tests/d10_ata_idempotent_test.rs @@ -49,7 +49,8 @@ impl D10TestContext { csdk_anchor_full_derived_test::program_rent_sponsor(), payer.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &payer.pubkey(), &[&payer]) .await diff --git a/sdk-tests/csdk-anchor-full-derived-test/tests/d10_token_accounts_test.rs b/sdk-tests/csdk-anchor-full-derived-test/tests/d10_token_accounts_test.rs index 32b1e218e5..e9cbba42a7 100644 --- a/sdk-tests/csdk-anchor-full-derived-test/tests/d10_token_accounts_test.rs +++ b/sdk-tests/csdk-anchor-full-derived-test/tests/d10_token_accounts_test.rs @@ -54,7 +54,8 @@ impl D10TestContext { csdk_anchor_full_derived_test::program_rent_sponsor(), payer.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &payer.pubkey(), &[&payer]) .await diff --git a/sdk-tests/csdk-anchor-full-derived-test/tests/d11_zero_copy_test.rs b/sdk-tests/csdk-anchor-full-derived-test/tests/d11_zero_copy_test.rs index 7950bbee7a..7b6323dfd3 100644 --- a/sdk-tests/csdk-anchor-full-derived-test/tests/d11_zero_copy_test.rs +++ b/sdk-tests/csdk-anchor-full-derived-test/tests/d11_zero_copy_test.rs @@ -88,7 +88,8 @@ impl D11TestContext { csdk_anchor_full_derived_test::program_rent_sponsor(), payer.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &payer.pubkey(), &[&payer]) .await diff --git a/sdk-tests/csdk-anchor-full-derived-test/tests/failing_tests.rs b/sdk-tests/csdk-anchor-full-derived-test/tests/failing_tests.rs index d89a87e25c..ee9dda76e2 100644 --- a/sdk-tests/csdk-anchor-full-derived-test/tests/failing_tests.rs +++ b/sdk-tests/csdk-anchor-full-derived-test/tests/failing_tests.rs @@ -67,7 +67,8 @@ impl FailingTestContext { csdk_anchor_full_derived_test::program_rent_sponsor(), payer.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &payer.pubkey(), &[&payer]) .await diff --git a/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs b/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs index 9b40b900e5..40ae494961 100644 --- a/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs +++ b/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs @@ -71,7 +71,8 @@ impl TestContext { rent_sponsor, payer.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &payer.pubkey(), &[&payer]) .await diff --git a/sdk-tests/csdk-anchor-full-derived-test/tests/shared.rs b/sdk-tests/csdk-anchor-full-derived-test/tests/shared.rs index e2f75c07d5..676b8d534e 100644 --- a/sdk-tests/csdk-anchor-full-derived-test/tests/shared.rs +++ b/sdk-tests/csdk-anchor-full-derived-test/tests/shared.rs @@ -56,7 +56,8 @@ impl SharedTestContext { rent_sponsor, payer.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &payer.pubkey(), &[&payer]) .await diff --git a/sdk-tests/pinocchio-light-program-test/tests/shared/mod.rs b/sdk-tests/pinocchio-light-program-test/tests/shared/mod.rs index c855d4289a..93b36c169c 100644 --- a/sdk-tests/pinocchio-light-program-test/tests/shared/mod.rs +++ b/sdk-tests/pinocchio-light-program-test/tests/shared/mod.rs @@ -42,7 +42,8 @@ pub async fn setup_test_env() -> TestEnv { rent_sponsor, payer.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &payer.pubkey(), &[&payer]) .await diff --git a/sdk-tests/pinocchio-light-program-test/tests/stress_test.rs b/sdk-tests/pinocchio-light-program-test/tests/stress_test.rs index 42e41ece33..5fa3efd2f2 100644 --- a/sdk-tests/pinocchio-light-program-test/tests/stress_test.rs +++ b/sdk-tests/pinocchio-light-program-test/tests/stress_test.rs @@ -129,7 +129,8 @@ async fn setup() -> (StressTestContext, TestPdas) { rent_sponsor, payer.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &payer.pubkey(), &[&payer]) .await diff --git a/sdk-tests/pinocchio-manual-test/tests/shared.rs b/sdk-tests/pinocchio-manual-test/tests/shared.rs index 6289b74524..4ed3262fb3 100644 --- a/sdk-tests/pinocchio-manual-test/tests/shared.rs +++ b/sdk-tests/pinocchio-manual-test/tests/shared.rs @@ -36,7 +36,8 @@ pub async fn setup_test_env() -> (LightProgramTest, Keypair, Pubkey) { rent_sponsor, payer.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &payer.pubkey(), &[&payer]) .await diff --git a/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/read_only.rs b/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/read_only.rs index 154f4e2045..cca6594691 100644 --- a/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/read_only.rs +++ b/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/read_only.rs @@ -127,7 +127,7 @@ async fn create_compressed_account( ) .await? .value; - let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts)?; let output_tree_index = rpc .get_random_state_tree_info() @@ -177,7 +177,7 @@ async fn read_sha256_light_system_cpi( .value; let packed_tree_accounts = rpc_result - .pack_tree_infos(&mut remaining_accounts) + .pack_tree_infos(&mut remaining_accounts)? .state_trees .unwrap(); @@ -230,7 +230,7 @@ async fn read_sha256_lowlevel( .value; let packed_tree_accounts = rpc_result - .pack_tree_infos(&mut remaining_accounts) + .pack_tree_infos(&mut remaining_accounts)? .state_trees .unwrap(); @@ -289,7 +289,7 @@ async fn create_compressed_account_poseidon( ) .await? .value; - let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts)?; let output_tree_index = rpc .get_random_state_tree_info() @@ -339,7 +339,7 @@ async fn read_poseidon_light_system_cpi( .value; let packed_tree_accounts = rpc_result - .pack_tree_infos(&mut remaining_accounts) + .pack_tree_infos(&mut remaining_accounts)? .state_trees .unwrap(); @@ -392,7 +392,7 @@ async fn read_poseidon_lowlevel( .value; let packed_tree_accounts = rpc_result - .pack_tree_infos(&mut remaining_accounts) + .pack_tree_infos(&mut remaining_accounts)? .state_trees .unwrap(); diff --git a/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/test.rs b/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/test.rs index e19d0742de..861904d862 100644 --- a/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/test.rs +++ b/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/test.rs @@ -171,7 +171,7 @@ async fn create_compressed_account( ) .await? .value; - let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts)?; let output_tree_index = rpc .get_random_state_tree_info() @@ -222,7 +222,7 @@ async fn update_compressed_account( .value; let packed_tree_accounts = rpc_result - .pack_tree_infos(&mut remaining_accounts) + .pack_tree_infos(&mut remaining_accounts)? .state_trees .unwrap(); @@ -276,7 +276,7 @@ async fn close_compressed_account( .value; let packed_tree_accounts = rpc_result - .pack_tree_infos(&mut remaining_accounts) + .pack_tree_infos(&mut remaining_accounts)? .state_trees .unwrap(); @@ -339,7 +339,7 @@ async fn reinit_closed_account( .value; let packed_tree_accounts = rpc_result - .pack_tree_infos(&mut remaining_accounts) + .pack_tree_infos(&mut remaining_accounts)? .state_trees .unwrap(); @@ -387,7 +387,7 @@ async fn close_compressed_account_permanent( .value; let packed_tree_accounts = rpc_result - .pack_tree_infos(&mut remaining_accounts) + .pack_tree_infos(&mut remaining_accounts)? .state_trees .unwrap(); diff --git a/sdk-tests/sdk-native-test/tests/test.rs b/sdk-tests/sdk-native-test/tests/test.rs index 30d792487f..ad2ad0894f 100644 --- a/sdk-tests/sdk-native-test/tests/test.rs +++ b/sdk-tests/sdk-native-test/tests/test.rs @@ -103,7 +103,8 @@ pub async fn create_pda( .value; let output_merkle_tree_index = accounts.insert_or_get(*merkle_tree_pubkey); - let packed_address_tree_info = rpc_result.pack_tree_infos(&mut accounts).address_trees[0]; + let packed_tree_infos = rpc_result.pack_tree_infos(&mut accounts)?; + let packed_address_tree_info = packed_tree_infos.address_trees[0]; let (accounts, system_accounts_offset, tree_accounts_offset) = accounts.to_account_metas(); let instruction_data = CreatePdaInstructionData { @@ -146,7 +147,7 @@ pub async fn update_pda( .value; let packed_accounts = rpc_result - .pack_tree_infos(&mut accounts) + .pack_tree_infos(&mut accounts)? .state_trees .unwrap(); diff --git a/sdk-tests/sdk-pinocchio-v1-test/tests/test.rs b/sdk-tests/sdk-pinocchio-v1-test/tests/test.rs index 0ae7f5c029..9558847fb1 100644 --- a/sdk-tests/sdk-pinocchio-v1-test/tests/test.rs +++ b/sdk-tests/sdk-pinocchio-v1-test/tests/test.rs @@ -101,7 +101,8 @@ pub async fn create_pda( .value; let output_merkle_tree_index = accounts.insert_or_get(*merkle_tree_pubkey); - let packed_address_tree_info = rpc_result.pack_tree_infos(&mut accounts).address_trees[0]; + let packed_tree_infos = rpc_result.pack_tree_infos(&mut accounts)?; + let packed_address_tree_info = packed_tree_infos.address_trees[0]; let (accounts, system_accounts_offset, tree_accounts_offset) = accounts.to_account_metas(); let instruction_data = CreatePdaInstructionData { proof: rpc_result.proof, @@ -144,7 +145,7 @@ pub async fn update_pda( .value; let packed_accounts = rpc_result - .pack_tree_infos(&mut accounts) + .pack_tree_infos(&mut accounts)? .state_trees .unwrap(); diff --git a/sdk-tests/sdk-pinocchio-v2-test/tests/test.rs b/sdk-tests/sdk-pinocchio-v2-test/tests/test.rs index 59a0562c63..510c98b2b5 100644 --- a/sdk-tests/sdk-pinocchio-v2-test/tests/test.rs +++ b/sdk-tests/sdk-pinocchio-v2-test/tests/test.rs @@ -111,7 +111,8 @@ pub async fn create_pda( .value; let output_merkle_tree_index = accounts.insert_or_get(*merkle_tree_pubkey); - let packed_address_tree_info = rpc_result.pack_tree_infos(&mut accounts).address_trees[0]; + let packed_tree_infos = rpc_result.pack_tree_infos(&mut accounts)?; + let packed_address_tree_info = packed_tree_infos.address_trees[0]; let (accounts, system_accounts_offset, tree_accounts_offset) = accounts.to_account_metas(); let instruction_data = CreatePdaInstructionData { proof: rpc_result.proof, @@ -154,7 +155,7 @@ pub async fn update_pda( .value; let packed_accounts = rpc_result - .pack_tree_infos(&mut accounts) + .pack_tree_infos(&mut accounts)? .state_trees .unwrap(); diff --git a/sdk-tests/sdk-token-test/tests/ctoken_pda.rs b/sdk-tests/sdk-token-test/tests/ctoken_pda.rs index 8e2b595285..b695674b5f 100644 --- a/sdk-tests/sdk-token-test/tests/ctoken_pda.rs +++ b/sdk-tests/sdk-token-test/tests/ctoken_pda.rs @@ -156,7 +156,7 @@ pub async fn create_mint( let config = SystemAccountMetaConfig::new_with_cpi_context(ID, tree_info.cpi_context.unwrap()); packed_accounts.add_system_accounts_v2(config).unwrap(); // packed_accounts.insert_or_get(tree_info.get_output_pubkey()?); - rpc_result.pack_tree_infos(&mut packed_accounts); + let _packed_tree_infos = rpc_result.pack_tree_infos(&mut packed_accounts)?; // Create PDA parameters let pda_amount = 100u64; diff --git a/sdk-tests/sdk-token-test/tests/decompress_full_cpi.rs b/sdk-tests/sdk-token-test/tests/decompress_full_cpi.rs index 5f096af560..96a4a4021b 100644 --- a/sdk-tests/sdk-token-test/tests/decompress_full_cpi.rs +++ b/sdk-tests/sdk-token-test/tests/decompress_full_cpi.rs @@ -213,7 +213,7 @@ async fn test_decompress_full_cpi() { .unwrap() .value; - let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts).unwrap(); let config = DecompressFullAccounts::new(None); remaining_accounts .add_custom_system_accounts(config) @@ -370,7 +370,7 @@ async fn test_decompress_full_cpi_with_context() { .value; // Add tree accounts first, then custom system accounts (no CPI context since params is None) - let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts).unwrap(); let config = DecompressFullAccounts::new(None); remaining_accounts .add_custom_system_accounts(config) diff --git a/sdk-tests/sdk-token-test/tests/pda_ctoken.rs b/sdk-tests/sdk-token-test/tests/pda_ctoken.rs index 91e0f2db9e..86c4785c86 100644 --- a/sdk-tests/sdk-token-test/tests/pda_ctoken.rs +++ b/sdk-tests/sdk-token-test/tests/pda_ctoken.rs @@ -214,7 +214,7 @@ pub async fn create_mint( let mut packed_accounts = PackedAccounts::default(); let config = SystemAccountMetaConfig::new_with_cpi_context(ID, tree_info.cpi_context.unwrap()); packed_accounts.add_system_accounts_v2(config).unwrap(); - rpc_result.pack_tree_infos(&mut packed_accounts); + let _packed_tree_infos = rpc_result.pack_tree_infos(&mut packed_accounts)?; // Create PDA parameters let pda_amount = 100u64; diff --git a/sdk-tests/sdk-token-test/tests/test.rs b/sdk-tests/sdk-token-test/tests/test.rs index 3c6941881d..ff2dab3236 100644 --- a/sdk-tests/sdk-token-test/tests/test.rs +++ b/sdk-tests/sdk-token-test/tests/test.rs @@ -367,7 +367,7 @@ async fn transfer_compressed_tokens( .await? .value; - let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts)?; let output_tree_index = packed_tree_info .state_trees .as_ref() @@ -433,7 +433,7 @@ async fn decompress_compressed_tokens( .await? .value; - let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts)?; let output_tree_index = packed_tree_info .state_trees .as_ref() diff --git a/sdk-tests/sdk-token-test/tests/test_4_invocations.rs b/sdk-tests/sdk-token-test/tests/test_4_invocations.rs index 9e70170056..e0de22648e 100644 --- a/sdk-tests/sdk-token-test/tests/test_4_invocations.rs +++ b/sdk-tests/sdk-token-test/tests/test_4_invocations.rs @@ -389,7 +389,7 @@ async fn create_compressed_escrow_pda( .await? .value; - let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts)?; let new_address_params = packed_tree_info.address_trees[0] .into_new_address_params_assigned_packed(address_seed, Some(0)); @@ -498,7 +498,7 @@ async fn test_four_invokes_instruction( // We need to pack the tree after the cpi context. remaining_accounts.insert_or_get(rpc_result.accounts[0].tree_info.tree); - let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts)?; let output_tree_index = packed_tree_info .state_trees .as_ref() diff --git a/sdk-tests/sdk-token-test/tests/test_4_transfer2.rs b/sdk-tests/sdk-token-test/tests/test_4_transfer2.rs index d7ef38a08c..252715de31 100644 --- a/sdk-tests/sdk-token-test/tests/test_4_transfer2.rs +++ b/sdk-tests/sdk-token-test/tests/test_4_transfer2.rs @@ -339,7 +339,7 @@ async fn create_compressed_escrow_pda( .await? .value; - let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts)?; let new_address_params = packed_tree_info.address_trees[0] .into_new_address_params_assigned_packed(address_seed, Some(0)); @@ -438,7 +438,7 @@ async fn test_four_transfer2_instruction( // We need to pack the tree after the cpi context. remaining_accounts.insert_or_get(rpc_result.accounts[0].tree_info.tree); - let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_tree_info = rpc_result.pack_tree_infos(&mut remaining_accounts)?; let output_tree_index = packed_tree_info .state_trees .as_ref() diff --git a/sdk-tests/sdk-token-test/tests/test_deposit.rs b/sdk-tests/sdk-token-test/tests/test_deposit.rs index 9ebcbd8549..c01512cd54 100644 --- a/sdk-tests/sdk-token-test/tests/test_deposit.rs +++ b/sdk-tests/sdk-token-test/tests/test_deposit.rs @@ -206,7 +206,7 @@ async fn create_deposit_compressed_account( ) .await? .value; - let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts)?; println!("packed_accounts {:?}", packed_accounts.state_trees); // Create token meta from compressed account @@ -318,7 +318,7 @@ async fn update_deposit_compressed_account( // Get validity proof for the compressed token account and new address println!("rpc_result {:?}", rpc_result); - let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts)?; println!("packed_accounts {:?}", packed_accounts.state_trees); // TODO: investigate why packed_tree_infos seem to be out of order // Create token meta from compressed account diff --git a/sdk-tests/sdk-v1-native-test/tests/test.rs b/sdk-tests/sdk-v1-native-test/tests/test.rs index a93beab599..2e10e61e14 100644 --- a/sdk-tests/sdk-v1-native-test/tests/test.rs +++ b/sdk-tests/sdk-v1-native-test/tests/test.rs @@ -94,7 +94,8 @@ pub async fn create_pda( .value; let output_merkle_tree_index = accounts.insert_or_get(*merkle_tree_pubkey); - let packed_address_tree_info = rpc_result.pack_tree_infos(&mut accounts).address_trees[0]; + let packed_tree_infos = rpc_result.pack_tree_infos(&mut accounts)?; + let packed_address_tree_info = packed_tree_infos.address_trees[0]; let (accounts, system_accounts_offset, tree_accounts_offset) = accounts.to_account_metas(); let instruction_data = CreatePdaInstructionData { @@ -137,7 +138,7 @@ pub async fn update_pda( .value; let packed_accounts = rpc_result - .pack_tree_infos(&mut accounts) + .pack_tree_infos(&mut accounts)? .state_trees .unwrap(); diff --git a/sdk-tests/single-account-loader-test/tests/test.rs b/sdk-tests/single-account-loader-test/tests/test.rs index 49e51e544c..e1d54dec5a 100644 --- a/sdk-tests/single-account-loader-test/tests/test.rs +++ b/sdk-tests/single-account-loader-test/tests/test.rs @@ -44,7 +44,8 @@ async fn test_create_zero_copy_record() { rent_sponsor, payer.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &payer.pubkey(), &[&payer]) .await @@ -149,7 +150,8 @@ async fn test_zero_copy_record_full_lifecycle() { rent_sponsor, payer.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &payer.pubkey(), &[&payer]) .await diff --git a/sdk-tests/single-ata-test/tests/test.rs b/sdk-tests/single-ata-test/tests/test.rs index 0196b17a23..eda2c79fa0 100644 --- a/sdk-tests/single-ata-test/tests/test.rs +++ b/sdk-tests/single-ata-test/tests/test.rs @@ -103,7 +103,8 @@ async fn test_create_single_ata() { rent_sponsor, payer.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &payer.pubkey(), &[&payer]) .await diff --git a/sdk-tests/single-mint-test/tests/test.rs b/sdk-tests/single-mint-test/tests/test.rs index ede75d0504..5d641b8e4a 100644 --- a/sdk-tests/single-mint-test/tests/test.rs +++ b/sdk-tests/single-mint-test/tests/test.rs @@ -41,7 +41,8 @@ async fn test_create_single_mint() { rent_sponsor, payer.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &payer.pubkey(), &[&payer]) .await diff --git a/sdk-tests/single-pda-test/tests/test.rs b/sdk-tests/single-pda-test/tests/test.rs index 79dc17d0e5..564dcb225d 100644 --- a/sdk-tests/single-pda-test/tests/test.rs +++ b/sdk-tests/single-pda-test/tests/test.rs @@ -39,7 +39,8 @@ async fn test_create_single_pda() { rent_sponsor, payer.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &payer.pubkey(), &[&payer]) .await diff --git a/sdk-tests/single-token-test/tests/test.rs b/sdk-tests/single-token-test/tests/test.rs index 485b4bba7f..a1bf2337d2 100644 --- a/sdk-tests/single-token-test/tests/test.rs +++ b/sdk-tests/single-token-test/tests/test.rs @@ -103,7 +103,8 @@ async fn test_create_single_token_vault() { rent_sponsor, payer.pubkey(), ) - .build(); + .build() + .unwrap(); rpc.create_and_send_transaction(&[init_config_ix], &payer.pubkey(), &[&payer]) .await diff --git a/sparse-merkle-tree/src/indexed_changelog.rs b/sparse-merkle-tree/src/indexed_changelog.rs index bbd30e1ee6..7e6a26cff7 100644 --- a/sparse-merkle-tree/src/indexed_changelog.rs +++ b/sparse-merkle-tree/src/indexed_changelog.rs @@ -29,7 +29,7 @@ pub fn patch_indexed_changelogs( low_element: &mut IndexedElement, new_element: &mut IndexedElement, low_element_next_value: &mut BigUint, - low_leaf_proof: &mut Vec<[u8; 32]>, + low_leaf_proof: &mut [[u8; 32]; HEIGHT], ) -> Result<(), SparseMerkleTreeError> { // Tests are in program-tests/merkle-tree/tests/indexed_changelog.rs let next_indexed_changelog_indices: Vec = (*indexed_changelogs) @@ -69,7 +69,7 @@ pub fn patch_indexed_changelogs( // Patch the next value. *low_element_next_value = BigUint::from_bytes_be(&changelog_entry.element.next_value); // Patch the proof. - *low_leaf_proof = changelog_entry.proof.to_vec(); + *low_leaf_proof = changelog_entry.proof; } // If we found a new low element. @@ -82,7 +82,7 @@ pub fn patch_indexed_changelogs( next_index: new_low_element_changelog_entry.element.next_index, }; - *low_leaf_proof = new_low_element_changelog_entry.proof.to_vec(); + *low_leaf_proof = new_low_element_changelog_entry.proof; new_element.next_index = low_element.next_index; if new_low_element_changelog_index == indexed_changelogs.len() - 1 { return Ok(()); diff --git a/sparse-merkle-tree/tests/indexed_changelog.rs b/sparse-merkle-tree/tests/indexed_changelog.rs index 7d37142b46..59efda6fde 100644 --- a/sparse-merkle-tree/tests/indexed_changelog.rs +++ b/sparse-merkle-tree/tests/indexed_changelog.rs @@ -92,7 +92,8 @@ fn test_indexed_changelog() { next_index: low_element_next_indices[i], }; println!("unpatched new_element: {:?}", new_element); - let mut low_element_proof = low_element_proofs[i].to_vec(); + let mut low_element_proof: [[u8; 32]; 8] = + low_element_proofs[i].as_slice().try_into().unwrap(); let mut low_element_next_value = BigUint::from_bytes_be(&low_element_next_values[i]); if i > 0 { @@ -114,7 +115,7 @@ fn test_indexed_changelog() { next_value: bigint_to_be_bytes_array::<32>(&new_element.value).unwrap(), index: low_element.index, }, - proof: low_element_proof.as_slice().to_vec().try_into().unwrap(), + proof: low_element_proof, changelog_index: indexed_changelog.len(), }); indexed_changelog.push(IndexedChangelogEntry { @@ -124,7 +125,7 @@ fn test_indexed_changelog() { next_value: bigint_to_be_bytes_array::<32>(&low_element_next_value).unwrap(), index: new_element.index, }, - proof: low_element_proof.as_slice().to_vec().try_into().unwrap(), + proof: low_element_proof, changelog_index: indexed_changelog.len(), }); println!("patched -------------------"); @@ -206,7 +207,8 @@ fn debug_test_indexed_changelog() { next_index: low_element_next_indices[i], }; println!("unpatched new_element: {:?}", new_element); - let mut low_element_proof = low_element_proofs[i].to_vec(); + let mut low_element_proof: [[u8; 32]; 8] = + low_element_proofs[i].as_slice().try_into().unwrap(); let mut low_element_next_value = BigUint::from_bytes_be(&low_element_next_values[i]); if i > 0 { @@ -228,7 +230,7 @@ fn debug_test_indexed_changelog() { next_value: bigint_to_be_bytes_array::<32>(&new_element.value).unwrap(), index: low_element.index, }, - proof: low_element_proof.as_slice().to_vec().try_into().unwrap(), + proof: low_element_proof, changelog_index: indexed_changelog.len(), }); indexed_changelog.push(IndexedChangelogEntry { @@ -238,7 +240,7 @@ fn debug_test_indexed_changelog() { next_value: bigint_to_be_bytes_array::<32>(&low_element_next_value).unwrap(), index: new_element.index, }, - proof: low_element_proof.as_slice().to_vec().try_into().unwrap(), + proof: low_element_proof, changelog_index: indexed_changelog.len(), }); man_indexed_array.elements[low_element.index()] = low_element.clone();