diff --git a/core/src/banking_stage/immutable_deserialized_packet.rs b/core/src/banking_stage/immutable_deserialized_packet.rs index 26ede7045d3480..8e31f9cd462473 100644 --- a/core/src/banking_stage/immutable_deserialized_packet.rs +++ b/core/src/banking_stage/immutable_deserialized_packet.rs @@ -1,4 +1,5 @@ use { + solana_cost_model::block_cost_limits::BUILT_IN_INSTRUCTION_COSTS, solana_perf::packet::Packet, solana_runtime::compute_budget_details::{ComputeBudgetDetails, GetComputeBudgetDetails}, solana_sdk::{ @@ -6,6 +7,7 @@ use { hash::Hash, message::Message, sanitize::SanitizeError, + saturating_add_assign, short_vec::decode_shortu16_len, signature::Signature, transaction::{ @@ -98,6 +100,22 @@ impl ImmutableDeserializedPacket { self.compute_budget_details.clone() } + /// Returns true if the transaction's compute unit limit is at least as + /// large as the sum of the static builtins' costs. + /// This is a simple sanity check so the leader can discard transactions + /// which are statically known to exceed the compute budget, and will + /// result in no useful state-change. + pub fn compute_unit_limit_above_static_builtins(&self) -> bool { + let mut static_builtin_cost_sum: u64 = 0; + for (program_id, _) in self.transaction.get_message().program_instructions_iter() { + if let Some(ix_cost) = BUILT_IN_INSTRUCTION_COSTS.get(program_id) { + saturating_add_assign!(static_builtin_cost_sum, *ix_cost); + } + } + + self.compute_unit_limit() >= static_builtin_cost_sum + } + // This function deserializes packets into transactions, computes the blake3 hash of transaction // messages, and verifies secp256k1 instructions. pub fn build_sanitized_transaction( @@ -150,7 +168,10 @@ fn packet_message(packet: &Packet) -> Result<&[u8], DeserializedPacketError> { mod tests { use { super::*, - solana_sdk::{signature::Keypair, system_transaction}, + solana_sdk::{ + compute_budget, instruction::Instruction, pubkey::Pubkey, signature::Keypair, + signer::Signer, system_instruction, system_transaction, transaction::Transaction, + }, }; #[test] @@ -166,4 +187,33 @@ mod tests { assert!(deserialized_packet.is_ok()); } + + #[test] + fn compute_unit_limit_above_static_builtins() { + // Cases: + // 1. compute_unit_limit under static builtins + // 2. compute_unit_limit equal to static builtins + // 3. compute_unit_limit above static builtins + for (cu_limit, expectation) in [(250, false), (300, true), (350, true)] { + let keypair = Keypair::new(); + let bpf_program_id = Pubkey::new_unique(); + let ixs = vec![ + system_instruction::transfer(&keypair.pubkey(), &Pubkey::new_unique(), 1), + compute_budget::ComputeBudgetInstruction::set_compute_unit_limit(cu_limit), + Instruction::new_with_bytes(bpf_program_id, &[], vec![]), // non-builtin - not counted in filter + ]; + let tx = Transaction::new_signed_with_payer( + &ixs, + Some(&keypair.pubkey()), + &[&keypair], + Hash::new_unique(), + ); + let packet = Packet::from_data(None, tx).unwrap(); + let deserialized_packet = ImmutableDeserializedPacket::new(packet).unwrap(); + assert_eq!( + deserialized_packet.compute_unit_limit_above_static_builtins(), + expectation + ); + } + } } diff --git a/core/src/banking_stage/packet_deserializer.rs b/core/src/banking_stage/packet_deserializer.rs index a405b626568482..1d1079eaf97fcd 100644 --- a/core/src/banking_stage/packet_deserializer.rs +++ b/core/src/banking_stage/packet_deserializer.rs @@ -50,6 +50,7 @@ impl PacketDeserializer { &self, recv_timeout: Duration, capacity: usize, + packet_filter: impl Fn(&ImmutableDeserializedPacket) -> bool, ) -> Result { let (packet_count, packet_batches) = self.receive_until(recv_timeout, capacity)?; @@ -62,6 +63,7 @@ impl PacketDeserializer { packet_count, &packet_batches, round_compute_unit_price_enabled, + &packet_filter, )) } @@ -71,6 +73,7 @@ impl PacketDeserializer { packet_count: usize, banking_batches: &[BankingPacketBatch], round_compute_unit_price_enabled: bool, + packet_filter: &impl Fn(&ImmutableDeserializedPacket) -> bool, ) -> ReceivePacketResults { let mut passed_sigverify_count: usize = 0; let mut failed_sigverify_count: usize = 0; @@ -88,6 +91,7 @@ impl PacketDeserializer { packet_batch, &packet_indexes, round_compute_unit_price_enabled, + packet_filter, )); } @@ -158,13 +162,16 @@ impl PacketDeserializer { packet_batch: &'a PacketBatch, packet_indexes: &'a [usize], round_compute_unit_price_enabled: bool, + packet_filter: &'a (impl Fn(&ImmutableDeserializedPacket) -> bool + 'a), ) -> impl Iterator + 'a { packet_indexes.iter().filter_map(move |packet_index| { let mut packet_clone = packet_batch[*packet_index].clone(); packet_clone .meta_mut() .set_round_compute_unit_price(round_compute_unit_price_enabled); - ImmutableDeserializedPacket::new(packet_clone).ok() + ImmutableDeserializedPacket::new(packet_clone) + .ok() + .filter(packet_filter) }) } } @@ -186,7 +193,7 @@ mod tests { #[test] fn test_deserialize_and_collect_packets_empty() { - let results = PacketDeserializer::deserialize_and_collect_packets(0, &[], false); + let results = PacketDeserializer::deserialize_and_collect_packets(0, &[], false, &|_| true); assert_eq!(results.deserialized_packets.len(), 0); assert!(results.new_tracer_stats_option.is_none()); assert_eq!(results.passed_sigverify_count, 0); @@ -204,6 +211,7 @@ mod tests { packet_count, &[BankingPacketBatch::new((packet_batches, None))], false, + &|_| true, ); assert_eq!(results.deserialized_packets.len(), 2); assert!(results.new_tracer_stats_option.is_none()); @@ -223,6 +231,7 @@ mod tests { packet_count, &[BankingPacketBatch::new((packet_batches, None))], false, + &|_| true, ); assert_eq!(results.deserialized_packets.len(), 1); assert!(results.new_tracer_stats_option.is_none()); diff --git a/core/src/banking_stage/packet_receiver.rs b/core/src/banking_stage/packet_receiver.rs index a566ef7cf3e4c1..bbb753967f20ce 100644 --- a/core/src/banking_stage/packet_receiver.rs +++ b/core/src/banking_stage/packet_receiver.rs @@ -49,6 +49,7 @@ impl PacketReceiver { .receive_packets( recv_timeout, unprocessed_transaction_storage.max_receive_size(), + |packet| packet.compute_unit_limit_above_static_builtins(), ) // Consumes results if Ok, otherwise we keep the Err .map(|receive_packet_results| { diff --git a/core/src/banking_stage/transaction_scheduler/scheduler_controller.rs b/core/src/banking_stage/transaction_scheduler/scheduler_controller.rs index 12e8f7bf8bf0bf..0b10f613e64cd6 100644 --- a/core/src/banking_stage/transaction_scheduler/scheduler_controller.rs +++ b/core/src/banking_stage/transaction_scheduler/scheduler_controller.rs @@ -322,7 +322,7 @@ impl SchedulerController { let (received_packet_results, receive_time_us) = measure_us!(self .packet_receiver - .receive_packets(recv_timeout, remaining_queue_capacity)); + .receive_packets(recv_timeout, remaining_queue_capacity, |_| true)); self.timing_metrics.update(|timing_metrics| { saturating_add_assign!(timing_metrics.receive_time_us, receive_time_us);