diff --git a/CHANGELOG.md b/CHANGELOG.md index 51a0b5e3af..9f9d4f5554 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The minor version will be incremented upon a breaking change and the patch versi - cli, idl: Pass `cargo` args to IDL generation when building program or IDL ([#3059](https://github.com/coral-xyz/anchor/pull/3059)). - cli: Add checks for incorrect usage of `idl-build` feature ([#3061](https://github.com/coral-xyz/anchor/pull/3061)). - lang: Export `Discriminator` trait from `prelude` ([#3075](https://github.com/coral-xyz/anchor/pull/3075)). +- lang: Add `Account` utility type to get accounts from bytes ([#3091](https://github.com/coral-xyz/anchor/pull/3091)). ### Fixes diff --git a/lang/attribute/program/src/declare_program/mods/utils.rs b/lang/attribute/program/src/declare_program/mods/utils.rs index 65c6c95fe3..9ba325c29a 100644 --- a/lang/attribute/program/src/declare_program/mods/utils.rs +++ b/lang/attribute/program/src/declare_program/mods/utils.rs @@ -4,16 +4,72 @@ use quote::{format_ident, quote}; use super::common::gen_discriminator; pub fn gen_utils_mod(idl: &Idl) -> proc_macro2::TokenStream { + let account = gen_account(idl); let event = gen_event(idl); quote! { /// Program utilities. pub mod utils { + use super::*; + + #account #event } } } +fn gen_account(idl: &Idl) -> proc_macro2::TokenStream { + let variants = idl + .accounts + .iter() + .map(|acc| format_ident!("{}", acc.name)) + .map(|name| quote! { #name(#name) }); + let match_arms = idl.accounts.iter().map(|acc| { + let disc = gen_discriminator(&acc.discriminator); + let name = format_ident!("{}", acc.name); + let account = quote! { + #name::try_from_slice(&value[8..]) + .map(Self::#name) + .map_err(Into::into) + }; + quote! { #disc => #account } + }); + + quote! { + /// An enum that includes all accounts of the declared program as a tuple variant. + /// + /// See [`Self::try_from_bytes`] to create an instance from bytes. + pub enum Account { + #(#variants,)* + } + + impl Account { + /// Try to create an account based on the given bytes. + /// + /// This method returns an error if the discriminator of the given bytes don't match + /// with any of the existing accounts, or if the deserialization fails. + pub fn try_from_bytes(bytes: &[u8]) -> Result { + Self::try_from(bytes) + } + } + + impl TryFrom<&[u8]> for Account { + type Error = anchor_lang::error::Error; + + fn try_from(value: &[u8]) -> Result { + if value.len() < 8 { + return Err(ProgramError::InvalidArgument.into()); + } + + match &value[..8] { + #(#match_arms,)* + _ => Err(ProgramError::InvalidArgument.into()), + } + } + } + } +} + fn gen_event(idl: &Idl) -> proc_macro2::TokenStream { let variants = idl .events @@ -32,8 +88,6 @@ fn gen_event(idl: &Idl) -> proc_macro2::TokenStream { }); quote! { - use super::*; - /// An enum that includes all events of the declared program as a tuple variant. /// /// See [`Self::try_from_bytes`] to create an instance from bytes. diff --git a/tests/declare-program/programs/declare-program/src/lib.rs b/tests/declare-program/programs/declare-program/src/lib.rs index 01df3b22a8..e3dd6cd18a 100644 --- a/tests/declare-program/programs/declare-program/src/lib.rs +++ b/tests/declare-program/programs/declare-program/src/lib.rs @@ -52,6 +52,30 @@ pub mod declare_program { Ok(()) } + pub fn account_utils(_ctx: Context) -> Result<()> { + use external::utils::Account; + + // Empty + if Account::try_from_bytes(&[]).is_ok() { + return Err(ProgramError::Custom(0).into()); + } + + const DISC: &[u8] = &external::accounts::MyAccount::DISCRIMINATOR; + + // Correct discriminator but invalid data + if Account::try_from_bytes(DISC).is_ok() { + return Err(ProgramError::Custom(1).into()); + }; + + // Correct discriminator and valid data + match Account::try_from_bytes(&[DISC, &[1, 0, 0, 0]].concat()) { + Ok(Account::MyAccount(my_account)) => require_eq!(my_account.field, 1), + Err(e) => return Err(e.into()), + } + + Ok(()) + } + pub fn event_utils(_ctx: Context) -> Result<()> { use external::utils::Event; diff --git a/tests/declare-program/tests/declare-program.ts b/tests/declare-program/tests/declare-program.ts index de73ab25ed..0f3f684416 100644 --- a/tests/declare-program/tests/declare-program.ts +++ b/tests/declare-program/tests/declare-program.ts @@ -47,6 +47,10 @@ describe("declare-program", () => { assert.strictEqual(myAccount.field, value); }); + it("Can use account utils", async () => { + await program.methods.accountUtils().rpc(); + }); + it("Can use event utils", async () => { await program.methods.eventUtils().rpc(); });