Skip to content

Commit

Permalink
review: refactor create_tx
Browse files Browse the repository at this point in the history
  • Loading branch information
ValuedMammal committed Sep 1, 2024
1 parent 9bd6c47 commit 9341041
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 33 deletions.
38 changes: 9 additions & 29 deletions crates/wallet/src/wallet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1449,24 +1449,18 @@ impl Wallet {

// get drain script
let mut drain_index = Option::<(KeychainKind, u32)>::None;
let mut must_reveal = false;
let change_keychain = self.map_keychain(KeychainKind::Internal);
let drain_script = match params.drain_to {
Some(ref drain_recipient) => drain_recipient.clone(),
None => {
let addr = self
.list_unused_addresses(change_keychain)
.next()
.unwrap_or_else(|| {
let next_index = self
.derivation_index(change_keychain)
.map(|i| i + 1)
.unwrap_or(0);
must_reveal = true;
self.peek_address(change_keychain, next_index)
});
drain_index = Some((change_keychain, addr.index));
addr.script_pubkey()
let change_keychain = self.map_keychain(KeychainKind::Internal);
let ((index, spk), index_changeset) = self
.indexed_graph
.index
.next_unused_spk(change_keychain)
.expect("keychain must exist");
self.stage.merge(index_changeset.into());
drain_index = Some((change_keychain, index));
spk
}
};

Expand Down Expand Up @@ -1569,20 +1563,6 @@ impl Wallet {
if let Excess::Change { .. } = excess {
if let Some((keychain, index)) = drain_index {
self.mark_used(keychain, index);
if must_reveal {
let (_, index_changeset) = self
.indexed_graph
.index
.reveal_to_target(keychain, index)
.expect("keychain must exist");
if let Some(last_reveal) = index_changeset.last_revealed.iter().next() {
debug_assert_eq!(
last_reveal,
(&self.public_descriptor(keychain).descriptor_id(), &index),
);
}
self.stage.merge(index_changeset.into());
}
}
}

Expand Down
11 changes: 7 additions & 4 deletions crates/wallet/tests/wallet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1394,10 +1394,7 @@ fn test_create_tx_global_xpubs_with_origin() {

#[test]
fn test_create_tx_increment_change_index() {
// Check that a change address is marked used after `create_tx`.
// Failing to create a tx should not mark a change address used.
use coin_selection::Error;

let (desc, change_desc) = get_test_tr_single_sig_xprv_with_change_desc();
let (mut wallet, _) = get_funded_wallet_with_change(desc, change_desc);
let addr = wallet.next_unused_address(KeychainKind::External);
Expand All @@ -1416,6 +1413,7 @@ fn test_create_tx_increment_change_index() {
assert_eq!(
wallet.next_unused_address(KeychainKind::Internal),
internal_addr0,
"create tx fail should not mark change address used",
);

// send all (no change) should not increment index
Expand All @@ -1427,6 +1425,7 @@ fn test_create_tx_increment_change_index() {
assert_eq!(
wallet.next_unused_address(KeychainKind::Internal),
internal_addr0,
"no change output should not mark change address used",
);

// create tx with change should increment index
Expand All @@ -1435,7 +1434,11 @@ fn test_create_tx_increment_change_index() {
let psbt = builder.finish().unwrap();
assert_eq!(psbt.unsigned_tx.output.len(), 2);
let internal_addr1 = wallet.next_unused_address(KeychainKind::Internal);
assert_eq!(internal_addr1.index, 1);
assert_eq!(
internal_addr1.index, 1,
"internally derived drain output should mark change address used"
);
assert!(!wallet.mark_used(KeychainKind::Internal, 0))
}

#[test]
Expand Down

0 comments on commit 9341041

Please sign in to comment.