diff --git a/example/benchmark.cpp b/example/benchmark.cpp index 05324af3..53637d9b 100644 --- a/example/benchmark.cpp +++ b/example/benchmark.cpp @@ -7,6 +7,7 @@ int main() { coro_http_server server(std::thread::hardware_concurrency(), 8090); server.set_http_handler( "/plaintext", [](coro_http_request& req, coro_http_response& resp) { + resp.get_conn()->set_multi_buf(false); resp.set_content_type(); resp.set_status_and_content(status_type::ok, "Hello, world!"); }); diff --git a/include/cinatra/coro_http_connection.hpp b/include/cinatra/coro_http_connection.hpp index 92aeaeb2..f0590379 100644 --- a/include/cinatra/coro_http_connection.hpp +++ b/include/cinatra/coro_http_connection.hpp @@ -338,6 +338,7 @@ class coro_http_connection buffers_.clear(); body_.clear(); resp_str_.clear(); + multi_buf_ = true; if (need_shrink_every_time_) { body_.shrink_to_fit(); } @@ -345,11 +346,21 @@ class coro_http_connection } async_simple::coro::Lazy reply(bool need_to_bufffer = true) { - // avoid duplicate reply - if (need_to_bufffer) { - response_.to_buffers(buffers_); + std::error_code ec; + size_t size; + if (multi_buf_) { + if (need_to_bufffer) { + response_.to_buffers(buffers_); + } + std::tie(ec, size) = co_await async_write(buffers_); } - auto [ec, _] = co_await async_write(buffers_); + else { + if (need_to_bufffer) { + response_.build_resp_str(resp_str_); + } + std::tie(ec, size) = co_await async_write(asio::buffer(resp_str_)); + } + if (ec) { CINATRA_LOG_ERROR << "async_write error: " << ec.message(); close(); @@ -394,6 +405,8 @@ class coro_http_connection return ss.str(); } + void set_multi_buf(bool r) { multi_buf_ = r; } + async_simple::coro::Lazy write_data(std::string_view message) { std::vector buffers; buffers.push_back(asio::buffer(message)); @@ -761,13 +774,10 @@ class coro_http_connection private: bool check_keep_alive() { - bool keep_alive = true; - auto val = request_.get_header_value("connection"); - if (!val.empty() && iequal0(val, "close")) { - keep_alive = false; + if (parser_.has_close()) { + return false; } - - return keep_alive; + return true; } void build_ws_handshake_head() { @@ -823,5 +833,6 @@ class coro_http_connection bool use_ssl_ = false; #endif bool need_shrink_every_time_ = false; + bool multi_buf_ = true; }; } // namespace cinatra diff --git a/include/cinatra/coro_http_request.hpp b/include/cinatra/coro_http_request.hpp index fa45b9cb..f5788d31 100644 --- a/include/cinatra/coro_http_request.hpp +++ b/include/cinatra/coro_http_request.hpp @@ -188,17 +188,13 @@ class coro_http_request { coro_http_connection *get_conn() { return conn_; } bool is_upgrade() { - auto h = get_header_value("Connection"); - if (h.empty()) + if (!parser_.has_upgrade()) return false; auto u = get_header_value("Upgrade"); if (u.empty()) return false; - if (h != UPGRADE) - return false; - if (u != WEBSOCKET) return false; diff --git a/include/cinatra/coro_http_response.hpp b/include/cinatra/coro_http_response.hpp index e3c73e8a..0d1d61ac 100644 --- a/include/cinatra/coro_http_response.hpp +++ b/include/cinatra/coro_http_response.hpp @@ -94,6 +94,7 @@ class coro_http_response { status_type status() { return status_; } std::string_view content() { return content_; } + size_t content_size() { return content_.size(); } void add_header(auto k, auto v) { resp_headers_.emplace_back(resp_header{std::move(k), std::move(v)}); diff --git a/include/cinatra/http_parser.hpp b/include/cinatra/http_parser.hpp index 6c2fa898..15e1e2f5 100644 --- a/include/cinatra/http_parser.hpp +++ b/include/cinatra/http_parser.hpp @@ -8,6 +8,7 @@ #include #include "cinatra_log_wrapper.hpp" +#include "define.h" #include "picohttpparser.h" #include "url_encode_decode.hpp" @@ -64,9 +65,12 @@ class http_parser { size_t method_len; const char *url; size_t url_len; + + bool has_query{}; header_len_ = detail::phr_parse_request( data, size, &method, &method_len, &url, &url_len, &minor_version, - headers_.data(), &num_headers_, last_len); + headers_.data(), &num_headers_, last_len, has_connection_, has_close_, + has_upgrade_, has_query); if (header_len_ < 0) [[unlikely]] { CINATRA_LOG_WARNING << "parse http head failed"; @@ -76,21 +80,28 @@ class http_parser { << ", you can define macro " "CINATRA_MAX_HTTP_HEADER_FIELD_SIZE to expand it."; } + return header_len_; } method_ = {method, method_len}; url_ = {url, url_len}; - auto content_len = this->get_header_value("content-length"sv); - if (content_len.empty()) { + auto methd_type = method_type(method_); + if (methd_type == http_method::GET || methd_type == http_method::HEAD) { body_len_ = 0; } else { - body_len_ = atoi(content_len.data()); + auto content_len = this->get_header_value("content-length"sv); + if (content_len.empty()) { + body_len_ = 0; + } + else { + body_len_ = atoi(content_len.data()); + } } - size_t pos = url_.find('?'); - if (pos != std::string_view::npos) { + if (has_query) { + size_t pos = url_.find('?'); parse_query(url_.substr(pos + 1, url_len - pos - 1)); url_ = {url, pos}; } @@ -98,6 +109,12 @@ class http_parser { return header_len_; } + bool has_connection() { return has_connection_; } + + bool has_close() { return has_close_; } + + bool has_upgrade() { return has_upgrade_; } + std::string_view get_header_value(std::string_view key) const { for (size_t i = 0; i < num_headers_; i++) { if (iequal0(headers_[i].name, key)) @@ -247,6 +264,9 @@ class http_parser { size_t num_headers_ = 0; int header_len_ = 0; int body_len_ = 0; + bool has_connection_{}; + bool has_close_{}; + bool has_upgrade_{}; std::array headers_; std::string_view method_; std::string_view url_; diff --git a/include/cinatra/picohttpparser.h b/include/cinatra/picohttpparser.h index 031480cf..310044b3 100644 --- a/include/cinatra/picohttpparser.h +++ b/include/cinatra/picohttpparser.h @@ -808,7 +808,9 @@ static const char *parse_headers(const char *buf, const char *buf_end, static const char *parse_headers(const char *buf, const char *buf_end, http_header *headers, size_t *num_headers, - size_t max_headers, int *ret) { + size_t max_headers, int *ret, + bool &has_connection, bool &has_close, + bool &has_upgrade) { for (;; ++*num_headers) { const char *name; size_t name_len; @@ -877,6 +879,21 @@ static const char *parse_headers(const char *buf, const char *buf_end, NULL) { return NULL; } + if (name_len == 10) { + if (memcmp(name + 1, "onnection", name_len - 1) == 0) { + // has connection + has_connection = true; + char ch = *value; + if (ch == 'U') { + // has upgrade + has_upgrade = true; + } + else if (ch == 'c' || ch == 'C') { + // has_close + has_close = true; + } + } + } headers[*num_headers] = {std::string_view{name, name_len}, std::string_view{value, value_len}}; } @@ -885,12 +902,40 @@ static const char *parse_headers(const char *buf, const char *buf_end, #endif -static const char *parse_request(const char *buf, const char *buf_end, - const char **method, size_t *method_len, - const char **path, size_t *path_len, - int *minor_version, http_header *headers, - size_t *num_headers, size_t max_headers, - int *ret) { +#define ADVANCE_PATH(tok, toklen, has_query) \ + do { \ + const char *tok_start = buf; \ + static const char ALIGNED(16) ranges2[] = "\000\040\177\177"; \ + int found2; \ + buf = findchar_fast(buf, buf_end, ranges2, sizeof(ranges2) - 1, &found2); \ + if (!found2) { \ + CHECK_EOF(); \ + } \ + while (1) { \ + if (*buf == ' ') { \ + break; \ + } \ + else if (unlikely(!IS_PRINTABLE_ASCII(*buf))) { \ + if ((unsigned char)*buf < '\040' || *buf == '\177') { \ + *ret = -1; \ + return NULL; \ + } \ + } \ + else if (unlikely(*buf == '?')) { \ + has_query = true; \ + } \ + ++buf; \ + CHECK_EOF(); \ + } \ + tok = tok_start; \ + toklen = buf - tok_start; \ + } while (0) + +static const char *parse_request( + const char *buf, const char *buf_end, const char **method, + size_t *method_len, const char **path, size_t *path_len, int *minor_version, + http_header *headers, size_t *num_headers, size_t max_headers, int *ret, + bool &has_connection, bool &has_close, bool &has_upgrade, bool &has_query) { /* skip first empty line (some clients add CRLF after POST content) */ CHECK_EOF(); if (*buf == '\015') { @@ -904,7 +949,7 @@ static const char *parse_request(const char *buf, const char *buf_end, /* parse request line */ ADVANCE_TOKEN(*method, *method_len); ++buf; - ADVANCE_TOKEN(*path, *path_len); + ADVANCE_PATH(*path, *path_len, has_query); ++buf; if ((buf = parse_http_version(buf, buf_end, minor_version, ret)) == NULL) { return NULL; @@ -921,14 +966,17 @@ static const char *parse_request(const char *buf, const char *buf_end, return NULL; } - return parse_headers(buf, buf_end, headers, num_headers, max_headers, ret); + return parse_headers(buf, buf_end, headers, num_headers, max_headers, ret, + has_connection, has_close, has_upgrade); } inline int phr_parse_request(const char *buf_start, size_t len, const char **method, size_t *method_len, const char **path, size_t *path_len, int *minor_version, http_header *headers, - size_t *num_headers, size_t last_len) { + size_t *num_headers, size_t last_len, + bool &has_connection, bool &has_close, + bool &has_upgrade, bool &has_query) { const char *buf = buf_start, *buf_end = buf_start + len; size_t max_headers = *num_headers; int r; @@ -948,7 +996,8 @@ inline int phr_parse_request(const char *buf_start, size_t len, if ((buf = parse_request(buf + last_len, buf_end, method, method_len, path, path_len, minor_version, headers, num_headers, - max_headers, &r)) == NULL) { + max_headers, &r, has_connection, has_close, + has_upgrade, has_query)) == NULL) { return r; } @@ -987,7 +1036,10 @@ inline const char *parse_response(const char *buf, const char *buf_end, return NULL; } - return parse_headers(buf, buf_end, headers, num_headers, max_headers, ret); + bool has_connection, has_close, has_upgrade; + + return parse_headers(buf, buf_end, headers, num_headers, max_headers, ret, + has_connection, has_close, has_upgrade); } inline int phr_parse_response(const char *buf_start, size_t len, @@ -1033,8 +1085,9 @@ inline int phr_parse_headers(const char *buf_start, size_t len, return r; } - if ((buf = parse_headers(buf, buf_end, headers, num_headers, max_headers, - &r)) == NULL) { + bool has_connection, has_close, has_upgrade; + if ((buf = parse_headers(buf, buf_end, headers, num_headers, max_headers, &r, + has_connection, has_close, has_upgrade)) == NULL) { return r; } diff --git a/tests/test_http_parse.cpp b/tests/test_http_parse.cpp index 36fa6597..acee8e49 100644 --- a/tests/test_http_parse.cpp +++ b/tests/test_http_parse.cpp @@ -36,11 +36,13 @@ TEST_CASE("http parser test") { cinatra::http_header headers[64]; size_t num_headers; int i, ret; + bool has_connection, has_close, has_upgrade, has_query; num_headers = sizeof(headers) / sizeof(headers[0]); ret = cinatra::detail::phr_parse_request( REQ, sizeof(REQ) - 1, &method, &method_len, &path, &path_len, - &minor_version, headers, &num_headers, 0); + &minor_version, headers, &num_headers, 0, has_connection, has_close, + has_upgrade, has_query); CHECK(ret == 703); CHECK(strncmp(method, "GET", method_len) == 0); CHECK(minor_version == 1);