Merge pull request #18464 from mmaslankaprv/mpx-client-parsing
Made client id parsing vcluster aware
mmaslankaprv authored May 16, 2024
2 parents 3ec4ac3 + 187dcbf commit 076ddb8
Showing 8 changed files with 238 additions and 37 deletions.
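For orientation before the diff: the client id can now carry a virtual-connection prefix, a 20-character xid naming the virtual cluster followed by an 8-character hex connection id, with the ordinary client id appended after it. A client might build such an id as in this sketch (Python, mirroring create_client_id in the test file below; values are illustrative):

# Sketch of how a client builds a vcluster-aware client id, mirroring the
# layout documented in connection_context.cc below (illustrative only).
def make_virtual_client_id(vcluster_xid: str, connection_id: int,
                           client_id: str) -> str:
    assert len(vcluster_xid) == 20, "string form of an xid is 20 characters"
    assert 0 <= connection_id <= 0xFFFFFFFF, "connection id is a 32-bit integer"
    return f"{vcluster_xid}{connection_id:08x}{client_id}"

# make_virtual_client_id("9m4e2mr0ui3e8a215n4g", 1, "client-produce")
#   -> "9m4e2mr0ui3e8a215n4g00000001client-produce"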
118 changes: 103 additions & 15 deletions src/v/kafka/server/connection_context.cc
@@ -64,24 +64,94 @@ class invalid_virtual_connection_id : public std::exception {
explicit invalid_virtual_connection_id(ss::sstring msg)
: _msg(std::move(msg)) {}

const char* what() const noexcept final { return _msg.c_str(); }

private:
ss::sstring _msg;
};

bytes parse_virtual_connection_id(const kafka::request_header& header) {
// Holds the virtual connection id together with the remaining client id; it is
// returned after parsing the virtual connection id out of the client id buffer.
struct virtual_connection_client_id {
virtual_connection_id v_connection_id;
std::optional<std::string_view> client_id;
};

const std::regex hex_characters_regexp{R"REGEX(^[a-f0-9A-F]{8}$)REGEX"};

vcluster_connection_id
parse_vcluster_connection_id(const std::string& hex_str) {
std::smatch matches;
auto match = std::regex_match(
hex_str.cbegin(), hex_str.cend(), matches, hex_characters_regexp);
if (!match) {
throw invalid_virtual_connection_id(fmt::format(
"virtual cluster connection id '{}' is not a hexadecimal integer",
hex_str));
}

vcluster_connection_id cid;

std::stringstream sstream(hex_str);
sstream >> std::hex >> cid;
return cid;
}

/**
* The virtual connection id is encoded with the following structure:
*
* [vcluster_id][connection_id][actual client id]
*
* vcluster_id - a string-encoded xid identifying the virtual cluster
* (20 characters)
* connection_id - a hex-encoded 32-bit integer identifying the virtual
* connection (8 characters)
*
* client_id - the standard protocol-defined client id
*/
virtual_connection_client_id
parse_virtual_connection_id(const kafka::request_header& header) {
static constexpr size_t connection_id_str_size
= sizeof(vcluster_connection_id::type) * 2;
static constexpr size_t v_connection_id_size = xid::str_size
+ connection_id_str_size;
if (header.client_id_buffer.empty()) {
throw invalid_virtual_connection_id(
"virtual connection client id can not be empty");
}

// TODO: should we use vcluster_id here ?
return bytes{
reinterpret_cast<const uint8_t*>(header.client_id_buffer.begin()),
header.client_id_buffer.size()};
if (header.client_id->size() < v_connection_id_size) {
throw invalid_virtual_connection_id(fmt::format(
"virtual connection client id size must contain at least {} "
"characters. Current size: {}",
v_connection_id_size,
header.client_id_buffer.size()));
}
try {
virtual_connection_id connection_id{
.virtual_cluster_id = xid::from_string(
std::string_view(header.client_id->begin(), xid::str_size)),
.connection_id = parse_vcluster_connection_id(std::string(
std::next(header.client_id_buffer.begin(), xid::str_size),
connection_id_str_size))};

return virtual_connection_client_id{
.v_connection_id = connection_id,
// the remainder of the client id buffer is used as the standard protocol
// client_id.
.client_id
= header.client_id_buffer.size() == v_connection_id_size
? std::nullopt
: std::make_optional<std::string_view>(
std::next(
header.client_id_buffer.begin(), v_connection_id_size),
header.client_id_buffer.size() - v_connection_id_size),
};
} catch (const invalid_xid& e) {
throw invalid_virtual_connection_id(e.what());
}
}

} // namespace

connection_context::connection_context(
std::optional<
std::reference_wrapper<boost::intrusive::list<connection_context>>> hook,
@@ -637,18 +707,27 @@ connection_context::dispatch_method_once(request_header hdr, size_t size) {
* Not virtualized connection, simply forward to protocol state for request
* processing.
*/
if (!_is_virtualized_connection) {
if (
!_is_virtualized_connection
|| rctx.header().client_id == multi_proxy_initial_client_id) {
co_return co_await _protocol_state.process_request(
shared_from_this(), std::move(rctx), sres);
}
auto client_connection_id = parse_virtual_connection_id(rctx.header());
rctx.override_client_id(client_connection_id.client_id);
vlog(
klog.trace,
"request from virtual connection {}, client id: {}",
client_connection_id.v_connection_id,
client_connection_id.client_id);

auto v_connection_id = parse_virtual_connection_id(rctx.header());
auto it = _virtual_states.lazy_emplace(
v_connection_id, [v_connection_id](const auto& ctr) mutable {
return ctr(
std::move(v_connection_id),
ss::make_lw_shared<virtual_connection_state>());
});
auto it = _virtual_states.find(client_connection_id.v_connection_id);
if (it == _virtual_states.end()) {
auto p = _virtual_states.emplace(
client_connection_id.v_connection_id,
ss::make_lw_shared<virtual_connection_state>());
it = p.first;
}

co_await it->second->process_request(
shared_from_this(), std::move(rctx), sres);
@@ -865,4 +944,13 @@ ss::future<> connection_context::client_protocol_state::maybe_process_responses(
});
}

std::ostream& operator<<(std::ostream& o, const virtual_connection_id& id) {
fmt::print(
o,
"{{virtual_cluster_id: {}, connection_id: {}}}",
id.virtual_cluster_id,
id.connection_id);
return o;
}

} // namespace kafka
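Conceptually, the parsing and routing added above boil down to the following sketch (Python, illustrative only; the helper names and the simplified xid check are not the real code):

# A rough Python restatement of parse_virtual_connection_id above.
XID_LEN = 20          # xid::str_size
CONN_ID_HEX_LEN = 8   # sizeof(vcluster_connection_id::type) * 2
PREFIX_LEN = XID_LEN + CONN_ID_HEX_LEN

def parse_virtual_connection_id(client_id: str):
    if not client_id:
        raise ValueError("virtual connection client id cannot be empty")
    if len(client_id) < PREFIX_LEN:
        raise ValueError(
            f"client id must contain at least {PREFIX_LEN} characters")
    vcluster_xid = client_id[:XID_LEN]              # real code: xid::from_string
    conn_hex = client_id[XID_LEN:PREFIX_LEN]
    if not all(c in "0123456789abcdefABCDEF" for c in conn_hex):
        raise ValueError(f"'{conn_hex}' is not a hexadecimal integer")
    connection_id = int(conn_hex, 16)
    remaining = client_id[PREFIX_LEN:] or None      # overrides the header client id
    return (vcluster_xid, connection_id), remaining

# Requests sharing a (vcluster_xid, connection_id) key are routed to the same
# virtual connection state and stay ordered; distinct keys proceed independently.
virtual_states = {}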
26 changes: 25 additions & 1 deletion src/v/kafka/server/connection_context.h
Expand Up @@ -11,6 +11,7 @@
#pragma once
#include "base/seastarx.h"
#include "config/property.h"
#include "container/chunked_hash_map.h"
#include "kafka/server/fwd.h"
#include "kafka/server/handlers/handler_probe.h"
#include "kafka/server/logger.h"
@@ -120,7 +121,28 @@ struct session_resources {
std::unique_ptr<request_tracker> tracker;
request_data request_data;
};
using vcluster_connection_id
= named_type<uint32_t, struct vcluster_connection_id_tag>;
/**
* Struct representing virtual connection identifier. Each virtual cluster may
* have multiple connections identified with connection_id.
*/
struct virtual_connection_id {
xid virtual_cluster_id;
vcluster_connection_id connection_id;

template<typename H>
friend H AbslHashValue(H h, const virtual_connection_id& id) {
return H::combine(
std::move(h), id.virtual_cluster_id, id.connection_id);
}
friend bool
operator==(const virtual_connection_id&, const virtual_connection_id&)
= default;

friend std::ostream&
operator<<(std::ostream& o, const virtual_connection_id& id);
};
class connection_context final
: public ss::enable_lw_shared_from_this<connection_context>
, public boost::intrusive::list_base_hook<> {
@@ -416,7 +438,9 @@ class connection_context final
* A map keeping virtual connection states, during default operation the map
* is empty
*/
absl::node_hash_map<bytes, ss::lw_shared_ptr<virtual_connection_state>>
chunked_hash_map<
virtual_connection_id,
ss::lw_shared_ptr<virtual_connection_state>>
_virtual_states;

ss::gate _gate;
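The virtual_connection_id struct above provides AbslHashValue and a defaulted operator== so it can serve as the key of the chunked_hash_map holding per-connection state. Roughly, the Python equivalent of such a key would be a frozen dataclass (illustrative analogue, not part of the change):

from dataclasses import dataclass

# Illustrative analogue of virtual_connection_id: an immutable, hashable key
# combining the virtual cluster xid with the 32-bit connection id.
@dataclass(frozen=True)
class VirtualConnectionId:
    virtual_cluster_id: str   # 20-character xid string
    connection_id: int        # vcluster_connection_id (uint32)

virtual_states = {VirtualConnectionId("9m4e2mr0ui3e8a215n4g", 1): object()}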
6 changes: 6 additions & 0 deletions src/v/kafka/server/request_context.h
@@ -106,6 +106,12 @@ class request_context {

const request_header& header() const { return _header; }

// Override the client id. This method is used when handling virtual
// connections, where the actual client id is embedded in the client id buffer.
void override_client_id(std::optional<std::string_view> new_client_id) {
_header.client_id = new_client_id;
}

ss::lw_shared_ptr<connection_context> connection() { return _conn; }

ssx::sharded_abort_source& abort_source() { return _conn->abort_source(); }
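The effect of override_client_id is that everything downstream of dispatch (handlers, metrics, logging) observes only the trailing, user-supplied client id. With illustrative values:

# Illustrative values: a 20-char xid plus 8 hex chars prefix the real client id.
raw_client_id = "9m4e2mr0ui3e8a215n4g00000001my-app"
visible_client_id = raw_client_id[28:] or None   # what handlers see after the override
assert visible_client_id == "my-app"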
4 changes: 2 additions & 2 deletions src/v/utils/xid.cc
@@ -39,10 +39,10 @@ constexpr decoder_t build_decoder_table() {
static constexpr decoder_t decoder = build_decoder_table();
} // namespace

invalid_xid::invalid_xid(const ss::sstring& current_string)
invalid_xid::invalid_xid(std::string_view current_string)
: _msg(ssx::sformat("String '{}' is not a valid xid", current_string)) {}

xid xid::from_string(const ss::sstring& str) {
xid xid::from_string(std::string_view str) {
if (str.size() != str_size) {
throw invalid_xid(str);
}
4 changes: 2 additions & 2 deletions src/v/utils/xid.h
@@ -28,7 +28,7 @@
*/
class invalid_xid final : public std::exception {
public:
explicit invalid_xid(const ss::sstring&);
explicit invalid_xid(std::string_view);
const char* what() const noexcept final { return _msg.c_str(); }

private:
@@ -72,7 +72,7 @@ class xid {
*
* @return an xid decoded from the string provided
*/
static xid from_string(const ss::sstring&);
static xid from_string(std::string_view);

friend bool operator==(const xid&, const xid&) = default;

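Switching xid::from_string to std::string_view lets the connection code decode the 20-character prefix directly from the client id buffer without copying it into an ss::sstring. As a reference point, and assuming the usual xid base32-hex alphabet (an assumption, not stated in the diff), a valid prefix looks like this:

import re

# Assumed xid surface form: 20 characters drawn from the base32-hex alphabet
# (digits 0-9 and letters a-v). The test below picks "zzzzzzzzzzzzzzzzzzzz"
# as an invalid id precisely because 'z' falls outside that alphabet.
XID_RE = re.compile(r"[0-9a-v]{20}")

def looks_like_xid(s: str) -> bool:
    return XID_RE.fullmatch(s) is not None

assert looks_like_xid("9m4e2mr0ui3e8a215n4g")
assert not looks_like_xid("zzzzzzzzzzzzzzzzzzzz")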
88 changes: 79 additions & 9 deletions tests/rptest/tests/connection_virtualizing_test.py
@@ -8,10 +8,12 @@
# by the Apache License, Version 2.0

from dataclasses import dataclass
import random
from types import MethodType
from rptest.services.cluster import cluster
from rptest.clients.types import TopicSpec
from kafka.protocol.fetch import FetchRequest
from ducktape.mark import matrix

from rptest.tests.redpanda_test import RedpandaTest
from rptest.util import wait_until
@@ -22,6 +24,8 @@
from kafka import KafkaClient, KafkaConsumer
from kafka.protocol.produce import ProduceRequest

from rptest.utils.xid_utils import random_xid_string


@dataclass
class PartitionInfo:
@@ -157,6 +161,10 @@ def no_validation_process_response(self, read_buffer):
return (recv_correlation_id, response)


def create_client_id(vcluster_id: str, connection_id: int, client_id: str):
return f"{vcluster_id}{connection_id:08x}{client_id}"


class TestVirtualConnections(RedpandaTest):
def __init__(self, test_context):
super(TestVirtualConnections, self).__init__(
@@ -187,22 +195,31 @@ def _fetch_and_produce(self, client: MpxMockClient, topic: str,
return (fetch_fut, produce_fut)

@cluster(num_nodes=3)
def test_no_head_of_line_blocking(self):
@matrix(different_clusters=[True, False],
different_connections=[True, False])
def test_no_head_of_line_blocking(self, different_clusters,
different_connections):

# create topic with single partition
spec = TopicSpec(partition_count=1, replication_factor=3)
self.client().create_topic(spec)

mpx_client = MpxMockClient(self.redpanda)
mpx_client.start()
v_cluster_1 = random_xid_string()
v_cluster_2 = random_xid_string()

fetch_client = create_client_id(v_cluster_1, 0, "client-fetch")
produce_client = create_client_id(
v_cluster_1 if not different_clusters else v_cluster_2,
0 if not different_connections else 1, "client-produce")
# validate that fetch request is blocking produce request first as mpx extensions are disabled
(fetch_fut, produce_fut) = self._fetch_and_produce(
client=mpx_client,
topic=spec.name,
partition=0,
fetch_client_id="v-cluster-1",
produce_client_id="v-cluster-2")
fetch_client_id=fetch_client,
produce_client_id=produce_client)

mpx_client.poll(produce_fut)
assert produce_fut.is_done and produce_fut.succeeded
@@ -231,8 +248,8 @@ def test_no_head_of_line_blocking(self):
client=mpx_client,
topic=spec.name,
partition=0,
fetch_client_id="v-cluster-10",
produce_client_id="v-cluster-20")
fetch_client_id=fetch_client,
produce_client_id=produce_client)

for connection in mpx_client.client._conns.values():
if len(connection._protocol.in_flight_requests) == 2:
@@ -241,9 +258,11 @@
no_validation_process_response, connection._protocol)

# wait for fetch as it will be released after produce finishes
should_interleave_requests = different_clusters or different_connections

def _produce_is_ready():
mpx_client.poll(fetch_fut)
mpx_client.poll(
fetch_fut if should_interleave_requests else produce_fut)
return produce_fut.is_done

wait_until(
@@ -260,7 +279,58 @@ def _produce_is_ready():

f_resp = fetch_fut.value

#assert produce_fut.is_done and produce_fut.succeeded, "produce future should be ready when fetch resolved"
assert f_resp.topics[0][1][0][
6] != b'', "Fetch should be unblocked by produce from another virtual connection"
if should_interleave_requests:
assert f_resp.topics[0][1][0][
6] != b'', "Fetch should be unblocked by produce from another virtual connection"
else:
assert f_resp.topics[0][1][0][
6] == b'', "Fetch should be executed before the produce finishes"
mpx_client.close()

@cluster(num_nodes=3)
def test_handling_invalid_ids(self):
self.redpanda.set_cluster_config({"enable_mpx_extensions": True})
# create topic with single partition
spec = TopicSpec(partition_count=1, replication_factor=3)
topic = spec.name
self.client().create_topic(spec)

def produce_with_client(client_id: str):
mpx_client = MpxMockClient(self.redpanda)
mpx_client.start()
partition_info = mpx_client.get_partition_info(topic, 0)
mpx_client.set_client_id(client_id)
produce_fut = mpx_client.send(
node_id=partition_info.leader_id,
request=mpx_client.create_produce_request(topic=topic,
partition=0))

mpx_client.poll(produce_fut)
assert produce_fut.is_done and produce_fut.succeeded
pi = mpx_client.get_partition_info(topic, 0)
mpx_client.close()
return pi

v_cluster = random_xid_string()
valid_client_id = create_client_id(v_cluster, 0, "client-fetch")
p_info = produce_with_client(valid_client_id)

assert p_info.end_offset > 0, "Produce request should be successful"
starting_end_offset = p_info.end_offset
invalid_xid_id = create_client_id("zzzzzzzzzzzzzzzzzzzz", 0,
"client-fetch")

p_info = produce_with_client(invalid_xid_id)

assert starting_end_offset == p_info.end_offset, "Produce request with invalid client id should fail"

invalid_connection_id = f"{v_cluster}00blob00client"

p_info = produce_with_client(invalid_connection_id)

assert starting_end_offset == p_info.end_offset, "Produce request with invalid client id should fail"

valid_client_id_empty = create_client_id(v_cluster, 0, "")
p_info = produce_with_client(valid_client_id_empty)

assert starting_end_offset < p_info.end_offset, "Produce request with valid client id should succeed "
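The matrix in test_no_head_of_line_blocking encodes one rule: requests may interleave only when they arrive on different virtual connections, i.e. when the vcluster xid or the connection id differs. A quick sketch of the expected outcomes per combination:

# Expected behaviour per matrix combination (mirrors should_interleave_requests).
for different_clusters in (True, False):
    for different_connections in (True, False):
        interleave = different_clusters or different_connections
        outcome = ("produce completes while the fetch is still waiting"
                   if interleave else
                   "same virtual connection: the fetch blocks the produce")
        print(f"clusters differ={different_clusters}, "
              f"connections differ={different_connections}: {outcome}")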