Skip to content

Commit

Permalink
feat: add an optional way to handle contract errors with Result (#745)
Browse files Browse the repository at this point in the history
  • Loading branch information
itegulov authored Mar 10, 2022
1 parent 2a1091d commit e3c303f
Show file tree
Hide file tree
Showing 13 changed files with 258 additions and 41 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
## [unreleased]
### Features
- Added `Debug` and `PartialEq` implementations for `PromiseError`. [PR 728](https://github.com/near/near-sdk-rs/pull/728).

- Added convenience function `env::block_timestamp_ms` to return ms since 1970. [PR 736](https://github.com/near/near-sdk-rs/pull/728)
- Added an optional way to handle contract errors with `Result`. [PR 745](https://github.com/near/near-sdk-rs/pull/745).

## `4.0.0-pre.7` [02-02-2022]

Expand Down
36 changes: 2 additions & 34 deletions near-sdk-macros/src/core_impl/code_generator/attr_sig_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use proc_macro2::TokenStream as TokenStream2;
use crate::core_impl::info_extractor::{
ArgInfo, AttrSigInfo, BindgenArgType, InputStructType, SerializerType,
};
use crate::core_impl::utils;
use quote::quote;
use syn::{GenericArgument, Path, PathArguments, Type};

impl AttrSigInfo {
/// Create struct representing input arguments.
Expand Down Expand Up @@ -192,7 +192,7 @@ impl AttrSigInfo {
}
}
BindgenArgType::CallbackResultArg => {
let ok_type = if let Some(ok_type) = extract_ok_type(ty) {
let ok_type = if let Some(ok_type) = utils::extract_ok_type(ty) {
ok_type
} else {
return syn::Error::new_spanned(ty, "Function parameters marked with \
Expand Down Expand Up @@ -262,38 +262,6 @@ impl AttrSigInfo {
}
}

/// Checks whether the given path is literally "Result".
/// Note that it won't match a fully qualified name `core::result::Result` or a type alias like
/// `type StringResult = Result<String, String>`.
fn path_is_result(path: &Path) -> bool {
path.leading_colon.is_none()
&& path.segments.len() == 1
&& path.segments.iter().next().unwrap().ident == "Result"
}

/// Extracts the Ok type from a `Result` type.
///
/// For example, given `Result<String, u8>` type it will return `String` type.
fn extract_ok_type(ty: &Type) -> Option<&Type> {
match ty {
Type::Path(type_path) if type_path.qself.is_none() && path_is_result(&type_path.path) => {
// Get the first segment of the path (there should be only one, in fact: "Result"):
let type_params = &type_path.path.segments.first()?.arguments;
// We are interested in the first angle-bracketed param responsible for Ok type ("<String, _>"):
let generic_arg = match type_params {
PathArguments::AngleBracketed(params) => Some(params.args.first()?),
_ => None,
}?;
// This argument must be a type:
match generic_arg {
GenericArgument::Type(ty) => Some(ty),
_ => None,
}
}
_ => None,
}
}

pub fn deserialize_data(ty: &SerializerType) -> TokenStream2 {
match ty {
SerializerType::JSON => quote! {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::core_impl::info_extractor::{
AttrSigInfo, ImplItemMethodInfo, InputStructType, MethodType, SerializerType,
};
use crate::core_impl::utils;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{ReturnType, Signature};
Expand Down Expand Up @@ -53,6 +54,7 @@ impl ImplItemMethodInfo {
method_type,
is_payable,
is_private,
is_returns_result,
..
} = attr_signature_info;
let deposit_check = if *is_payable || matches!(method_type, &MethodType::View) {
Expand Down Expand Up @@ -136,6 +138,30 @@ impl ImplItemMethodInfo {
#method_invocation;
#contract_ser
},
ReturnType::Type(_, return_type)
if utils::type_is_result(return_type) && *is_returns_result =>
{
let value_ser = match result_serializer {
SerializerType::JSON => quote! {
let result = near_sdk::serde_json::to_vec(&result).expect("Failed to serialize the return value using JSON.");
},
SerializerType::Borsh => quote! {
let result = near_sdk::borsh::BorshSerialize::try_to_vec(&result).expect("Failed to serialize the return value using Borsh.");
},
};
quote! {
#contract_deser
let result = #method_invocation;
match result {
Ok(result) => {
#value_ser
near_sdk::env::value_return(&result);
#contract_ser
}
Err(err) => near_sdk::FunctionError::panic(&err)
}
}
}
ReturnType::Type(_, _) => {
let value_ser = match result_serializer {
SerializerType::JSON => quote! {
Expand All @@ -146,11 +172,11 @@ impl ImplItemMethodInfo {
},
};
quote! {
#contract_deser
let result = #method_invocation;
#value_ser
near_sdk::env::value_return(&result);
#contract_ser
#contract_deser
let result = #method_invocation;
#value_ser
near_sdk::env::value_return(&result);
#contract_ser
}
}
}
Expand Down
59 changes: 59 additions & 0 deletions near-sdk-macros/src/core_impl/code_generator/item_impl_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -751,4 +751,63 @@ mod tests {
);
assert_eq!(expected.to_string(), actual.to_string());
}

#[test]
fn return_result_json() {
let impl_type: Type = syn::parse_str("Hello").unwrap();
let mut method: ImplItemMethod = parse_quote! {
#[return_result]
pub fn method(&self) -> Result<u64, &'static str> { }
};
let method_info = ImplItemMethodInfo::new(&mut method, impl_type).unwrap();
let actual = method_info.method_wrapper();
let expected = quote!(
#[cfg(target_arch = "wasm32")]
#[no_mangle]
pub extern "C" fn method() {
near_sdk::env::setup_panic_hook();
let contract: Hello = near_sdk::env::state_read().unwrap_or_default();
let result = contract.method();
match result {
Ok(result) => {
let result =
near_sdk::serde_json::to_vec(&result).expect("Failed to serialize the return value using JSON.");
near_sdk::env::value_return(&result);
}
Err(err) => near_sdk::FunctionError::panic(&err)
}
}
);
assert_eq!(expected.to_string(), actual.to_string());
}

#[test]
fn return_result_borsh() {
let impl_type: Type = syn::parse_str("Hello").unwrap();
let mut method: ImplItemMethod = parse_quote! {
#[return_result]
#[result_serializer(borsh)]
pub fn method(&self) -> Result<u64, &'static str> { }
};
let method_info = ImplItemMethodInfo::new(&mut method, impl_type).unwrap();
let actual = method_info.method_wrapper();
let expected = quote!(
#[cfg(target_arch = "wasm32")]
#[no_mangle]
pub extern "C" fn method() {
near_sdk::env::setup_panic_hook();
let contract: Hello = near_sdk::env::state_read().unwrap_or_default();
let result = contract.method();
match result {
Ok(result) => {
let result =
near_sdk::borsh::BorshSerialize::try_to_vec(&result).expect("Failed to serialize the return value using Borsh.");
near_sdk::env::value_return(&result);
}
Err(err) => near_sdk::FunctionError::panic(&err)
}
}
);
assert_eq!(expected.to_string(), actual.to_string());
}
}
7 changes: 7 additions & 0 deletions near-sdk-macros/src/core_impl/info_extractor/attr_sig_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ pub struct AttrSigInfo {
pub is_payable: bool,
/// Whether method can accept calls from self (current account)
pub is_private: bool,
/// Whether method returns Result type where only Ok type is serialized
pub is_returns_result: bool,
/// The serializer that we use for `env::input()`.
pub input_serializer: SerializerType,
/// The serializer that we use for the return type.
Expand Down Expand Up @@ -61,6 +63,7 @@ impl AttrSigInfo {
let mut method_type = MethodType::Regular;
let mut is_payable = false;
let mut is_private = false;
let mut is_returns_result = false;
// By the default we serialize the result with JSON.
let mut result_serializer = SerializerType::JSON;

Expand All @@ -87,6 +90,9 @@ impl AttrSigInfo {
let serializer: SerializerAttr = syn::parse2(attr.tokens.clone())?;
result_serializer = serializer.serializer_type;
}
"return_result" => {
is_returns_result = true;
}
_ => {
non_bindgen_attrs.push((*attr).clone());
}
Expand Down Expand Up @@ -136,6 +142,7 @@ impl AttrSigInfo {
method_type,
is_payable,
is_private,
is_returns_result,
result_serializer,
receiver,
returns,
Expand Down
1 change: 1 addition & 0 deletions near-sdk-macros/src/core_impl/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod code_generator;
mod info_extractor;
mod metadata;
mod utils;
pub use code_generator::*;
pub use info_extractor::*;
pub use metadata::metadata_visitor::MetadataVisitor;
41 changes: 41 additions & 0 deletions near-sdk-macros/src/core_impl/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use syn::{GenericArgument, Path, PathArguments, Type};

/// Checks whether the given path is literally "Result".
/// Note that it won't match a fully qualified name `core::result::Result` or a type alias like
/// `type StringResult = Result<String, String>`.
pub(crate) fn path_is_result(path: &Path) -> bool {
path.leading_colon.is_none()
&& path.segments.len() == 1
&& path.segments.iter().next().unwrap().ident == "Result"
}

/// Equivalent to `path_is_result` except that it works on `Type` values.
pub(crate) fn type_is_result(ty: &Type) -> bool {
match ty {
Type::Path(type_path) if type_path.qself.is_none() => path_is_result(&type_path.path),
_ => false,
}
}

/// Extracts the Ok type from a `Result` type.
///
/// For example, given `Result<String, u8>` type it will return `String` type.
pub(crate) fn extract_ok_type(ty: &Type) -> Option<&Type> {
match ty {
Type::Path(type_path) if type_path.qself.is_none() && path_is_result(&type_path.path) => {
// Get the first segment of the path (there should be only one, in fact: "Result"):
let type_params = &type_path.path.segments.first()?.arguments;
// We are interested in the first angle-bracketed param responsible for Ok type ("<String, _>"):
let generic_arg = match type_params {
PathArguments::AngleBracketed(params) => Some(params.args.first()?),
_ => None,
}?;
// This argument must be a type:
match generic_arg {
GenericArgument::Type(ty) => Some(ty),
_ => None,
}
}
_ => None,
}
}
27 changes: 27 additions & 0 deletions near-sdk-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,30 @@ pub fn borsh_storage_key(item: TokenStream) -> TokenStream {
impl near_sdk::BorshIntoStorageKey for #name {}
})
}

/// `FunctionError` generates implementation for `near_sdk::FunctionError` trait.
/// It allows contract runtime to panic with the type using its `ToString` implementation
/// as the message.
#[proc_macro_derive(FunctionError)]
pub fn function_error(item: TokenStream) -> TokenStream {
let name = if let Ok(input) = syn::parse::<ItemEnum>(item.clone()) {
input.ident
} else if let Ok(input) = syn::parse::<ItemStruct>(item) {
input.ident
} else {
return TokenStream::from(
syn::Error::new(
Span::call_site(),
"FunctionError can only be used as a derive on enums or structs.",
)
.to_compile_error(),
);
};
TokenStream::from(quote! {
impl near_sdk::FunctionError for #name {
fn panic(&self) -> ! {
near_sdk::env::panic_str(&::std::string::ToString::to_string(&self))
}
}
})
}
1 change: 1 addition & 0 deletions near-sdk/compilation_tests/all.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ fn compilation_tests() {
t.pass("compilation_tests/cond_compilation.rs");
t.compile_fail("compilation_tests/payable_view.rs");
t.pass("compilation_tests/borsh_storage_key.rs");
t.pass("compilation_tests/function_error.rs");
}
50 changes: 50 additions & 0 deletions near-sdk/compilation_tests/function_error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
//! Testing FunctionError macro.
use borsh::{BorshDeserialize, BorshSerialize};
use near_sdk::{near_bindgen, FunctionError};
use std::fmt;

#[derive(FunctionError, BorshSerialize)]
struct ErrorStruct {
message: String,
}

impl fmt::Display for ErrorStruct {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "error ocurred: {}", self.message)
}
}

#[derive(FunctionError, BorshSerialize)]
enum ErrorEnum {
NotFound,
Banned { account_id: String },
}

impl fmt::Display for ErrorEnum {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ErrorEnum::NotFound => write!(f, "not found"),
ErrorEnum::Banned { account_id } => write!(f, "account {} is banned", account_id)
}
}
}

#[near_bindgen]
#[derive(BorshDeserialize, BorshSerialize, Default)]
struct Contract {}

#[near_bindgen]
impl Contract {
#[return_result]
pub fn set(&self, value: String) -> Result<String, ErrorStruct> {
Err(ErrorStruct { message: format!("Could not set to {}", value) })
}

#[return_result]
pub fn get(&self) -> Result<String, ErrorEnum> {
Err(ErrorEnum::NotFound)
}
}

fn main() {}
2 changes: 1 addition & 1 deletion near-sdk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ extern crate quickcheck;

pub use near_sdk_macros::{
callback, callback_vec, ext_contract, init, metadata, near_bindgen, result_serializer,
serializer, BorshStorageKey, PanicOnDefault,
serializer, BorshStorageKey, FunctionError, PanicOnDefault,
};

#[cfg(feature = "unstable")]
Expand Down
Loading

0 comments on commit e3c303f

Please sign in to comment.