Skip to content

Commit

Permalink
Force stop iteration after local response is sent (#88)
Browse files Browse the repository at this point in the history
Signed-off-by: mathetake <takeshi@tetrate.io>
  • Loading branch information
mathetake authored Nov 9, 2020
1 parent 64313a6 commit 376ffaf
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 48 deletions.
13 changes: 13 additions & 0 deletions include/proxy-wasm/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,11 @@ class ContextBase : public RootInterface,
// Called before deleting the context.
virtual void destroy();

// Called to raise the flag which indicates that the context should stop iteration regardless of
// returned filter status from Proxy-Wasm extensions. For example, we ignore
// FilterHeadersStatus::Continue after a local reponse is sent by the host.
void stopIteration() { stop_iteration_ = true; };

/**
* Calls into the VM.
* These are implemented by the proxy-independent host code. They are virtual to support some
Expand Down Expand Up @@ -385,6 +390,14 @@ class ContextBase : public RootInterface,
std::shared_ptr<PluginBase> plugin_;
bool in_vm_context_created_ = false;
bool destroyed_ = false;
bool stop_iteration_ = false;

private:
// helper functions
FilterHeadersStatus convertVmCallResultToFilterHeadersStatus(uint64_t result);
FilterDataStatus convertVmCallResultToFilterDataStatus(uint64_t result);
FilterTrailersStatus convertVmCallResultToFilterTrailersStatus(uint64_t result);
FilterMetadataStatus convertVmCallResultToFilterMetadataStatus(uint64_t result);
};

class DeferAfterCallActions {
Expand Down
106 changes: 58 additions & 48 deletions src/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -477,93 +477,71 @@ FilterHeadersStatus ContextBase::onRequestHeaders(uint32_t headers, bool end_of_
CHECK_HTTP2(on_request_headers_abi_01_, on_request_headers_abi_02_, FilterHeadersStatus::Continue,
FilterHeadersStatus::StopIteration);
DeferAfterCallActions actions(this);
auto result = wasm_->on_request_headers_abi_01_
? wasm_->on_request_headers_abi_01_(this, id_, headers).u64_
: wasm_
->on_request_headers_abi_02_(this, id_, headers,
static_cast<uint32_t>(end_of_stream))
.u64_;
if (result > static_cast<uint64_t>(FilterHeadersStatus::StopAllIterationAndWatermark))
return FilterHeadersStatus::StopAllIterationAndWatermark;
return static_cast<FilterHeadersStatus>(result);
return convertVmCallResultToFilterHeadersStatus(
wasm_->on_request_headers_abi_01_
? wasm_->on_request_headers_abi_01_(this, id_, headers).u64_
: wasm_
->on_request_headers_abi_02_(this, id_, headers,
static_cast<uint32_t>(end_of_stream))
.u64_);
}

FilterDataStatus ContextBase::onRequestBody(uint32_t data_length, bool end_of_stream) {
CHECK_HTTP(on_request_body_, FilterDataStatus::Continue, FilterDataStatus::StopIterationNoBuffer);
DeferAfterCallActions actions(this);
auto result =
wasm_->on_request_body_(this, id_, data_length, static_cast<uint32_t>(end_of_stream)).u64_;
if (result > static_cast<uint64_t>(FilterDataStatus::StopIterationNoBuffer))
return FilterDataStatus::StopIterationNoBuffer;
return static_cast<FilterDataStatus>(result);
return convertVmCallResultToFilterDataStatus(
wasm_->on_request_body_(this, id_, data_length, static_cast<uint32_t>(end_of_stream)).u64_);
}

FilterTrailersStatus ContextBase::onRequestTrailers(uint32_t trailers) {
CHECK_HTTP(on_request_trailers_, FilterTrailersStatus::Continue,
FilterTrailersStatus::StopIteration);
DeferAfterCallActions actions(this);
if (static_cast<FilterTrailersStatus>(wasm_->on_request_trailers_(this, id_, trailers).u64_) ==
FilterTrailersStatus::Continue) {
return FilterTrailersStatus::Continue;
}
return FilterTrailersStatus::StopIteration;
return convertVmCallResultToFilterTrailersStatus(
wasm_->on_request_trailers_(this, id_, trailers).u64_);
}

FilterMetadataStatus ContextBase::onRequestMetadata(uint32_t elements) {
CHECK_HTTP(on_request_metadata_, FilterMetadataStatus::Continue, FilterMetadataStatus::Continue);
DeferAfterCallActions actions(this);
if (static_cast<FilterMetadataStatus>(wasm_->on_request_metadata_(this, id_, elements).u64_) ==
FilterMetadataStatus::Continue) {
return FilterMetadataStatus::Continue;
}
return FilterMetadataStatus::Continue; // This is currently the only return code.
return convertVmCallResultToFilterMetadataStatus(
wasm_->on_request_metadata_(this, id_, elements).u64_);
}

FilterHeadersStatus ContextBase::onResponseHeaders(uint32_t headers, bool end_of_stream) {
CHECK_HTTP2(on_response_headers_abi_01_, on_response_headers_abi_02_,
FilterHeadersStatus::Continue, FilterHeadersStatus::StopIteration);
DeferAfterCallActions actions(this);
auto result = wasm_->on_response_headers_abi_01_
? wasm_->on_response_headers_abi_01_(this, id_, headers).u64_
: wasm_
->on_response_headers_abi_02_(this, id_, headers,
static_cast<uint32_t>(end_of_stream))
.u64_;
if (result > static_cast<uint64_t>(FilterHeadersStatus::StopAllIterationAndWatermark))
return FilterHeadersStatus::StopAllIterationAndWatermark;
return static_cast<FilterHeadersStatus>(result);
return convertVmCallResultToFilterHeadersStatus(
wasm_->on_response_headers_abi_01_
? wasm_->on_response_headers_abi_01_(this, id_, headers).u64_
: wasm_
->on_response_headers_abi_02_(this, id_, headers,
static_cast<uint32_t>(end_of_stream))
.u64_);
}

FilterDataStatus ContextBase::onResponseBody(uint32_t body_length, bool end_of_stream) {
CHECK_HTTP(on_response_body_, FilterDataStatus::Continue,
FilterDataStatus::StopIterationNoBuffer);
DeferAfterCallActions actions(this);
auto result =
wasm_->on_response_body_(this, id_, body_length, static_cast<uint32_t>(end_of_stream)).u64_;
if (result > static_cast<uint64_t>(FilterDataStatus::StopIterationNoBuffer))
return FilterDataStatus::StopIterationNoBuffer;
return static_cast<FilterDataStatus>(result);
return convertVmCallResultToFilterDataStatus(
wasm_->on_response_body_(this, id_, body_length, static_cast<uint32_t>(end_of_stream)).u64_);
}

FilterTrailersStatus ContextBase::onResponseTrailers(uint32_t trailers) {
CHECK_HTTP(on_response_trailers_, FilterTrailersStatus::Continue,
FilterTrailersStatus::StopIteration);
DeferAfterCallActions actions(this);
if (static_cast<FilterTrailersStatus>(wasm_->on_response_trailers_(this, id_, trailers).u64_) ==
FilterTrailersStatus::Continue) {
return FilterTrailersStatus::Continue;
}
return FilterTrailersStatus::StopIteration;
return convertVmCallResultToFilterTrailersStatus(
wasm_->on_response_trailers_(this, id_, trailers).u64_);
}

FilterMetadataStatus ContextBase::onResponseMetadata(uint32_t elements) {
CHECK_HTTP(on_response_metadata_, FilterMetadataStatus::Continue, FilterMetadataStatus::Continue);
DeferAfterCallActions actions(this);
if (static_cast<FilterMetadataStatus>(wasm_->on_response_metadata_(this, id_, elements).u64_) ==
FilterMetadataStatus::Continue) {
return FilterMetadataStatus::Continue;
}
return FilterMetadataStatus::Continue; // This is currently the only return code.
return convertVmCallResultToFilterMetadataStatus(
wasm_->on_response_metadata_(this, id_, elements).u64_);
}

void ContextBase::onHttpCallResponse(uint32_t token, uint32_t headers, uint32_t body_size,
Expand Down Expand Up @@ -643,6 +621,38 @@ WasmResult ContextBase::setTimerPeriod(std::chrono::milliseconds period,
return WasmResult::Ok;
}

FilterHeadersStatus ContextBase::convertVmCallResultToFilterHeadersStatus(uint64_t result) {
if (stop_iteration_ ||
result > static_cast<uint64_t>(FilterHeadersStatus::StopAllIterationAndWatermark)) {
stop_iteration_ = false;
return FilterHeadersStatus::StopAllIterationAndWatermark;
}
return static_cast<FilterHeadersStatus>(result);
}

FilterDataStatus ContextBase::convertVmCallResultToFilterDataStatus(uint64_t result) {
if (stop_iteration_ || result > static_cast<uint64_t>(FilterDataStatus::StopIterationNoBuffer)) {
stop_iteration_ = false;
return FilterDataStatus::StopIterationNoBuffer;
}
return static_cast<FilterDataStatus>(result);
}

FilterTrailersStatus ContextBase::convertVmCallResultToFilterTrailersStatus(uint64_t result) {
if (stop_iteration_ || result > static_cast<uint64_t>(FilterTrailersStatus::StopIteration)) {
stop_iteration_ = false;
return FilterTrailersStatus::StopIteration;
}
return static_cast<FilterTrailersStatus>(result);
}

FilterMetadataStatus ContextBase::convertVmCallResultToFilterMetadataStatus(uint64_t result) {
if (static_cast<FilterMetadataStatus>(result) == FilterMetadataStatus::Continue) {
return FilterMetadataStatus::Continue;
}
return FilterMetadataStatus::Continue; // This is currently the only return code.
}

ContextBase::~ContextBase() {
// Do not remove vm or root contexts which have the same lifetime as wasm_.
if (parent_context_id_) {
Expand Down
1 change: 1 addition & 0 deletions src/exports.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ Word send_local_response(void *raw_context, Word response_code, Word response_co
auto additional_headers = toPairs(additional_response_header_pairs.value());
context->sendLocalResponse(response_code, body.value(), std::move(additional_headers), grpc_code,
details.value());
context->stopIteration();
return WasmResult::Ok;
}

Expand Down

0 comments on commit 376ffaf

Please sign in to comment.