Skip to content

Commit

Permalink
prevent nested DelayedOrigin
Browse files Browse the repository at this point in the history
  • Loading branch information
xlc committed Nov 21, 2022
1 parent 740f0b8 commit a93caac
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 1 deletion.
86 changes: 85 additions & 1 deletion authority/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use frame_support::{
};
use frame_system::{pallet_prelude::*, EnsureRoot, EnsureSigned};
use scale_info::TypeInfo;
use sp_core::defer;
use sp_runtime::{
traits::{CheckedSub, Dispatchable, Hash, Saturating},
ArithmeticError, DispatchError, DispatchResult, Either, RuntimeDebug,
Expand All @@ -45,14 +46,97 @@ mod weights;
pub use weights::WeightInfo;

/// A delayed origin. Can only be dispatched via `dispatch_as` with a delay.
#[derive(PartialEq, Eq, Clone, RuntimeDebug, Encode, Decode, TypeInfo, MaxEncodedLen)]
#[derive(PartialEq, Eq, Clone, RuntimeDebug, Encode, TypeInfo)]
pub struct DelayedOrigin<BlockNumber, PalletsOrigin> {
/// Number of blocks that this call have been delayed.
pub delay: BlockNumber,
/// The initial origin.
pub origin: Box<PalletsOrigin>,
}

#[cfg(feature = "std")]
mod helper {
use std::cell::RefCell;

thread_local! {
static NESTED_DECODE: RefCell<bool> = RefCell::new(false);
static NESTED_MAX_ENCODED_LEN: RefCell<bool> = RefCell::new(false);
}

pub fn set_nested_decode(val: bool) {
NESTED_DECODE.with(|v| *v.borrow_mut() = val);
}

pub fn nested_decode() -> bool {
NESTED_DECODE.with(|v| *v.borrow())
}

pub fn set_nested_max_encoded_len(val: bool) {
NESTED_MAX_ENCODED_LEN.with(|v| *v.borrow_mut() = val);
}

pub fn nested_max_encoded_len() -> bool {
NESTED_MAX_ENCODED_LEN.with(|v| *v.borrow())
}
}

#[cfg(not(feature = "std"))]
mod helper {
static mut NESTED_DECODE: bool = false;
static mut NESTED_MAX_ENCODED_LEN: bool = false;

pub fn set_nested_decode(val: bool) {
unsafe {
NESTED_DECODE = val;
}
}

pub fn nested_decode() -> bool {
unsafe { NESTED_DECODE }
}

pub fn set_nested_max_encoded_len(val: bool) {
unsafe {
NESTED_MAX_ENCODED_LEN = val;
}
}

pub fn nested_max_encoded_len() -> bool {
unsafe { NESTED_MAX_ENCODED_LEN }
}
}

impl<BlockNumber: Decode, PalletsOrigin: Decode> Decode for DelayedOrigin<BlockNumber, PalletsOrigin> {
fn decode<I: codec::Input>(input: &mut I) -> Result<Self, codec::Error> {
if helper::nested_decode() {
return Err("Nested DelayedOrigin::decode is not allowed".into());
}

helper::set_nested_decode(true);
defer!(helper::set_nested_decode(false));

Ok(DelayedOrigin {
delay: Decode::decode(input)?,
origin: Decode::decode(input)?,
})
}
}

impl<BlockNumber: MaxEncodedLen, PalletsOrigin: MaxEncodedLen> MaxEncodedLen
for DelayedOrigin<BlockNumber, PalletsOrigin>
{
fn max_encoded_len() -> usize {
if helper::nested_max_encoded_len() {
return 0;
}

helper::set_nested_max_encoded_len(true);
defer!(helper::set_nested_max_encoded_len(false));

BlockNumber::max_encoded_len() + PalletsOrigin::max_encoded_len()
}
}

/// Ensure the origin have a minimum amount of delay.
pub struct EnsureDelayed<Delay, Inner, BlockNumber, PalletsOrigin>(
sp_std::marker::PhantomData<(Delay, Inner, BlockNumber, PalletsOrigin)>,
Expand Down
21 changes: 21 additions & 0 deletions authority/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#![cfg(test)]

use super::*;
use codec::MaxEncodedLen;
use frame_support::{
assert_noop, assert_ok,
dispatch::DispatchErrorWithPostInfo,
Expand Down Expand Up @@ -691,3 +692,23 @@ fn trigger_old_call_should_be_free_and_operational() {
);
});
}

#[test]
fn origin_max_encoded_len_works() {
assert_eq!(DelayedOrigin::<u32, OriginCaller>::max_encoded_len(), 22);
assert_eq!(OriginCaller::max_encoded_len(), 27);
}

#[test]
fn nested_delayed_origin_is_invalid() {
let orgin = DelayedOrigin::<u64, OriginCaller> {
delay: 1u64,
origin: Box::new(OriginCaller::Authority(DelayedOrigin::<u64, OriginCaller> {
delay: 1u64,
origin: Box::new(OriginCaller::system(frame_system::RawOrigin::Root)),
})),
};

let encoded = orgin.encode();
assert!(OriginCaller::decode(&mut &encoded[..]).is_err());
}

0 comments on commit a93caac

Please sign in to comment.