Skip to content

Commit

Permalink
lang: Add init_with_needed keyword (#906)
Browse files Browse the repository at this point in the history
  • Loading branch information
armaniferrante authored Oct 21, 2021
1 parent d41fb4f commit 95bb9b3
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 73 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ incremented for features.
* ts: `Program<T>` can now be typed with an IDL type ([#795](https://github.com/project-serum/anchor/pull/795)).
* lang: Add `mint::freeze_authority` keyword for mint initialization within `#[derive(Accounts)]` ([#835](https://github.com/project-serum/anchor/pull/835)).
* lang: Add `AccountLoader` type for `zero_copy` accounts with support for CPI ([#792](https://github.com/project-serum/anchor/pull/792)).
* lang: Add `#[account(init_if_needed)]` keyword for allowing one to invoke the same instruction even if the account was created already ([#906](https://github.com/project-serum/anchor/pull/906)).
* lang: Add custom errors support for raw constraints ([#905](https://github.com/project-serum/anchor/pull/905)).

### Breaking
Expand Down
112 changes: 64 additions & 48 deletions lang/syn/src/codegen/accounts/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma
}
}
};
generate_init(f, seeds_with_nonce, payer, &c.space, &c.kind)
generate_init(f, c.if_needed, seeds_with_nonce, payer, &c.space, &c.kind)
}

fn generate_constraint_seeds(f: &Field, c: &ConstraintSeedsGroup) -> proc_macro2::TokenStream {
Expand Down Expand Up @@ -397,8 +397,10 @@ fn generate_constraint_associated_token(
}
}

// `if_needed` is set if account allocation and initialization is optional.
pub fn generate_init(
f: &Field,
if_needed: bool,
seeds_with_nonce: proc_macro2::TokenStream,
payer: proc_macro2::TokenStream,
space: &Option<Expr>,
Expand All @@ -407,6 +409,11 @@ pub fn generate_init(
let field = &f.ident;
let ty_decl = f.ty_decl();
let from_account_info = f.from_account_info_unchecked(Some(kind));
let if_needed = if if_needed {
quote! {true}
} else {
quote! {false}
};
match kind {
InitKind::Token { owner, mint } => {
let create_account = generate_create_account(
Expand All @@ -417,22 +424,25 @@ pub fn generate_init(
);
quote! {
let #field: #ty_decl = {
// Define payer variable.
#payer

// Create the account with the system program.
#create_account

// Initialize the token account.
let cpi_program = token_program.to_account_info();
let accounts = anchor_spl::token::InitializeAccount {
account: #field.to_account_info(),
mint: #mint.to_account_info(),
authority: #owner.to_account_info(),
rent: rent.to_account_info(),
};
let cpi_ctx = CpiContext::new(cpi_program, accounts);
anchor_spl::token::initialize_account(cpi_ctx)?;
if !#if_needed || #field.to_account_info().owner == &anchor_lang::solana_program::system_program::ID {
// Define payer variable.
#payer

// Create the account with the system program.
#create_account

// Initialize the token account.
let cpi_program = token_program.to_account_info();
let accounts = anchor_spl::token::InitializeAccount {
account: #field.to_account_info(),
mint: #mint.to_account_info(),
authority: #owner.to_account_info(),
rent: rent.to_account_info(),
};
let cpi_ctx = CpiContext::new(cpi_program, accounts);
anchor_spl::token::initialize_account(cpi_ctx)?;
}

let pa: #ty_decl = #from_account_info;
pa
};
Expand All @@ -441,20 +451,22 @@ pub fn generate_init(
InitKind::AssociatedToken { owner, mint } => {
quote! {
let #field: #ty_decl = {
#payer

let cpi_program = associated_token_program.to_account_info();
let cpi_accounts = anchor_spl::associated_token::Create {
payer: payer.to_account_info(),
associated_token: #field.to_account_info(),
authority: #owner.to_account_info(),
mint: #mint.to_account_info(),
system_program: system_program.to_account_info(),
token_program: token_program.to_account_info(),
rent: rent.to_account_info(),
};
let cpi_ctx = CpiContext::new(cpi_program, cpi_accounts);
anchor_spl::associated_token::create(cpi_ctx)?;
if !#if_needed || #field.to_account_info().owner == &anchor_lang::solana_program::system_program::ID {
#payer

let cpi_program = associated_token_program.to_account_info();
let cpi_accounts = anchor_spl::associated_token::Create {
payer: payer.to_account_info(),
associated_token: #field.to_account_info(),
authority: #owner.to_account_info(),
mint: #mint.to_account_info(),
system_program: system_program.to_account_info(),
token_program: token_program.to_account_info(),
rent: rent.to_account_info(),
};
let cpi_ctx = CpiContext::new(cpi_program, cpi_accounts);
anchor_spl::associated_token::create(cpi_ctx)?;
}
let pa: #ty_decl = #from_account_info;
pa
};
Expand All @@ -477,20 +489,22 @@ pub fn generate_init(
};
quote! {
let #field: #ty_decl = {
// Define payer variable.
#payer

// Create the account with the system program.
#create_account

// Initialize the mint account.
let cpi_program = token_program.to_account_info();
let accounts = anchor_spl::token::InitializeMint {
mint: #field.to_account_info(),
rent: rent.to_account_info(),
};
let cpi_ctx = CpiContext::new(cpi_program, accounts);
anchor_spl::token::initialize_mint(cpi_ctx, #decimals, &#owner.to_account_info().key, #freeze_authority)?;
if !#if_needed || #field.to_account_info().owner == &anchor_lang::solana_program::system_program::ID {
// Define payer variable.
#payer

// Create the account with the system program.
#create_account

// Initialize the mint account.
let cpi_program = token_program.to_account_info();
let accounts = anchor_spl::token::InitializeMint {
mint: #field.to_account_info(),
rent: rent.to_account_info(),
};
let cpi_ctx = CpiContext::new(cpi_program, accounts);
anchor_spl::token::initialize_mint(cpi_ctx, #decimals, &#owner.to_account_info().key, #freeze_authority)?;
}
let pa: #ty_decl = #from_account_info;
pa
};
Expand Down Expand Up @@ -535,9 +549,11 @@ pub fn generate_init(
generate_create_account(field, quote! {space}, owner, seeds_with_nonce);
quote! {
let #field = {
#space
#payer
#create_account
if !#if_needed || #field.to_account_info().owner == &anchor_lang::solana_program::system_program::ID {
#space
#payer
#create_account
}
let pa: #ty_decl = #from_account_info;
pa
};
Expand Down
8 changes: 7 additions & 1 deletion lang/syn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,12 @@ impl Parse for ConstraintToken {
}

#[derive(Debug, Clone)]
pub struct ConstraintInit {}
pub struct ConstraintInit {
pub if_needed: bool,
}

#[derive(Debug, Clone)]
pub struct ConstraintInitIfNeeded {}

#[derive(Debug, Clone)]
pub struct ConstraintZeroed {}
Expand Down Expand Up @@ -639,6 +644,7 @@ pub enum ConstraintRentExempt {

#[derive(Debug, Clone)]
pub struct ConstraintInitGroup {
pub if_needed: bool,
pub seeds: Option<ConstraintSeedsGroup>,
pub payer: Option<Expr>,
pub space: Option<Expr>,
Expand Down
12 changes: 10 additions & 2 deletions lang/syn/src/parser/accounts/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,14 @@ pub fn parse_token(stream: ParseStream) -> ParseResult<ConstraintToken> {
let kw = ident.to_string();

let c = match kw.as_str() {
"init" => ConstraintToken::Init(Context::new(ident.span(), ConstraintInit {})),
"init" => ConstraintToken::Init(Context::new(
ident.span(),
ConstraintInit { if_needed: false },
)),
"init_if_needed" => ConstraintToken::Init(Context::new(
ident.span(),
ConstraintInit { if_needed: true },
)),
"zero" => ConstraintToken::Zeroed(Context::new(ident.span(), ConstraintZeroed {})),
"mut" => ConstraintToken::Mut(Context::new(ident.span(), ConstraintMut {})),
"signer" => ConstraintToken::Signer(Context::new(ident.span(), ConstraintSigner {})),
Expand Down Expand Up @@ -518,7 +525,8 @@ impl<'ty> ConstraintGroupBuilder<'ty> {
_ => None,
};
Ok(ConstraintGroup {
init: init.as_ref().map(|_| Ok(ConstraintInitGroup {
init: init.as_ref().map(|i| Ok(ConstraintInitGroup {
if_needed: i.if_needed,
seeds: seeds.clone(),
payer: into_inner!(payer.clone()).map(|a| a.target),
space: space.clone().map(|s| s.space.clone()),
Expand Down
8 changes: 8 additions & 0 deletions tests/misc/programs/misc/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,11 @@ pub struct TestEmptySeedsConstraint<'info> {
#[account(seeds = [], bump)]
pub pda: AccountInfo<'info>,
}

#[derive(Accounts)]
pub struct TestInitIfNeeded<'info> {
#[account(init_if_needed, payer = payer)]
pub data: Account<'info, DataU16>,
pub payer: Signer<'info>,
pub system_program: Program<'info, System>,
}
5 changes: 5 additions & 0 deletions tests/misc/programs/misc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,4 +184,9 @@ pub mod misc {
pub fn test_empty_seeds_constraint(ctx: Context<TestEmptySeedsConstraint>) -> ProgramResult {
Ok(())
}

pub fn test_init_if_needed(ctx: Context<TestInitIfNeeded>, data: u16) -> ProgramResult {
ctx.accounts.data.data = data;
Ok(())
}
}
79 changes: 57 additions & 22 deletions tests/misc/tests/misc.js
Original file line number Diff line number Diff line change
Expand Up @@ -652,8 +652,8 @@ describe("misc", () => {
accounts: {
token: associatedToken,
mint: mint.publicKey,
wallet: program.provider.wallet.publicKey
}
wallet: program.provider.wallet.publicKey,
},
});

await assert.rejects(
Expand All @@ -662,8 +662,8 @@ describe("misc", () => {
accounts: {
token: associatedToken,
mint: mint.publicKey,
wallet: anchor.web3.Keypair.generate().publicKey
}
wallet: anchor.web3.Keypair.generate().publicKey,
},
});
},
(err) => {
Expand Down Expand Up @@ -735,30 +735,31 @@ describe("misc", () => {
]);
// Call for multiple kinds of .all.
const allAccounts = await program.account.dataWithFilter.all();
const allAccountsFilteredByBuffer =
await program.account.dataWithFilter.all(
program.provider.wallet.publicKey.toBuffer()
);
const allAccountsFilteredByProgramFilters1 =
await program.account.dataWithFilter.all([
const allAccountsFilteredByBuffer = await program.account.dataWithFilter.all(
program.provider.wallet.publicKey.toBuffer()
);
const allAccountsFilteredByProgramFilters1 = await program.account.dataWithFilter.all(
[
{
memcmp: {
offset: 8,
bytes: program.provider.wallet.publicKey.toBase58(),
},
},
{ memcmp: { offset: 40, bytes: filterable1.toBase58() } },
]);
const allAccountsFilteredByProgramFilters2 =
await program.account.dataWithFilter.all([
]
);
const allAccountsFilteredByProgramFilters2 = await program.account.dataWithFilter.all(
[
{
memcmp: {
offset: 8,
bytes: program.provider.wallet.publicKey.toBase58(),
},
},
{ memcmp: { offset: 40, bytes: filterable2.toBase58() } },
]);
]
);
// Without filters there should be 4 accounts.
assert.equal(allAccounts.length, 4);
// Filtering by main wallet there should be 3 accounts.
Expand All @@ -772,32 +773,66 @@ describe("misc", () => {
});

it("Can use pdas with empty seeds", async () => {
const [pda, bump] = await PublicKey.findProgramAddress([], program.programId);
const [pda, bump] = await PublicKey.findProgramAddress(
[],
program.programId
);

await program.rpc.testInitWithEmptySeeds({
accounts: {
pda: pda,
authority: program.provider.wallet.publicKey,
systemProgram: anchor.web3.SystemProgram.programId
}
systemProgram: anchor.web3.SystemProgram.programId,
},
});
await program.rpc.testEmptySeedsConstraint({
accounts: {
pda: pda
}
pda: pda,
},
});

const [pda2, bump2] = await PublicKey.findProgramAddress(["non-empty"], program.programId);
const [pda2, bump2] = await PublicKey.findProgramAddress(
["non-empty"],
program.programId
);
await assert.rejects(
program.rpc.testEmptySeedsConstraint({
accounts: {
pda: pda2
}
pda: pda2,
},
}),
(err) => {
assert.equal(err.code, 146);
return true;
}
);
});

const ifNeededAcc = anchor.web3.Keypair.generate();

it("Can init if needed a new account", async () => {
await program.rpc.testInitIfNeeded(1, {
accounts: {
data: ifNeededAcc.publicKey,
systemProgram: anchor.web3.SystemProgram.programId,
payer: program.provider.wallet.publicKey,
},
signers: [ifNeededAcc],
});
const account = await program.account.dataU16.fetch(ifNeededAcc.publicKey);
assert.ok(account.data, 1);
});

it("Can init if needed a previously created account", async () => {
await program.rpc.testInitIfNeeded(3, {
accounts: {
data: ifNeededAcc.publicKey,
systemProgram: anchor.web3.SystemProgram.programId,
payer: program.provider.wallet.publicKey,
},
signers: [ifNeededAcc],
});
const account = await program.account.dataU16.fetch(ifNeededAcc.publicKey);
assert.ok(account.data, 3);
});
});

0 comments on commit 95bb9b3

Please sign in to comment.