Skip to content

Commit

Permalink
simplify code
Browse files Browse the repository at this point in the history
  • Loading branch information
qicosmos committed May 2, 2024
1 parent 2f3347b commit c114a09
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 61 deletions.
8 changes: 4 additions & 4 deletions include/cinatra/coro_http_client.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ class coro_http_client : public std::enable_shared_from_this<coro_http_client> {
}
else {
#endif
std::string encode_header = ws.encode_frame(source, op, true);
auto encode_header = ws.encode_frame(source, op, true);
std::vector<asio::const_buffer> buffers{
asio::buffer(encode_header.data(), encode_header.size()),
asio::buffer(source.data(), source.size())};
Expand All @@ -459,7 +459,7 @@ class coro_http_client : public std::enable_shared_from_this<coro_http_client> {
if (cinatra::gzip_codec::deflate(
{result.buf.data(), result.buf.size()}, dest_buf)) {
std::span<char> msg(dest_buf.data(), dest_buf.size());
std::string header = ws.encode_frame(msg, op, result.eof, true);
auto header = ws.encode_frame(msg, op, result.eof, true);
std::vector<asio::const_buffer> buffers{asio::buffer(header),
asio::buffer(dest_buf)};
auto [ec, sz] = co_await async_write(buffers);
Expand All @@ -478,7 +478,7 @@ class coro_http_client : public std::enable_shared_from_this<coro_http_client> {
else {
#endif
std::span<char> msg(result.buf.data(), result.buf.size());
std::string encode_header = ws.encode_frame(msg, op, result.eof);
auto encode_header = ws.encode_frame(msg, op, result.eof);
std::vector<asio::const_buffer> buffers{
asio::buffer(encode_header.data(), encode_header.size()),
asio::buffer(msg.data(), msg.size())};
Expand Down Expand Up @@ -1953,7 +1953,7 @@ class coro_http_client : public std::enable_shared_from_this<coro_http_client> {
auto close_str = ws.format_close_payload(close_code::normal,
reason.data(), reason.size());
auto span = std::span<char>(close_str);
std::string encode_header = ws.encode_frame(span, opcode::close, true);
auto encode_header = ws.encode_frame(span, opcode::close, true);
std::vector<asio::const_buffer> buffers{asio::buffer(encode_header),
asio::buffer(reason)};

Expand Down
7 changes: 3 additions & 4 deletions include/cinatra/coro_http_connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ class coro_http_connection
async_simple::coro::Lazy<std::error_code> write_websocket(
std::string_view msg, opcode op = opcode::text) {
std::vector<asio::const_buffer> buffers;
std::string header;
std::string_view header;
#ifdef CINATRA_ENABLE_GZIP
std::string dest_buf;
if (is_client_ws_compressed_ && msg.size() > 0) {
Expand All @@ -568,13 +568,13 @@ class coro_http_connection
co_return std::make_error_code(std::errc::protocol_error);
}

header = ws_.format_header(dest_buf.length(), op, true);
header = ws_.encode_ws_header(dest_buf.length(), op, true, true, false);
buffers.push_back(asio::buffer(header));
buffers.push_back(asio::buffer(dest_buf));
}
else {
#endif
header = ws_.format_header(msg.length(), op);
header = ws_.encode_ws_header(msg.length(), op, true, false, false);
buffers.push_back(asio::buffer(header));
buffers.push_back(asio::buffer(msg));
#ifdef CINATRA_ENABLE_GZIP
Expand Down Expand Up @@ -666,7 +666,6 @@ class coro_http_connection

std::string close_msg = ws_.format_close_payload(
close_code::normal, close_frame.message, close_frame.length);
auto header = ws_.format_header(close_msg.length(), opcode::close);

co_await write_websocket(close_msg, opcode::close);
close();
Expand Down
98 changes: 45 additions & 53 deletions include/cinatra/websocket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class websocket {
}

if (msg_masked) {
std::memcpy(mask_, inp + pos, 4);
std::memcpy(mask_key_, inp + pos, 4);
}

return left_header_len_ == 0 ? ws_header_status::complete
Expand All @@ -95,9 +95,9 @@ class websocket {

ws_frame_type parse_payload(std::span<char> buf) {
// unmask data:
if (*(uint32_t *)mask_ != 0) {
if (*(uint32_t *)mask_key_ != 0) {
for (size_t i = 0; i < payload_length_; i++) {
buf[i] = buf[i] ^ mask_[i % 4];
buf[i] = buf[i] ^ mask_key_[i % 4];
}
}

Expand All @@ -121,16 +121,9 @@ class websocket {
return ws_frame_type::WS_BINARY_FRAME;
}

std::string format_header(size_t length, opcode code,
bool is_compressed = false) {
size_t header_length = encode_header(length, code, is_compressed);
return {msg_header_, header_length};
}

std::string encode_frame(std::span<char> &data, opcode op, bool eof,
bool need_compression = false) {
std::string header;
/// Base header.
std::string_view encode_ws_header(size_t size, opcode op, bool eof,
bool need_compression = false,
bool is_client = true) {
frame_header hdr{};
hdr.fin = eof;
hdr.rsv1 = 0;
Expand All @@ -140,56 +133,55 @@ class websocket {
hdr.rsv2 = 0;
hdr.rsv3 = 0;
hdr.opcode = static_cast<uint8_t>(op);
hdr.mask = 1;

if (data.empty()) {
int mask = 0;
header.resize(sizeof(frame_header) + sizeof(mask));
std::memcpy(header.data(), &hdr, sizeof(hdr));
std::memcpy(header.data() + sizeof(hdr), &mask, sizeof(mask));
return header;
}
hdr.mask = is_client;

hdr.len = size < 126 ? size : (size < 65536 ? 126 : 127);

hdr.len =
data.size() < 126 ? data.size() : (data.size() < 65536 ? 126 : 127);

uint8_t buffer[sizeof(frame_header)];
std::memcpy(buffer, (uint8_t *)&hdr, sizeof(hdr));
std::string str_hdr_len =
std::string((const char *)buffer, sizeof(frame_header));
header.append(str_hdr_len);

/// The payload length may be larger than 126 bytes.
std::string str_payload_len;
if (data.size() >= 126) {
if (data.size() >= 65536) {
uint64_t len = data.size();
str_payload_len.resize(sizeof(uint64_t));
*((uint64_t *)&str_payload_len[0]) = htobe64(len);
std::memcpy(msg_header_, (char *)&hdr, sizeof(hdr));

size_t len_bytes = 0;
if (size >= 126) {
if (size >= 65536) {
len_bytes = 8;
*((uint64_t *)(msg_header_ + 2)) = htobe64(size);
}
else {
uint16_t len = data.size();
str_payload_len.resize(sizeof(uint16_t));
*((uint16_t *)&str_payload_len[0]) = htons(static_cast<uint16_t>(len));
len_bytes = 2;
*((uint16_t *)(msg_header_ + 2)) = htons(static_cast<uint16_t>(size));
}
header.append(str_payload_len);
}

/// The mask is a 32-bit value.
uint8_t mask[4] = {};
header[1] |= 0x80;
uint32_t random = (uint32_t)rand();
memcpy(mask, &random, 4);
size_t header_len = 6;

if (is_client) {
if (size > 0) {
// generate mask key.
uint32_t random = (uint32_t)rand();
memcpy(mask_key_, &random, 4);
}

std::memcpy(msg_header_ + 2 + len_bytes, mask_key_, 4);
}
else {
header_len = 2;
}

size_t size = header.size();
header.resize(size + 4);
std::memcpy(header.data() + size, mask, 4);
return {msg_header_, header_len + len_bytes};
}

void encode_ws_payload(std::span<char> &data) {
for (int i = 0; i < data.size(); ++i) {
data[i] ^= mask[i % 4];
data[i] ^= mask_key_[i % 4];
}
}

std::string_view encode_frame(std::span<char> &data, opcode op, bool eof,
bool need_compression = false) {
std::string_view ws_header =
encode_ws_header(data.size(), op, eof, need_compression);
encode_ws_payload(data);

return header;
return ws_header;
}

close_frame parse_close_payload(char *src, size_t length) {
Expand Down Expand Up @@ -264,7 +256,7 @@ class websocket {
size_t payload_length_ = 0;

size_t left_header_len_ = 0;
uint8_t mask_[4] = {};
uint8_t mask_key_[4] = {};
unsigned char msg_opcode_ = 0;
unsigned char msg_fin_ = 0;

Expand Down

0 comments on commit c114a09

Please sign in to comment.