Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(next): expand macro into axum routes #488

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions codegen/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod main;
mod next;

use proc_macro::TokenStream;
use proc_macro_error::proc_macro_error;
Expand Down
192 changes: 192 additions & 0 deletions codegen/src/next/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
use proc_macro_error::emit_error;
use quote::{quote, ToTokens};
use syn::{Ident, LitStr};

struct Endpoint {
route: LitStr,
method: Ident,
function: Ident,
}

impl ToTokens for Endpoint {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
let Self {
route,
method,
function,
} = self;

match method.to_string().as_str() {
"get" | "post" | "delete" | "put" | "options" | "head" | "trace" | "patch" => {}
_ => {
emit_error!(
method,
"method is not supported";
hint = "Try one of the following: `get`, `post`, `delete`, `put`, `options`, `head`, `trace` or `patch`"
)
}
};

let route = quote!(.route(#route, axum::routing::#method(#function)));

route.to_tokens(tokens);
}
}

pub(crate) struct App {
endpoints: Vec<Endpoint>,
}

impl ToTokens for App {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
let Self { endpoints } = self;

let app = quote!(
async fn __app<B>(request: http::Request<B>) -> axum::response::Response
where
B: axum::body::HttpBody + Send + 'static,
{
use tower_service::Service;

let mut router = axum::Router::new()
#(#endpoints)*
.into_service();

let response = router.call(request).await.unwrap();

response
}
);

app.to_tokens(tokens);
}
}

pub(crate) fn wasi_bindings(app: App) -> proc_macro2::TokenStream {
quote!(
#app

#[no_mangle]
#[allow(non_snake_case)]
pub extern "C" fn __SHUTTLE_Axum_call(
fd_3: std::os::wasi::prelude::RawFd,
fd_4: std::os::wasi::prelude::RawFd,
) {
use axum::body::HttpBody;
use std::io::{Read, Write};
use std::os::wasi::io::FromRawFd;

println!("inner handler awoken; interacting with fd={fd_3},{fd_4}");

// file descriptor 3 for reading and writing http parts
let mut parts_fd = unsafe { std::fs::File::from_raw_fd(fd_3) };

let reader = std::io::BufReader::new(&mut parts_fd);

// deserialize request parts from rust messagepack
let wrapper: shuttle_common::wasm::RequestWrapper = rmp_serde::from_read(reader).unwrap();

// file descriptor 4 for reading and writing http body
let mut body_fd = unsafe { std::fs::File::from_raw_fd(fd_4) };

// read body from host
let mut body_buf = Vec::new();
let mut c_buf: [u8; 1] = [0; 1];
loop {
body_fd.read(&mut c_buf).unwrap();
if c_buf[0] == 0 {
break;
} else {
body_buf.push(c_buf[0]);
}
}

let request: http::Request<axum::body::Body> = wrapper
.into_request_builder()
.body(body_buf.into())
.unwrap();

println!("inner router received request: {:?}", &request);
let res = futures_executor::block_on(__app(request));

let (parts, mut body) = res.into_parts();

// wrap and serialize response parts as rmp
let response_parts = shuttle_common::wasm::ResponseWrapper::from(parts).into_rmp();

// write response parts
parts_fd.write_all(&response_parts).unwrap();

// write body if there is one
if let Some(body) = futures_executor::block_on(body.data()) {
body_fd.write_all(body.unwrap().as_ref()).unwrap();
}
// signal to the reader that end of file has been reached
body_fd.write(&[0]).unwrap();
}
)
}

#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use quote::quote;
use syn::parse_quote;

use crate::next::App;

use super::Endpoint;

