Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

69 changes: 12 additions & 57 deletions forester-utils/src/account_zero_copy.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{fmt, marker::PhantomData, mem, pin::Pin};
use std::{fmt, mem};

use light_client::rpc::Rpc;
use light_concurrent_merkle_tree::{
Expand All @@ -8,7 +8,7 @@ use light_hash_set::HashSet;
use light_hasher::Hasher;
use light_indexed_merkle_tree::{copy::IndexedMerkleTreeCopy, errors::IndexedMerkleTreeError};
use num_traits::{CheckedAdd, CheckedSub, ToBytes, Unsigned};
use solana_sdk::{account::Account, pubkey::Pubkey};
use solana_sdk::pubkey::Pubkey;
use thiserror::Error;

#[derive(Error, Debug)]
Expand All @@ -19,52 +19,15 @@ pub enum AccountZeroCopyError {
AccountNotFound(Pubkey),
}

#[derive(Debug, Clone)]
pub struct AccountZeroCopy<'a, T> {
pub account: Pin<Box<Account>>,
deserialized: *const T,
_phantom_data: PhantomData<&'a T>,
}

impl<'a, T> AccountZeroCopy<'a, T> {
pub async fn new<R: Rpc>(
rpc: &mut R,
address: Pubkey,
) -> Result<AccountZeroCopy<'a, T>, AccountZeroCopyError> {
let account = rpc
.get_account(address)
.await
.map_err(|e| AccountZeroCopyError::RpcError(e.to_string()))?
.ok_or(AccountZeroCopyError::AccountNotFound(address))?;
let account = Box::pin(account);
let deserialized = account.data[8..].as_ptr() as *const T;

Ok(Self {
account,
deserialized,
_phantom_data: PhantomData,
})
}

// Safe method to access `deserialized` ensuring the lifetime is respected
pub fn deserialized(&self) -> &'a T {
unsafe { &*self.deserialized }
}
fn copy_hash_set_from_account_bytes(bytes: &[u8]) -> Result<HashSet, light_hash_set::HashSetError> {
let mut owned = bytes.to_vec();
// SAFETY: We pass an owned, mutable copy of bytes that are expected to contain a serialized
// hash set account image.
unsafe { HashSet::from_bytes_copy(&mut owned) }
}

