Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve IndexConfig derive macro #616

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
121 changes: 88 additions & 33 deletions meilisearch-index-setting-macro/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use convert_case::{Case, Casing};
use proc_macro2::Ident;
use quote::quote;
use structmeta::{Flag, StructMeta};
use syn::{parse_macro_input, spanned::Spanned};
use structmeta::{Flag, NameValue, StructMeta};
use syn::{parse_macro_input, spanned::Spanned, Attribute, LitStr};

#[derive(Clone, StructMeta, Default)]
struct FieldAttrs {
Expand All @@ -14,30 +14,51 @@ struct FieldAttrs {
sortable: Flag,
}

#[derive(StructMeta)]
struct StructAttrs {
index_name: Option<NameValue<LitStr>>,
max_total_hits: Option<NameValue<syn::Expr>>,
}

fn is_valid_name(name: &str) -> bool {
name.chars()
.all(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_'))
&& !name.is_empty()
}

#[proc_macro_derive(IndexConfig, attributes(index_config))]
pub fn generate_index_settings(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);
let syn::DeriveInput {
attrs, ident, data, ..
} = parse_macro_input!(input as syn::DeriveInput);

let fields: &syn::Fields = match ast.data {
let fields: &syn::Fields = match data {
syn::Data::Struct(ref data) => &data.fields,
_ => {
return proc_macro::TokenStream::from(
syn::Error::new(ast.ident.span(), "Applicable only to struct").to_compile_error(),
syn::Error::new(ident.span(), "Applicable only to struct").to_compile_error(),
);
}
};

let struct_ident = &ast.ident;
let struct_ident = &ident;

let index_config_implementation = get_index_config_implementation(struct_ident, fields);
let index_config_implementation = get_index_config_implementation(struct_ident, fields, attrs);
proc_macro::TokenStream::from(quote! {
#index_config_implementation
})
}

fn filter_attrs(attrs: &[Attribute]) -> impl Iterator<Item = &Attribute> {
attrs
.iter()
.filter(|attr| attr.path().is_ident("index_config"))
}

fn get_index_config_implementation(
struct_ident: &Ident,
fields: &syn::Fields,
attrs: Vec<Attribute>,
) -> proc_macro2::TokenStream {
let mut primary_key_attribute = String::new();
let mut distinct_key_attribute = String::new();
Expand All @@ -46,23 +67,40 @@ fn get_index_config_implementation(
let mut filterable_attributes = vec![];
let mut sortable_attributes = vec![];

let index_name = struct_ident
.to_string()
.from_case(Case::UpperCamel)
.to_case(Case::Snake);
let mut index_name_override = None;

let mut max_total_hits = None;

let struct_attrs =
filter_attrs(&attrs).filter_map(|attr| attr.parse_args::<StructAttrs>().ok());
for struct_attr in struct_attrs {
if let Some(index_name_value) = struct_attr.index_name {
index_name_override = Some((index_name_value.value.value(), index_name_value.name_span))
}

if let Some(max_total_hits_value) = struct_attr.max_total_hits {
max_total_hits = Some(max_total_hits_value.value)
}
}

let (index_name, span) = index_name_override.unwrap_or_else(|| {
(
struct_ident.to_string().to_case(Case::Snake),
struct_ident.span(),
)
});

if !is_valid_name(&index_name) {
return syn::Error::new(span, "Index must follow the naming guidelines.")
.to_compile_error();
}

let mut primary_key_found = false;
let mut distinct_found = false;

for field in fields {
let attrs = field
.attrs
.iter()
.filter(|attr| attr.path().is_ident("index_config"))
.map(|attr| attr.parse_args::<FieldAttrs>().unwrap())
.collect::<Vec<_>>()
.first()
.cloned()
let attrs = filter_attrs(&field.attrs)
.find_map(|attr| attr.parse_args::<FieldAttrs>().ok())
.unwrap_or_default();

// Check if the primary key field is unique
Expand Down Expand Up @@ -128,28 +166,45 @@ fn get_index_config_implementation(
"with_distinct_attribute",
);

let pagination_token = get_pagination_token(&max_total_hits, "with_pagination");

quote! {
#[::meilisearch_sdk::macro_helper::async_trait(?Send)]
impl ::meilisearch_sdk::documents::IndexConfig for #struct_ident {
const INDEX_STR: &'static str = #index_name;

fn generate_settings() -> ::meilisearch_sdk::settings::Settings {
::meilisearch_sdk::settings::Settings::new()
#display_attr_tokens
#sortable_attr_tokens
#filterable_attr_tokens
#searchable_attr_tokens
#distinct_attr_token
}

async fn generate_index<Http: ::meilisearch_sdk::request::HttpClient>(client: &::meilisearch_sdk::client::Client<Http>) -> std::result::Result<::meilisearch_sdk::indexes::Index<Http>, ::meilisearch_sdk::tasks::Task> {
return client.create_index(#index_name, #primary_key_token)
.await.unwrap()
.wait_for_completion(&client, ::std::option::Option::None, ::std::option::Option::None)
.await.unwrap()
.try_make_index(&client);
::meilisearch_sdk::settings::Settings::new()
#display_attr_tokens
#sortable_attr_tokens
#filterable_attr_tokens
#searchable_attr_tokens
#distinct_attr_token
#pagination_token
}

async fn generate_index<Http: ::meilisearch_sdk::request::HttpClient>(client: &::meilisearch_sdk::client::Client<Http>) -> std::result::Result<::meilisearch_sdk::indexes::Index<Http>, ::meilisearch_sdk::tasks::Task> {
client.create_index(#index_name, #primary_key_token)
.await.unwrap()
.wait_for_completion(client, ::std::option::Option::None, ::std::option::Option::None)
.await.unwrap()
.try_make_index(client)
}
}
}
}

fn get_pagination_token(
max_hits: &Option<syn::Expr>,
method_name: &str,
) -> proc_macro2::TokenStream {
let method_ident = Ident::new(method_name, proc_macro2::Span::call_site());

match max_hits {
Some(value) => {
quote! { .#method_ident(::meilisearch_sdk::settings::PaginationSetting { max_total_hits: #value }) }
}
None => quote! {},
}
}

Expand Down
16 changes: 15 additions & 1 deletion src/documents.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,19 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize};

/// Derive the [`IndexConfig`](crate::documents::IndexConfig) trait.
///
/// ## Struct attribute
/// Use the `#[index_config(..)]` struct attribute to set general index settings.
///
/// The available parameters are:
/// - `index_name = "new_name"` - Override index name
/// - `max_total_hits = 5_000` - [Set pagination settings](https://www.meilisearch.com/docs/reference/api/settings#update-pagination-settings)
/// - Value can be anything that returns usize.
///
/// ## Field attribute
/// Use the `#[index_config(..)]` field attribute to generate the correct settings
/// for each field. The available parameters are:
/// for each field.
///
/// The available parameters are:
/// - `primary_key` (can only be used once)
/// - `distinct` (can only be used once)
/// - `searchable`
Expand All @@ -16,6 +26,10 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize};
/// ## Index name
/// The name of the index will be the name of the struct converted to snake case.
///
/// Or it can be overridden with `index_name` at the struct attribute level.
///
/// ⚠️ Struct and index names should follow the naming [guidelines](https://www.meilisearch.com/docs/learn/getting_started/indexes#index-uid)
///
/// ## Sample usage:
/// ```
/// use serde::{Serialize, Deserialize};
Expand Down