Skip to content

Commit

Permalink
Discard packets statically known to fail (solana-labs#370)
Browse files Browse the repository at this point in the history
* Discard packets statically known to fail

* add test
  • Loading branch information
apfitzge authored Mar 21, 2024
1 parent 8b66a67 commit 5f16932
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 4 deletions.
52 changes: 51 additions & 1 deletion core/src/banking_stage/immutable_deserialized_packet.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use {
solana_cost_model::block_cost_limits::BUILT_IN_INSTRUCTION_COSTS,
solana_perf::packet::Packet,
solana_runtime::compute_budget_details::{ComputeBudgetDetails, GetComputeBudgetDetails},
solana_sdk::{
feature_set,
hash::Hash,
message::Message,
sanitize::SanitizeError,
saturating_add_assign,
short_vec::decode_shortu16_len,
signature::Signature,
transaction::{
Expand Down Expand Up @@ -98,6 +100,22 @@ impl ImmutableDeserializedPacket {
self.compute_budget_details.clone()
}

/// Returns true if the transaction's compute unit limit is at least as
/// large as the sum of the static builtins' costs.
/// This is a simple sanity check so the leader can discard transactions
/// which are statically known to exceed the compute budget, and will
/// result in no useful state-change.
pub fn compute_unit_limit_above_static_builtins(&self) -> bool {
let mut static_builtin_cost_sum: u64 = 0;
for (program_id, _) in self.transaction.get_message().program_instructions_iter() {
if let Some(ix_cost) = BUILT_IN_INSTRUCTION_COSTS.get(program_id) {
saturating_add_assign!(static_builtin_cost_sum, *ix_cost);
}
}

self.compute_unit_limit() >= static_builtin_cost_sum
}

// This function deserializes packets into transactions, computes the blake3 hash of transaction
// messages, and verifies secp256k1 instructions.
pub fn build_sanitized_transaction(
Expand Down Expand Up @@ -150,7 +168,10 @@ fn packet_message(packet: &Packet) -> Result<&[u8], DeserializedPacketError> {
mod tests {
use {
super::*,
solana_sdk::{signature::Keypair, system_transaction},
solana_sdk::{
compute_budget, instruction::Instruction, pubkey::Pubkey, signature::Keypair,
signer::Signer, system_instruction, system_transaction, transaction::Transaction,
},
};

#[test]
Expand All @@ -166,4 +187,33 @@ mod tests {

assert!(deserialized_packet.is_ok());
}

#[test]
fn compute_unit_limit_above_static_builtins() {
// Cases:
// 1. compute_unit_limit under static builtins
// 2. compute_unit_limit equal to static builtins
// 3. compute_unit_limit above static builtins
for (cu_limit, expectation) in [(250, false), (300, true), (350, true)] {
let keypair = Keypair::new();
let bpf_program_id = Pubkey::new_unique();
let ixs = vec![
system_instruction::transfer(&keypair.pubkey(), &Pubkey::new_unique(), 1),
compute_budget::ComputeBudgetInstruction::set_compute_unit_limit(cu_limit),
Instruction::new_with_bytes(bpf_program_id, &[], vec![]), // non-builtin - not counted in filter
];
let tx = Transaction::new_signed_with_payer(
&ixs,
Some(&keypair.pubkey()),
&[&keypair],
Hash::new_unique(),
);
let packet = Packet::from_data(None, tx).unwrap();
let deserialized_packet = ImmutableDeserializedPacket::new(packet).unwrap();
assert_eq!(
deserialized_packet.compute_unit_limit_above_static_builtins(),
expectation
);
}
}
}
13 changes: 11 additions & 2 deletions core/src/banking_stage/packet_deserializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ impl PacketDeserializer {
&self,
recv_timeout: Duration,
capacity: usize,
packet_filter: impl Fn(&ImmutableDeserializedPacket) -> bool,
) -> Result<ReceivePacketResults, RecvTimeoutError> {
let (packet_count, packet_batches) = self.receive_until(recv_timeout, capacity)?;

Expand All @@ -62,6 +63,7 @@ impl PacketDeserializer {
packet_count,
&packet_batches,
round_compute_unit_price_enabled,
&packet_filter,
))
}

Expand All @@ -71,6 +73,7 @@ impl PacketDeserializer {
packet_count: usize,
banking_batches: &[BankingPacketBatch],
round_compute_unit_price_enabled: bool,
packet_filter: &impl Fn(&ImmutableDeserializedPacket) -> bool,
) -> ReceivePacketResults {
let mut passed_sigverify_count: usize = 0;
let mut failed_sigverify_count: usize = 0;
Expand All @@ -88,6 +91,7 @@ impl PacketDeserializer {
packet_batch,
&packet_indexes,
round_compute_unit_price_enabled,
packet_filter,
));
}

Expand Down Expand Up @@ -158,13 +162,16 @@ impl PacketDeserializer {
packet_batch: &'a PacketBatch,
packet_indexes: &'a [usize],
round_compute_unit_price_enabled: bool,
packet_filter: &'a (impl Fn(&ImmutableDeserializedPacket) -> bool + 'a),
) -> impl Iterator<Item = ImmutableDeserializedPacket> + 'a {
packet_indexes.iter().filter_map(move |packet_index| {
let mut packet_clone = packet_batch[*packet_index].clone();
packet_clone
.meta_mut()
.set_round_compute_unit_price(round_compute_unit_price_enabled);
ImmutableDeserializedPacket::new(packet_clone).ok()
ImmutableDeserializedPacket::new(packet_clone)
.ok()
.filter(packet_filter)
})
}
}
Expand All @@ -186,7 +193,7 @@ mod tests {

#[test]
fn test_deserialize_and_collect_packets_empty() {
let results = PacketDeserializer::deserialize_and_collect_packets(0, &[], false);
let results = PacketDeserializer::deserialize_and_collect_packets(0, &[], false, &|_| true);
assert_eq!(results.deserialized_packets.len(), 0);
assert!(results.new_tracer_stats_option.is_none());
assert_eq!(results.passed_sigverify_count, 0);
Expand All @@ -204,6 +211,7 @@ mod tests {
packet_count,
&[BankingPacketBatch::new((packet_batches, None))],
false,
&|_| true,
);
assert_eq!(results.deserialized_packets.len(), 2);
assert!(results.new_tracer_stats_option.is_none());
Expand All @@ -223,6 +231,7 @@ mod tests {
packet_count,
&[BankingPacketBatch::new((packet_batches, None))],
false,
&|_| true,
);
assert_eq!(results.deserialized_packets.len(), 1);
assert!(results.new_tracer_stats_option.is_none());
Expand Down
1 change: 1 addition & 0 deletions core/src/banking_stage/packet_receiver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ impl PacketReceiver {
.receive_packets(
recv_timeout,
unprocessed_transaction_storage.max_receive_size(),
|packet| packet.compute_unit_limit_above_static_builtins(),
)
// Consumes results if Ok, otherwise we keep the Err
.map(|receive_packet_results| {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ impl SchedulerController {

let (received_packet_results, receive_time_us) = measure_us!(self
.packet_receiver
.receive_packets(recv_timeout, remaining_queue_capacity));
.receive_packets(recv_timeout, remaining_queue_capacity, |_| true));

self.timing_metrics.update(|timing_metrics| {
saturating_add_assign!(timing_metrics.receive_time_us, receive_time_us);
Expand Down

0 comments on commit 5f16932

Please sign in to comment.