Skip to content
This repository has been archived by the owner on Feb 22, 2024. It is now read-only.

Ensure request does not hang in non-async runtimes #23

Merged
merged 1 commit into from
Mar 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ jobs:
- name: Build
run: rustup target add wasm32-wasi && cargo build --verbose
- name: Run simple test
run: cargo run
run: cargo test --all -- --nocapture
92 changes: 65 additions & 27 deletions crates/wasi-experimental-http-wasmtime/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use anyhow::Error;
use bytes::Bytes;
use futures::executor::block_on;
use http::HeaderMap;
use reqwest::{Client, Method, Response};
use http::{HeaderMap, HeaderValue};
use reqwest::{Client, Method};
use std::str::FromStr;
use tokio::runtime::Handle;
use wasi_experimental_http;
use wasmtime::*;

Expand Down Expand Up @@ -61,7 +63,7 @@ pub fn link_http(linker: &mut Linker) -> Result<(), Error> {
};

// Get the URL, headers, method, and request body from the module's memory.
let (url, headers, method, req_body) = unsafe {
let (url, headers, method, req_body) = match unsafe {
http_parts_from_memory(
&memory,
url_ptr,
Expand All @@ -73,40 +75,40 @@ pub fn link_http(linker: &mut Linker) -> Result<(), Error> {
headers_ptr,
headers_len_ptr,
)
.unwrap()
} {
Ok(r) => (r.0, r.1, r.2, r.3),
Err(_) => {
return err(
"cannot get HTTP parts from memory".to_string(),
Some(&memory),
Some(&alloc),
err_ptr,
err_len_ptr,
4,
)
}
};

// TODO
// We probably need separate methods for blocking and non-blocking
// versions of the HTTP client.
// let res = reqwest::blocking::get(&url).unwrap().text().unwrap();

// Make the HTTP request using `reqwest`.
let client = Client::builder().build().unwrap();
let res = match block_on(
client
.request(method, &url)
.headers(headers)
.body(req_body)
.send(),
) {
Ok(r) => r,
let (status, headers, body) = match request(url, headers, method, req_body) {
Ok(r) => (r.0, r.1, r.2),
Err(e) => {
return err(
e.to_string(),
Some(&memory),
Some(&alloc),
Some(&memory.clone()),
Some(&alloc.clone()),
err_ptr,
err_len_ptr,
2,
3,
)
}
};

// Write the HTTP response back to the module's memory.
unsafe {
match write_http_response_to_memory(
res,
status,
headers,
body,
memory.clone(),
alloc.clone(),
headers_written_ptr,
Expand Down Expand Up @@ -135,6 +137,42 @@ pub fn link_http(linker: &mut Linker) -> Result<(), Error> {
Ok(())
}

fn request(
url: String,
headers: HeaderMap,
method: Method,
body: Vec<u8>,
) -> Result<(u16, HeaderMap<HeaderValue>, Bytes), Error> {
match Handle::try_current() {
Ok(_) => {
println!("wasi_experimental_http::request: tokio runtime available");
let client = Client::builder().build().unwrap();
let res = block_on(
client
.request(method, &url)
.headers(headers)
.body(body)
.send(),
)?;

return Ok((
res.status().as_u16(),
res.headers().clone(),
block_on(res.bytes())?,
));
}
Err(_) => {
println!("wasi_experimental_http::request: no Tokio runtime available");
let res = reqwest::blocking::Client::new()
.request(method, &url)
.headers(headers)
.body(body)
.send()?;
return Ok((res.status().as_u16(), res.headers().clone(), res.bytes()?));
}
};
}

/// Get the URL, method, request body, and headers from the
/// module's memory.
unsafe fn http_parts_from_memory(
Expand All @@ -160,7 +198,9 @@ unsafe fn http_parts_from_memory(

/// Write the response data to the module's memory.
unsafe fn write_http_response_to_memory(
res: Response,
status: u16,
headers: HeaderMap,
res: Bytes,
memory: Memory,
alloc: Func,
headers_written_ptr: u32,
Expand All @@ -169,9 +209,7 @@ unsafe fn write_http_response_to_memory(
status_code_ptr: u32,
body_written_ptr: u32,
) -> Result<(), Error> {
let hs = wasi_experimental_http::header_map_to_string(res.headers())?;
let status = res.status().as_u16();
let res = block_on(res.bytes())?;
let hs = wasi_experimental_http::header_map_to_string(&headers)?;
// Write the response headers.
write(
&hs.as_bytes().to_vec(),
Expand Down
72 changes: 72 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#[cfg(test)]
mod tests {
use anyhow::Error;
use std::time::Instant;
use wasi_experimental_http_wasmtime::link_http;
use wasmtime::*;
use wasmtime_wasi::{Wasi, WasiCtxBuilder};

#[test]
fn test_http() {
let modules = vec![
"target/wasm32-wasi/release/simple_wasi_http_tests.wasm",
"tests/as/build/optimized.wasm",
];
let test_funcs = vec!["get", "post"];

for module in modules {
let instance = create_instance(module.to_string()).unwrap();
run_tests(&instance, &test_funcs.clone()).unwrap();
}
}

#[tokio::test(flavor = "multi_thread")]
async fn test_http_async() {
let modules = vec![
"target/wasm32-wasi/release/simple_wasi_http_tests.wasm",
"tests/as/build/optimized.wasm",
];
let test_funcs = vec!["get", "post"];

for module in modules {
let instance = create_instance(module.to_string()).unwrap();
run_tests(&instance, &test_funcs.clone()).unwrap();
}
}

/// Execute the module's `_start` function.
fn run_tests(instance: &Instance, test_funcs: &Vec<&str>) -> Result<(), Error> {
for func in test_funcs.iter() {
let func = instance.get_func(func).expect("cannot find function");
func.call(&vec![])?;
}

Ok(())
}

/// Create a Wasmtime::Instance from a compiled module and
/// link the WASI imports.
fn create_instance(filename: String) -> Result<Instance, Error> {
let start = Instant::now();
let store = Store::default();
let mut linker = Linker::new(&store);

let ctx = WasiCtxBuilder::new()
.inherit_stdin()
.inherit_stdout()
.inherit_stderr()
.build()?;

let wasi = Wasi::new(&store, ctx);
wasi.add_to_linker(&mut linker)?;
// Link `wasi_experimental_http::req`.
link_http(&mut linker)?;

let module = wasmtime::Module::from_file(store.engine(), filename)?;

let instance = linker.instantiate(&module)?;
let duration = start.elapsed();
println!("module instantiation time: {:#?}", duration);
Ok(instance)
}
}
57 changes: 0 additions & 57 deletions src/main.rs

This file was deleted.