Skip to content

Commit

Permalink
Make callout responses accessible from stream context. (envoyproxy#429)
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Sikora <piotrsikora@google.com>
  • Loading branch information
PiotrSikora committed Feb 26, 2020
1 parent baf361b commit 8dbdb4d
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 25 deletions.
35 changes: 21 additions & 14 deletions source/extensions/common/wasm/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -659,20 +659,26 @@ const Http::HeaderMap* Context::getConstMap(HeaderMapType type) {
}
return response_trailers_;
case HeaderMapType::GrpcCreateInitialMetadata:
return grpc_create_initial_metadata_;
return rootContext()->grpc_create_initial_metadata_;
case HeaderMapType::GrpcReceiveInitialMetadata:
return grpc_receive_initial_metadata_.get();
return rootContext()->grpc_receive_initial_metadata_.get();
case HeaderMapType::GrpcReceiveTrailingMetadata:
return grpc_receive_trailing_metadata_.get();
case HeaderMapType::HttpCallResponseHeaders:
if (http_call_response_)
return &(*http_call_response_)->headers();
return rootContext()->grpc_receive_trailing_metadata_.get();
case HeaderMapType::HttpCallResponseHeaders: {
Envoy::Http::MessagePtr* response = rootContext()->http_call_response_;
if (response) {
return &(*response)->headers();
}
return nullptr;
case HeaderMapType::HttpCallResponseTrailers:
if (http_call_response_)
return (*http_call_response_)->trailers();
}
case HeaderMapType::HttpCallResponseTrailers: {
Envoy::Http::MessagePtr* response = rootContext()->http_call_response_;
if (response) {
return (*response)->trailers();
}
return nullptr;
}
}
return nullptr;
}

