diff --git a/Cargo.toml b/Cargo.toml index 2e5b8d7e..c3e4ac7d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,3 +60,8 @@ crate-type = ["cdylib"] name = "http_body" path = "examples/http_body.rs" crate-type = ["cdylib"] + +[[example]] +name = "http_config" +path = "examples/http_config.rs" +crate-type = ["cdylib"] diff --git a/examples/http_body.rs b/examples/http_body.rs index 5db8ed8f..da1bda55 100644 --- a/examples/http_body.rs +++ b/examples/http_body.rs @@ -18,7 +18,21 @@ use proxy_wasm::types::*; #[no_mangle] pub fn _start() { proxy_wasm::set_log_level(LogLevel::Trace); - proxy_wasm::set_http_context(|_, _| -> Box { Box::new(HttpBody) }); + proxy_wasm::set_root_context(|_| -> Box { Box::new(HttpBodyRoot) }); +} + +struct HttpBodyRoot; + +impl Context for HttpBodyRoot {} + +impl RootContext for HttpBodyRoot { + fn get_type(&self) -> Option { + Some(ContextType::HttpContext) + } + + fn create_http_context(&self, _context_id: u32) -> Option> { + Some(Box::new(HttpBody)) + } } struct HttpBody; diff --git a/examples/http_config.rs b/examples/http_config.rs new file mode 100644 index 00000000..fc297a08 --- /dev/null +++ b/examples/http_config.rs @@ -0,0 +1,64 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use proxy_wasm::traits::*; +use proxy_wasm::types::*; + +#[no_mangle] +pub fn _start() { + proxy_wasm::set_log_level(LogLevel::Trace); + proxy_wasm::set_root_context(|_| -> Box { + Box::new(HttpConfigHeaderRoot { + header_content: String::new(), + }) + }); +} + +struct HttpConfigHeader { + header_content: String, +} + +impl Context for HttpConfigHeader {} + +impl HttpContext for HttpConfigHeader { + fn on_http_response_headers(&mut self, _num_headers: usize) -> Action { + self.add_http_response_header("custom-header", self.header_content.as_str()); + Action::Continue + } +} + +struct HttpConfigHeaderRoot { + header_content: String, +} + +impl Context for HttpConfigHeaderRoot {} + +impl RootContext for HttpConfigHeaderRoot { + fn on_configure(&mut self, _plugin_configuration_size: usize) -> bool { + if let Some(config_bytes) = self.get_configuration() { + self.header_content = String::from_utf8(config_bytes).unwrap() + } + true + } + + fn create_http_context(&self, _context_id: u32) -> Option> { + Some(Box::new(HttpConfigHeader { + header_content: self.header_content.clone(), + })) + } + + fn get_type(&self) -> Option { + Some(ContextType::HttpContext) + } +} diff --git a/examples/http_headers.rs b/examples/http_headers.rs index 14b32861..615fdcc6 100644 --- a/examples/http_headers.rs +++ b/examples/http_headers.rs @@ -19,9 +19,23 @@ use proxy_wasm::types::*; #[no_mangle] pub fn _start() { proxy_wasm::set_log_level(LogLevel::Trace); - proxy_wasm::set_http_context(|context_id, _| -> Box { - Box::new(HttpHeaders { context_id }) - }); + proxy_wasm::set_root_context(|_| -> Box { Box::new(HttpHeadersRoot) }); +} + +struct HttpHeadersRoot; + +impl Context for HttpHeadersRoot {} + +impl RootContext for HttpHeadersRoot { + fn get_type(&self) -> Option { + Some(ContextType::HttpContext) + } + + fn create_http_context(&self, _context_id: u32) -> Option> { + Some(Box::new(HttpHeaders { + context_id: _context_id, + })) + } } struct HttpHeaders { diff --git a/src/dispatcher.rs b/src/dispatcher.rs index bc7d639b..d9c9d491 100644 --- a/src/dispatcher.rs +++ b/src/dispatcher.rs @@ -80,6 +80,17 @@ impl Dispatcher { self.new_http_stream.set(Some(callback)); } + fn register_callout(&self, token_id: u32) { + if self + .callouts + .borrow_mut() + .insert(token_id, self.active_id.get()) + .is_some() + { + panic!("duplicate token_id") + } + } + fn create_root_context(&self, context_id: u32) { let new_context = match self.new_root.get() { Some(f) => f(context_id), @@ -96,12 +107,15 @@ impl Dispatcher { } fn create_stream_context(&self, context_id: u32, root_context_id: u32) { - if !self.roots.borrow().contains_key(&root_context_id) { - panic!("invalid root_context_id") - } - let new_context = match self.new_stream.get() { - Some(f) => f(context_id, root_context_id), - None => panic!("missing constructor"), + let new_context = match self.roots.borrow().get(&root_context_id) { + Some(root_context) => match self.new_stream.get() { + Some(f) => f(context_id, root_context_id), + None => match root_context.create_stream_context(context_id) { + Some(stream_context) => stream_context, + None => panic!("create_stream_context returned None"), + }, + }, + None => panic!("invalid root_context_id"), }; if self .streams @@ -114,12 +128,15 @@ impl Dispatcher { } fn create_http_context(&self, context_id: u32, root_context_id: u32) { - if !self.roots.borrow().contains_key(&root_context_id) { - panic!("invalid root_context_id") - } - let new_context = match self.new_http_stream.get() { - Some(f) => f(context_id, root_context_id), - None => panic!("missing constructor"), + let new_context = match self.roots.borrow().get(&root_context_id) { + Some(root_context) => match self.new_http_stream.get() { + Some(f) => f(context_id, root_context_id), + None => match root_context.create_http_context(context_id) { + Some(stream_context) => stream_context, + None => panic!("create_http_context returned None"), + }, + }, + None => panic!("invalid root_context_id"), }; if self .http_streams @@ -131,26 +148,25 @@ impl Dispatcher { } } - fn register_callout(&self, token_id: u32) { - if self - .callouts - .borrow_mut() - .insert(token_id, self.active_id.get()) - .is_some() - { - panic!("duplicate token_id") - } - } - fn on_create_context(&self, context_id: u32, root_context_id: u32) { if root_context_id == 0 { - self.create_root_context(context_id) + self.create_root_context(context_id); } else if self.new_http_stream.get().is_some() { self.create_http_context(context_id, root_context_id); } else if self.new_stream.get().is_some() { self.create_stream_context(context_id, root_context_id); + } else if let Some(root_context) = self.roots.borrow().get(&root_context_id) { + match root_context.get_type() { + Some(ContextType::HttpContext) => { + self.create_http_context(context_id, root_context_id) + } + Some(ContextType::StreamContext) => { + self.create_stream_context(context_id, root_context_id) + } + None => panic!("missing ContextType on root_context"), + } } else { - panic!("missing constructors") + panic!("invalid root_context_id and missing constructors"); } } diff --git a/src/traits.rs b/src/traits.rs index 10b24514..5b7fc4be 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -121,6 +121,18 @@ pub trait RootContext: Context { fn on_queue_ready(&mut self, _queue_id: u32) {} fn on_log(&mut self) {} + + fn create_http_context(&self, _context_id: u32) -> Option> { + None + } + + fn create_stream_context(&self, _context_id: u32) -> Option> { + None + } + + fn get_type(&self) -> Option { + None + } } pub trait StreamContext: Context { diff --git a/src/types.rs b/src/types.rs index a951f78f..855a414b 100644 --- a/src/types.rs +++ b/src/types.rs @@ -47,6 +47,13 @@ pub enum Status { InternalFailure = 10, } +#[repr(u32)] +#[derive(Debug)] +pub enum ContextType { + HttpContext = 0, + StreamContext = 1, +} + #[repr(u32)] #[derive(Debug)] pub enum BufferType {