#[test]
fn endpoint_to_token() {
let endpoint = Endpoint {
route: parse_quote!("/hello"),
method: parse_quote!(get),
function: parse_quote!(hello),
};

let actual = quote!(#endpoint);
let expected = quote!(.route("/hello", axum::routing::get(hello)));

assert_eq!(actual.to_string(), expected.to_string());
}

#[test]
fn app_to_token() {
let app = App {
endpoints: vec![
Endpoint {
route: parse_quote!("/hello"),
method: parse_quote!(get),
function: parse_quote!(hello),
},
Endpoint {
route: parse_quote!("/goodbye"),
method: parse_quote!(post),
function: parse_quote!(goodbye),
},
],
};

let actual = quote!(#app);
let expected = quote!(
async fn __app<B>(request: http::Request<B>) -> axum::response::Response
where
B: axum::body::HttpBody + Send + 'static,
{
use tower_service::Service;

let mut router = axum::Router::new()
.route("/hello", axum::routing::get(hello))
.route("/goodbye", axum::routing::post(goodbye))
.into_service();

let response = router.call(request).await.unwrap();

response
}
);

assert_eq!(actual.to_string(), expected.to_string());
}
}
54 changes: 25 additions & 29 deletions tmp/axum-wasm/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,19 @@
use axum::body::{Body, HttpBody};
use axum::{response::Response, routing::get, Router};
use futures_executor::block_on;
use http::Request;
use shuttle_common::wasm::{RequestWrapper, ResponseWrapper};
use std::fs::File;
use std::io::BufReader;
use std::io::{Read, Write};
use std::os::wasi::prelude::*;
use tower_service::Service;

extern crate rmp_serde as rmps;

pub fn handle_request<B>(req: Request<B>) -> Response
pub fn handle_request<B>(req: http::Request<B>) -> axum::response::Response
where
B: HttpBody + Send + 'static,
B: axum::body::HttpBody + Send + 'static,
{
block_on(app(req))
futures_executor::block_on(app(req))
}

async fn app<B>(request: Request<B>) -> Response
async fn app<B>(request: http::Request<B>) -> axum::response::Response
where
B: HttpBody + Send + 'static,
B: axum::body::HttpBody + Send + 'static,
{
let mut router = Router::new()
.route("/hello", get(hello))
.route("/goodbye", get(goodbye))
use tower_service::Service;

let mut router = axum::Router::new()
.route("/hello", axum::routing::get(hello))
.route("/goodbye", axum::routing::get(goodbye))
.into_service();

let response = router.call(request).await.unwrap();
Expand All @@ -42,19 +31,26 @@ async fn goodbye() -> &'static str {

#[no_mangle]
#[allow(non_snake_case)]
pub extern "C" fn __SHUTTLE_Axum_call(fd_3: RawFd, fd_4: RawFd) {
pub extern "C" fn __SHUTTLE_Axum_call(
fd_3: std::os::wasi::prelude::RawFd,
fd_4: std::os::wasi::prelude::RawFd,
) {
use axum::body::HttpBody;
use std::io::{Read, Write};
use std::os::wasi::io::FromRawFd;

println!("inner handler awoken; interacting with fd={fd_3},{fd_4}");

// file descriptor 3 for reading and writing http parts
let mut parts_fd = unsafe { File::from_raw_fd(fd_3) };
let mut parts_fd = unsafe { std::fs::File::from_raw_fd(fd_3) };

let reader = BufReader::new(&mut parts_fd);
let reader = std::io::BufReader::new(&mut parts_fd);

// deserialize request parts from rust messagepack
let wrapper: RequestWrapper = rmps::from_read(reader).unwrap();
let wrapper: shuttle_common::wasm::RequestWrapper = rmp_serde::from_read(reader).unwrap();

// file descriptor 4 for reading and writing http body
let mut body_fd = unsafe { File::from_raw_fd(fd_4) };
let mut body_fd = unsafe { std::fs::File::from_raw_fd(fd_4) };

// read body from host
let mut body_buf = Vec::new();
Expand All @@ -68,7 +64,7 @@ pub extern "C" fn __SHUTTLE_Axum_call(fd_3: RawFd, fd_4: RawFd) {
}
}

let request: Request<Body> = wrapper
let request: http::Request<axum::body::Body> = wrapper
.into_request_builder()
.body(body_buf.into())
.unwrap();
Expand All @@ -79,13 +75,13 @@ pub extern "C" fn __SHUTTLE_Axum_call(fd_3: RawFd, fd_4: RawFd) {
let (parts, mut body) = res.into_parts();

// wrap and serialize response parts as rmp
let response_parts = ResponseWrapper::from(parts).into_rmp();
let response_parts = shuttle_common::wasm::ResponseWrapper::from(parts).into_rmp();

// write response parts
parts_fd.write_all(&response_parts).unwrap();

// write body if there is one
if let Some(body) = block_on(body.data()) {
if let Some(body) = futures_executor::block_on(body.data()) {
body_fd.write_all(body.unwrap().as_ref()).unwrap();
}
// signal to the reader that end of file has been reached
Expand Down