diff --git a/source/extensions/common/wasm/context.cc b/source/extensions/common/wasm/context.cc index 8e95871129d0..70ef6e01a38c 100644 --- a/source/extensions/common/wasm/context.cc +++ b/source/extensions/common/wasm/context.cc @@ -659,20 +659,26 @@ const Http::HeaderMap* Context::getConstMap(HeaderMapType type) { } return response_trailers_; case HeaderMapType::GrpcCreateInitialMetadata: - return grpc_create_initial_metadata_; + return rootContext()->grpc_create_initial_metadata_; case HeaderMapType::GrpcReceiveInitialMetadata: - return grpc_receive_initial_metadata_.get(); + return rootContext()->grpc_receive_initial_metadata_.get(); case HeaderMapType::GrpcReceiveTrailingMetadata: - return grpc_receive_trailing_metadata_.get(); - case HeaderMapType::HttpCallResponseHeaders: - if (http_call_response_) - return &(*http_call_response_)->headers(); + return rootContext()->grpc_receive_trailing_metadata_.get(); + case HeaderMapType::HttpCallResponseHeaders: { + Envoy::Http::MessagePtr* response = rootContext()->http_call_response_; + if (response) { + return &(*response)->headers(); + } return nullptr; - case HeaderMapType::HttpCallResponseTrailers: - if (http_call_response_) - return (*http_call_response_)->trailers(); + } + case HeaderMapType::HttpCallResponseTrailers: { + Envoy::Http::MessagePtr* response = rootContext()->http_call_response_; + if (response) { + return (*response)->trailers(); + } return nullptr; } + } return nullptr; } @@ -780,13 +786,14 @@ Buffer::Instance* Context::getBuffer(BufferType type) { return network_downstream_data_buffer_; case BufferType::NetworkUpstreamData: return network_upstream_data_buffer_; - case BufferType::HttpCallResponseBody: - if (http_call_response_) { - return (*http_call_response_)->body().get(); + case BufferType::HttpCallResponseBody: { + Envoy::Http::MessagePtr* response = rootContext()->http_call_response_; + if (response) { + return (*response)->body().get(); } - break; + } break; case BufferType::GrpcReceiveBuffer: - return grpc_receive_buffer_.get(); + return rootContext()->grpc_receive_buffer_.get(); default: break; } diff --git a/source/extensions/common/wasm/context.h b/source/extensions/common/wasm/context.h index 07082b9fba36..351db432d67a 100644 --- a/source/extensions/common/wasm/context.h +++ b/source/extensions/common/wasm/context.h @@ -81,7 +81,7 @@ class Context : public Logger::Loggable, uint32_t id() const { return id_; } bool isVmContext() { return id_ == 0; } bool isRootContext() { return root_context_id_ == 0; } - Context* root_context() { return root_context_; } + Context* rootContext() { return isRootContext() ? this : root_context_; } absl::string_view root_id() const { return plugin_ ? plugin_->root_id_ : root_id_; } absl::string_view log_prefix() const { return plugin_ ? plugin_->log_prefix_ : root_log_prefix_; } diff --git a/source/extensions/common/wasm/exports.cc b/source/extensions/common/wasm/exports.cc index df03669761de..293eb7658342 100644 --- a/source/extensions/common/wasm/exports.cc +++ b/source/extensions/common/wasm/exports.cc @@ -497,7 +497,7 @@ Word get_buffer_status(void* raw_context, Word type, Word length_ptr, Word flags Word http_call(void* raw_context, Word uri_ptr, Word uri_size, Word header_pairs_ptr, Word header_pairs_size, Word body_ptr, Word body_size, Word trailer_pairs_ptr, Word trailer_pairs_size, Word timeout_milliseconds, Word token_ptr) { - auto context = WASM_CONTEXT(raw_context)->root_context(); + auto context = WASM_CONTEXT(raw_context)->rootContext(); auto uri = context->wasmVm()->getMemory(uri_ptr.u64_, uri_size.u64_); auto body = context->wasmVm()->getMemory(body_ptr.u64_, body_size.u64_); auto header_pairs = context->wasmVm()->getMemory(header_pairs_ptr.u64_, header_pairs_size.u64_); @@ -567,7 +567,7 @@ Word get_metric(void* raw_context, Word metric_id, Word result_uint64_ptr) { Word grpc_call(void* raw_context, Word service_ptr, Word service_size, Word service_name_ptr, Word service_name_size, Word method_name_ptr, Word method_name_size, Word request_ptr, Word request_size, Word timeout_milliseconds, Word token_ptr) { - auto context = WASM_CONTEXT(raw_context)->root_context(); + auto context = WASM_CONTEXT(raw_context)->rootContext(); auto service = context->wasmVm()->getMemory(service_ptr.u64_, service_size.u64_); auto service_name = context->wasmVm()->getMemory(service_name_ptr.u64_, service_name_size.u64_); auto method_name = context->wasmVm()->getMemory(method_name_ptr.u64_, method_name_size.u64_); @@ -595,7 +595,7 @@ Word grpc_call(void* raw_context, Word service_ptr, Word service_size, Word serv Word grpc_stream(void* raw_context, Word service_ptr, Word service_size, Word service_name_ptr, Word service_name_size, Word method_name_ptr, Word method_name_size, Word token_ptr) { - auto context = WASM_CONTEXT(raw_context)->root_context(); + auto context = WASM_CONTEXT(raw_context)->rootContext(); auto service = context->wasmVm()->getMemory(service_ptr.u64_, service_size.u64_); auto service_name = context->wasmVm()->getMemory(service_name_ptr.u64_, service_name_size.u64_); auto method_name = context->wasmVm()->getMemory(method_name_ptr.u64_, method_name_size.u64_); @@ -619,18 +619,18 @@ Word grpc_stream(void* raw_context, Word service_ptr, Word service_size, Word se } Word grpc_cancel(void* raw_context, Word token) { - auto context = WASM_CONTEXT(raw_context)->root_context(); + auto context = WASM_CONTEXT(raw_context)->rootContext(); return wasmResultToWord(context->grpcCancel(token.u64_)); } Word grpc_close(void* raw_context, Word token) { - auto context = WASM_CONTEXT(raw_context)->root_context(); + auto context = WASM_CONTEXT(raw_context)->rootContext(); return wasmResultToWord(context->grpcClose(token.u64_)); } Word grpc_send(void* raw_context, Word token, Word message_ptr, Word message_size, Word end_stream) { - auto context = WASM_CONTEXT(raw_context)->root_context(); + auto context = WASM_CONTEXT(raw_context)->rootContext(); auto message = context->wasmVm()->getMemory(message_ptr.u64_, message_size.u64_); if (!message) { return wasmResultToWord(WasmResult::InvalidMemoryAccess); diff --git a/source/extensions/common/wasm/foreign.cc b/source/extensions/common/wasm/foreign.cc index 3c1f43d2d9d8..1ce54af8ba7d 100644 --- a/source/extensions/common/wasm/foreign.cc +++ b/source/extensions/common/wasm/foreign.cc @@ -124,7 +124,7 @@ class CreateExpressionFactory : public ExpressionFactory, public ForeignFunction return WasmResult::BadArgument; } - auto& expr_context = getOrCreateContext(current_context_->root_context()); + auto& expr_context = getOrCreateContext(current_context_->rootContext()); auto token = expr_context.createToken(); auto& handler = expr_context.getExpression(token); @@ -154,7 +154,7 @@ class EvaluateExpressionFactory : public ExpressionFactory, public ForeignFuncti WasmForeignFunction create() const override { WasmForeignFunction f = [](Wasm&, absl::string_view argument, std::function alloc_result) -> WasmResult { - auto& expr_context = getOrCreateContext(current_context_->root_context()); + auto& expr_context = getOrCreateContext(current_context_->rootContext()); if (argument.size() != sizeof(uint32_t)) { return WasmResult::BadArgument; } @@ -190,7 +190,7 @@ class DeleteExpressionFactory : public ExpressionFactory, public ForeignFunction WasmForeignFunction create() const override { WasmForeignFunction f = [](Wasm&, absl::string_view argument, std::function) -> WasmResult { - auto& expr_context = getOrCreateContext(current_context_->root_context()); + auto& expr_context = getOrCreateContext(current_context_->rootContext()); if (argument.size() != sizeof(uint32_t)) { return WasmResult::BadArgument; } diff --git a/test/extensions/filters/http/wasm/test_data/async_call_cpp.cc b/test/extensions/filters/http/wasm/test_data/async_call_cpp.cc index c61ea7bd5265..e7509a67a8bb 100644 --- a/test/extensions/filters/http/wasm/test_data/async_call_cpp.cc +++ b/test/extensions/filters/http/wasm/test_data/async_call_cpp.cc @@ -16,9 +16,10 @@ FilterHeadersStatus ExampleContext::onRequestHeaders(uint32_t) { auto context_id = id(); auto callback = [context_id](uint32_t, size_t body_size, uint32_t) { auto response_headers = getHeaderMapPairs(HeaderMapType::HttpCallResponseHeaders); + // Switch context after getting headers, but before getting body to exercise both code paths. + getContext(context_id)->setEffectiveContext(); auto body = getBufferBytes(BufferType::HttpCallResponseBody, 0, body_size); auto response_trailers = getHeaderMapPairs(HeaderMapType::HttpCallResponseTrailers); - getContext(context_id)->setEffectiveContext(); for (auto& p : response_headers->pairs()) { logInfo(std::string(p.first) + std::string(" -> ") + std::string(p.second)); } diff --git a/test/extensions/filters/http/wasm/test_data/async_call_cpp.wasm b/test/extensions/filters/http/wasm/test_data/async_call_cpp.wasm index 30426c895af1..1a85dfb592e3 100644 Binary files a/test/extensions/filters/http/wasm/test_data/async_call_cpp.wasm and b/test/extensions/filters/http/wasm/test_data/async_call_cpp.wasm differ