Skip to content

Commit

Permalink
Merge pull request #183 from asomers/async_trait
Browse files Browse the repository at this point in the history
async_trait compatibility
  • Loading branch information
asomers authored Aug 29, 2020
2 parents eb04806 + 7b15dd0 commit eb64bd8
Show file tree
Hide file tree
Showing 12 changed files with 191 additions and 42 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ This project adheres to [Semantic Versioning](http://semver.org/).

## [Unreleased] - ReleaseDate
### Added
- Compatibility with the `#[async_trait]` macro.
([#183](https://github.com/asomers/mockall/pull/183))

- Better support for non-Send types:
* Added `return_const_st` for returning non-`Send` constants, similar to
`returning_st`.
Expand Down
2 changes: 2 additions & 0 deletions mockall/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ predicates-tree = "1.0"
mockall_derive = { version = "= 0.7.2", path = "../mockall_derive" }

[dev-dependencies]
async-trait = "0.1.38"
futures = "0.3"
serde = "1.0"
serde_derive = "1.0"
serde_json = "1.0"
Expand Down
33 changes: 33 additions & 0 deletions mockall/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
//! * [`Static methods`](#static-methods)
//! * [`Foreign functions`](#foreign-functions)
//! * [`Modules`](#modules)
//! * [`Async Traits`](#async-traits)
//! * [`Crate features`](#crate-features)
//! * [`Examples`](#examples)
//!
Expand Down Expand Up @@ -1001,6 +1002,38 @@
//! # fn main() {}
//! ```
//!
//! ## Async Traits
//!
//! Async traits aren't yet (as of 1.47.0) a part of the Rust language. But
//! they're available from the
//! [`async_trait`](https://docs.rs/async-trait/0.1.38/async_trait/) crate.
//! Mockall is compatible with this crate, with two important limitations:
//!
//! * The `#[automock]` attribute must appear _before_ the `#[async_trait]`
//! attribute.
//!
//! * The `#[async_trait]` macro must be imported with its canonical name.
//!
//! ```
//! # use async_trait::async_trait;
//! # use mockall::*;
//! // async_trait works with both #[automock]
//! #[automock]
//! #[async_trait]
//! pub trait Foo {
//! async fn foo(&self) -> u32;
//! }
//! // and mock!
//! mock! {
//! pub Bar {}
//! #[async_trait]
//! trait Foo {
//! async fn foo(&self) -> u32;
//! }
//! }
//! # fn main() {}
//! ```
//!
//! ## Crate features
//!
//! Mockall has a **nightly** feature. Currently this feature has three
Expand Down
22 changes: 22 additions & 0 deletions mockall/tests/automock_async_trait.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// vim: tw=80
//! An async trait, for use with Futures
#![deny(warnings)]

use async_trait::async_trait;
use futures::executor::block_on;
use mockall::*;

#[automock]
#[async_trait]
pub trait Foo {
async fn foo(&self) -> u32;
}


#[test]
fn return_const() {
let mut mock = MockFoo::new();
mock.expect_foo()
.return_const(42u32);
assert_eq!(block_on(mock.foo()), 42);
}
20 changes: 20 additions & 0 deletions mockall/tests/mock_async_fn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// vim: tw=80
//! A struct with an async function
#![deny(warnings)]

use futures::executor::block_on;
use mockall::*;

mock! {
pub Foo {
async fn foo(&self) -> u32;
}
}

#[test]
fn return_const() {
let mut mock = MockFoo::new();
mock.expect_foo()
.return_const(42u32);
assert_eq!(block_on(mock.foo()), 42);
}
28 changes: 28 additions & 0 deletions mockall/tests/mock_async_trait.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// vim: tw=80
//! An async trait, for use with Futures
#![deny(warnings)]

use async_trait::async_trait;
use futures::executor::block_on;
use mockall::*;

#[async_trait]
pub trait Foo {
async fn foo(&self) -> u32;
}

mock! {
pub Bar { }
#[async_trait]
trait Foo {
async fn foo(&self) -> u32;
}
}

#[test]
fn return_const() {
let mut mock = MockBar::new();
mock.expect_foo()
.return_const(42u32);
assert_eq!(block_on(mock.foo()), 42);
}
2 changes: 1 addition & 1 deletion mockall/tools/allgen.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ cd ${TOOLSDIR}/../tests
#exit 0
for t in `ls *.rs | sed 's/\.rs//'`; do
env MOCKALL_DEBUG=1 cargo +nightly check --all-features --test $t > ${ODIR}/$t.rs || break;
rustfmt ${ODIR}/$t.rs ;
rustfmt --edition 2018 ${ODIR}/$t.rs ;
done
2 changes: 1 addition & 1 deletion mockall/tools/gen.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@ if [ -z "$t" ]; then
exit 2
fi

env MOCKALL_DEBUG=1 cargo check --all-features --test $t > ${PP}/$t.rs; rustfmt ${PP}/$t.rs
env MOCKALL_DEBUG=1 cargo check --all-features --test $t > ${PP}/$t.rs; rustfmt --edition 2018 ${PP}/$t.rs
echo "diff stat: " `diff ${PP_OLD}/$t.rs ${PP}/$t.rs | wc`
48 changes: 41 additions & 7 deletions mockall_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,15 +443,49 @@ fn find_lifetimes(ty: &Type) -> HashSet<Lifetime> {
}
}

fn format_attrs(attrs: &[syn::Attribute], include_docs: bool) -> TokenStream {
let mut out = TokenStream::new();
for attr in attrs {
let is_doc = attr.path.get_ident().map(|i| i == "doc").unwrap_or(false);
if !is_doc || include_docs {
attr.to_tokens(&mut out);

struct AttrFormatter<'a>{
attrs: &'a [Attribute],
async_trait: bool,
doc: bool,
}

impl<'a> AttrFormatter<'a> {
fn new(attrs: &'a [Attribute]) -> AttrFormatter<'a> {
Self {
attrs,
async_trait: true,
doc: true
}
}
out

fn async_trait(&mut self, allowed: bool) -> &mut Self {
self.async_trait = allowed;
self
}

fn doc(&mut self, allowed: bool) -> &mut Self {
self.doc = allowed;
self
}

// XXX This logic requires that attributes are imported with their
// standard names.
fn format(&mut self) -> Vec<Attribute> {
self.attrs.iter()
.cloned()
.filter(|attr|
( self.doc ||
attr.path.get_ident()
.map(|i| i != "doc")
.unwrap_or(false)
) && (self.async_trait ||
attr.path.get_ident()
.map(|i| i != "async_trait")
.unwrap_or(false)
)
).collect()
}
}

/// Determine if this Pat is any kind of `self` binding
Expand Down
54 changes: 26 additions & 28 deletions mockall_derive/src/mock_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ pub(crate) struct MockFunction {
/// Types of the method arguments
argty: Vec<Type>,
/// any attributes on the original function, like #[inline]
attrs: Vec<Attribute>,
pub attrs: Vec<Attribute>,
/// Expressions that should be used for Expectation::call's arguments
call_exprs: Vec<TokenStream>,
/// Generics used for the expectation call
Expand Down Expand Up @@ -382,7 +382,7 @@ impl MockFunction {
// Supplying modname is an unfortunately hack. Ideally MockFunction
// wouldn't need to know that.
pub fn call(&self, modname: Option<&Ident>) -> impl ToTokens {
let attrs = self.format_attrs(true);
let attrs = AttrFormatter::new(&self.attrs).format();
let call_exprs = &self.call_exprs;
let (_, tg, _) = if self.is_method_generic() || self.is_static() {
&self.egenerics
Expand Down Expand Up @@ -418,7 +418,7 @@ impl MockFunction {
let outer_mod_path = self.outer_mod_path(modname);
quote!(
// Don't add a doc string. The original is included in #attrs
#attrs
#(#attrs)*
#vis #sig {
{
let __mockall_guard = #outer_mod_path::EXPECTATIONS
Expand All @@ -437,7 +437,7 @@ impl MockFunction {
} else {
quote!(
// Don't add a doc string. The original is included in #attrs
#attrs
#(#attrs)*
#vis #sig {
self.#substruct_obj #name.#call#tbf(#(#call_exprs,)*)
.expect(#no_match_msg)
Expand All @@ -449,11 +449,13 @@ impl MockFunction {

/// Return this method's contribution to its parent's checkpoint method
pub fn checkpoint(&self) -> impl ToTokens {
let attrs = self.format_attrs(false);
let attrs = AttrFormatter::new(&self.attrs)
.doc(false)
.format();
let inner_mod_ident = self.inner_mod_ident();
if self.is_static {
quote!(
#attrs
#(#attrs)*
{
let __mockall_timeses = #inner_mod_ident::EXPECTATIONS.lock()
.unwrap()
Expand All @@ -462,9 +464,8 @@ impl MockFunction {
}
)
} else {
let attrs = self.format_attrs(false);
let name = &self.name();
quote!(#attrs { self.#name.checkpoint(); })
quote!(#(#attrs)* { self.#name.checkpoint(); })
}
}

Expand All @@ -476,7 +477,9 @@ impl MockFunction {
// Supplying modname is an unfortunately hack. Ideally MockFunction
// wouldn't need to know that.
pub fn context_fn(&self, modname: Option<&Ident>) -> impl ToTokens {
let attrs = self.format_attrs(false);
let attrs = AttrFormatter::new(&self.attrs)
.doc(false)
.format();
let context_docstr = format!("Create a [`Context`]({}{}/struct.Context.html) for mocking the `{}` method",
modname.map(|m| format!("{}/", m)).unwrap_or_default(),
self.inner_mod_ident(),
Expand All @@ -486,7 +489,7 @@ impl MockFunction {
let outer_mod_path = self.outer_mod_path(modname);
let v = &self.call_vis;
quote!(
#attrs
#(#attrs)*
#[doc = #context_docstr]
#v fn #context_ident() -> #outer_mod_path::Context #tg
{
Expand All @@ -503,7 +506,9 @@ impl MockFunction {
// Supplying modname is an unfortunately hack. Ideally MockFunction
// wouldn't need to know that.
pub fn expect(&self, modname: &Ident) -> impl ToTokens {
let attrs = self.format_attrs(false);
let attrs = AttrFormatter::new(&self.attrs)
.doc(false)
.format();
let name = self.name();
let expect_ident = format_ident!("expect_{}", &name);
let expectation_obj = self.expectation_obj();
Expand Down Expand Up @@ -541,7 +546,7 @@ impl MockFunction {
quote!(
#must_use
#[doc = #docstr]
#attrs
#(#attrs)*
#vis fn #expect_ident #ig(&mut self)
-> &mut #modname::#expectation_obj
#wc
Expand Down Expand Up @@ -573,31 +578,22 @@ impl MockFunction {

pub fn field_definition(&self, modname: Option<&Ident>) -> TokenStream {
let name = self.name();
let attrs = self.format_attrs(false);
let attrs = AttrFormatter::new(&self.attrs)
.doc(false)
.format();
let expectations_obj = &self.expectations_obj();
if self.is_method_generic() {
quote!(#attrs #name: #modname::#expectations_obj)
quote!(#(#attrs)* #name: #modname::#expectations_obj)
} else {
// staticize any lifetimes. This is necessary for methods that
// return non-static types, because the Expectation itself must be
// 'static.
let segenerics = staticize(&self.egenerics);
let (_, tg, _) = segenerics.split_for_impl();
quote!(#attrs #name: #modname::#expectations_obj #tg)
quote!(#(#attrs)* #name: #modname::#expectations_obj #tg)
}
}

pub fn format_attrs(&self, include_docs: bool) -> impl ToTokens {
let mut out = TokenStream::new();
for attr in &self.attrs {
let is_doc = attr.path.get_ident().map(|i| i == "doc").unwrap_or(false);
if !is_doc || include_docs {
attr.to_tokens(&mut out);
}
}
out
}

/// Human-readable name of the mock function
fn funcname(&self) -> String {
if let Some(si) = &self.struct_ {
Expand Down Expand Up @@ -658,7 +654,9 @@ impl MockFunction {

/// Generate code for this function's private module
pub fn priv_module(&self) -> impl ToTokens {
let attrs = self.format_attrs(false);
let attrs = AttrFormatter::new(&self.attrs)
.doc(false)
.format();
let common = &Common{f: self};
let context = &Context{f: self};
let expectation: Box<dyn ToTokens> = if self.return_ref {
Expand Down Expand Up @@ -696,7 +694,7 @@ impl MockFunction {
Box::new(StaticRfunc{f: self})
};
quote!(
#attrs
#(#attrs)*
#[allow(missing_docs)]
pub mod #inner_mod_ident {
use super::*;
Expand Down
Loading

0 comments on commit eb64bd8

Please sign in to comment.