Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
290 changes: 286 additions & 4 deletions pallets/shield/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,29 @@

extern crate alloc;

use alloc::vec;
use chacha20poly1305::{
KeyInit, XChaCha20Poly1305, XNonce,
aead::{Aead, Payload},
};
use frame_support::{pallet_prelude::*, traits::IsSubType};
use frame_system::{ensure_none, ensure_signed, pallet_prelude::*};
use frame_support::{
dispatch::{GetDispatchInfo, PostDispatchInfo},
pallet_prelude::*,
traits::{ConstU64, IsSubType},
};
use frame_system::{ensure_none, ensure_root, ensure_signed, pallet_prelude::*};
use ml_kem::{
Ciphertext, EncodedSizeUser, MlKem768, MlKem768Params,
kem::{Decapsulate, DecapsulationKey},
};
use sp_io::hashing::twox_128;
use sp_runtime::traits::{Applyable, Block as BlockT, Checkable, Hash};
use sp_runtime::traits::{Dispatchable, Saturating};
use stp_shield::{
INHERENT_IDENTIFIER, InherentType, LOG_TARGET, MLKEM768_ENC_KEY_LEN, ShieldEncKey,
ShieldedTransaction,
};

use alloc::vec;
use subtensor_macros::freeze_struct;

pub use pallet::*;

Expand All @@ -45,6 +50,19 @@ type ApplyableCallOf<T> = <T as Applyable>::Call;

const MAX_EXTRINSIC_DEPTH: u32 = 8;

/// Trait for decrypting stored extrinsics before dispatch.
pub trait ExtrinsicDecryptor<RuntimeCall> {
/// Decrypt the stored bytes and return the decoded RuntimeCall.
fn decrypt(data: &[u8]) -> Result<RuntimeCall, DispatchError>;
}

/// Default implementation that always returns an error.
impl<RuntimeCall> ExtrinsicDecryptor<RuntimeCall> for () {
fn decrypt(_data: &[u8]) -> Result<RuntimeCall, DispatchError> {
Err(DispatchError::Other("ExtrinsicDecryptor not implemented"))
}
}

