diff --git a/include/proxy-wasm/context.h b/include/proxy-wasm/context.h index 03657a1b..2313d9cd 100644 --- a/include/proxy-wasm/context.h +++ b/include/proxy-wasm/context.h @@ -167,11 +167,6 @@ class ContextBase : public RootInterface, // Called before deleting the context. virtual void destroy(); - // Called to raise the flag which indicates that the context should stop iteration regardless of - // returned filter status from Proxy-Wasm extensions. For example, we ignore - // FilterHeadersStatus::Continue after a local reponse is sent by the host. - void stopIteration() { stop_iteration_ = true; }; - /** * Calls into the VM. * These are implemented by the proxy-independent host code. They are virtual to support some @@ -390,7 +385,6 @@ class ContextBase : public RootInterface, std::shared_ptr plugin_; bool in_vm_context_created_ = false; bool destroyed_ = false; - bool stop_iteration_ = false; private: // helper functions diff --git a/include/proxy-wasm/wasm.h b/include/proxy-wasm/wasm.h index b2c694db..c71c5848 100644 --- a/include/proxy-wasm/wasm.h +++ b/include/proxy-wasm/wasm.h @@ -122,6 +122,12 @@ class WasmBase : public std::enable_shared_from_this { AbiVersion abiVersion() { return abi_version_; } + // Called to raise the flag which indicates that the context should stop iteration regardless of + // returned filter status from Proxy-Wasm extensions. For example, we ignore + // FilterHeadersStatus::Continue after a local reponse is sent by the host. + void stopNextIteration(bool stop) { stop_iteration_ = stop; }; + bool isNextIterationStopped() { return stop_iteration_; }; + void addAfterVmCallAction(std::function f) { after_vm_call_actions_.push_back(f); } void doAfterVmCallActions() { // NB: this may be deleted by a delayed function unless prevented. @@ -223,6 +229,7 @@ class WasmBase : public std::enable_shared_from_this { std::string code_; std::string vm_configuration_; bool allow_precompiled_ = false; + bool stop_iteration_ = false; FailState failed_ = FailState::Ok; // Wasm VM fatal error. // ABI version. diff --git a/src/context.cc b/src/context.cc index 6ab4b5e0..89daa97a 100644 --- a/src/context.cc +++ b/src/context.cc @@ -225,7 +225,10 @@ SharedData global_shared_data; } // namespace -DeferAfterCallActions::~DeferAfterCallActions() { wasm_->doAfterVmCallActions(); } +DeferAfterCallActions::~DeferAfterCallActions() { + wasm_->stopNextIteration(false); + wasm_->doAfterVmCallActions(); +} WasmResult BufferBase::copyTo(WasmBase *wasm, size_t start, size_t length, uint64_t ptr_ptr, uint64_t size_ptr) const { @@ -622,25 +625,24 @@ WasmResult ContextBase::setTimerPeriod(std::chrono::milliseconds period, } FilterHeadersStatus ContextBase::convertVmCallResultToFilterHeadersStatus(uint64_t result) { - if (stop_iteration_ || + if (wasm()->isNextIterationStopped() || result > static_cast(FilterHeadersStatus::StopAllIterationAndWatermark)) { - stop_iteration_ = false; return FilterHeadersStatus::StopAllIterationAndWatermark; } return static_cast(result); } FilterDataStatus ContextBase::convertVmCallResultToFilterDataStatus(uint64_t result) { - if (stop_iteration_ || result > static_cast(FilterDataStatus::StopIterationNoBuffer)) { - stop_iteration_ = false; + if (wasm()->isNextIterationStopped() || + result > static_cast(FilterDataStatus::StopIterationNoBuffer)) { return FilterDataStatus::StopIterationNoBuffer; } return static_cast(result); } FilterTrailersStatus ContextBase::convertVmCallResultToFilterTrailersStatus(uint64_t result) { - if (stop_iteration_ || result > static_cast(FilterTrailersStatus::StopIteration)) { - stop_iteration_ = false; + if (wasm()->isNextIterationStopped() || + result > static_cast(FilterTrailersStatus::StopIteration)) { return FilterTrailersStatus::StopIteration; } return static_cast(result); diff --git a/src/exports.cc b/src/exports.cc index 4c31a75e..1ffee322 100644 --- a/src/exports.cc +++ b/src/exports.cc @@ -186,7 +186,7 @@ Word send_local_response(void *raw_context, Word response_code, Word response_co auto additional_headers = toPairs(additional_response_header_pairs.value()); context->sendLocalResponse(response_code, body.value(), std::move(additional_headers), grpc_code, details.value()); - context->stopIteration(); + context->wasm()->stopNextIteration(true); return WasmResult::Ok; }