diff --git a/include/cinatra/coro_http_client.hpp b/include/cinatra/coro_http_client.hpp index 62df486c..590c9c7f 100644 --- a/include/cinatra/coro_http_client.hpp +++ b/include/cinatra/coro_http_client.hpp @@ -436,7 +436,7 @@ class coro_http_client : public std::enable_shared_from_this { } else { #endif - std::string encode_header = ws.encode_frame(source, op, true); + auto encode_header = ws.encode_frame(source, op, true); std::vector buffers{ asio::buffer(encode_header.data(), encode_header.size()), asio::buffer(source.data(), source.size())}; @@ -459,7 +459,7 @@ class coro_http_client : public std::enable_shared_from_this { if (cinatra::gzip_codec::deflate( {result.buf.data(), result.buf.size()}, dest_buf)) { std::span msg(dest_buf.data(), dest_buf.size()); - std::string header = ws.encode_frame(msg, op, result.eof, true); + auto header = ws.encode_frame(msg, op, result.eof, true); std::vector buffers{asio::buffer(header), asio::buffer(dest_buf)}; auto [ec, sz] = co_await async_write(buffers); @@ -478,7 +478,7 @@ class coro_http_client : public std::enable_shared_from_this { else { #endif std::span msg(result.buf.data(), result.buf.size()); - std::string encode_header = ws.encode_frame(msg, op, result.eof); + auto encode_header = ws.encode_frame(msg, op, result.eof); std::vector buffers{ asio::buffer(encode_header.data(), encode_header.size()), asio::buffer(msg.data(), msg.size())}; @@ -1953,7 +1953,7 @@ class coro_http_client : public std::enable_shared_from_this { auto close_str = ws.format_close_payload(close_code::normal, reason.data(), reason.size()); auto span = std::span(close_str); - std::string encode_header = ws.encode_frame(span, opcode::close, true); + auto encode_header = ws.encode_frame(span, opcode::close, true); std::vector buffers{asio::buffer(encode_header), asio::buffer(reason)}; diff --git a/include/cinatra/coro_http_connection.hpp b/include/cinatra/coro_http_connection.hpp index 482b3b9f..7159a493 100644 --- a/include/cinatra/coro_http_connection.hpp +++ b/include/cinatra/coro_http_connection.hpp @@ -559,7 +559,7 @@ class coro_http_connection async_simple::coro::Lazy write_websocket( std::string_view msg, opcode op = opcode::text) { std::vector buffers; - std::string header; + std::string_view header; #ifdef CINATRA_ENABLE_GZIP std::string dest_buf; if (is_client_ws_compressed_ && msg.size() > 0) { @@ -568,13 +568,13 @@ class coro_http_connection co_return std::make_error_code(std::errc::protocol_error); } - header = ws_.format_header(dest_buf.length(), op, true); + header = ws_.encode_ws_header(dest_buf.length(), op, true, true, false); buffers.push_back(asio::buffer(header)); buffers.push_back(asio::buffer(dest_buf)); } else { #endif - header = ws_.format_header(msg.length(), op); + header = ws_.encode_ws_header(msg.length(), op, true, false, false); buffers.push_back(asio::buffer(header)); buffers.push_back(asio::buffer(msg)); #ifdef CINATRA_ENABLE_GZIP @@ -666,7 +666,6 @@ class coro_http_connection std::string close_msg = ws_.format_close_payload( close_code::normal, close_frame.message, close_frame.length); - auto header = ws_.format_header(close_msg.length(), opcode::close); co_await write_websocket(close_msg, opcode::close); close(); diff --git a/include/cinatra/websocket.hpp b/include/cinatra/websocket.hpp index 02dfac4b..82178960 100644 --- a/include/cinatra/websocket.hpp +++ b/include/cinatra/websocket.hpp @@ -83,7 +83,7 @@ class websocket { } if (msg_masked) { - std::memcpy(mask_, inp + pos, 4); + std::memcpy(mask_key_, inp + pos, 4); } return left_header_len_ == 0 ? ws_header_status::complete @@ -95,9 +95,9 @@ class websocket { ws_frame_type parse_payload(std::span buf) { // unmask data: - if (*(uint32_t *)mask_ != 0) { + if (*(uint32_t *)mask_key_ != 0) { for (size_t i = 0; i < payload_length_; i++) { - buf[i] = buf[i] ^ mask_[i % 4]; + buf[i] = buf[i] ^ mask_key_[i % 4]; } } @@ -121,16 +121,9 @@ class websocket { return ws_frame_type::WS_BINARY_FRAME; } - std::string format_header(size_t length, opcode code, - bool is_compressed = false) { - size_t header_length = encode_header(length, code, is_compressed); - return {msg_header_, header_length}; - } - - std::string encode_frame(std::span &data, opcode op, bool eof, - bool need_compression = false) { - std::string header; - /// Base header. + std::string_view encode_ws_header(size_t size, opcode op, bool eof, + bool need_compression = false, + bool is_client = true) { frame_header hdr{}; hdr.fin = eof; hdr.rsv1 = 0; @@ -140,56 +133,55 @@ class websocket { hdr.rsv2 = 0; hdr.rsv3 = 0; hdr.opcode = static_cast(op); - hdr.mask = 1; - - if (data.empty()) { - int mask = 0; - header.resize(sizeof(frame_header) + sizeof(mask)); - std::memcpy(header.data(), &hdr, sizeof(hdr)); - std::memcpy(header.data() + sizeof(hdr), &mask, sizeof(mask)); - return header; - } + hdr.mask = is_client; + + hdr.len = size < 126 ? size : (size < 65536 ? 126 : 127); - hdr.len = - data.size() < 126 ? data.size() : (data.size() < 65536 ? 126 : 127); - - uint8_t buffer[sizeof(frame_header)]; - std::memcpy(buffer, (uint8_t *)&hdr, sizeof(hdr)); - std::string str_hdr_len = - std::string((const char *)buffer, sizeof(frame_header)); - header.append(str_hdr_len); - - /// The payload length may be larger than 126 bytes. - std::string str_payload_len; - if (data.size() >= 126) { - if (data.size() >= 65536) { - uint64_t len = data.size(); - str_payload_len.resize(sizeof(uint64_t)); - *((uint64_t *)&str_payload_len[0]) = htobe64(len); + std::memcpy(msg_header_, (char *)&hdr, sizeof(hdr)); + + size_t len_bytes = 0; + if (size >= 126) { + if (size >= 65536) { + len_bytes = 8; + *((uint64_t *)(msg_header_ + 2)) = htobe64(size); } else { - uint16_t len = data.size(); - str_payload_len.resize(sizeof(uint16_t)); - *((uint16_t *)&str_payload_len[0]) = htons(static_cast(len)); + len_bytes = 2; + *((uint16_t *)(msg_header_ + 2)) = htons(static_cast(size)); } - header.append(str_payload_len); } - /// The mask is a 32-bit value. - uint8_t mask[4] = {}; - header[1] |= 0x80; - uint32_t random = (uint32_t)rand(); - memcpy(mask, &random, 4); + size_t header_len = 6; + + if (is_client) { + if (size > 0) { + // generate mask key. + uint32_t random = (uint32_t)rand(); + memcpy(mask_key_, &random, 4); + } + + std::memcpy(msg_header_ + 2 + len_bytes, mask_key_, 4); + } + else { + header_len = 2; + } - size_t size = header.size(); - header.resize(size + 4); - std::memcpy(header.data() + size, mask, 4); + return {msg_header_, header_len + len_bytes}; + } + void encode_ws_payload(std::span &data) { for (int i = 0; i < data.size(); ++i) { - data[i] ^= mask[i % 4]; + data[i] ^= mask_key_[i % 4]; } + } + + std::string_view encode_frame(std::span &data, opcode op, bool eof, + bool need_compression = false) { + std::string_view ws_header = + encode_ws_header(data.size(), op, eof, need_compression); + encode_ws_payload(data); - return header; + return ws_header; } close_frame parse_close_payload(char *src, size_t length) { @@ -264,7 +256,7 @@ class websocket { size_t payload_length_ = 0; size_t left_header_len_ = 0; - uint8_t mask_[4] = {}; + uint8_t mask_key_[4] = {}; unsigned char msg_opcode_ = 0; unsigned char msg_fin_ = 0;