Skip to content

Commit

Permalink
Add support for multiple traits
Browse files Browse the repository at this point in the history
  • Loading branch information
commonsensesoftware committed Dec 20, 2023
1 parent e0ed5d3 commit 6826c3e
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 20 deletions.
6 changes: 3 additions & 3 deletions src/di_macros/internal/attribute.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use syn::{
parse::{Parse, ParseStream},
Path, Result,
Path, Result, punctuated::Punctuated, token::Plus,
};

pub struct InjectableAttribute {
pub trait_: Option<Path>,
pub trait_: Option<Punctuated<Path, Plus>>,
}

impl Parse for InjectableAttribute {
fn parse(input: ParseStream) -> Result<Self> {
Ok(Self {
trait_: input.parse().ok(),
trait_: input.parse_terminated(Path::parse, Plus).ok(),
})
}
}
14 changes: 9 additions & 5 deletions src/di_macros/internal/derive.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use syn::{Generics, ItemStruct, Path, Signature};
use syn::{punctuated::Punctuated, token::Plus, Generics, ItemStruct, Path, Signature};

pub enum MacroTarget<'a> {
Method(&'a Signature),
Expand All @@ -8,15 +8,15 @@ pub enum MacroTarget<'a> {
pub struct DeriveContext<'a> {
pub generics: &'a Generics,
pub implementation: &'a Path,
pub service: &'a Path,
pub service: Punctuated<Path, Plus>,
target: MacroTarget<'a>,
}

impl<'a> DeriveContext<'a> {
pub fn for_method(
generics: &'a Generics,
implementation: &'a Path,
service: &'a Path,
service: Punctuated<Path, Plus>,
method: &'a Signature,
) -> Self {
Self {
Expand All @@ -30,7 +30,7 @@ impl<'a> DeriveContext<'a> {
pub fn for_struct(
generics: &'a Generics,
implementation: &'a Path,
service: &'a Path,
service: Punctuated<Path, Plus>,
struct_: &'a ItemStruct,
) -> Self {
Self {
Expand All @@ -46,8 +46,12 @@ impl<'a> DeriveContext<'a> {
}

pub fn is_trait(&self) -> bool {
if self.service.len() > 1 {
return true;
}

let impl_ = &self.implementation.segments.last().unwrap().ident;
let struct_ = &self.service.segments.last().unwrap().ident;
let struct_ = &self.service.first().unwrap().segments.last().unwrap().ident;
impl_ != struct_
}
}
6 changes: 3 additions & 3 deletions src/di_macros/internal/derive_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ impl InjectableTrait {
}

let service = if context.is_trait() {
let svc = context.service;
quote! { dyn #svc }
let svc = context.service.iter();
quote! { dyn #(#svc)+* }
} else {
quote! { Self }
};
Expand All @@ -48,7 +48,7 @@ impl InjectableTrait {
};
let activate2 = activate.clone();
let code = quote! {
impl#generics di::Injectable for #implementation #where_ {
impl #generics di::Injectable for #implementation #where_ {
fn inject(lifetime: di::ServiceLifetime) -> di::InjectBuilder {
di::InjectBuilder::new(
di::Activator::new::<#service, Self>(
Expand Down
33 changes: 25 additions & 8 deletions src/di_macros/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@ extern crate proc_macro;
use crate::internal::*;
use internal::{Constructor, DeriveContext, InjectableTrait};
use proc_macro2::TokenStream;
use syn::{punctuated::Punctuated, spanned::Spanned, token::PathSep, *};
use syn::{
punctuated::Punctuated,
spanned::Spanned,
token::{PathSep, Plus},
*,
};

/// Represents the metadata used to identify an injected function.
///
Expand Down Expand Up @@ -164,11 +169,10 @@ fn derive_from_struct_impl(
) -> Result<TokenStream> {
if let Type::Path(type_) = &*impl_.self_ty {
let imp = &type_.path;
let svc = attribute.trait_.as_ref().unwrap_or(imp);

let svc = service_from_attribute(imp, attribute);
match Constructor::select(&impl_, imp) {
Ok(method) => {
let context = DeriveContext::for_method(&impl_.generics, imp, &svc, method);
let context = DeriveContext::for_method(&impl_.generics, imp, svc, method);
derive(context, original)
}
Err(error) => Err(error),
Expand All @@ -184,12 +188,22 @@ fn derive_from_struct(
original: TokenStream,
) -> Result<TokenStream> {
let imp = &build_path_from_struct(&struct_);
let svc = attribute.trait_.as_ref().unwrap_or(imp);
let svc = service_from_attribute(imp, attribute);
let context = DeriveContext::for_struct(&struct_.generics, imp, svc, &struct_);

derive(context, original)
}

fn service_from_attribute(impl_: &Path, mut attribute: InjectableAttribute) -> Punctuated<Path, Plus> {
let mut punctuated = attribute.trait_.take().unwrap_or_else(Punctuated::<Path, Plus>::new);

if punctuated.is_empty() {
punctuated.push(impl_.clone());
}

punctuated
}

fn build_path_from_struct(struct_: &ItemStruct) -> Path {
let generics = &struct_.generics;
let mut segments = Punctuated::<PathSegment, PathSep>::new();
Expand Down Expand Up @@ -279,7 +293,8 @@ mod test {
"| sp : & di :: ServiceProvider | di :: RefMut :: new (Self :: new () . into ())) , ",
"lifetime) ",
"} ",
"}");
"}"
);

assert_eq!(expected, result.to_string());
}
Expand Down Expand Up @@ -516,7 +531,8 @@ mod test {
"| sp : & di :: ServiceProvider | di :: RefMut :: new (Self :: new () . into ())) , ",
"lifetime) ",
"} ",
"}");
"}"
);

assert_eq!(expected, result.to_string());
}
Expand Down Expand Up @@ -593,7 +609,8 @@ mod test {
"| sp : & di :: ServiceProvider | di :: RefMut :: new (Self :: new () . into ())) , ",
"lifetime) ",
"} ",
"}");
"}"
);

assert_eq!(expected, result.to_string());
}
Expand Down
24 changes: 23 additions & 1 deletion test/di/scenarios.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ fn inject_should_resolve_keyed_mut() {
}

#[test]
fn inject_should_support_multiple_traits() {
fn inject_should_resolve_multiple_traits() {
// arrange
let provider = ServiceCollection::new()
.add(MultiService::singleton())
Expand All @@ -464,3 +464,25 @@ fn inject_should_support_multiple_traits() {
// assert
// no panic!
}

#[test]
fn inject_should_support_multiple_traits() {
// arrange
trait IPityTheFoo {}

#[injectable(IPityTheFoo + Send + Sync)]
struct Foo;

impl IPityTheFoo for Foo {}

let provider = ServiceCollection::new()
.add(Foo::transient())
.build_provider()
.unwrap();

// act
let _ = provider.get_required::<dyn IPityTheFoo + Send + Sync>();

// assert
// no panic!
}

0 comments on commit 6826c3e

Please sign in to comment.