From 376ffaf53323b0b678f9d507b3b3ce81aa13f1b4 Mon Sep 17 00:00:00 2001 From: Takeshi Yoneda Date: Tue, 10 Nov 2020 08:13:51 +0900 Subject: [PATCH] Force stop iteration after local response is sent (#88) Signed-off-by: mathetake --- include/proxy-wasm/context.h | 13 +++++ src/context.cc | 106 +++++++++++++++++++---------------- src/exports.cc | 1 + 3 files changed, 72 insertions(+), 48 deletions(-) diff --git a/include/proxy-wasm/context.h b/include/proxy-wasm/context.h index ddad88e8..03657a1b 100644 --- a/include/proxy-wasm/context.h +++ b/include/proxy-wasm/context.h @@ -167,6 +167,11 @@ class ContextBase : public RootInterface, // Called before deleting the context. virtual void destroy(); + // Called to raise the flag which indicates that the context should stop iteration regardless of + // returned filter status from Proxy-Wasm extensions. For example, we ignore + // FilterHeadersStatus::Continue after a local reponse is sent by the host. + void stopIteration() { stop_iteration_ = true; }; + /** * Calls into the VM. * These are implemented by the proxy-independent host code. They are virtual to support some @@ -385,6 +390,14 @@ class ContextBase : public RootInterface, std::shared_ptr plugin_; bool in_vm_context_created_ = false; bool destroyed_ = false; + bool stop_iteration_ = false; + +private: + // helper functions + FilterHeadersStatus convertVmCallResultToFilterHeadersStatus(uint64_t result); + FilterDataStatus convertVmCallResultToFilterDataStatus(uint64_t result); + FilterTrailersStatus convertVmCallResultToFilterTrailersStatus(uint64_t result); + FilterMetadataStatus convertVmCallResultToFilterMetadataStatus(uint64_t result); }; class DeferAfterCallActions { diff --git a/src/context.cc b/src/context.cc index c1929e14..6ab4b5e0 100644 --- a/src/context.cc +++ b/src/context.cc @@ -477,93 +477,71 @@ FilterHeadersStatus ContextBase::onRequestHeaders(uint32_t headers, bool end_of_ CHECK_HTTP2(on_request_headers_abi_01_, on_request_headers_abi_02_, FilterHeadersStatus::Continue, FilterHeadersStatus::StopIteration); DeferAfterCallActions actions(this); - auto result = wasm_->on_request_headers_abi_01_ - ? wasm_->on_request_headers_abi_01_(this, id_, headers).u64_ - : wasm_ - ->on_request_headers_abi_02_(this, id_, headers, - static_cast(end_of_stream)) - .u64_; - if (result > static_cast(FilterHeadersStatus::StopAllIterationAndWatermark)) - return FilterHeadersStatus::StopAllIterationAndWatermark; - return static_cast(result); + return convertVmCallResultToFilterHeadersStatus( + wasm_->on_request_headers_abi_01_ + ? wasm_->on_request_headers_abi_01_(this, id_, headers).u64_ + : wasm_ + ->on_request_headers_abi_02_(this, id_, headers, + static_cast(end_of_stream)) + .u64_); } FilterDataStatus ContextBase::onRequestBody(uint32_t data_length, bool end_of_stream) { CHECK_HTTP(on_request_body_, FilterDataStatus::Continue, FilterDataStatus::StopIterationNoBuffer); DeferAfterCallActions actions(this); - auto result = - wasm_->on_request_body_(this, id_, data_length, static_cast(end_of_stream)).u64_; - if (result > static_cast(FilterDataStatus::StopIterationNoBuffer)) - return FilterDataStatus::StopIterationNoBuffer; - return static_cast(result); + return convertVmCallResultToFilterDataStatus( + wasm_->on_request_body_(this, id_, data_length, static_cast(end_of_stream)).u64_); } FilterTrailersStatus ContextBase::onRequestTrailers(uint32_t trailers) { CHECK_HTTP(on_request_trailers_, FilterTrailersStatus::Continue, FilterTrailersStatus::StopIteration); DeferAfterCallActions actions(this); - if (static_cast(wasm_->on_request_trailers_(this, id_, trailers).u64_) == - FilterTrailersStatus::Continue) { - return FilterTrailersStatus::Continue; - } - return FilterTrailersStatus::StopIteration; + return convertVmCallResultToFilterTrailersStatus( + wasm_->on_request_trailers_(this, id_, trailers).u64_); } FilterMetadataStatus ContextBase::onRequestMetadata(uint32_t elements) { CHECK_HTTP(on_request_metadata_, FilterMetadataStatus::Continue, FilterMetadataStatus::Continue); DeferAfterCallActions actions(this); - if (static_cast(wasm_->on_request_metadata_(this, id_, elements).u64_) == - FilterMetadataStatus::Continue) { - return FilterMetadataStatus::Continue; - } - return FilterMetadataStatus::Continue; // This is currently the only return code. + return convertVmCallResultToFilterMetadataStatus( + wasm_->on_request_metadata_(this, id_, elements).u64_); } FilterHeadersStatus ContextBase::onResponseHeaders(uint32_t headers, bool end_of_stream) { CHECK_HTTP2(on_response_headers_abi_01_, on_response_headers_abi_02_, FilterHeadersStatus::Continue, FilterHeadersStatus::StopIteration); DeferAfterCallActions actions(this); - auto result = wasm_->on_response_headers_abi_01_ - ? wasm_->on_response_headers_abi_01_(this, id_, headers).u64_ - : wasm_ - ->on_response_headers_abi_02_(this, id_, headers, - static_cast(end_of_stream)) - .u64_; - if (result > static_cast(FilterHeadersStatus::StopAllIterationAndWatermark)) - return FilterHeadersStatus::StopAllIterationAndWatermark; - return static_cast(result); + return convertVmCallResultToFilterHeadersStatus( + wasm_->on_response_headers_abi_01_ + ? wasm_->on_response_headers_abi_01_(this, id_, headers).u64_ + : wasm_ + ->on_response_headers_abi_02_(this, id_, headers, + static_cast(end_of_stream)) + .u64_); } FilterDataStatus ContextBase::onResponseBody(uint32_t body_length, bool end_of_stream) { CHECK_HTTP(on_response_body_, FilterDataStatus::Continue, FilterDataStatus::StopIterationNoBuffer); DeferAfterCallActions actions(this); - auto result = - wasm_->on_response_body_(this, id_, body_length, static_cast(end_of_stream)).u64_; - if (result > static_cast(FilterDataStatus::StopIterationNoBuffer)) - return FilterDataStatus::StopIterationNoBuffer; - return static_cast(result); + return convertVmCallResultToFilterDataStatus( + wasm_->on_response_body_(this, id_, body_length, static_cast(end_of_stream)).u64_); } FilterTrailersStatus ContextBase::onResponseTrailers(uint32_t trailers) { CHECK_HTTP(on_response_trailers_, FilterTrailersStatus::Continue, FilterTrailersStatus::StopIteration); DeferAfterCallActions actions(this); - if (static_cast(wasm_->on_response_trailers_(this, id_, trailers).u64_) == - FilterTrailersStatus::Continue) { - return FilterTrailersStatus::Continue; - } - return FilterTrailersStatus::StopIteration; + return convertVmCallResultToFilterTrailersStatus( + wasm_->on_response_trailers_(this, id_, trailers).u64_); } FilterMetadataStatus ContextBase::onResponseMetadata(uint32_t elements) { CHECK_HTTP(on_response_metadata_, FilterMetadataStatus::Continue, FilterMetadataStatus::Continue); DeferAfterCallActions actions(this); - if (static_cast(wasm_->on_response_metadata_(this, id_, elements).u64_) == - FilterMetadataStatus::Continue) { - return FilterMetadataStatus::Continue; - } - return FilterMetadataStatus::Continue; // This is currently the only return code. + return convertVmCallResultToFilterMetadataStatus( + wasm_->on_response_metadata_(this, id_, elements).u64_); } void ContextBase::onHttpCallResponse(uint32_t token, uint32_t headers, uint32_t body_size, @@ -643,6 +621,38 @@ WasmResult ContextBase::setTimerPeriod(std::chrono::milliseconds period, return WasmResult::Ok; } +FilterHeadersStatus ContextBase::convertVmCallResultToFilterHeadersStatus(uint64_t result) { + if (stop_iteration_ || + result > static_cast(FilterHeadersStatus::StopAllIterationAndWatermark)) { + stop_iteration_ = false; + return FilterHeadersStatus::StopAllIterationAndWatermark; + } + return static_cast(result); +} + +FilterDataStatus ContextBase::convertVmCallResultToFilterDataStatus(uint64_t result) { + if (stop_iteration_ || result > static_cast(FilterDataStatus::StopIterationNoBuffer)) { + stop_iteration_ = false; + return FilterDataStatus::StopIterationNoBuffer; + } + return static_cast(result); +} + +FilterTrailersStatus ContextBase::convertVmCallResultToFilterTrailersStatus(uint64_t result) { + if (stop_iteration_ || result > static_cast(FilterTrailersStatus::StopIteration)) { + stop_iteration_ = false; + return FilterTrailersStatus::StopIteration; + } + return static_cast(result); +} + +FilterMetadataStatus ContextBase::convertVmCallResultToFilterMetadataStatus(uint64_t result) { + if (static_cast(result) == FilterMetadataStatus::Continue) { + return FilterMetadataStatus::Continue; + } + return FilterMetadataStatus::Continue; // This is currently the only return code. +} + ContextBase::~ContextBase() { // Do not remove vm or root contexts which have the same lifetime as wasm_. if (parent_context_id_) { diff --git a/src/exports.cc b/src/exports.cc index cda12514..4c31a75e 100644 --- a/src/exports.cc +++ b/src/exports.cc @@ -186,6 +186,7 @@ Word send_local_response(void *raw_context, Word response_code, Word response_co auto additional_headers = toPairs(additional_response_header_pairs.value()); context->sendLocalResponse(response_code, body.value(), std::move(additional_headers), grpc_code, details.value()); + context->stopIteration(); return WasmResult::Ok; }