Expand Down Expand Up @@ -780,13 +786,14 @@ Buffer::Instance* Context::getBuffer(BufferType type) {
return network_downstream_data_buffer_;
case BufferType::NetworkUpstreamData:
return network_upstream_data_buffer_;
case BufferType::HttpCallResponseBody:
if (http_call_response_) {
return (*http_call_response_)->body().get();
case BufferType::HttpCallResponseBody: {
Envoy::Http::MessagePtr* response = rootContext()->http_call_response_;
if (response) {
return (*response)->body().get();
}
break;
} break;
case BufferType::GrpcReceiveBuffer:
return grpc_receive_buffer_.get();
return rootContext()->grpc_receive_buffer_.get();
default:
break;
}
Expand Down
2 changes: 1 addition & 1 deletion source/extensions/common/wasm/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class Context : public Logger::Loggable<Logger::Id::wasm>,
uint32_t id() const { return id_; }
bool isVmContext() { return id_ == 0; }
bool isRootContext() { return root_context_id_ == 0; }
Context* root_context() { return root_context_; }
Context* rootContext() { return isRootContext() ? this : root_context_; }

absl::string_view root_id() const { return plugin_ ? plugin_->root_id_ : root_id_; }
absl::string_view log_prefix() const { return plugin_ ? plugin_->log_prefix_ : root_log_prefix_; }
Expand Down
12 changes: 6 additions & 6 deletions source/extensions/common/wasm/exports.cc
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ Word get_buffer_status(void* raw_context, Word type, Word length_ptr, Word flags
Word http_call(void* raw_context, Word uri_ptr, Word uri_size, Word header_pairs_ptr,
Word header_pairs_size, Word body_ptr, Word body_size, Word trailer_pairs_ptr,
Word trailer_pairs_size, Word timeout_milliseconds, Word token_ptr) {
auto context = WASM_CONTEXT(raw_context)->root_context();
auto context = WASM_CONTEXT(raw_context)->rootContext();
auto uri = context->wasmVm()->getMemory(uri_ptr.u64_, uri_size.u64_);
auto body = context->wasmVm()->getMemory(body_ptr.u64_, body_size.u64_);
auto header_pairs = context->wasmVm()->getMemory(header_pairs_ptr.u64_, header_pairs_size.u64_);
Expand Down Expand Up @@ -567,7 +567,7 @@ Word get_metric(void* raw_context, Word metric_id, Word result_uint64_ptr) {
Word grpc_call(void* raw_context, Word service_ptr, Word service_size, Word service_name_ptr,
Word service_name_size, Word method_name_ptr, Word method_name_size,
Word request_ptr, Word request_size, Word timeout_milliseconds, Word token_ptr) {
auto context = WASM_CONTEXT(raw_context)->root_context();
auto context = WASM_CONTEXT(raw_context)->rootContext();
auto service = context->wasmVm()->getMemory(service_ptr.u64_, service_size.u64_);
auto service_name = context->wasmVm()->getMemory(service_name_ptr.u64_, service_name_size.u64_);
auto method_name = context->wasmVm()->getMemory(method_name_ptr.u64_, method_name_size.u64_);
Expand Down Expand Up @@ -595,7 +595,7 @@ Word grpc_call(void* raw_context, Word service_ptr, Word service_size, Word serv
Word grpc_stream(void* raw_context, Word service_ptr, Word service_size, Word service_name_ptr,
Word service_name_size, Word method_name_ptr, Word method_name_size,
Word token_ptr) {
auto context = WASM_CONTEXT(raw_context)->root_context();
auto context = WASM_CONTEXT(raw_context)->rootContext();
auto service = context->wasmVm()->getMemory(service_ptr.u64_, service_size.u64_);
auto service_name = context->wasmVm()->getMemory(service_name_ptr.u64_, service_name_size.u64_);
auto method_name = context->wasmVm()->getMemory(method_name_ptr.u64_, method_name_size.u64_);
Expand All @@ -619,18 +619,18 @@ Word grpc_stream(void* raw_context, Word service_ptr, Word service_size, Word se
}

Word grpc_cancel(void* raw_context, Word token) {
auto context = WASM_CONTEXT(raw_context)->root_context();
auto context = WASM_CONTEXT(raw_context)->rootContext();
return wasmResultToWord(context->grpcCancel(token.u64_));
}

Word grpc_close(void* raw_context, Word token) {
auto context = WASM_CONTEXT(raw_context)->root_context();
auto context = WASM_CONTEXT(raw_context)->rootContext();
return wasmResultToWord(context->grpcClose(token.u64_));
}

Word grpc_send(void* raw_context, Word token, Word message_ptr, Word message_size,
Word end_stream) {
auto context = WASM_CONTEXT(raw_context)->root_context();
auto context = WASM_CONTEXT(raw_context)->rootContext();
auto message = context->wasmVm()->getMemory(message_ptr.u64_, message_size.u64_);
if (!message) {
return wasmResultToWord(WasmResult::InvalidMemoryAccess);
Expand Down
6 changes: 3 additions & 3 deletions source/extensions/common/wasm/foreign.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class CreateExpressionFactory : public ExpressionFactory, public ForeignFunction
return WasmResult::BadArgument;
}

auto& expr_context = getOrCreateContext(current_context_->root_context());
auto& expr_context = getOrCreateContext(current_context_->rootContext());
auto token = expr_context.createToken();
auto& handler = expr_context.getExpression(token);

Expand Down Expand Up @@ -154,7 +154,7 @@ class EvaluateExpressionFactory : public ExpressionFactory, public ForeignFuncti
WasmForeignFunction create() const override {
WasmForeignFunction f = [](Wasm&, absl::string_view argument,
std::function<void*(size_t size)> alloc_result) -> WasmResult {
auto& expr_context = getOrCreateContext(current_context_->root_context());
auto& expr_context = getOrCreateContext(current_context_->rootContext());
if (argument.size() != sizeof(uint32_t)) {
return WasmResult::BadArgument;
}
Expand Down Expand Up @@ -190,7 +190,7 @@ class DeleteExpressionFactory : public ExpressionFactory, public ForeignFunction
WasmForeignFunction create() const override {
WasmForeignFunction f = [](Wasm&, absl::string_view argument,
std::function<void*(size_t size)>) -> WasmResult {
auto& expr_context = getOrCreateContext(current_context_->root_context());
auto& expr_context = getOrCreateContext(current_context_->rootContext());
if (argument.size() != sizeof(uint32_t)) {
return WasmResult::BadArgument;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ FilterHeadersStatus ExampleContext::onRequestHeaders(uint32_t) {
auto context_id = id();
auto callback = [context_id](uint32_t, size_t body_size, uint32_t) {
auto response_headers = getHeaderMapPairs(HeaderMapType::HttpCallResponseHeaders);
// Switch context after getting headers, but before getting body to exercise both code paths.
getContext(context_id)->setEffectiveContext();
auto body = getBufferBytes(BufferType::HttpCallResponseBody, 0, body_size);
auto response_trailers = getHeaderMapPairs(HeaderMapType::HttpCallResponseTrailers);
getContext(context_id)->setEffectiveContext();
for (auto& p : response_headers->pairs()) {
logInfo(std::string(p.first) + std::string(" -> ") + std::string(p.second));
}
Expand Down
Binary file modified test/extensions/filters/http/wasm/test_data/async_call_cpp.wasm
Binary file not shown.

0 comments on commit 8dbdb4d

Please sign in to comment.