Skip to content

Commit

Permalink
lang: Allow CPI return values (#1598)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomlinton authored Mar 24, 2022
1 parent 0f7675c commit 1cb7429
Show file tree
Hide file tree
Showing 23 changed files with 471 additions and 8 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,8 @@ jobs:
path: tests/custom-coder
- cmd: cd tests/validator-clone && yarn --frozen-lockfile && anchor test --skip-lint
path: tests/validator-clone
- cmd: cd tests/cpi-returns && anchor test --skip-lint
path: tests/cpi-returns
steps:
- uses: actions/checkout@v2
- uses: ./.github/actions/setup/
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ The minor version will be incremented upon a breaking change and the patch versi

### Features

* lang: Add return values to CPI client. ([#1598](https://github.com/project-serum/anchor/pull/1598)).
* avm: New `avm update` command to update the Anchor CLI to the latest version ([#1670](https://github.com/project-serum/anchor/pull/1670)).

### Fixes
Expand Down
31 changes: 28 additions & 3 deletions lang/syn/src/codegen/program/cpi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::codegen::program::common::{generate_ix_variant, sighash, SIGHASH_GLOB
use crate::Program;
use crate::StateIx;
use heck::SnakeCase;
use quote::quote;
use quote::{quote, ToTokens};

pub fn generate(program: &Program) -> proc_macro2::TokenStream {
// Generate cpi methods for the state struct.
Expand Down Expand Up @@ -70,11 +70,20 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
let sighash_arr = sighash(SIGHASH_GLOBAL_NAMESPACE, name);
let sighash_tts: proc_macro2::TokenStream =
format!("{:?}", sighash_arr).parse().unwrap();
let ret_type = &ix.returns.ty.to_token_stream();
let (method_ret, maybe_return) = match ret_type.to_string().as_str() {
"()" => (quote! {anchor_lang::Result<()> }, quote! { Ok(()) }),
_ => (
quote! { anchor_lang::Result<crate::cpi::Return::<#ret_type>> },
quote! { Ok(crate::cpi::Return::<#ret_type> { phantom: crate::cpi::PhantomData }) }
)
};

quote! {
pub fn #method_name<'a, 'b, 'c, 'info>(
ctx: anchor_lang::context::CpiContext<'a, 'b, 'c, 'info, #accounts_ident<'info>>,
#(#args),*
) -> anchor_lang::Result<()> {
) -> #method_ret {
let ix = {
let ix = instruction::#ix_variant;
let mut ix_data = AnchorSerialize::try_to_vec(&ix)
Expand All @@ -93,7 +102,11 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
&ix,
&acc_infos,
ctx.signer_seeds,
).map_err(Into::into)
).map_or_else(
|e| Err(Into::into(e)),
// Maybe handle Solana return data.
|_| { #maybe_return }
)
}
}
};
Expand All @@ -108,13 +121,25 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
#[cfg(feature = "cpi")]
pub mod cpi {
use super::*;
use std::marker::PhantomData;

pub mod state {
use super::*;

#(#state_cpi_methods)*
}

pub struct Return<T> {
phantom: std::marker::PhantomData<T>
}

impl<T: AnchorDeserialize> Return<T> {
pub fn get(&self) -> T {
let (_key, data) = anchor_lang::solana_program::program::get_return_data().unwrap();
T::try_from_slice(&data).unwrap()
}
}

#(#global_cpi_methods)*

#accounts
Expand Down
14 changes: 12 additions & 2 deletions lang/syn/src/codegen/program/handlers.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::codegen::program::common::*;
use crate::{Program, State};
use heck::CamelCase;
use quote::quote;
use quote::{quote, ToTokens};

// Generate non-inlined wrappers for each instruction handler, since Solana's
// BPF max stack size can't handle reasonable sized dispatch trees without doing
Expand Down Expand Up @@ -694,6 +694,13 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
let anchor = &ix.anchor_ident;
let variant_arm = generate_ix_variant(ix.raw_method.sig.ident.to_string(), &ix.args);
let ix_name_log = format!("Instruction: {}", ix_name);
let ret_type = &ix.returns.ty.to_token_stream();
let maybe_set_return_data = match ret_type.to_string().as_str() {
"()" => quote! {},
_ => quote! {
anchor_lang::solana_program::program::set_return_data(&result.try_to_vec().unwrap());
},
};
quote! {
#[inline(never)]
pub fn #ix_method_name(
Expand Down Expand Up @@ -722,7 +729,7 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
)?;

// Invoke user defined handler.
#program_name::#ix_method_name(
let result = #program_name::#ix_method_name(
anchor_lang::context::Context::new(
program_id,
&mut accounts,
Expand All @@ -732,6 +739,9 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
#(#ix_arg_names),*
)?;

// Maybe set Solana return data.
#maybe_set_return_data

// Exit routine.
accounts.exit(program_id)
}
Expand Down
8 changes: 8 additions & 0 deletions lang/syn/src/idl/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ pub fn parse(
name,
accounts,
args,
returns: None,
}
})
.collect::<Vec<_>>()
Expand Down Expand Up @@ -105,6 +106,7 @@ pub fn parse(
name,
accounts,
args,
returns: None,
}
};

Expand Down Expand Up @@ -164,10 +166,16 @@ pub fn parse(
// todo: don't unwrap
let accounts_strct = accs.get(&ix.anchor_ident.to_string()).unwrap();
let accounts = idl_accounts(&ctx, accounts_strct, &accs, seeds_feature);
let ret_type_str = ix.returns.ty.to_token_stream().to_string();
let returns = match ret_type_str.as_str() {
"()" => None,
_ => Some(ret_type_str.parse().unwrap()),
};
IdlInstruction {
name: ix.ident.to_string().to_mixed_case(),
accounts,
args,
returns,
}
})
.collect::<Vec<_>>();
Expand Down
1 change: 1 addition & 0 deletions lang/syn/src/idl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ pub struct IdlInstruction {
pub name: String,
pub accounts: Vec<IdlAccountItem>,
pub args: Vec<IdlField>,
pub returns: Option<IdlType>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
Expand Down
8 changes: 7 additions & 1 deletion lang/syn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use syn::spanned::Spanned;
use syn::token::Comma;
use syn::{
Expr, Generics, Ident, ImplItemMethod, ItemEnum, ItemFn, ItemImpl, ItemMod, ItemStruct, LitInt,
LitStr, PatType, Token, TypePath,
LitStr, PatType, Token, Type, TypePath,
};

pub mod codegen;
Expand Down Expand Up @@ -85,6 +85,7 @@ pub struct Ix {
pub raw_method: ItemFn,
pub ident: Ident,
pub args: Vec<IxArg>,
pub returns: IxReturn,
// The ident for the struct deriving Accounts.
pub anchor_ident: Ident,
}
Expand All @@ -95,6 +96,11 @@ pub struct IxArg {
pub raw_arg: PatType,
}

#[derive(Debug)]
pub struct IxReturn {
pub ty: Type,
}

#[derive(Debug)]
pub struct FallbackFn {
raw_method: ItemFn,
Expand Down
35 changes: 34 additions & 1 deletion lang/syn/src/parser/program/instructions.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::parser::program::ctx_accounts_ident;
use crate::{FallbackFn, Ix, IxArg};
use crate::{FallbackFn, Ix, IxArg, IxReturn};
use syn::parse::{Error as ParseError, Result as ParseResult};
use syn::spanned::Spanned;

Expand All @@ -23,12 +23,14 @@ pub fn parse(program_mod: &syn::ItemMod) -> ParseResult<(Vec<Ix>, Option<Fallbac
})
.map(|method: &syn::ItemFn| {
let (ctx, args) = parse_args(method)?;
let returns = parse_return(method)?;
let anchor_ident = ctx_accounts_ident(&ctx.raw_arg)?;
Ok(Ix {
raw_method: method.clone(),
ident: method.sig.ident.clone(),
args,
anchor_ident,
returns,
})
})
.collect::<ParseResult<Vec<Ix>>>()?;
Expand Down Expand Up @@ -91,3 +93,34 @@ pub fn parse_args(method: &syn::ItemFn) -> ParseResult<(IxArg, Vec<IxArg>)> {

Ok((ctx, args))
}

pub fn parse_return(method: &syn::ItemFn) -> ParseResult<IxReturn> {
match method.sig.output {
syn::ReturnType::Type(_, ref ty) => {
let ty = match ty.as_ref() {
syn::Type::Path(ty) => ty,
_ => return Err(ParseError::new(ty.span(), "expected a return type")),
};
// Assume unit return by default
let default_generic_arg = syn::GenericArgument::Type(syn::parse_str("()").unwrap());
let generic_args = match &ty.path.segments.last().unwrap().arguments {
syn::PathArguments::AngleBracketed(params) => params.args.iter().last().unwrap(),
_ => &default_generic_arg,
};
let ty = match generic_args {
syn::GenericArgument::Type(ty) => ty.clone(),
_ => {
return Err(ParseError::new(
ty.span(),
"expected generic return type to be a type",
))
}
};
Ok(IxReturn { ty })
}
_ => Err(ParseError::new(
method.sig.output.span(),
"expected a return type",
)),
}
}
6 changes: 6 additions & 0 deletions tests/cpi-returns/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

.anchor
.DS_Store
target
**/*.rs.bk
node_modules
16 changes: 16 additions & 0 deletions tests/cpi-returns/Anchor.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[features]
seeds = false

[programs.localnet]
callee = "Fg6PaFpoGXkYsidMpWTK6W2BeZ7FEfcYkg476zPFsLnS"
caller = "HmbTLCmaGvZhKnn1Zfa1JVnp7vkMV4DYVxPLWBVoN65L"

[registry]
url = "https://anchor.projectserum.com"

[provider]
cluster = "localnet"
wallet = "~/.config/solana/id.json"

[scripts]
test = "yarn run ts-mocha -p ./tsconfig.json -t 1000000 tests/**/*.ts"
4 changes: 4 additions & 0 deletions tests/cpi-returns/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[workspace]
members = [
"programs/*"
]
12 changes: 12 additions & 0 deletions tests/cpi-returns/migrations/deploy.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Migrations are an early feature. Currently, they're nothing more than this
// single deploy script that's invoked from the CLI, injecting a provider
// configured from the workspace's Anchor.toml.

const anchor = require("@project-serum/anchor");

module.exports = async function (provider) {
// Configure client to use the provider.
anchor.setProvider(provider);

// Add your deploy script here.
};
19 changes: 19 additions & 0 deletions tests/cpi-returns/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"name": "cpi-returns",
"version": "0.23.0",
"license": "(MIT OR Apache-2.0)",
"homepage": "https://github.com/project-serum/anchor#readme",
"bugs": {
"url": "https://github.com/project-serum/anchor/issues"
},
"repository": {
"type": "git",
"url": "https://github.com/project-serum/anchor.git"
},
"engines": {
"node": ">=11"
},
"scripts": {
"test": "anchor run test-with-build"
}
}
19 changes: 19 additions & 0 deletions tests/cpi-returns/programs/callee/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[package]
name = "callee"
version = "0.1.0"
description = "Created with Anchor"
edition = "2018"

[lib]
crate-type = ["cdylib", "lib"]
name = "callee"

[features]
no-entrypoint = []
no-idl = []
no-log-ix-name = []
cpi = ["no-entrypoint"]
default = []

[dependencies]
anchor-lang = { path = "../../../../lang", features = ["init-if-needed"] }
2 changes: 2 additions & 0 deletions tests/cpi-returns/programs/callee/Xargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[target.bpfel-unknown-unknown.dependencies.std]
features = []
50 changes: 50 additions & 0 deletions tests/cpi-returns/programs/callee/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use anchor_lang::prelude::*;

declare_id!("Fg6PaFpoGXkYsidMpWTK6W2BeZ7FEfcYkg476zPFsLnS");

#[program]
pub mod callee {
use super::*;

#[derive(AnchorSerialize, AnchorDeserialize)]
pub struct StructReturn {
pub value: u64,
}

pub fn initialize(_ctx: Context<Initialize>) -> Result<()> {
Ok(())
}

pub fn return_u64(_ctx: Context<CpiReturn>) -> Result<u64> {
Ok(10)
}

pub fn return_struct(_ctx: Context<CpiReturn>) -> Result<StructReturn> {
let s = StructReturn { value: 11 };
Ok(s)
}

pub fn return_vec(_ctx: Context<CpiReturn>) -> Result<Vec<u8>> {
Ok(vec![12, 13, 14, 100])
}
}

#[derive(Accounts)]
pub struct Initialize<'info> {
#[account(init, payer = user, space = 8 + 8)]
pub account: Account<'info, CpiReturnAccount>,
#[account(mut)]
pub user: Signer<'info>,
pub system_program: Program<'info, System>,
}

#[derive(Accounts)]
pub struct CpiReturn<'info> {
#[account(mut)]
pub account: Account<'info, CpiReturnAccount>,
}

#[account]
pub struct CpiReturnAccount {
pub value: u64,
}
Loading

0 comments on commit 1cb7429

Please sign in to comment.