#[frame_support::pallet]
pub mod pallet {
use super::*;
Expand All @@ -56,6 +74,14 @@ pub mod pallet {

/// A way to find the current and next block author.
type FindAuthors: FindAuthors<Self>;

/// The overarching call type for dispatching stored extrinsics.
type RuntimeCall: Parameter
+ Dispatchable<RuntimeOrigin = Self::RuntimeOrigin, PostInfo = PostDispatchInfo>
+ GetDispatchInfo;

/// Decryptor for stored extrinsics.
type ExtrinsicDecryptor: ExtrinsicDecryptor<<Self as pallet::Config>::RuntimeCall>;
}

#[pallet::pallet]
Expand Down Expand Up @@ -93,11 +119,89 @@ pub mod pallet {
pub type HasMigrationRun<T: Config> =
StorageMap<_, Identity, BoundedVec<u8, MigrationKeyMaxLen>, bool, ValueQuery>;

/// Maximum size of a single encoded call.
pub type MaxCallSize = ConstU32<8192>;

/// Default maximum number of pending extrinsics.
pub type DefaultMaxPendingExtrinsics = ConstU32<100>;

/// Configurable maximum number of pending extrinsics.
/// Defaults to 100 if not explicitly set via `set_max_pending_extrinsics`.
#[pallet::storage]
pub type MaxPendingExtrinsicsLimit<T: Config> =
StorageValue<_, u32, ValueQuery, DefaultMaxPendingExtrinsics>;

/// Default extrinsic lifetime in blocks.
pub const DEFAULT_EXTRINSIC_LIFETIME: u32 = 10;

/// Configurable extrinsic lifetime (max block difference between submission and execution).
/// Defaults to 10 blocks if not explicitly set.
#[pallet::storage]
pub type ExtrinsicLifetime<T: Config> =
StorageValue<_, u32, ValueQuery, ConstU32<DEFAULT_EXTRINSIC_LIFETIME>>;

/// Default maximum weight allowed for on_initialize processing.
pub const DEFAULT_ON_INITIALIZE_WEIGHT: u64 = 500_000_000_000;

/// Absolute maximum weight for on_initialize: half the total block weight (2s of 4s).
pub const MAX_ON_INITIALIZE_WEIGHT: u64 = 2_000_000_000_000;

/// Configurable maximum weight for on_initialize processing.
/// Defaults to 500_000_000_000 ref_time if not explicitly set.
#[pallet::storage]
pub type OnInitializeWeight<T: Config> =
StorageValue<_, u64, ValueQuery, ConstU64<DEFAULT_ON_INITIALIZE_WEIGHT>>;

/// A pending extrinsic stored for later execution.
#[freeze_struct("c5749ec89253be61")]
#[derive(Clone, Encode, Decode, TypeInfo, MaxEncodedLen, PartialEq, Debug)]
#[scale_info(skip_type_params(T))]
pub struct PendingExtrinsic<T: Config> {
/// The account that submitted the extrinsic.
pub who: T::AccountId,
/// The encoded call data.
pub call: BoundedVec<u8, MaxCallSize>,
/// The block number when the extrinsic was submitted.
pub submitted_at: BlockNumberFor<T>,
}

/// Storage map for encrypted extrinsics to be executed in on_initialize.
/// Uses u32 index for O(1) insertion and removal.
#[pallet::storage]
pub type PendingExtrinsics<T: Config> =
StorageMap<_, Identity, u32, PendingExtrinsic<T>, OptionQuery>;
Comment on lines +168 to +172
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is maybe one of the rare case where we could use a CountedStorageMap


/// Next index to use when inserting a pending extrinsic (unique auto-increment).
#[pallet::storage]
pub type NextPendingExtrinsicIndex<T: Config> = StorageValue<_, u32, ValueQuery>;

/// Number of pending extrinsics currently stored (for limit checking).
#[pallet::storage]
pub type PendingExtrinsicCount<T: Config> = StorageValue<_, u32, ValueQuery>;

#[pallet::event]
#[pallet::generate_deposit(pub(super) fn deposit_event)]
pub enum Event<T: Config> {
/// Encrypted wrapper accepted.
EncryptedSubmitted { id: T::Hash, who: T::AccountId },
/// Encrypted extrinsic was stored for later execution.
ExtrinsicStored { index: u32, who: T::AccountId },
/// Extrinsic decode failed during on_initialize.
ExtrinsicDecodeFailed { index: u32 },
/// Extrinsic dispatch failed during on_initialize.
ExtrinsicDispatchFailed { index: u32, error: DispatchError },
/// Extrinsic was successfully dispatched during on_initialize.
ExtrinsicDispatched { index: u32 },
/// Extrinsic expired (exceeded max block lifetime).
ExtrinsicExpired { index: u32 },
/// Extrinsic postponed due to weight limit.
ExtrinsicPostponed { index: u32 },
/// Maximum pending extrinsics limit was updated.
MaxPendingExtrinsicsNumberSet { value: u32 },
/// Maximum on_initialize weight was updated.
OnInitializeWeightSet { value: u64 },
/// Extrinsic lifetime was updated.
ExtrinsicLifetimeSet { value: u32 },
}

#[pallet::error]
Expand All @@ -106,10 +210,18 @@ pub mod pallet {
BadEncKeyLen,
/// Unreachable.
Unreachable,
/// Too many pending extrinsics in storage.
TooManyPendingExtrinsics,
/// Weight exceeds the absolute maximum (half of total block weight).
WeightExceedsAbsoluteMax,
}

#[pallet::hooks]
impl<T: Config> Hooks<BlockNumberFor<T>> for Pallet<T> {
fn on_initialize(_block_number: BlockNumberFor<T>) -> Weight {
Self::process_pending_extrinsics()
}

fn on_runtime_upgrade() -> frame_support::weights::Weight {
let mut weight = frame_support::weights::Weight::from_parts(0, 0);

Expand Down Expand Up @@ -229,6 +341,84 @@ pub mod pallet {
Self::deposit_event(Event::EncryptedSubmitted { id, who });
Ok(())
}

/// Store an encrypted extrinsic for later execution in on_initialize.
#[pallet::call_index(2)]
#[pallet::weight(Weight::from_parts(10_000_000, 0)
.saturating_add(T::DbWeight::get().reads(2_u64))
.saturating_add(T::DbWeight::get().writes(3_u64)))]
pub fn store_encrypted(
origin: OriginFor<T>,
call: BoundedVec<u8, MaxCallSize>,
) -> DispatchResult {
let who = ensure_signed(origin)?;

let count = PendingExtrinsicCount::<T>::get();

ensure!(
count < MaxPendingExtrinsicsLimit::<T>::get(),
Error::<T>::TooManyPendingExtrinsics
);

let index = NextPendingExtrinsicIndex::<T>::get();
let pending = PendingExtrinsic {
who: who.clone(),
call,
submitted_at: frame_system::Pallet::<T>::block_number(),
};
PendingExtrinsics::<T>::insert(index, pending);

NextPendingExtrinsicIndex::<T>::put(index.saturating_add(1));
PendingExtrinsicCount::<T>::put(count.saturating_add(1));

Self::deposit_event(Event::ExtrinsicStored { index, who });
Ok(())
}

/// Set the maximum number of pending extrinsics allowed in the queue.
#[pallet::call_index(3)]
#[pallet::weight(T::DbWeight::get().writes(1_u64))]
pub fn set_max_pending_extrinsics_number(
origin: OriginFor<T>,
value: u32,
) -> DispatchResult {
ensure_root(origin)?;

MaxPendingExtrinsicsLimit::<T>::put(value);

Self::deposit_event(Event::MaxPendingExtrinsicsNumberSet { value });
Ok(())
}

/// Set the maximum weight allowed for on_initialize processing.
/// Rejects values exceeding the absolute limit (half of total block weight).
#[pallet::call_index(4)]
#[pallet::weight(T::DbWeight::get().writes(1_u64))]
pub fn set_on_initialize_weight(origin: OriginFor<T>, value: u64) -> DispatchResult {
ensure_root(origin)?;

ensure!(
value <= MAX_ON_INITIALIZE_WEIGHT,
Error::<T>::WeightExceedsAbsoluteMax
);

OnInitializeWeight::<T>::put(value);

Self::deposit_event(Event::OnInitializeWeightSet { value });
Ok(())
}

/// Set the extrinsic lifetime (max blocks between submission and execution).
#[pallet::call_index(5)]
#[pallet::weight(T::DbWeight::get().writes(1_u64))]
pub fn set_stored_extrinsic_lifetime(origin: OriginFor<T>, value: u32) -> DispatchResult {
ensure_root(origin)?;

ExtrinsicLifetime::<T>::put(value);

Self::deposit_event(Event::ExtrinsicLifetimeSet { value });
Ok(())
}
}

#[pallet::inherent]
Expand All @@ -255,6 +445,98 @@ pub mod pallet {
}

impl<T: Config> Pallet<T> {
/// Process pending encrypted extrinsics up to the weight limit.
/// Returns the total weight consumed.
pub fn process_pending_extrinsics() -> Weight {
let next_index = NextPendingExtrinsicIndex::<T>::get();
let count = PendingExtrinsicCount::<T>::get();

let mut weight = T::DbWeight::get().reads(2);

if count == 0 {
return weight;
}

let start_index = next_index.saturating_sub(count);
let current_block = frame_system::Pallet::<T>::block_number();

// Process extrinsics
for index in start_index..next_index {
let Some(pending) = PendingExtrinsics::<T>::get(index) else {
weight = weight.saturating_add(T::DbWeight::get().reads(1));

continue;
};

// Check if the extrinsic has expired
let age = current_block.saturating_sub(pending.submitted_at);
if age > ExtrinsicLifetime::<T>::get().into() {
remove_pending_extrinsic::<T>(index, &mut weight);

Self::deposit_event(Event::ExtrinsicExpired { index });

continue;
}

let call = match T::ExtrinsicDecryptor::decrypt(&pending.call) {
Ok(call) => call,
Err(_) => {
remove_pending_extrinsic::<T>(index, &mut weight);

Self::deposit_event(Event::ExtrinsicDecodeFailed { index });

continue;
}
};

// Check if dispatching would exceed weight limit
let info = call.get_dispatch_info();
let dispatch_weight = T::DbWeight::get()
.writes(2)
.saturating_add(info.call_weight);

let max_weight = Weight::from_parts(OnInitializeWeight::<T>::get(), 0);

if weight.saturating_add(dispatch_weight).any_gt(max_weight) {
Self::deposit_event(Event::ExtrinsicPostponed { index });
break;
}

// We're going to execute it - remove the item from storage
remove_pending_extrinsic::<T>(index, &mut weight);

// Dispatch the extrinsic
let origin: T::RuntimeOrigin = frame_system::RawOrigin::Signed(pending.who).into();
let result = call.dispatch(origin);

match result {
Ok(post_info) => {
let actual_weight = post_info.actual_weight.unwrap_or(info.call_weight);
weight = weight.saturating_add(actual_weight);

Self::deposit_event(Event::ExtrinsicDispatched { index });
}
Err(e) => {
weight = weight.saturating_add(info.call_weight);

Self::deposit_event(Event::ExtrinsicDispatchFailed {
index,
error: e.error,
});
}
}
}

/// Remove a pending extrinsic from storage and decrement count.
fn remove_pending_extrinsic<T: Config>(index: u32, weight: &mut Weight) {
PendingExtrinsics::<T>::remove(index);
PendingExtrinsicCount::<T>::mutate(|c| *c = c.saturating_sub(1));
*weight = weight.saturating_add(T::DbWeight::get().writes(2));
}

weight
}

pub fn try_decode_shielded_tx<Block: BlockT, Context: Default>(
uxt: ExtrinsicOf<Block>,
) -> Option<ShieldedTransaction>
Expand Down
13 changes: 13 additions & 0 deletions pallets/shield/src/mock.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate as pallet_shield;
use stp_shield::MLKEM768_ENC_KEY_LEN;

use codec::Decode;
use frame_support::pallet_prelude::DispatchError;
use frame_support::traits::{ConstBool, ConstU64};
use frame_support::{BoundedVec, construct_runtime, derive_impl, parameter_types};
use sp_consensus_aura::sr25519::AuthorityId as AuraId;
Expand Down Expand Up @@ -85,9 +87,20 @@ impl pallet_shield::FindAuthors<Test> for MockFindAuthors {
}
}

/// Mock decryptor that just decodes the bytes without decryption.
pub struct MockDecryptor;

impl pallet_shield::ExtrinsicDecryptor<RuntimeCall> for MockDecryptor {
fn decrypt(data: &[u8]) -> Result<RuntimeCall, DispatchError> {
RuntimeCall::decode(&mut &data[..]).map_err(|_| DispatchError::Other("decode failed"))
}
}

impl pallet_shield::Config for Test {
type AuthorityId = AuraId;
type FindAuthors = MockFindAuthors;
type RuntimeCall = RuntimeCall;
type ExtrinsicDecryptor = MockDecryptor;
}

pub fn new_test_ext() -> sp_io::TestExternalities {
Expand Down
Loading
Loading