/// Fetches the given account, then copies and serializes it as a `HashSet`.
///
/// # Safety
///
/// This is highly unsafe. Ensuring that:
///
/// * The correct account is used.
/// * The account has enough space to be treated as a HashSet with specified
/// parameters.
/// * The account data is aligned.
///
/// Is the caller's responsibility.
pub async unsafe fn get_hash_set<T, R: Rpc>(
pub async fn get_hash_set<T, R: Rpc>(
rpc: &mut R,
pubkey: Pubkey,
) -> Result<HashSet, AccountZeroCopyError> {
Expand All @@ -73,9 +36,7 @@ pub async unsafe fn get_hash_set<T, R: Rpc>(
.await
.map_err(|e| AccountZeroCopyError::RpcError(e.to_string()))?
.ok_or(AccountZeroCopyError::AccountNotFound(pubkey))?;
let mut data = account.data.clone();

HashSet::from_bytes_copy(&mut data[8 + mem::size_of::<T>()..])
copy_hash_set_from_account_bytes(&account.data[8 + mem::size_of::<T>()..])
.map_err(|e| AccountZeroCopyError::RpcError(format!("HashSet parse error: {:?}", e)))
}

Expand Down Expand Up @@ -171,17 +132,11 @@ where
IndexedMerkleTreeCopy::from_bytes_copy(&data[offset..])
}

/// Parse HashSet from raw queue account data bytes
///
/// # Safety
/// Same safety requirements as `get_hash_set`.
pub unsafe fn parse_hash_set_from_bytes<T>(
data: &[u8],
) -> Result<HashSet, light_hash_set::HashSetError> {
/// Parse HashSet from raw queue account data bytes.
pub fn parse_hash_set_from_bytes<T>(data: &[u8]) -> Result<HashSet, light_hash_set::HashSetError> {
let offset = 8 + mem::size_of::<T>();
if data.len() <= offset {
return Err(light_hash_set::HashSetError::BufferSize(offset, data.len()));
}
let mut data_copy = data[offset..].to_vec();
HashSet::from_bytes_copy(&mut data_copy)
copy_hash_set_from_account_bytes(&data[offset..])
}
176 changes: 109 additions & 67 deletions forester-utils/src/address_merkle_tree_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use account_compression::{
AddressMerkleTreeAccount, AddressMerkleTreeConfig, AddressQueueConfig, NullifierQueueConfig,
QueueAccount, StateMerkleTreeAccount, StateMerkleTreeConfig,
};
use anchor_lang::Discriminator;
use anchor_lang::{AccountDeserialize, Discriminator};
use light_account_checks::discriminator::Discriminator as LightDiscriminator;
use light_batched_merkle_tree::merkle_tree::BatchedMerkleTreeAccount;
use light_client::{
Expand All @@ -14,40 +14,66 @@ use num_traits::Zero;
use solana_sdk::pubkey::Pubkey;

use crate::account_zero_copy::{
get_concurrent_merkle_tree, get_hash_set, get_indexed_merkle_tree, AccountZeroCopy,
AccountZeroCopyError,
get_concurrent_merkle_tree, get_indexed_merkle_tree, parse_concurrent_merkle_tree_from_bytes,
parse_hash_set_from_bytes, parse_indexed_merkle_tree_from_bytes, AccountZeroCopyError,
};

fn deserialize_account<T: AccountDeserialize>(
data: &[u8],
pubkey: Pubkey,
) -> Result<T, AccountZeroCopyError> {
if data.len() < 8 {
return Err(AccountZeroCopyError::RpcError(format!(
"Account {} data too short: {}",
pubkey,
data.len()
)));
}

T::try_deserialize(&mut &data[..]).map_err(|e| {
AccountZeroCopyError::RpcError(format!("Failed to deserialize account {}: {}", pubkey, e))
})
}

pub async fn get_address_bundle_config<R: Rpc>(
rpc: &mut R,
address_bundle: AddressMerkleTreeAccounts,
) -> Result<(AddressMerkleTreeConfig, AddressQueueConfig), AccountZeroCopyError> {
// Get queue metadata - don't hold AccountZeroCopy across await points
let address_queue_meta_data = {
let account = AccountZeroCopy::<QueueAccount>::new(rpc, address_bundle.queue).await?;
account.deserialized().metadata
};
let address_queue =
unsafe { get_hash_set::<QueueAccount, R>(rpc, address_bundle.queue).await? };
let address_queue_account = rpc
.get_account(address_bundle.queue)
.await
.map_err(|e| AccountZeroCopyError::RpcError(e.to_string()))?
.ok_or(AccountZeroCopyError::AccountNotFound(address_bundle.queue))?;
let address_queue_meta_data =
deserialize_account::<QueueAccount>(&address_queue_account.data, address_bundle.queue)?
.metadata;
let address_queue = parse_hash_set_from_bytes::<QueueAccount>(&address_queue_account.data)
.map_err(|e| AccountZeroCopyError::RpcError(format!("HashSet parse error: {:?}", e)))?;
let queue_config = AddressQueueConfig {
network_fee: Some(address_queue_meta_data.rollover_metadata.network_fee),
// rollover_threshold: address_queue_meta_data.rollover_threshold,
capacity: address_queue.get_capacity() as u16,
sequence_threshold: address_queue.sequence_threshold as u64,
};
// Get tree metadata - don't hold AccountZeroCopy across await points
let address_tree_meta_data = {
let account =
AccountZeroCopy::<AddressMerkleTreeAccount>::new(rpc, address_bundle.merkle_tree)
.await?;
account.deserialized().metadata
};
let address_tree =
get_indexed_merkle_tree::<AddressMerkleTreeAccount, R, Poseidon, usize, 26, 16>(
rpc,
let address_tree_account = rpc
.get_account(address_bundle.merkle_tree)
.await
.map_err(|e| AccountZeroCopyError::RpcError(e.to_string()))?
.ok_or(AccountZeroCopyError::AccountNotFound(
address_bundle.merkle_tree,
))?;
let address_tree_meta_data = deserialize_account::<AddressMerkleTreeAccount>(
&address_tree_account.data,
address_bundle.merkle_tree,
)?
.metadata;
let address_tree =
parse_indexed_merkle_tree_from_bytes::<AddressMerkleTreeAccount, Poseidon, usize, 26, 16>(
&address_tree_account.data,
)
.await?;
.map_err(|e| {
AccountZeroCopyError::RpcError(format!("IndexedMerkleTree parse error: {:?}", e))
})?;
let address_merkle_tree_config = AddressMerkleTreeConfig {
height: address_tree.height as u32,
changelog_size: address_tree.merkle_tree.changelog.capacity() as u64,
Expand All @@ -73,64 +99,77 @@ pub async fn get_state_bundle_config<R: Rpc>(
rpc: &mut R,
state_tree_bundle: StateMerkleTreeAccounts,
) -> Result<(StateMerkleTreeConfig, NullifierQueueConfig), AccountZeroCopyError> {
// Get queue metadata - don't hold AccountZeroCopy across await points
let address_queue_meta_data = {
let account =
AccountZeroCopy::<QueueAccount>::new(rpc, state_tree_bundle.nullifier_queue).await?;
account.deserialized().metadata
};
let address_queue =
unsafe { get_hash_set::<QueueAccount, R>(rpc, state_tree_bundle.nullifier_queue).await? };
let queue_account = rpc
.get_account(state_tree_bundle.nullifier_queue)
.await
.map_err(|e| AccountZeroCopyError::RpcError(e.to_string()))?
.ok_or(AccountZeroCopyError::AccountNotFound(
state_tree_bundle.nullifier_queue,
))?;
let nullifier_queue_metadata = deserialize_account::<QueueAccount>(
&queue_account.data,
state_tree_bundle.nullifier_queue,
)?
.metadata;
let nullifier_queue = parse_hash_set_from_bytes::<QueueAccount>(&queue_account.data)
.map_err(|e| AccountZeroCopyError::RpcError(format!("HashSet parse error: {:?}", e)))?;
let queue_config = NullifierQueueConfig {
network_fee: Some(address_queue_meta_data.rollover_metadata.network_fee),
capacity: address_queue.get_capacity() as u16,
sequence_threshold: address_queue.sequence_threshold as u64,
};
// Get tree metadata - don't hold AccountZeroCopy across await points
let address_tree_meta_data = {
let account =
AccountZeroCopy::<StateMerkleTreeAccount>::new(rpc, state_tree_bundle.merkle_tree)
.await?;
account.deserialized().metadata
network_fee: Some(nullifier_queue_metadata.rollover_metadata.network_fee),
capacity: nullifier_queue.get_capacity() as u16,
sequence_threshold: nullifier_queue.sequence_threshold as u64,
};
let address_tree = get_concurrent_merkle_tree::<StateMerkleTreeAccount, R, Poseidon, 26>(
rpc,
let state_tree_account = rpc
.get_account(state_tree_bundle.merkle_tree)
.await
.map_err(|e| AccountZeroCopyError::RpcError(e.to_string()))?
.ok_or(AccountZeroCopyError::AccountNotFound(
state_tree_bundle.merkle_tree,
))?;
let state_tree_metadata = deserialize_account::<StateMerkleTreeAccount>(
&state_tree_account.data,
state_tree_bundle.merkle_tree,
)
.await?;
let address_merkle_tree_config = StateMerkleTreeConfig {
height: address_tree.height as u32,
changelog_size: address_tree.changelog.capacity() as u64,
roots_size: address_tree.roots.capacity() as u64,
canopy_depth: address_tree.canopy_depth as u64,
rollover_threshold: if address_tree_meta_data
)?
.metadata;
let state_tree =
parse_concurrent_merkle_tree_from_bytes::<StateMerkleTreeAccount, Poseidon, 26>(
&state_tree_account.data,
)
.map_err(|e| {
AccountZeroCopyError::RpcError(format!("ConcurrentMerkleTree parse error: {:?}", e))
})?;
let state_merkle_tree_config = StateMerkleTreeConfig {
height: state_tree.height as u32,
changelog_size: state_tree.changelog.capacity() as u64,
roots_size: state_tree.roots.capacity() as u64,
canopy_depth: state_tree.canopy_depth as u64,
rollover_threshold: if state_tree_metadata
.rollover_metadata
.rollover_threshold
.is_zero()
{
None
} else {
Some(address_tree_meta_data.rollover_metadata.rollover_threshold)
Some(state_tree_metadata.rollover_metadata.rollover_threshold)
},
network_fee: Some(address_tree_meta_data.rollover_metadata.network_fee),
network_fee: Some(state_tree_metadata.rollover_metadata.network_fee),
close_threshold: None,
};
Ok((address_merkle_tree_config, queue_config))
Ok((state_merkle_tree_config, queue_config))
}

pub async fn address_tree_ready_for_rollover<R: Rpc>(
rpc: &mut R,
merkle_tree: Pubkey,
) -> Result<bool, AccountZeroCopyError> {
// Get account data - don't hold AccountZeroCopy across await points
let (address_tree_meta_data, account_data_len, account_lamports) = {
let account = AccountZeroCopy::<AddressMerkleTreeAccount>::new(rpc, merkle_tree).await?;
(
account.deserialized().metadata,
account.account.data.len(),
account.account.lamports,
)
};
let account = rpc
.get_account(merkle_tree)
.await
.map_err(|e| AccountZeroCopyError::RpcError(e.to_string()))?
.ok_or(AccountZeroCopyError::AccountNotFound(merkle_tree))?;
let address_tree_meta_data =
deserialize_account::<AddressMerkleTreeAccount>(&account.data, merkle_tree)?.metadata;
let account_data_len = account.data.len();
let account_lamports = account.lamports;
let rent_exemption = rpc
.get_minimum_balance_for_rent_exemption(account_data_len)
.await
Expand Down Expand Up @@ -166,15 +205,18 @@ pub async fn state_tree_ready_for_rollover<R: Rpc>(
.get_minimum_balance_for_rent_exemption(account.data.len())
.await
.map_err(|e| AccountZeroCopyError::RpcError(e.to_string()))?;
if account.data.len() < 8 {
return Err(AccountZeroCopyError::RpcError(format!(
"Account {} data too short: {}",
merkle_tree,
account.data.len()
)));
}
let discriminator = &account.data[0..8];
let (next_index, tree_meta_data, height) = match discriminator {
d if d == StateMerkleTreeAccount::DISCRIMINATOR => {
// Get tree metadata - don't hold AccountZeroCopy across await points
let tree_meta_data = {
let account =
AccountZeroCopy::<StateMerkleTreeAccount>::new(rpc, merkle_tree).await?;
account.deserialized().metadata
};
let tree_meta_data =
deserialize_account::<StateMerkleTreeAccount>(&account.data, merkle_tree)?.metadata;
let tree = get_concurrent_merkle_tree::<StateMerkleTreeAccount, R, Poseidon, 26>(
rpc,
merkle_tree,
Expand Down
Loading
Loading