diff --git a/node/src/benchmarking.rs b/node/src/benchmarking.rs index d0c0ac9a40..15023c1245 100644 --- a/node/src/benchmarking.rs +++ b/node/src/benchmarking.rs @@ -141,6 +141,7 @@ pub fn create_benchmark_extrinsic( pallet_shield::CheckShieldedTxValidity::::new(), pallet_subtensor::SubtensorTransactionExtension::::new(), pallet_drand::drand_priority::DrandPriority::::new(), + pallet_commitments::CommitmentsTransactionExtension::::new(), ), frame_metadata_hash_extension::CheckMetadataHash::::new(true), ); @@ -158,7 +159,7 @@ pub fn create_benchmark_extrinsic( (), (), ), - ((), (), (), (), ()), + ((), (), (), (), (), ()), None, ), ); diff --git a/pallets/commitments/src/lib.rs b/pallets/commitments/src/lib.rs index 98e5961708..ee78910044 100644 --- a/pallets/commitments/src/lib.rs +++ b/pallets/commitments/src/lib.rs @@ -12,17 +12,27 @@ pub mod weights; use ark_serialize::CanonicalDeserialize; use codec::Encode; -use frame_support::IterableStorageDoubleMap; use frame_support::{ - BoundedVec, - traits::{Currency, Get}, + BoundedVec, IterableStorageDoubleMap, + dispatch::{DispatchErrorWithPostInfo, DispatchExtension, DispatchInfo, PostDispatchInfo}, + pallet_prelude::{ + Decode, DecodeWithMemTracking, PhantomData, ValidTransaction, ValidateResult, + }, + traits::{Currency, Get, IsSubType, OriginTrait}, }; use frame_system::pallet_prelude::BlockNumberFor; pub use pallet::*; use scale_info::prelude::collections::BTreeSet; -use sp_runtime::SaturatedConversion; -use sp_runtime::{Saturating, Weight, traits::Zero}; +use sp_runtime::{ + SaturatedConversion, Saturating, Weight, + traits::Zero, + traits::{ + AsSystemOriginSigner, DispatchInfoOf, Dispatchable, Implication, TransactionExtension, + }, + transaction_validity::{InvalidTransaction, TransactionSource, TransactionValidityError}, +}; use sp_std::{boxed::Box, vec::Vec}; +use subtensor_macros::freeze_struct; use subtensor_runtime_common::NetUid; use tle::{ curves::drand::TinyBLS381, @@ -576,6 +586,117 @@ impl Pallet { } } +type CallOf = ::RuntimeCall; +type OriginOf = as Dispatchable>::RuntimeOrigin; + +#[derive(Default, Encode, Decode, DecodeWithMemTracking, Clone, Eq, PartialEq, TypeInfo)] +#[scale_info(skip_type_params(T))] +#[freeze_struct("7f03f99666ee2c4f")] +pub struct CommitmentsTransactionExtension(PhantomData); + +impl sp_std::fmt::Debug for CommitmentsTransactionExtension { + fn fmt(&self, f: &mut sp_std::fmt::Formatter) -> sp_std::fmt::Result { + write!(f, "CommitmentsTransactionExtension") + } +} + +impl CommitmentsTransactionExtension +where + T: Config + Send + Sync + TypeInfo, + CallOf: Dispatchable + IsSubType>, +{ + pub fn new() -> Self { + Self(Default::default()) + } +} + +impl TransactionExtension> for CommitmentsTransactionExtension +where + T: Config + Send + Sync + TypeInfo, + CallOf: Dispatchable + IsSubType>, + OriginOf: AsSystemOriginSigner + Clone, +{ + const IDENTIFIER: &'static str = "CommitmentsTransactionExtension"; + + type Implicit = (); + type Val = (); + type Pre = (); + + fn weight(&self, _call: &CallOf) -> Weight { + Weight::from_parts(0, 0) + } + + fn validate( + &self, + origin: OriginOf, + call: &CallOf, + _info: &DispatchInfoOf>, + _len: usize, + _self_implicit: Self::Implicit, + _inherited_implication: &impl Implication, + _source: TransactionSource, + ) -> ValidateResult> { + let Some(who) = origin.as_system_origin_signer() else { + return Ok((ValidTransaction::default(), (), origin)); + }; + + match call.is_sub_type() { + Some(pallet::Call::set_commitment { netuid, .. }) => { + if !T::CanCommit::can_commit(*netuid, who) { + return Err(InvalidTransaction::BadSigner.into()); + } + + Ok((ValidTransaction::default(), (), origin)) + } + _ => Ok((ValidTransaction::default(), (), origin)), + } + } + + fn prepare( + self, + _val: Self::Val, + _origin: &OriginOf, + _call: &CallOf, + _info: &DispatchInfoOf>, + _len: usize, + ) -> Result { + Ok(()) + } +} + +pub struct CommitmentsDispatchExtension(PhantomData); + +impl DispatchExtension> for CommitmentsDispatchExtension +where + T: Config, + CallOf: + Dispatchable + IsSubType>, + OriginOf: OriginTrait, +{ + type Pre = (); + + fn weight(_call: &CallOf) -> Weight { + T::DbWeight::get().reads(1) + } + + fn pre_dispatch( + origin: &OriginOf, + call: &CallOf, + ) -> Result { + let Some(who) = origin.as_signer() else { + return Ok(()); + }; + + if let Some(pallet::Call::set_commitment { netuid, .. }) = call.is_sub_type() { + if !T::CanCommit::can_commit(*netuid, who) { + return Err(Error::::AccountNotAllowedCommit.into()); + } + } + + Ok(()) + } +} + pub trait GetCommitments { fn get_commitments(netuid: NetUid) -> Vec<(AccountId, Vec)>; } diff --git a/runtime/src/lib.rs b/runtime/src/lib.rs index 76645e790a..da449a2164 100644 --- a/runtime/src/lib.rs +++ b/runtime/src/lib.rs @@ -22,10 +22,13 @@ use codec::{Compact, Decode, Encode}; use ethereum::AuthorizationList; use frame_support::{ PalletId, - dispatch::DispatchResult, + dispatch::{ + DispatchErrorWithPostInfo, DispatchExtension, DispatchInfo, DispatchResult, + PostDispatchInfo, + }, genesis_builder_helper::{build_state, get_preset}, pallet_prelude::Get, - traits::{Contains, InsideBoth, LinearStoragePrice, fungible::HoldConsideration}, + traits::{Contains, InsideBoth, LinearStoragePrice, OriginTrait, fungible::HoldConsideration}, }; use frame_system::{EnsureRoot, EnsureRootWithSuccess, EnsureSigned}; use pallet_commitments::{CanCommit, OnMetadataCommitment}; @@ -204,6 +207,7 @@ impl frame_system::offchain::CreateSignedTransaction pallet_shield::CheckShieldedTxValidity::::new(), pallet_subtensor::SubtensorTransactionExtension::::new(), pallet_drand::drand_priority::DrandPriority::::new(), + pallet_commitments::CommitmentsTransactionExtension::::new(), ), frame_metadata_hash_extension::CheckMetadataHash::::new(true), ); @@ -381,7 +385,7 @@ impl frame_system::Config for Runtime { type PostInherents = (); type PostTransactions = (); type ExtensionsWeightInfo = frame_system::SubstrateExtensionsWeight; - type DispatchExtension = pallet_subtensor::CheckColdkeySwap; + type DispatchExtension = RuntimeDispatchExtension; } impl pallet_insecure_randomness_collective_flip::Config for Runtime {} @@ -1681,6 +1685,7 @@ pub type CustomTxExtension = ( pallet_shield::CheckShieldedTxValidity, pallet_subtensor::SubtensorTransactionExtension, pallet_drand::drand_priority::DrandPriority, + pallet_commitments::CommitmentsTransactionExtension, ); pub type TxExtension = ( SystemTxExtension, @@ -1719,6 +1724,47 @@ pub type Executive = frame_executive::Executive< Migrations, >; +type RuntimeDispatchableOrigin = ::RuntimeOrigin; +type ColdkeySwapDispatchPre = + as DispatchExtension>::Pre; +type CommitmentsDispatchPre = + as DispatchExtension>::Pre; + +pub struct RuntimeDispatchExtension; + +impl DispatchExtension for RuntimeDispatchExtension +where + RuntimeCall: Dispatchable, + RuntimeDispatchableOrigin: OriginTrait, +{ + type Pre = (ColdkeySwapDispatchPre, CommitmentsDispatchPre); + + fn weight(call: &RuntimeCall) -> Weight { + as DispatchExtension>::weight( + call, + ) + .saturating_add( + as DispatchExtension< + RuntimeCall, + >>::weight(call), + ) + } + + fn pre_dispatch( + origin: &RuntimeDispatchableOrigin, + call: &RuntimeCall, + ) -> Result { + let coldkey_swap_pre = + as DispatchExtension>::pre_dispatch(origin, call)?; + let commitments_pre = + as DispatchExtension< + RuntimeCall, + >>::pre_dispatch(origin, call)?; + + Ok((coldkey_swap_pre, commitments_pre)) + } +} + #[cfg(feature = "runtime-benchmarks")] #[macro_use] extern crate frame_benchmarking;