From fd26b88dce65fe657e508183a7e2ae4522def1bd Mon Sep 17 00:00:00 2001 From: Lucas Kent Date: Thu, 24 Oct 2024 11:38:22 +1100 Subject: [PATCH] Move decoding into frame --- shotover-proxy/tests/kafka_int_tests/mod.rs | 59 +++-- shotover/src/codec/kafka.rs | 213 ++++++++---------- shotover/src/codec/mod.rs | 6 +- shotover/src/frame/kafka.rs | 39 +++- .../kafka/sink_cluster/connections.rs | 18 +- test-helpers/src/connection/kafka/python.rs | 26 +++ .../src/connection/kafka/python/auth_fail.py | 18 ++ 7 files changed, 231 insertions(+), 148 deletions(-) create mode 100644 test-helpers/src/connection/kafka/python/auth_fail.py diff --git a/shotover-proxy/tests/kafka_int_tests/mod.rs b/shotover-proxy/tests/kafka_int_tests/mod.rs index 3eaabb642..5f8e1cd83 100644 --- a/shotover-proxy/tests/kafka_int_tests/mod.rs +++ b/shotover-proxy/tests/kafka_int_tests/mod.rs @@ -9,6 +9,7 @@ use test_cases::produce_consume_partitions1; use test_cases::produce_consume_partitions3; use test_cases::{assert_topic_creation_is_denied_due_to_acl, setup_basic_user_acls}; use test_helpers::connection::kafka::node::run_node_smoke_test_scram; +use test_helpers::connection::kafka::python::run_python_bad_auth_sasl_scram; use test_helpers::connection::kafka::python::run_python_smoke_test_sasl_scram; use test_helpers::connection::kafka::{KafkaConnectionBuilder, KafkaDriver}; use test_helpers::docker_compose::docker_compose; @@ -755,21 +756,53 @@ async fn cluster_sasl_scram_over_mtls_nodejs_and_python() { let _docker_compose = docker_compose("tests/test-configs/kafka/cluster-sasl-scram-over-mtls/docker-compose.yaml"); - let shotover = shotover_process( - "tests/test-configs/kafka/cluster-sasl-scram-over-mtls/topology-single.yaml", - ) - .start() - .await; - run_node_smoke_test_scram("127.0.0.1:9192", "super_user", "super_password").await; - run_python_smoke_test_sasl_scram("127.0.0.1:9192", "super_user", "super_password").await; + { + let shotover = shotover_process( + "tests/test-configs/kafka/cluster-sasl-scram-over-mtls/topology-single.yaml", + ) + .start() + .await; + + run_node_smoke_test_scram("127.0.0.1:9192", "super_user", "super_password").await; + run_python_smoke_test_sasl_scram("127.0.0.1:9192", "super_user", "super_password").await; - tokio::time::timeout( - Duration::from_secs(10), - shotover.shutdown_and_then_consume_events(&[]), - ) - .await - .expect("Shotover did not shutdown within 10s"); + tokio::time::timeout( + Duration::from_secs(10), + shotover.shutdown_and_then_consume_events(&[]), + ) + .await + .expect("Shotover did not shutdown within 10s"); + } + + { + let shotover = shotover_process( + "tests/test-configs/kafka/cluster-sasl-scram-over-mtls/topology-single.yaml", + ) + .start() + .await; + + run_python_bad_auth_sasl_scram("127.0.0.1:9192", "incorrect_user", "super_password").await; + run_python_bad_auth_sasl_scram("127.0.0.1:9192", "super_user", "incorrect_password").await; + + tokio::time::timeout( + Duration::from_secs(10), + shotover.shutdown_and_then_consume_events(&[EventMatcher::new() + .with_level(Level::Error) + .with_target("shotover::server") + .with_message(r#"encountered an error when flushing the chain kafka for shutdown + +Caused by: + 1: KafkaSinkCluster transform failed + 2: Failed to receive responses (without sending requests) + 3: Outgoing connection had pending requests, those requests/responses are lost so connection recovery cannot be attempted. + 4: Failed to receive from ControlConnection + 5: The other side of this connection closed the connection"#) + .with_count(Count::Times(2))]), + ) + .await + .expect("Shotover did not shutdown within 10s"); + } } #[rstest] diff --git a/shotover/src/codec/kafka.rs b/shotover/src/codec/kafka.rs index e61a9bb45..4a0bbf47f 100644 --- a/shotover/src/codec/kafka.rs +++ b/shotover/src/codec/kafka.rs @@ -5,10 +5,8 @@ use crate::frame::{Frame, MessageType}; use crate::message::{Encodable, Message, MessageId, Messages}; use anyhow::{anyhow, Result}; use bytes::BytesMut; -use kafka_protocol::messages::{ - ApiKey, RequestHeader as RequestHeaderProtocol, RequestKind, ResponseHeader, ResponseKind, - SaslAuthenticateRequest, SaslAuthenticateResponse, -}; +use kafka_protocol::messages::{ApiKey, RequestKind, ResponseKind}; +use kafka_protocol::protocol::StrBytes; use metrics::Histogram; use std::sync::mpsc; use std::time::Instant; @@ -66,21 +64,32 @@ impl CodecBuilder for KafkaCodecBuilder { pub struct RequestInfo { header: RequestHeader, id: MessageId, - expect_raw_sasl: Option, + expect_raw_sasl: Option, } #[derive(Debug, Clone, PartialEq, Copy)] -pub enum SaslType { +pub enum SaslMechanismState { Plain, ScramMessage1, ScramMessage2, } +impl SaslMechanismState { + fn from_name(mechanism: &StrBytes) -> Result { + match mechanism.as_str() { + "PLAIN" => Ok(SaslMechanismState::Plain), + "SCRAM-SHA-512" => Ok(SaslMechanismState::ScramMessage1), + "SCRAM-SHA-256" => Ok(SaslMechanismState::ScramMessage1), + mechanism => Err(anyhow!("Unknown sasl mechanism {mechanism}")), + } + } +} + pub struct KafkaDecoder { // Some when Sink (because it receives responses) request_header_rx: Option>, direction: Direction, - expect_raw_sasl: Option, + expect_raw_sasl: Option, } impl KafkaDecoder { @@ -123,11 +132,6 @@ impl Decoder for KafkaDecoder { pretty_hex::pretty_hex(&bytes) ); - struct Meta { - request_header: RequestHeader, - message_id: Option, - } - let request_info = self .request_header_rx .as_ref() @@ -137,89 +141,75 @@ impl Decoder for KafkaDecoder { }) .transpose()?; - let message = if self.expect_raw_sasl.is_some() { - // Convert the unframed raw sasl into a framed sasl - // This allows transforms to correctly parse the message and inspect the sasl request - let kafka_frame = match self.direction { - Direction::Source => KafkaFrame::Request { - header: RequestHeaderProtocol::default() - .with_request_api_key(ApiKey::SaslAuthenticateKey as i16), - body: RequestKind::SaslAuthenticate( - SaslAuthenticateRequest::default().with_auth_bytes(bytes.freeze()), - ), - }, - Direction::Sink => KafkaFrame::Response { + struct Meta { + request_header: RequestHeader, + message_id: Option, + } + + let meta = if let Some(RequestInfo { header, id, .. }) = request_info { + Meta { + request_header: header, + message_id: Some(id), + } + } else if self.expect_raw_sasl.is_some() { + Meta { + request_header: RequestHeader { + api_key: ApiKey::SaslAuthenticateKey, version: 0, - header: ResponseHeader::default(), - body: ResponseKind::SaslAuthenticate( - SaslAuthenticateResponse::default().with_auth_bytes(bytes.freeze()), - // TODO: we need to set with_error_code - ), }, - }; - let codec_state = CodecState::Kafka(KafkaCodecState { - request_header: None, - raw_sasl: self.expect_raw_sasl, - }); - self.expect_raw_sasl = match self.expect_raw_sasl { - Some(SaslType::Plain) => None, - Some(SaslType::ScramMessage1) => Some(SaslType::ScramMessage2), - Some(SaslType::ScramMessage2) => None, - None => None, - }; - Message::from_frame_and_codec_state_at_instant( - Frame::Kafka(kafka_frame), - codec_state, + // This code path is only used for requests, so message_id can be None. + message_id: None, + } + } else { + Meta { + request_header: RequestHeader { + api_key: ApiKey::try_from(i16::from_be_bytes( + bytes[4..6].try_into().unwrap(), + )) + .unwrap(), + version: i16::from_be_bytes(bytes[6..8].try_into().unwrap()), + }, + // This code path is only used for requests, so message_id can be None. + message_id: None, + } + }; + let mut message = if let Some(id) = meta.message_id.as_ref() { + let mut message = Message::from_bytes_at_instant( + bytes.freeze(), + CodecState::Kafka(KafkaCodecState { + request_header: Some(meta.request_header), + raw_sasl: self.expect_raw_sasl, + }), Some(received_at), - ) + ); + message.set_request_id(*id); + message } else { - let meta = if let Some(RequestInfo { - header, - id, - expect_raw_sasl, - }) = request_info - { - if let Some(expect_raw_sasl) = expect_raw_sasl { - self.expect_raw_sasl = Some(expect_raw_sasl); - } - Meta { - request_header: header, - message_id: Some(id), - } - } else { - Meta { - request_header: RequestHeader { - api_key: ApiKey::try_from(i16::from_be_bytes( - bytes[4..6].try_into().unwrap(), - )) - .unwrap(), - version: i16::from_be_bytes(bytes[6..8].try_into().unwrap()), - }, - message_id: None, - } - }; - let mut message = if let Some(id) = meta.message_id.as_ref() { - let mut message = Message::from_bytes_at_instant( - bytes.freeze(), - CodecState::Kafka(KafkaCodecState { - request_header: Some(meta.request_header), - raw_sasl: None, - }), - Some(received_at), - ); - message.set_request_id(*id); - message - } else { - Message::from_bytes_at_instant( - bytes.freeze(), - CodecState::Kafka(KafkaCodecState { - request_header: None, - raw_sasl: None, - }), - Some(received_at), - ) - }; + Message::from_bytes_at_instant( + bytes.freeze(), + CodecState::Kafka(KafkaCodecState { + request_header: None, + raw_sasl: self.expect_raw_sasl, + }), + Some(received_at), + ) + }; + // advanced to the next state of expect_raw_sasl + self.expect_raw_sasl = match self.expect_raw_sasl { + Some(SaslMechanismState::Plain) => None, + Some(SaslMechanismState::ScramMessage1) => Some(SaslMechanismState::ScramMessage2), + Some(SaslMechanismState::ScramMessage2) => None, + None => None, + }; + + if let Some(request_info) = request_info { + // set expect_raw_sasl for responses + if let Some(expect_raw_sasl) = request_info.expect_raw_sasl { + self.expect_raw_sasl = Some(expect_raw_sasl); + } + } else { + // set expect_raw_sasl for requests if meta.request_header.api_key == ApiKey::SaslHandshakeKey && meta.request_header.version == 0 { @@ -229,16 +219,10 @@ impl Decoder for KafkaDecoder { .. })) = message.frame() { - self.expect_raw_sasl = Some(match sasl_handshake.mechanism.as_str() { - "PLAIN" => SaslType::Plain, - "SCRAM-SHA-512" => SaslType::ScramMessage1, - "SCRAM-SHA-256" => SaslType::ScramMessage1, - mechanism => { - return Err(CodecReadError::Parser(anyhow!( - "Unknown sasl mechanism {mechanism}" - ))) - } - }); + self.expect_raw_sasl = Some( + SaslMechanismState::from_name(&sasl_handshake.mechanism) + .map_err(CodecReadError::Parser)?, + ); // Clear raw bytes of the message to force the encoder to encode from frame. // This is needed because the encoder only has access to the frame if it does not have any raw bytes, @@ -246,8 +230,7 @@ impl Decoder for KafkaDecoder { message.invalidate_cache(); } } - message - }; + } Ok(Some(vec![message])) } else { @@ -288,7 +271,11 @@ impl Encoder for KafkaEncoder { let response_is_dummy = m.response_is_dummy(); let id = m.id(); let received_at = m.received_from_source_or_sink_at; - let codec_state = m.codec_state.as_kafka(); + let message_contains_raw_sasl = if let CodecState::Kafka(codec_state) = m.codec_state { + codec_state.raw_sasl.is_some() + } else { + false + }; let mut expect_raw_sasl = None; let result = match m.into_encodable() { Encodable::Bytes(bytes) => { @@ -296,7 +283,7 @@ impl Encoder for KafkaEncoder { Ok(()) } Encodable::Frame(frame) => { - if codec_state.raw_sasl.is_some() { + if message_contains_raw_sasl { match frame { Frame::Kafka(KafkaFrame::Request { body: RequestKind::SaslAuthenticate(body), @@ -315,23 +302,17 @@ impl Encoder for KafkaEncoder { Ok(()) } else { let frame = frame.into_kafka().unwrap(); - // it is garanteed that all v0 SaslHandshakes will be in a parsed state since we parse it in the KafkaDecoder. + // it is garanteed that all v0 SaslHandshakes will be in a parsed state since we parse + invalidate_cache in the KafkaDecoder. if let KafkaFrame::Request { body: RequestKind::SaslHandshake(sasl_handshake), header, } = &frame { if header.request_api_version == 0 { - expect_raw_sasl = Some(match sasl_handshake.mechanism.as_str() { - "PLAIN" => SaslType::Plain, - "SCRAM-SHA-512" => SaslType::ScramMessage1, - "SCRAM-SHA-256" => SaslType::ScramMessage1, - mechanism => { - return Err(CodecWriteError::Encoder(anyhow!( - "Unknown sasl mechanism {mechanism}" - ))) - } - }); + expect_raw_sasl = Some( + SaslMechanismState::from_name(&sasl_handshake.mechanism) + .map_err(CodecWriteError::Encoder)?, + ); } } frame.encode(dst) @@ -343,7 +324,7 @@ impl Encoder for KafkaEncoder { // or if it will generate a dummy response if !dst[start..].is_empty() && !response_is_dummy { if let Some(tx) = self.request_header_tx.as_ref() { - let header = if codec_state.raw_sasl.is_some() { + let header = if message_contains_raw_sasl { RequestHeader { api_key: ApiKey::SaslAuthenticateKey, version: 0, diff --git a/shotover/src/codec/mod.rs b/shotover/src/codec/mod.rs index 6a373cfe5..77646c9df 100644 --- a/shotover/src/codec/mod.rs +++ b/shotover/src/codec/mod.rs @@ -7,7 +7,7 @@ use core::fmt; #[cfg(feature = "kafka")] use kafka::RequestHeader; #[cfg(feature = "kafka")] -use kafka::SaslType; +use kafka::SaslMechanismState; use metrics::{histogram, Histogram}; use tokio_util::codec::{Decoder, Encoder}; @@ -87,7 +87,9 @@ impl CodecState { #[derive(Debug, Clone, PartialEq, Copy)] pub struct KafkaCodecState { pub request_header: Option, - pub raw_sasl: Option, + /// When some this message is not in the kafka protocol and is instead a raw SASL message + /// KafkaFrame will parse this as a SaslHandshake to hide the legacy raw message from transform implementations. + pub raw_sasl: Option, } #[derive(Debug)] diff --git a/shotover/src/frame/kafka.rs b/shotover/src/frame/kafka.rs index c08773b1a..8b55ce77a 100644 --- a/shotover/src/frame/kafka.rs +++ b/shotover/src/frame/kafka.rs @@ -2,7 +2,9 @@ use crate::codec::kafka::RequestHeader as CodecRequestHeader; use crate::codec::KafkaCodecState; use anyhow::{anyhow, Context, Result}; use bytes::{BufMut, Bytes, BytesMut}; -use kafka_protocol::messages::{ApiKey, RequestHeader, ResponseHeader}; +use kafka_protocol::messages::{ + ApiKey, RequestHeader, ResponseHeader, SaslAuthenticateRequest, SaslAuthenticateResponse, +}; use kafka_protocol::protocol::{Decodable, Encodable}; use std::fmt::{Display, Formatter, Result as FmtResult}; @@ -70,12 +72,35 @@ impl Display for KafkaFrame { impl KafkaFrame { pub fn from_bytes(mut bytes: Bytes, codec_state: KafkaCodecState) -> Result { - // remove length header - let _ = bytes.split_to(4); - - match &codec_state.request_header { - Some(request_header) => KafkaFrame::parse_response(bytes, *request_header), - None => KafkaFrame::parse_request(bytes), + if codec_state.raw_sasl.is_some() { + match &codec_state.request_header { + Some(_) => Ok(KafkaFrame::Response { + version: 0, + header: ResponseHeader::default(), + body: ResponseBody::SaslAuthenticate( + SaslAuthenticateResponse::default().with_auth_bytes(bytes), + // We dont set error_code field when the response contains a scram error, which sounds problematic. + // But in reality, at least for raw sasl mode, if kafka encounters an auth failure, + // it just kills the connection without sending any sasl response at all. + // So we never actually receive a scram response containing an error and + // so there would be no case where the error_code field would need to be set. + ), + }), + None => Ok(KafkaFrame::Request { + header: RequestHeader::default() + .with_request_api_key(ApiKey::SaslAuthenticateKey as i16), + body: RequestBody::SaslAuthenticate( + SaslAuthenticateRequest::default().with_auth_bytes(bytes), + ), + }), + } + } else { + // remove length header + let _ = bytes.split_to(4); + match &codec_state.request_header { + Some(request_header) => KafkaFrame::parse_response(bytes, *request_header), + None => KafkaFrame::parse_request(bytes), + } } } diff --git a/shotover/src/transforms/kafka/sink_cluster/connections.rs b/shotover/src/transforms/kafka/sink_cluster/connections.rs index 97d71c4a5..8764230b8 100644 --- a/shotover/src/transforms/kafka/sink_cluster/connections.rs +++ b/shotover/src/transforms/kafka/sink_cluster/connections.rs @@ -238,16 +238,14 @@ impl Connections { } else { KafkaNodeState::Up }; - nodes - .iter() - .find(|x| match destination { - Destination::Id(id) => x.broker_id == id, - Destination::ControlConnection => { - &x.kafka_address == self.control_connection_address.as_ref().unwrap() - } - }) - .unwrap() - .set_state(node_state); + if let Some(node) = nodes.iter().find(|x| match destination { + Destination::Id(id) => x.broker_id == id, + Destination::ControlConnection => { + &x.kafka_address == self.control_connection_address.as_ref().unwrap() + } + }) { + node.set_state(node_state); + } if old_connection .map(|old| old.pending_requests_count()) diff --git a/test-helpers/src/connection/kafka/python.rs b/test-helpers/src/connection/kafka/python.rs index 90c7e62b1..7f500ed31 100644 --- a/test-helpers/src/connection/kafka/python.rs +++ b/test-helpers/src/connection/kafka/python.rs @@ -80,6 +80,32 @@ pub async fn run_python_smoke_test_sasl_scram(address: &str, user: &str, passwor .unwrap(); } +pub async fn run_python_bad_auth_sasl_scram(address: &str, user: &str, password: &str) { + ensure_uv_is_installed().await; + + let project_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("src/connection/kafka/python"); + let uv_binary = uv_binary_path(); + let config = format!( + r#"{{ + 'bootstrap_servers': ["{address}"], + 'security_protocol': "SASL_PLAINTEXT", + 'sasl_mechanism': "SCRAM-SHA-256", + 'sasl_plain_username': "{user}", + 'sasl_plain_password': "{password}", +}}"# + ); + tokio::time::timeout( + Duration::from_secs(60), + run_command_async( + &project_dir, + uv_binary.to_str().unwrap(), + &["run", "auth_fail.py", &config], + ), + ) + .await + .unwrap(); +} + /// Install a specific version of UV to: /// * avoid developers having to manually install an external tool /// * avoid issues due to a different version being installed diff --git a/test-helpers/src/connection/kafka/python/auth_fail.py b/test-helpers/src/connection/kafka/python/auth_fail.py new file mode 100644 index 000000000..f81229e18 --- /dev/null +++ b/test-helpers/src/connection/kafka/python/auth_fail.py @@ -0,0 +1,18 @@ +from kafka import KafkaConsumer +from kafka import KafkaProducer +from kafka.errors import KafkaError +import sys + +def main(): + config = eval(sys.argv[1]) + print("Running kafka-python script with config:") + print(config) + + try: + KafkaProducer(**config) + raise Exception("KafkaProducer was succesfully created but expected to fail due to using incorrect username/password") + except KafkaError: + print("kafka-python auth_fail script passed all test cases") + +if __name__ == "__main__": + main()