diff --git a/Cargo.toml b/Cargo.toml
index 0122d770d9..36bd4a6a03 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -3,6 +3,16 @@ resolver = "1"
 members = [
-  "crates/*",
-  "examples"
+  "crates/stdarch-verify",
+  "crates/core_arch",
+  "crates/std_detect",
+  "crates/stdarch-gen-arm",
+  "crates/stdarch-gen-loongarch",
+  "crates/stdarch-gen",
+  "crates/stdarch-gen2",
+  "crates/intrinsic-test",
+  "examples/"
+]
+exclude = [
+  "crates/wasm-assert-instr-tests"
 ]
 
 [profile.release]
diff --git a/crates/stdarch-gen2/Cargo.toml b/crates/stdarch-gen2/Cargo.toml
new file mode 100644
index 0000000000..c9a039ea6b
--- /dev/null
+++ b/crates/stdarch-gen2/Cargo.toml
@@ -0,0 +1,22 @@
+[package]
+name = "stdarch-gen2"
+version = "0.1.0"
+authors = ["Luca Vizzarro",
+           "Jamie Cunliffe",
+           "Adam Gemmell"]
+license = "MIT OR Apache-2.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+itertools = "0.10"
+lazy_static = "1.4.0"
+proc-macro2 = "1.0"
+quote = "1.0"
+regex = "1.5"
+serde = { version = "1.0", features = ["derive"] }
+serde_with = "1.14"
+serde_yaml = "0.8"
+walkdir = "2.3.2"
diff --git a/crates/stdarch-gen2/src/assert_instr.rs b/crates/stdarch-gen2/src/assert_instr.rs
new file mode 100644
index 0000000000..ce1bbe8b55
--- /dev/null
+++ b/crates/stdarch-gen2/src/assert_instr.rs
@@ -0,0 +1,372 @@
+use proc_macro2::TokenStream;
+use quote::{format_ident, quote, ToTokens, TokenStreamExt};
+use serde::de::{self, MapAccess, Visitor};
+use serde::{ser::SerializeSeq, Deserialize, Deserializer, Serialize};
+use std::fmt;
+
+use crate::{
+    context::{self, Context},
+    typekinds::{BaseType, BaseTypeKind},
+    wildstring::WildString,
+};
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum InstructionAssertion {
+    Basic(WildString),
+    WithArgs(WildString, WildString),
+}
+
+impl InstructionAssertion {
+    fn build(&mut self, ctx: &Context) -> context::Result {
+        match self {
+            InstructionAssertion::Basic(ws) => ws.build_acle(ctx.local),
+            InstructionAssertion::WithArgs(ws, args_ws) => [ws, args_ws]
+                .into_iter()
+                .try_for_each(|ws| ws.build_acle(ctx.local)),
+        }
+    }
+}
+
+impl ToTokens for InstructionAssertion {
+    fn to_tokens(&self, tokens: &mut TokenStream) {
+        let instr = format_ident!(
+            "{}",
+            match self {
+                Self::Basic(instr) => instr,
+                Self::WithArgs(instr, _) => instr,
+            }
+            .to_string()
+        );
+        tokens.append_all(quote! { #instr });
+
+        if let Self::WithArgs(_, args) = self {
+            let ex: TokenStream = args
+                .to_string()
+                .parse()
+                .expect("invalid instruction assertion arguments expression given");
+            tokens.append_all(quote! {, #ex})
+        }
+    }
+}
+
+// Asserts that the given instruction is present for the intrinsic of the associated type bitsize.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(remote = "Self")]
+pub struct InstructionAssertionMethodForBitsize {
+    pub default: InstructionAssertion,
+    pub byte: Option<InstructionAssertion>,
+    pub halfword: Option<InstructionAssertion>,
+    pub word: Option<InstructionAssertion>,
+    pub doubleword: Option<InstructionAssertion>,
+}
+
+impl InstructionAssertionMethodForBitsize {
+    fn build(&mut self, ctx: &Context) -> context::Result {
+        if let Some(ref mut byte) = self.byte {
+            byte.build(ctx)?
+        }
+        if let Some(ref mut halfword) = self.halfword {
+            halfword.build(ctx)?
+        }
+        if let Some(ref mut word) = self.word {
+            word.build(ctx)?
+        }
+        if let Some(ref mut doubleword) = self.doubleword {
+            doubleword.build(ctx)?
+        }
+        self.default.build(ctx)
+    }
+}
+
+impl Serialize for InstructionAssertionMethodForBitsize {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        match self {
+            InstructionAssertionMethodForBitsize {
+                default: InstructionAssertion::Basic(instr),
+                byte: None,
+                halfword: None,
+                word: None,
+                doubleword: None,
+            } => serializer.serialize_str(&instr.to_string()),
+            InstructionAssertionMethodForBitsize {
+                default: InstructionAssertion::WithArgs(instr, args),
+                byte: None,
+                halfword: None,
+                word: None,
+                doubleword: None,
+            } => {
+                let mut seq = serializer.serialize_seq(Some(2))?;
+                seq.serialize_element(&instr.to_string())?;
+                seq.serialize_element(&args.to_string())?;
+                seq.end()
+            }
+            _ => InstructionAssertionMethodForBitsize::serialize(self, serializer),
+        }
+    }
+}
+
+impl<'de> Deserialize<'de> for InstructionAssertionMethodForBitsize {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        struct IAMVisitor;
+
+        impl<'de> Visitor<'de> for IAMVisitor {
+            type Value = InstructionAssertionMethodForBitsize;
+
+            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+                formatter.write_str("array, string or map")
+            }
+
+            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
+            where
+                E: de::Error,
+            {
+                Ok(InstructionAssertionMethodForBitsize {
+                    default: InstructionAssertion::Basic(value.parse().map_err(E::custom)?),
+                    byte: None,
+                    halfword: None,
+                    word: None,
+                    doubleword: None,
+                })
+            }
+
+            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
+            where
+                A: de::SeqAccess<'de>,
+            {
+                use serde::de::Error;
+                let make_err =
+                    || Error::custom("invalid number of arguments passed to assert_instruction");
+                let instruction = seq.next_element()?.ok_or_else(make_err)?;
+                let args = seq.next_element()?.ok_or_else(make_err)?;
+
+                if let Some(true) = seq.size_hint().map(|len| len > 0) {
+                    Err(make_err())
+                } else {
+                    Ok(InstructionAssertionMethodForBitsize {
+                        default: InstructionAssertion::WithArgs(instruction, args),
+                        byte: None,
+                        halfword: None,
+                        word: None,
+                        doubleword: None,
+                    })
+                }
+            }
+
+            fn visit_map<M>(self, map: M) -> Result<Self::Value, M::Error>
+            where
+                M: MapAccess<'de>,
+            {
+                InstructionAssertionMethodForBitsize::deserialize(
+                    de::value::MapAccessDeserializer::new(map),
+                )
+            }
+        }
+
+        deserializer.deserialize_any(IAMVisitor)
+    }
+}
+
+/// Asserts that the given instruction is present for the intrinsic of the associated type.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(remote = "Self")]
+pub struct InstructionAssertionMethod {
+    /// Instruction for integer intrinsics
+    pub default: InstructionAssertionMethodForBitsize,
+    /// Instruction for floating-point intrinsics (optional)
+    #[serde(default)]
+    pub float: Option<InstructionAssertionMethodForBitsize>,
+    /// Instruction for unsigned integer intrinsics (optional)
+    #[serde(default)]
+    pub unsigned: Option<InstructionAssertionMethodForBitsize>,
+}
+
+impl InstructionAssertionMethod {
+    pub(crate) fn build(&mut self, ctx: &Context) -> context::Result {
+        if let Some(ref mut float) = self.float {
+            float.build(ctx)?
+        }
+        if let Some(ref mut unsigned) = self.unsigned {
+            unsigned.build(ctx)?
+ } + self.default.build(ctx) + } +} + +impl Serialize for InstructionAssertionMethod { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + InstructionAssertionMethod { + default: + InstructionAssertionMethodForBitsize { + default: InstructionAssertion::Basic(instr), + byte: None, + halfword: None, + word: None, + doubleword: None, + }, + float: None, + unsigned: None, + } => serializer.serialize_str(&instr.to_string()), + InstructionAssertionMethod { + default: + InstructionAssertionMethodForBitsize { + default: InstructionAssertion::WithArgs(instr, args), + byte: None, + halfword: None, + word: None, + doubleword: None, + }, + float: None, + unsigned: None, + } => { + let mut seq = serializer.serialize_seq(Some(2))?; + seq.serialize_element(&instr.to_string())?; + seq.serialize_element(&args.to_string())?; + seq.end() + } + _ => InstructionAssertionMethod::serialize(self, serializer), + } + } +} + +impl<'de> Deserialize<'de> for InstructionAssertionMethod { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct IAMVisitor; + + impl<'de> Visitor<'de> for IAMVisitor { + type Value = InstructionAssertionMethod; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("array, string or map") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + Ok(InstructionAssertionMethod { + default: InstructionAssertionMethodForBitsize { + default: InstructionAssertion::Basic(value.parse().map_err(E::custom)?), + byte: None, + halfword: None, + word: None, + doubleword: None, + }, + float: None, + unsigned: None, + }) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + use serde::de::Error; + let make_err = + || Error::custom("invalid number of arguments passed to assert_instruction"); + let instruction = seq.next_element()?.ok_or_else(make_err)?; + let args = seq.next_element()?.ok_or_else(make_err)?; + + if let Some(true) = seq.size_hint().map(|len| len > 0) { + Err(make_err()) + } else { + Ok(InstructionAssertionMethod { + default: InstructionAssertionMethodForBitsize { + default: InstructionAssertion::WithArgs(instruction, args), + byte: None, + halfword: None, + word: None, + doubleword: None, + }, + float: None, + unsigned: None, + }) + } + } + + fn visit_map(self, map: M) -> Result + where + M: MapAccess<'de>, + { + InstructionAssertionMethod::deserialize(de::value::MapAccessDeserializer::new(map)) + } + } + + deserializer.deserialize_any(IAMVisitor) + } +} + +#[derive(Debug)] +pub struct InstructionAssertionsForBaseType<'a>( + pub &'a Vec, + pub &'a Option<&'a BaseType>, +); + +impl<'a> ToTokens for InstructionAssertionsForBaseType<'a> { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.0.iter().for_each( + |InstructionAssertionMethod { + default, + float, + unsigned, + }| { + let kind = self.1.map(|ty| ty.kind()); + let instruction = match (kind, float, unsigned) { + (None, float, unsigned) if float.is_some() || unsigned.is_some() => { + unreachable!( + "cannot determine the base type kind for instruction assertion: {self:#?}") + } + (Some(BaseTypeKind::Float), Some(float), _) => float, + (Some(BaseTypeKind::UInt), _, Some(unsigned)) => unsigned, + _ => default, + }; + + let bitsize = self.1.and_then(|ty| ty.get_size().ok()); + let instruction = match (bitsize, instruction) { + ( + Some(8), + InstructionAssertionMethodForBitsize { + byte: Some(byte), .. 
+ }, + ) => byte, + ( + Some(16), + InstructionAssertionMethodForBitsize { + halfword: Some(halfword), + .. + }, + ) => halfword, + ( + Some(32), + InstructionAssertionMethodForBitsize { + word: Some(word), .. + }, + ) => word, + ( + Some(64), + InstructionAssertionMethodForBitsize { + doubleword: Some(doubleword), + .. + }, + ) => doubleword, + (_, InstructionAssertionMethodForBitsize { default, .. }) => default, + }; + + tokens.append_all(quote! { #[cfg_attr(test, assert_instr(#instruction))]}) + }, + ); + } +} diff --git a/crates/stdarch-gen2/src/context.rs b/crates/stdarch-gen2/src/context.rs new file mode 100644 index 0000000000..108f7ab706 --- /dev/null +++ b/crates/stdarch-gen2/src/context.rs @@ -0,0 +1,249 @@ +use itertools::Itertools; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +use crate::{ + expression::Expression, + input::{InputSet, InputType}, + intrinsic::{Constraint, Intrinsic, Signature}, + matching::SizeMatchable, + predicate_forms::PredicateForm, + typekinds::{ToRepr, TypeKind}, + wildcards::Wildcard, + wildstring::WildString, +}; + +/// Maximum SVE vector size +const SVE_VECTOR_MAX_SIZE: u32 = 2048; +/// Vector register size +const VECTOR_REG_SIZE: u32 = 128; + +/// Generator result +pub type Result = std::result::Result; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ArchitectureSettings { + #[serde(alias = "arch")] + pub arch_name: String, + pub target_feature: Vec, + #[serde(alias = "llvm_prefix")] + pub llvm_link_prefix: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GlobalContext { + pub arch_cfgs: Vec, + #[serde(default)] + pub uses_neon_types: bool, +} + +/// Context of an intrinsic group +#[derive(Debug, Clone, Default)] +pub struct GroupContext { + /// LLVM links to target input sets + pub links: HashMap, +} + +#[derive(Debug, Clone, Copy)] +pub enum VariableType { + Argument, + Internal, +} + +#[derive(Debug, Clone)] +pub struct LocalContext { + pub signature: Signature, + + pub input: InputSet, + + pub substitutions: HashMap, + pub variables: HashMap, +} + +impl LocalContext { + pub fn new(input: InputSet, original: &Intrinsic) -> LocalContext { + LocalContext { + signature: original.signature.clone(), + input, + substitutions: HashMap::new(), + variables: HashMap::new(), + } + } + + pub fn provide_type_wildcard(&self, wildcard: &Wildcard) -> Result { + let err = || format!("wildcard {{{wildcard}}} not found"); + + let make_neon = |tuple_size| move |ty| TypeKind::make_vector(ty, false, tuple_size); + let make_sve = |tuple_size| move |ty| TypeKind::make_vector(ty, true, tuple_size); + + match wildcard { + Wildcard::Type(idx) => self.input.typekind(*idx).ok_or_else(err), + Wildcard::NEONType(idx, tuple_size) => self + .input + .typekind(*idx) + .ok_or_else(err) + .and_then(make_neon(*tuple_size)), + Wildcard::SVEType(idx, tuple_size) => self + .input + .typekind(*idx) + .ok_or_else(err) + .and_then(make_sve(*tuple_size)), + Wildcard::Predicate(idx) => self.input.typekind(*idx).map_or_else( + || { + if idx.is_none() && self.input.types_len() == 1 { + Err(err()) + } else { + Err(format!( + "there is no type at index {} to infer the predicate from", + idx.unwrap_or(0) + )) + } + }, + |ref ty| TypeKind::make_predicate_from(ty), + ), + Wildcard::MaxPredicate => self + .input + .iter() + .filter_map(|arg| arg.typekind()) + .max_by(|x, y| { + x.base_type() + .and_then(|bt| bt.get_size().ok()) + .unwrap_or(0) + .cmp(&y.base_type().and_then(|bt| bt.get_size().ok()).unwrap_or(0)) + }) + .map_or_else( 
+ || Err("there are no types available to infer the predicate from".to_string()), + TypeKind::make_predicate_from, + ), + Wildcard::Scale(w, as_ty) => { + let mut ty = self.provide_type_wildcard(w)?; + if let Some(vty) = ty.vector_mut() { + let base_ty = if let Some(w) = as_ty.wildcard() { + *self.provide_type_wildcard(w)?.base_type().unwrap() + } else { + *as_ty.base_type().unwrap() + }; + vty.cast_base_type_as(base_ty) + } + Ok(ty) + } + _ => Err(err()), + } + } + + pub fn provide_substitution_wildcard(&self, wildcard: &Wildcard) -> Result { + let err = || Err(format!("wildcard {{{wildcard}}} not found")); + + match wildcard { + Wildcard::SizeLiteral(idx) => self.input.typekind(*idx) + .map_or_else(err, |ty| Ok(ty.size_literal())), + Wildcard::Size(idx) => self.input.typekind(*idx) + .map_or_else(err, |ty| Ok(ty.size())), + Wildcard::SizeMinusOne(idx) => self.input.typekind(*idx) + .map_or_else(err, |ty| Ok((ty.size().parse::().unwrap()-1).to_string())), + Wildcard::SizeInBytesLog2(idx) => self.input.typekind(*idx) + .map_or_else(err, |ty| Ok(ty.size_in_bytes_log2())), + Wildcard::NVariant if self.substitutions.get(wildcard).is_none() => Ok(String::new()), + Wildcard::TypeKind(idx, opts) => { + self.input.typekind(*idx) + .map_or_else(err, |ty| { + let literal = if let Some(opts) = opts { + opts.contains(ty.base_type().map(|bt| *bt.kind()).ok_or_else(|| { + format!("cannot retrieve a type literal out of {ty}") + })?) + .then(|| ty.type_kind()) + .unwrap_or_default() + } else { + ty.type_kind() + }; + Ok(literal) + }) + } + Wildcard::PredicateForms(_) => self + .input + .iter() + .find_map(|arg| { + if let InputType::PredicateForm(pf) = arg { + Some(pf.get_suffix().to_string()) + } else { + None + } + }) + .ok_or_else(|| unreachable!("attempting to render a predicate form wildcard, but no predicate form was compiled for it")), + _ => self + .substitutions + .get(wildcard) + .map_or_else(err, |s| Ok(s.clone())), + } + } + + pub fn make_assertion_from_constraint(&self, constraint: &Constraint) -> Result { + match constraint { + Constraint::AnyI32 { + variable, + any_values, + } => { + let where_ex = any_values + .iter() + .map(|value| format!("{variable} == {value}")) + .join(" || "); + Ok(Expression::MacroCall("static_assert".to_string(), where_ex)) + } + Constraint::RangeI32 { + variable, + range: SizeMatchable::Matched(range), + } => Ok(Expression::MacroCall( + "static_assert_range".to_string(), + format!( + "{variable}, {min}, {max}", + min = range.start(), + max = range.end() + ), + )), + Constraint::SVEMaxElems { + variable, + sve_max_elems_type: ty, + } + | Constraint::VecMaxElems { + variable, + vec_max_elems_type: ty, + } => { + if !self.input.is_empty() { + let higher_limit = match constraint { + Constraint::SVEMaxElems { .. } => SVE_VECTOR_MAX_SIZE, + Constraint::VecMaxElems { .. } => VECTOR_REG_SIZE, + _ => unreachable!(), + }; + + let max = ty.base_type() + .map(|ty| ty.get_size()) + .transpose()? 
+ .map_or_else( + || Err(format!("can't make an assertion out of constraint {self:?}: no valid type is present")), + |bitsize| Ok(higher_limit / bitsize - 1))?; + Ok(Expression::MacroCall( + "static_assert_range".to_string(), + format!("{variable}, 0, {max}"), + )) + } else { + Err(format!("can't make an assertion out of constraint {self:?}: no types are being used")) + } + } + _ => unreachable!("constraints were not built successfully!"), + } + } + + pub fn predicate_form(&self) -> Option<&PredicateForm> { + self.input.iter().find_map(|arg| arg.predicate_form()) + } + + pub fn n_variant_op(&self) -> Option<&WildString> { + self.input.iter().find_map(|arg| arg.n_variant_op()) + } +} + +pub struct Context<'ctx> { + pub local: &'ctx mut LocalContext, + pub group: &'ctx mut GroupContext, + pub global: &'ctx GlobalContext, +} diff --git a/crates/stdarch-gen2/src/expression.rs b/crates/stdarch-gen2/src/expression.rs new file mode 100644 index 0000000000..4434ae276e --- /dev/null +++ b/crates/stdarch-gen2/src/expression.rs @@ -0,0 +1,546 @@ +use itertools::Itertools; +use lazy_static::lazy_static; +use proc_macro2::{Literal, TokenStream}; +use quote::{format_ident, quote, ToTokens, TokenStreamExt}; +use regex::Regex; +use serde::de::{self, MapAccess, Visitor}; +use serde::{Deserialize, Deserializer, Serialize}; +use std::fmt; +use std::str::FromStr; + +use crate::intrinsic::Intrinsic; +use crate::{ + context::{self, Context, VariableType}, + intrinsic::{Argument, LLVMLink, StaticDefinition}, + matching::{MatchKindValues, MatchSizeValues}, + typekinds::{BaseType, BaseTypeKind, TypeKind}, + wildcards::Wildcard, + wildstring::WildString, +}; + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub enum IdentifierType { + Variable, + Symbol, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum LetVariant { + Basic(WildString, Box), + WithType(WildString, TypeKind, Box), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FnCall( + /// Function pointer + pub Box, + /// Function arguments + pub Vec, + /// Function turbofish arguments + #[serde(default)] + pub Vec, +); + +impl FnCall { + pub fn new_expression(fn_ptr: Expression, arguments: Vec) -> Expression { + FnCall(Box::new(fn_ptr), arguments, Vec::new()).into() + } + + pub fn is_llvm_link_call(&self, llvm_link_name: &String) -> bool { + if let Expression::Identifier(fn_name, IdentifierType::Symbol) = self.0.as_ref() { + &fn_name.to_string() == llvm_link_name + } else { + false + } + } + + pub fn pre_build(&mut self, ctx: &mut Context) -> context::Result { + self.0.pre_build(ctx)?; + self.1 + .iter_mut() + .chain(self.2.iter_mut()) + .try_for_each(|ex| ex.pre_build(ctx)) + } + + pub fn build(&mut self, intrinsic: &Intrinsic, ctx: &mut Context) -> context::Result { + self.0.build(intrinsic, ctx)?; + self.1 + .iter_mut() + .chain(self.2.iter_mut()) + .try_for_each(|ex| ex.build(intrinsic, ctx)) + } +} + +impl ToTokens for FnCall { + fn to_tokens(&self, tokens: &mut TokenStream) { + let FnCall(fn_ptr, arguments, turbofish) = self; + + fn_ptr.to_tokens(tokens); + + if !turbofish.is_empty() { + tokens.append_all(quote! {::<#(#turbofish),*>}); + } + + tokens.append_all(quote! 
{ (#(#arguments),*) }) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(remote = "Self", deny_unknown_fields)] +pub enum Expression { + /// (Re)Defines a variable + Let(LetVariant), + /// Performs a variable assignment operation + Assign(String, Box), + /// Performs a macro call + MacroCall(String, String), + /// Performs a function call + FnCall(FnCall), + /// Performs a method call. The following: + /// `MethodCall: ["$object", "to_string", []]` + /// is tokenized as: + /// `object.to_string()`. + MethodCall(Box, String, Vec), + /// Symbol identifier name, prepend with a `$` to treat it as a scope variable + /// which engages variable tracking and enables inference. + /// E.g. `my_function_name` for a generic symbol or `$my_variable` for + /// a variable. + Identifier(WildString, IdentifierType), + /// Constant signed integer number expression + IntConstant(i32), + /// Constant floating point number expression + FloatConstant(f32), + /// Constant boolean expression, either `true` or `false` + BoolConstant(bool), + /// Array expression + Array(Vec), + + // complex expressions + /// Makes an LLVM link. + /// + /// It stores the link's function name in the wildcard `{llvm_link}`, for use in + /// subsequent expressions. + LLVMLink(LLVMLink), + /// Casts the given expression to the specified (unchecked) type + CastAs(Box, String), + /// Returns the LLVM `undef` symbol + SvUndef, + /// Multiplication + Multiply(Box, Box), + /// Converts the specified constant to the specified type's kind + ConvertConst(TypeKind, i32), + /// Yields the given type in the Rust representation + Type(TypeKind), + + MatchSize(TypeKind, MatchSizeValues>), + MatchKind(TypeKind, MatchKindValues>), +} + +impl Expression { + pub fn pre_build(&mut self, ctx: &mut Context) -> context::Result { + match self { + Self::FnCall(fn_call) => fn_call.pre_build(ctx), + Self::MethodCall(cl_ptr_ex, _, arg_exs) => { + cl_ptr_ex.pre_build(ctx)?; + arg_exs.iter_mut().try_for_each(|ex| ex.pre_build(ctx)) + } + Self::Let(LetVariant::Basic(_, ex) | LetVariant::WithType(_, _, ex)) => { + ex.pre_build(ctx) + } + Self::CastAs(ex, _) => ex.pre_build(ctx), + Self::Multiply(lhs, rhs) => { + lhs.pre_build(ctx)?; + rhs.pre_build(ctx) + } + Self::MatchSize(match_ty, values) => { + *self = *values.get(match_ty, ctx.local)?.to_owned(); + self.pre_build(ctx) + } + Self::MatchKind(match_ty, values) => { + *self = *values.get(match_ty, ctx.local)?.to_owned(); + self.pre_build(ctx) + } + _ => Ok(()), + } + } + + pub fn build(&mut self, intrinsic: &Intrinsic, ctx: &mut Context) -> context::Result { + match self { + Self::LLVMLink(link) => link.build_and_save(ctx), + Self::Identifier(identifier, id_type) => { + identifier.build_acle(ctx.local)?; + + if let IdentifierType::Variable = id_type { + ctx.local + .variables + .get(&identifier.to_string()) + .map(|_| ()) + .ok_or_else(|| format!("invalid variable {identifier} being referenced")) + } else { + Ok(()) + } + } + Self::FnCall(fn_call) => { + fn_call.build(intrinsic, ctx)?; + + if let Some(llvm_link_name) = ctx.local.substitutions.get(&Wildcard::LLVMLink) { + if fn_call.is_llvm_link_call(llvm_link_name) { + *self = intrinsic + .llvm_link() + .expect("got LLVMLink wildcard without a LLVM link in `compose`") + .apply_conversions_to_call(fn_call.clone(), ctx.local)? 
+ } + } + + Ok(()) + } + Self::MethodCall(cl_ptr_ex, _, arg_exs) => { + cl_ptr_ex.build(intrinsic, ctx)?; + arg_exs + .iter_mut() + .try_for_each(|ex| ex.build(intrinsic, ctx)) + } + Self::Let(variant) => { + let (var_name, ex, ty) = match variant { + LetVariant::Basic(var_name, ex) => (var_name, ex, None), + LetVariant::WithType(var_name, ty, ex) => { + if let Some(w) = ty.wildcard() { + ty.populate_wildcard(ctx.local.provide_type_wildcard(w)?)?; + } + (var_name, ex, Some(ty.to_owned())) + } + }; + + var_name.build_acle(ctx.local)?; + ctx.local.variables.insert( + var_name.to_string(), + ( + ty.unwrap_or_else(|| TypeKind::Custom("unknown".to_string())), + VariableType::Internal, + ), + ); + ex.build(intrinsic, ctx) + } + Self::CastAs(ex, _) => ex.build(intrinsic, ctx), + Self::Multiply(lhs, rhs) => { + lhs.build(intrinsic, ctx)?; + rhs.build(intrinsic, ctx) + } + Self::ConvertConst(ty, num) => { + if let Some(w) = ty.wildcard() { + *ty = ctx.local.provide_type_wildcard(w)? + } + + if let Some(BaseType::Sized(BaseTypeKind::Float, _)) = ty.base() { + *self = Expression::FloatConstant(*num as f32) + } else { + *self = Expression::IntConstant(*num) + } + Ok(()) + } + Self::Type(ty) => { + if let Some(w) = ty.wildcard() { + *ty = ctx.local.provide_type_wildcard(w)? + } + + Ok(()) + } + _ => Ok(()), + } + } + + /// True if the expression requires an `unsafe` context in a safe function. + /// + /// The classification is somewhat fuzzy, based on actual usage (e.g. empirical function names) + /// rather than a full parse. This is a reasonable approach because mistakes here will usually + /// be caught at build time: + /// + /// - Missing an `unsafe` is a build error. + /// - An unnecessary `unsafe` is a warning, made into an error by the CI's `-D warnings`. + /// + /// This **panics** if it encounters an expression that shouldn't appear in a safe function at + /// all (such as `SvUndef`). + pub fn requires_unsafe_wrapper(&self, ctx_fn: &str) -> bool { + match self { + // The call will need to be unsafe, but the declaration does not. + Self::LLVMLink(..) => false, + // Identifiers, literals and type names are never unsafe. + Self::Identifier(..) => false, + Self::IntConstant(..) => false, + Self::FloatConstant(..) => false, + Self::BoolConstant(..) => false, + Self::Type(..) => false, + Self::ConvertConst(..) => false, + // Nested structures that aren't inherently unsafe, but could contain other expressions + // that might be. + Self::Assign(_var, exp) => exp.requires_unsafe_wrapper(ctx_fn), + Self::Let(LetVariant::Basic(_, exp) | LetVariant::WithType(_, _, exp)) => { + exp.requires_unsafe_wrapper(ctx_fn) + } + Self::Array(exps) => exps.iter().any(|exp| exp.requires_unsafe_wrapper(ctx_fn)), + Self::Multiply(lhs, rhs) => { + lhs.requires_unsafe_wrapper(ctx_fn) || rhs.requires_unsafe_wrapper(ctx_fn) + } + Self::CastAs(exp, _ty) => exp.requires_unsafe_wrapper(ctx_fn), + // Functions and macros can be unsafe, but can also contain other expressions. 
+ Self::FnCall(FnCall(fn_exp, args, turbo_args)) => { + let fn_name = fn_exp.to_string(); + fn_exp.requires_unsafe_wrapper(ctx_fn) + || fn_name.starts_with("_sv") + || fn_name.starts_with("simd_") + || fn_name.ends_with("transmute") + || args.iter().any(|exp| exp.requires_unsafe_wrapper(ctx_fn)) + || turbo_args + .iter() + .any(|exp| exp.requires_unsafe_wrapper(ctx_fn)) + } + Self::MethodCall(exp, fn_name, args) => match fn_name.as_str() { + // `as_signed` and `as_unsigned` are unsafe because they're trait methods with + // target features to allow use on feature-dependent types (such as SVE vectors). + // We can safely wrap them here. + "as_signed" => true, + "as_unsigned" => true, + _ => { + exp.requires_unsafe_wrapper(ctx_fn) + || args.iter().any(|exp| exp.requires_unsafe_wrapper(ctx_fn)) + } + }, + // We only use macros to check const generics (using static assertions). + Self::MacroCall(_name, _args) => false, + // Materialising uninitialised values is always unsafe, and we avoid it in safe + // functions. + Self::SvUndef => panic!("Refusing to wrap unsafe SvUndef in safe function '{ctx_fn}'."), + // Variants that aren't tokenised. We shouldn't encounter these here. + Self::MatchKind(..) => { + unimplemented!("The unsafety of {self:?} cannot be determined in '{ctx_fn}'.") + } + Self::MatchSize(..) => { + unimplemented!("The unsafety of {self:?} cannot be determined in '{ctx_fn}'.") + } + } + } +} + +impl FromStr for Expression { + type Err = String; + + fn from_str(s: &str) -> Result { + lazy_static! { + static ref MACRO_RE: Regex = + Regex::new(r"^(?P[\w\d_]+)!\((?P.*?)\);?$").unwrap(); + } + + if s == "SvUndef" { + Ok(Expression::SvUndef) + } else if MACRO_RE.is_match(s) { + let c = MACRO_RE.captures(s).unwrap(); + let ex = c["ex"].to_string(); + let _: TokenStream = ex + .parse() + .map_err(|e| format!("could not parse macro call expression: {e:#?}"))?; + Ok(Expression::MacroCall(c["name"].to_string(), ex)) + } else { + let (s, id_type) = if let Some(varname) = s.strip_prefix('$') { + (varname, IdentifierType::Variable) + } else { + (s, IdentifierType::Symbol) + }; + let identifier = s.trim().parse()?; + Ok(Expression::Identifier(identifier, id_type)) + } + } +} + +impl From for Expression { + fn from(fn_call: FnCall) -> Self { + Expression::FnCall(fn_call) + } +} + +impl From for Expression { + fn from(ws: WildString) -> Self { + Expression::Identifier(ws, IdentifierType::Symbol) + } +} + +impl From<&Argument> for Expression { + fn from(a: &Argument) -> Self { + Expression::Identifier(a.name.to_owned(), IdentifierType::Variable) + } +} + +impl TryFrom<&StaticDefinition> for Expression { + type Error = String; + + fn try_from(sd: &StaticDefinition) -> Result { + match sd { + StaticDefinition::Constant(imm) => Ok(imm.into()), + StaticDefinition::Generic(t) => t.parse(), + } + } +} + +impl fmt::Display for Expression { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Identifier(identifier, kind) => { + write!( + f, + "{}{identifier}", + matches!(kind, IdentifierType::Variable) + .then_some("$") + .unwrap_or_default() + ) + } + Self::MacroCall(name, expression) => { + write!(f, "{name}!({expression})") + } + _ => Err(fmt::Error), + } + } +} + +impl ToTokens for Expression { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + Self::Let(LetVariant::Basic(var_name, exp)) => { + let var_ident = format_ident!("{}", var_name.to_string()); + tokens.append_all(quote! 
{ let #var_ident = #exp }) + } + Self::Let(LetVariant::WithType(var_name, ty, exp)) => { + let var_ident = format_ident!("{}", var_name.to_string()); + tokens.append_all(quote! { let #var_ident: #ty = #exp }) + } + Self::Assign(var_name, exp) => { + let var_ident = format_ident!("{}", var_name); + tokens.append_all(quote! { #var_ident = #exp }) + } + Self::MacroCall(name, ex) => { + let name = format_ident!("{name}"); + let ex: TokenStream = ex.parse().unwrap(); + tokens.append_all(quote! { #name!(#ex) }) + } + Self::FnCall(fn_call) => fn_call.to_tokens(tokens), + Self::MethodCall(exp, fn_name, args) => { + let fn_ident = format_ident!("{}", fn_name); + tokens.append_all(quote! { #exp.#fn_ident(#(#args),*) }) + } + Self::Identifier(identifier, _) => { + assert!( + !identifier.has_wildcards(), + "expression {self:#?} was not built before calling to_tokens" + ); + identifier + .to_string() + .parse::() + .expect("invalid syntax") + .to_tokens(tokens); + } + Self::IntConstant(n) => tokens.append(Literal::i32_unsuffixed(*n)), + Self::FloatConstant(n) => tokens.append(Literal::f32_unsuffixed(*n)), + Self::BoolConstant(true) => tokens.append(format_ident!("true")), + Self::BoolConstant(false) => tokens.append(format_ident!("false")), + Self::Array(vec) => tokens.append_all(quote! { [ #(#vec),* ] }), + Self::LLVMLink(link) => link.to_tokens(tokens), + Self::CastAs(ex, ty) => { + let ty: TokenStream = ty.parse().expect("invalid syntax"); + tokens.append_all(quote! { #ex as #ty }) + } + Self::SvUndef => tokens.append_all(quote! { simd_reinterpret(()) }), + Self::Multiply(lhs, rhs) => tokens.append_all(quote! { #lhs * #rhs }), + Self::Type(ty) => ty.to_tokens(tokens), + _ => unreachable!("{self:?} cannot be converted to tokens."), + } + } +} + +impl Serialize for Expression { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + Self::IntConstant(v) => serializer.serialize_i32(*v), + Self::FloatConstant(v) => serializer.serialize_f32(*v), + Self::BoolConstant(v) => serializer.serialize_bool(*v), + Self::Identifier(..) => serializer.serialize_str(&self.to_string()), + Self::MacroCall(..) 
=> serializer.serialize_str(&self.to_string()), + _ => Expression::serialize(self, serializer), + } + } +} + +impl<'de> Deserialize<'de> for Expression { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct CustomExpressionVisitor; + + impl<'de> Visitor<'de> for CustomExpressionVisitor { + type Value = Expression; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("integer, float, boolean, string or map") + } + + fn visit_bool(self, v: bool) -> Result + where + E: de::Error, + { + Ok(Expression::BoolConstant(v)) + } + + fn visit_i64(self, v: i64) -> Result + where + E: de::Error, + { + Ok(Expression::IntConstant(v as i32)) + } + + fn visit_u64(self, v: u64) -> Result + where + E: de::Error, + { + Ok(Expression::IntConstant(v as i32)) + } + + fn visit_f64(self, v: f64) -> Result + where + E: de::Error, + { + Ok(Expression::FloatConstant(v as f32)) + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + FromStr::from_str(value).map_err(de::Error::custom) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + let arr = std::iter::from_fn(|| seq.next_element::().transpose()) + .try_collect()?; + Ok(Expression::Array(arr)) + } + + fn visit_map(self, map: M) -> Result + where + M: MapAccess<'de>, + { + // `MapAccessDeserializer` is a wrapper that turns a `MapAccess` + // into a `Deserializer`, allowing it to be used as the input to T's + // `Deserialize` implementation. T then deserializes itself using + // the entries from the map visitor. + Expression::deserialize(de::value::MapAccessDeserializer::new(map)) + } + } + + deserializer.deserialize_any(CustomExpressionVisitor) + } +} diff --git a/crates/stdarch-gen2/src/input.rs b/crates/stdarch-gen2/src/input.rs new file mode 100644 index 0000000000..bb2414adec --- /dev/null +++ b/crates/stdarch-gen2/src/input.rs @@ -0,0 +1,432 @@ +use itertools::Itertools; +use serde::{de, Deserialize, Deserializer, Serialize}; + +use crate::{ + context::{self, GlobalContext}, + intrinsic::Intrinsic, + predicate_forms::{PredicateForm, PredicationMask, PredicationMethods}, + typekinds::TypeKind, + wildstring::WildString, +}; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum InputType { + /// PredicateForm variant argument + #[serde(skip)] // Predicate forms have their own dedicated deserialization field. Skip. + PredicateForm(PredicateForm), + /// Operand from which to generate an N variant + #[serde(skip)] + NVariantOp(Option), + /// TypeKind variant argument + Type(TypeKind), +} + +impl InputType { + /// Optionally unwraps as a PredicateForm. + pub fn predicate_form(&self) -> Option<&PredicateForm> { + match self { + InputType::PredicateForm(pf) => Some(pf), + _ => None, + } + } + + /// Optionally unwraps as a mutable PredicateForm + pub fn predicate_form_mut(&mut self) -> Option<&mut PredicateForm> { + match self { + InputType::PredicateForm(pf) => Some(pf), + _ => None, + } + } + + /// Optionally unwraps as a TypeKind. 
+ pub fn typekind(&self) -> Option<&TypeKind> { + match self { + InputType::Type(ty) => Some(ty), + _ => None, + } + } + + /// Optionally unwraps as a NVariantOp + pub fn n_variant_op(&self) -> Option<&WildString> { + match self { + InputType::NVariantOp(Some(op)) => Some(op), + _ => None, + } + } +} + +impl PartialOrd for InputType { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for InputType { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + use std::cmp::Ordering::*; + + match (self, other) { + (InputType::PredicateForm(pf1), InputType::PredicateForm(pf2)) => pf1.cmp(pf2), + (InputType::Type(ty1), InputType::Type(ty2)) => ty1.cmp(ty2), + + (InputType::NVariantOp(None), InputType::NVariantOp(Some(..))) => Less, + (InputType::NVariantOp(Some(..)), InputType::NVariantOp(None)) => Greater, + (InputType::NVariantOp(_), InputType::NVariantOp(_)) => Equal, + + (InputType::Type(..), InputType::PredicateForm(..)) => Less, + (InputType::PredicateForm(..), InputType::Type(..)) => Greater, + + (InputType::Type(..), InputType::NVariantOp(..)) => Less, + (InputType::NVariantOp(..), InputType::Type(..)) => Greater, + + (InputType::PredicateForm(..), InputType::NVariantOp(..)) => Less, + (InputType::NVariantOp(..), InputType::PredicateForm(..)) => Greater, + } + } +} + +mod many_or_one { + use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize}; + + pub fn serialize(vec: &Vec, serializer: S) -> Result + where + T: Serialize, + S: Serializer, + { + if vec.len() == 1 { + vec.first().unwrap().serialize(serializer) + } else { + vec.serialize(serializer) + } + } + + pub fn deserialize<'de, T, D>(deserializer: D) -> Result, D::Error> + where + T: Deserialize<'de>, + D: Deserializer<'de>, + { + #[derive(Debug, Clone, Serialize, Deserialize)] + #[serde(untagged)] + enum ManyOrOne { + Many(Vec), + One(T), + } + + match ManyOrOne::deserialize(deserializer)? 
{ + ManyOrOne::Many(vec) => Ok(vec), + ManyOrOne::One(val) => Ok(vec![val]), + } + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub struct InputSet(#[serde(with = "many_or_one")] Vec); + +impl InputSet { + pub fn get(&self, idx: usize) -> Option<&InputType> { + self.0.get(idx) + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub fn iter(&self) -> impl Iterator + '_ { + self.0.iter() + } + + pub fn iter_mut(&mut self) -> impl Iterator + '_ { + self.0.iter_mut() + } + + pub fn into_iter(self) -> impl Iterator + Clone { + self.0.into_iter() + } + + pub fn types_len(&self) -> usize { + self.iter().filter_map(|arg| arg.typekind()).count() + } + + pub fn typekind(&self, idx: Option) -> Option { + let types_len = self.types_len(); + self.get(idx.unwrap_or(0)).and_then(move |arg: &InputType| { + if (idx.is_none() && types_len != 1) || (idx.is_some() && types_len == 1) { + None + } else { + arg.typekind().cloned() + } + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct InputSetEntry(#[serde(with = "many_or_one")] Vec); + +impl InputSetEntry { + pub fn new(input: Vec) -> Self { + Self(input) + } + + pub fn get(&self, idx: usize) -> Option<&InputSet> { + self.0.get(idx) + } +} + +fn validate_types<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + let v: Vec = Vec::deserialize(deserializer)?; + + let mut it = v.iter(); + if let Some(first) = it.next() { + it.try_fold(first, |last, cur| { + if last.0.len() == cur.0.len() { + Ok(cur) + } else { + Err("the length of the InputSets and the product lists must match".to_string()) + } + }) + .map_err(de::Error::custom)?; + } + + Ok(v) +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct IntrinsicInput { + #[serde(default)] + #[serde(deserialize_with = "validate_types")] + pub types: Vec, + + #[serde(flatten)] + pub predication_methods: PredicationMethods, + + /// Generates a _n variant where the specified operand is a primitive type + /// that requires conversion to an SVE one. The `{_n}` wildcard is required + /// in the intrinsic's name, otherwise an error will be thrown. + #[serde(default)] + pub n_variant_op: WildString, +} + +impl IntrinsicInput { + /// Extracts all the possible variants as an iterator. + pub fn variants( + &self, + intrinsic: &Intrinsic, + ) -> context::Result + '_> { + let mut top_product = vec![]; + + if !self.types.is_empty() { + top_product.push( + self.types + .iter() + .flat_map(|ty_in| { + ty_in + .0 + .iter() + .map(|v| v.clone().into_iter()) + .multi_cartesian_product() + }) + .collect_vec(), + ) + } + + if let Ok(mask) = PredicationMask::try_from(&intrinsic.signature.name) { + top_product.push( + PredicateForm::compile_list(&mask, &self.predication_methods)? 
+ .into_iter() + .map(|pf| vec![InputType::PredicateForm(pf)]) + .collect_vec(), + ) + } + + if !self.n_variant_op.is_empty() { + top_product.push(vec![ + vec![InputType::NVariantOp(None)], + vec![InputType::NVariantOp(Some(self.n_variant_op.to_owned()))], + ]) + } + + let it = top_product + .into_iter() + .map(|v| v.into_iter()) + .multi_cartesian_product() + .map(|set| InputSet(set.into_iter().flatten().collect_vec())); + Ok(it) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GeneratorInput { + #[serde(flatten)] + pub ctx: GlobalContext, + pub intrinsics: Vec, +} + +#[cfg(test)] +mod tests { + use crate::{ + input::*, + predicate_forms::{DontCareMethod, ZeroingMethod}, + }; + + #[test] + fn test_empty() { + let str = r#"types: []"#; + let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse"); + let mut variants = input.variants(&Intrinsic::default()).unwrap().into_iter(); + assert_eq!(variants.next(), None); + } + + #[test] + fn test_product() { + let str = r#"types: +- [f64, f32] +- [i64, [f64, f32]] +"#; + let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse"); + let mut intrinsic = Intrinsic::default(); + intrinsic.signature.name = "test_intrinsic{_mx}".parse().unwrap(); + let mut variants = input.variants(&intrinsic).unwrap().into_iter(); + assert_eq!( + variants.next(), + Some(InputSet(vec![ + InputType::Type("f64".parse().unwrap()), + InputType::Type("f32".parse().unwrap()), + InputType::PredicateForm(PredicateForm::Merging), + ])) + ); + assert_eq!( + variants.next(), + Some(InputSet(vec![ + InputType::Type("f64".parse().unwrap()), + InputType::Type("f32".parse().unwrap()), + InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsMerging)), + ])) + ); + assert_eq!( + variants.next(), + Some(InputSet(vec![ + InputType::Type("i64".parse().unwrap()), + InputType::Type("f64".parse().unwrap()), + InputType::PredicateForm(PredicateForm::Merging), + ])) + ); + assert_eq!( + variants.next(), + Some(InputSet(vec![ + InputType::Type("i64".parse().unwrap()), + InputType::Type("f64".parse().unwrap()), + InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsMerging)), + ])) + ); + assert_eq!( + variants.next(), + Some(InputSet(vec![ + InputType::Type("i64".parse().unwrap()), + InputType::Type("f32".parse().unwrap()), + InputType::PredicateForm(PredicateForm::Merging), + ])) + ); + assert_eq!( + variants.next(), + Some(InputSet(vec![ + InputType::Type("i64".parse().unwrap()), + InputType::Type("f32".parse().unwrap()), + InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsMerging)), + ])), + ); + assert_eq!(variants.next(), None); + } + + #[test] + fn test_n_variant() { + let str = r#"types: +- [f64, f32] +n_variant_op: op2 +"#; + let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse"); + let mut variants = input.variants(&Intrinsic::default()).unwrap().into_iter(); + assert_eq!( + variants.next(), + Some(InputSet(vec![ + InputType::Type("f64".parse().unwrap()), + InputType::Type("f32".parse().unwrap()), + InputType::NVariantOp(None), + ])) + ); + assert_eq!( + variants.next(), + Some(InputSet(vec![ + InputType::Type("f64".parse().unwrap()), + InputType::Type("f32".parse().unwrap()), + InputType::NVariantOp(Some("op2".parse().unwrap())), + ])) + ); + assert_eq!(variants.next(), None) + } + + #[test] + fn test_invalid_length() { + let str = r#"types: [i32, [[u64], [u32]]]"#; + serde_yaml::from_str::(str).expect_err("failure expected"); + } + + #[test] + fn 
test_invalid_predication() { + let str = "types: []"; + let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse"); + let mut intrinsic = Intrinsic::default(); + intrinsic.signature.name = "test_intrinsic{_mxz}".parse().unwrap(); + input + .variants(&intrinsic) + .map(|v| v.collect_vec()) + .expect_err("failure expected"); + } + + #[test] + fn test_invalid_predication_mask() { + "test_intrinsic{_mxy}" + .parse::() + .expect_err("failure expected"); + "test_intrinsic{_}" + .parse::() + .expect_err("failure expected"); + } + + #[test] + fn test_zeroing_predication() { + let str = r#"types: [i64] +zeroing_method: { drop: inactive }"#; + let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse"); + let mut intrinsic = Intrinsic::default(); + intrinsic.signature.name = "test_intrinsic{_mxz}".parse().unwrap(); + let mut variants = input.variants(&intrinsic).unwrap(); + assert_eq!( + variants.next(), + Some(InputSet(vec![ + InputType::Type("i64".parse().unwrap()), + InputType::PredicateForm(PredicateForm::Merging), + ])) + ); + assert_eq!( + variants.next(), + Some(InputSet(vec![ + InputType::Type("i64".parse().unwrap()), + InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsZeroing)), + ])) + ); + assert_eq!( + variants.next(), + Some(InputSet(vec![ + InputType::Type("i64".parse().unwrap()), + InputType::PredicateForm(PredicateForm::Zeroing(ZeroingMethod::Drop { + drop: "inactive".parse().unwrap() + })), + ])) + ); + assert_eq!(variants.next(), None) + } +} diff --git a/crates/stdarch-gen2/src/intrinsic.rs b/crates/stdarch-gen2/src/intrinsic.rs new file mode 100644 index 0000000000..d05b71e44d --- /dev/null +++ b/crates/stdarch-gen2/src/intrinsic.rs @@ -0,0 +1,1498 @@ +use itertools::Itertools; +use proc_macro2::{Punct, Spacing, TokenStream}; +use quote::{format_ident, quote, ToTokens, TokenStreamExt}; +use serde::{Deserialize, Serialize}; +use serde_with::{DeserializeFromStr, SerializeDisplay}; +use std::collections::{HashMap, HashSet}; +use std::fmt; +use std::ops::RangeInclusive; +use std::str::FromStr; + +use crate::assert_instr::InstructionAssertionsForBaseType; +use crate::context::{GlobalContext, GroupContext}; +use crate::input::{InputSet, InputSetEntry}; +use crate::predicate_forms::{DontCareMethod, PredicateForm, PredicationMask, ZeroingMethod}; +use crate::{ + assert_instr::InstructionAssertionMethod, + context::{self, ArchitectureSettings, Context, LocalContext, VariableType}, + expression::{Expression, FnCall, IdentifierType}, + input::IntrinsicInput, + matching::{KindMatchable, SizeMatchable}, + typekinds::*, + wildcards::Wildcard, + wildstring::WildString, +}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum SubstitutionType { + MatchSize(SizeMatchable), + MatchKind(KindMatchable), +} + +impl SubstitutionType { + pub fn get(&mut self, ctx: &LocalContext) -> context::Result { + match self { + Self::MatchSize(smws) => { + smws.perform_match(ctx)?; + Ok(smws.as_ref().clone()) + } + Self::MatchKind(kmws) => { + kmws.perform_match(ctx)?; + Ok(kmws.as_ref().clone()) + } + } + } +} + +/// Mutability level +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum AccessLevel { + /// Immutable + R, + /// Mutable + RW, +} + +/// Function signature argument. +/// +/// Prepend the `mut` keyword for a mutable argument. Separate argument name +/// and type with a semicolon `:`. 
Usage examples: +/// - Mutable argument: `mut arg1: *u64` +/// - Immutable argument: `arg2: u32` +#[derive(Debug, Clone, SerializeDisplay, DeserializeFromStr)] +pub struct Argument { + /// Argument name + pub name: WildString, + /// Mutability level + pub rw: AccessLevel, + /// Argument type + pub kind: TypeKind, +} + +impl Argument { + pub fn populate_variables(&self, vars: &mut HashMap) { + vars.insert( + self.name.to_string(), + (self.kind.clone(), VariableType::Argument), + ); + } +} + +impl FromStr for Argument { + type Err = String; + + fn from_str(s: &str) -> Result { + let mut it = s.splitn(2, ':').map(::trim); + if let Some(mut lhs) = it.next().map(|s| s.split_whitespace()) { + let lhs_len = lhs.clone().count(); + match (lhs_len, lhs.next(), it.next()) { + (2, Some("mut"), Some(kind)) => Ok(Argument { + name: lhs.next().unwrap().parse()?, + rw: AccessLevel::RW, + kind: kind.parse()?, + }), + (2, Some(ident), _) => Err(format!("invalid {ident:#?} keyword")), + (1, Some(name), Some(kind)) => Ok(Argument { + name: name.parse()?, + rw: AccessLevel::R, + kind: kind.parse()?, + }), + _ => Err(format!("invalid argument `{s}` provided")), + } + } else { + Err(format!("invalid argument `{s}` provided")) + } + } +} + +impl fmt::Display for Argument { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if let AccessLevel::RW = &self.rw { + write!(f, "mut ")?; + } + + write!(f, "{}: {}", self.name, self.kind) + } +} + +impl ToTokens for Argument { + fn to_tokens(&self, tokens: &mut TokenStream) { + if let AccessLevel::RW = &self.rw { + tokens.append(format_ident!("mut")) + } + + let (name, kind) = (format_ident!("{}", self.name.to_string()), &self.kind); + tokens.append_all(quote! { #name: #kind }) + } +} + +/// Static definition part of the signature. It may evaluate to a constant +/// expression with e.g. `const imm: u64`, or a generic `T: Into`. +#[derive(Debug, Clone, SerializeDisplay, DeserializeFromStr)] +pub enum StaticDefinition { + /// Constant expression + Constant(Argument), + /// Generic type + Generic(String), +} + +impl StaticDefinition { + pub fn as_variable(&self) -> Option<(String, (TypeKind, VariableType))> { + match self { + StaticDefinition::Constant(arg) => Some(( + arg.name.to_string(), + (arg.kind.clone(), VariableType::Argument), + )), + StaticDefinition::Generic(..) => None, + } + } +} + +impl FromStr for StaticDefinition { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.trim() { + s if s.starts_with("const ") => Ok(StaticDefinition::Constant(s[6..].trim().parse()?)), + s => Ok(StaticDefinition::Generic(s.to_string())), + } + } +} + +impl fmt::Display for StaticDefinition { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + StaticDefinition::Constant(arg) => write!(f, "const {arg}"), + StaticDefinition::Generic(generic) => write!(f, "{generic}"), + } + } +} + +impl ToTokens for StaticDefinition { + fn to_tokens(&self, tokens: &mut TokenStream) { + tokens.append_all(match self { + StaticDefinition::Constant(arg) => quote! { const #arg }, + StaticDefinition::Generic(generic) => { + let generic: TokenStream = generic.parse().expect("invalid Rust code"); + quote! { #generic } + } + }) + } +} + +/// Function constraints +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum Constraint { + /// Asserts that the given variable equals to any of the given integer values + AnyI32 { + variable: String, + any_values: Vec, + }, + /// WildString version of RangeI32. 
If the string values given for the range + /// are valid, this gets built into a RangeI32. + RangeWildstring { + variable: String, + range: (WildString, WildString), + }, + /// Asserts that the given variable's value falls in the specified range + RangeI32 { + variable: String, + range: SizeMatchable>, + }, + /// Asserts that the number of elements/lanes does not exceed the 2048-bit SVE constraint + SVEMaxElems { + variable: String, + sve_max_elems_type: TypeKind, + }, + /// Asserts that the number of elements/lanes does not exceed the 128-bit register constraint + VecMaxElems { + variable: String, + vec_max_elems_type: TypeKind, + }, +} + +impl Constraint { + fn variable(&self) -> &str { + match self { + Constraint::AnyI32 { variable, .. } + | Constraint::RangeWildstring { variable, .. } + | Constraint::RangeI32 { variable, .. } + | Constraint::SVEMaxElems { variable, .. } + | Constraint::VecMaxElems { variable, .. } => variable, + } + } + pub fn build(&mut self, ctx: &Context) -> context::Result { + if let Self::RangeWildstring { + variable, + range: (min, max), + } = self + { + min.build_acle(ctx.local)?; + max.build_acle(ctx.local)?; + let min = min.to_string(); + let max = max.to_string(); + let min: i32 = min + .parse() + .map_err(|_| format!("the minimum value `{min}` is not a valid number"))?; + let max: i32 = max + .parse() + .map_err(|_| format!("the maximum value `{max}` is not a valid number"))?; + *self = Self::RangeI32 { + variable: variable.to_owned(), + range: SizeMatchable::Matched(RangeInclusive::new(min, max)), + } + } + + if let Self::SVEMaxElems { + sve_max_elems_type: ty, + .. + } + | Self::VecMaxElems { + vec_max_elems_type: ty, + .. + } = self + { + if let Some(w) = ty.wildcard() { + ty.populate_wildcard(ctx.local.provide_type_wildcard(w)?)?; + } + } + + if let Self::RangeI32 { range, .. } = self { + range.perform_match(ctx.local)?; + } + + let variable = self.variable(); + ctx.local + .variables + .contains_key(variable) + .then_some(()) + .ok_or_else(|| format!("cannot build constraint, could not find variable {variable}")) + } +} + +/// Function signature +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct Signature { + /// Function name + pub name: WildString, + /// List of function arguments, leave unset or empty for no arguments + pub arguments: Vec, + /// Function return type, leave unset for void + pub return_type: Option, + + /// List of static definitions, leave unset of empty if not required + #[serde(default)] + pub static_defs: Vec, + + /// **Internal use only.** + /// Condition for which the ultimate function is specific to predicates. + #[serde(skip)] + pub is_predicate_specific: bool, + + /// **Internal use only.** + /// Setting this property will trigger the signature builder to convert any `svbool*_t` to `svbool_t` in the input and output. 
+ #[serde(skip)] + pub predicate_needs_conversion: bool, +} + +impl Signature { + pub fn drop_argument(&mut self, arg_name: &WildString) -> Result<(), String> { + if let Some(idx) = self + .arguments + .iter() + .position(|arg| arg.name.to_string() == arg_name.to_string()) + { + self.arguments.remove(idx); + Ok(()) + } else { + Err(format!("no argument {arg_name} found to drop")) + } + } + + pub fn build(&mut self, ctx: &LocalContext) -> context::Result { + self.name.build_acle(ctx)?; + + if let Some(ref mut return_type) = self.return_type { + if let Some(w) = return_type.clone().wildcard() { + return_type.populate_wildcard(ctx.provide_type_wildcard(w)?)?; + } + } + + self.arguments + .iter_mut() + .try_for_each(|arg| arg.name.build_acle(ctx))?; + + self.arguments + .iter_mut() + .filter_map(|arg| { + arg.kind + .clone() + .wildcard() + .map(|w| (&mut arg.kind, w.clone())) + }) + .try_for_each(|(ty, w)| ty.populate_wildcard(ctx.provide_type_wildcard(&w)?)) + } + + pub fn fn_name(&self) -> WildString { + self.name.replace(['[', ']'], "") + } + + pub fn doc_name(&self) -> String { + self.name.to_string() + } +} + +impl ToTokens for Signature { + fn to_tokens(&self, tokens: &mut TokenStream) { + let name_ident = format_ident!("{}", self.fn_name().to_string()); + let arguments = self + .arguments + .clone() + .into_iter() + .map(|mut arg| { + if arg + .kind + .vector() + .map_or(false, |ty| ty.base_type().is_bool()) + && self.predicate_needs_conversion + { + arg.kind = TypeKind::Vector(VectorType::make_predicate_from_bitsize(8)) + } + arg + }) + .collect_vec(); + let static_defs = &self.static_defs; + tokens.append_all(quote! { fn #name_ident<#(#static_defs),*>(#(#arguments),*) }); + + if let Some(ref return_type) = self.return_type { + if return_type + .vector() + .map_or(false, |ty| ty.base_type().is_bool()) + && self.predicate_needs_conversion + { + tokens.append_all(quote! { -> svbool_t }) + } else { + tokens.append_all(quote! { -> #return_type }) + } + } + } +} + +#[derive(Debug, Clone)] +pub struct LLVMLinkAttribute { + pub arch: String, + pub link: String, +} + +impl ToTokens for LLVMLinkAttribute { + fn to_tokens(&self, tokens: &mut TokenStream) { + let LLVMLinkAttribute { arch, link } = self; + tokens.append_all(quote! { + #[cfg_attr(target_arch = #arch, link_name = #link)] + }) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LLVMLink { + /// LLVM link function name without namespace and types, + /// e.g. `st1` in `llvm.aarch64.sve.st1.nxv4i32` + pub name: WildString, + + /// LLVM link signature arguments, leave unset if it inherits from intrinsic's signature + pub arguments: Option>, + /// LLVM link signature return type, leave unset if it inherits from intrinsic's signature + pub return_type: Option, + + /// **Internal use only. Do not set.** + /// Attribute LLVM links for the function. First element is the architecture it targets, + /// second element is the LLVM link itself. + #[serde(skip)] + pub links: Option>, + + /// **Internal use only. Do not set.** + /// Generated signature from these `arguments` and/or `return_type` if set, and the intrinsic's signature. 
+ #[serde(skip)] + pub signature: Option>, +} + +impl LLVMLink { + pub fn resolve(&self, cfg: &ArchitectureSettings) -> String { + self.name + .starts_with("llvm") + .then(|| self.name.to_string()) + .unwrap_or_else(|| format!("{}.{}", cfg.llvm_link_prefix, self.name)) + } + + pub fn build_and_save(&mut self, ctx: &mut Context) -> context::Result { + self.build(ctx)?; + + // Save LLVM link to the group context + ctx.global.arch_cfgs.iter().for_each(|cfg| { + ctx.group + .links + .insert(self.resolve(cfg), ctx.local.input.clone()); + }); + + Ok(()) + } + + pub fn build(&mut self, ctx: &mut Context) -> context::Result { + let mut sig_name = ctx.local.signature.name.clone(); + sig_name.prepend_str("_"); + + let mut sig = Signature { + name: sig_name, + arguments: self + .arguments + .clone() + .unwrap_or_else(|| ctx.local.signature.arguments.clone()), + return_type: self + .return_type + .clone() + .or_else(|| ctx.local.signature.return_type.clone()), + static_defs: vec![], + is_predicate_specific: ctx.local.signature.is_predicate_specific, + predicate_needs_conversion: false, + }; + + sig.build(ctx.local)?; + self.name.build(ctx.local, TypeRepr::LLVMMachine)?; + + // Add link function name to context + ctx.local + .substitutions + .insert(Wildcard::LLVMLink, sig.fn_name().to_string()); + + self.signature = Some(Box::new(sig)); + self.links = Some( + ctx.global + .arch_cfgs + .iter() + .map(|cfg| LLVMLinkAttribute { + arch: cfg.arch_name.to_owned(), + link: self.resolve(cfg), + }) + .collect_vec(), + ); + + Ok(()) + } + + /// Alters all the unsigned types from the signature, as unsupported by LLVM. + pub fn sanitise_uints(&mut self) { + let transform = |tk: &mut TypeKind| { + if let Some(BaseType::Sized(BaseTypeKind::UInt, size)) = tk.base_type() { + *tk.base_type_mut().unwrap() = BaseType::Sized(BaseTypeKind::Int, *size) + } + }; + + if let Some(sig) = self.signature.as_mut() { + for arg in sig.arguments.iter_mut() { + transform(&mut arg.kind); + } + + sig.return_type.as_mut().map(transform); + } + } + + /// Make a function call to the LLVM link + pub fn make_fn_call(&self, intrinsic_sig: &Signature) -> context::Result { + let link_sig = self.signature.as_ref().ok_or_else(|| { + "cannot derive the LLVM link call, as it does not hold a valid function signature" + .to_string() + })?; + + if intrinsic_sig.arguments.len() != link_sig.arguments.len() { + return Err( + "cannot derive the LLVM link call, the number of arguments does not match" + .to_string(), + ); + } + + let call_args = intrinsic_sig + .arguments + .iter() + .zip(link_sig.arguments.iter()) + .map(|(intrinsic_arg, link_arg)| { + // Could also add a type check... + if intrinsic_arg.name == link_arg.name { + Ok(Expression::Identifier( + intrinsic_arg.name.to_owned(), + IdentifierType::Variable, + )) + } else { + Err("cannot derive the LLVM link call, the arguments do not match".to_string()) + } + }) + .try_collect()?; + + Ok(FnCall::new_expression(link_sig.fn_name().into(), call_args)) + } + + /// Given a FnCall, apply all the predicate and unsigned conversions as required. 
+ pub fn apply_conversions_to_call( + &self, + mut fn_call: FnCall, + ctx: &LocalContext, + ) -> context::Result { + use BaseType::{Sized, Unsized}; + use BaseTypeKind::{Bool, UInt}; + use VariableType::Argument; + + let convert = + |method: &str, ex| Expression::MethodCall(Box::new(ex), method.to_string(), vec![]); + + fn_call.1 = fn_call + .1 + .into_iter() + .map(|arg| -> context::Result { + if let Expression::Identifier(ref var_name, IdentifierType::Variable) = arg { + let (kind, scope) = ctx + .variables + .get(&var_name.to_string()) + .ok_or_else(|| format!("invalid variable {var_name:?} being referenced"))?; + + match (scope, kind.base_type()) { + (Argument, Some(Sized(Bool, bitsize))) if *bitsize != 8 => { + Ok(convert("into", arg)) + } + (Argument, Some(Sized(UInt, _) | Unsized(UInt))) => { + Ok(convert("as_signed", arg)) + } + _ => Ok(arg), + } + } else { + Ok(arg) + } + }) + .try_collect()?; + + let return_type_requires_conversion = self + .signature + .as_ref() + .and_then(|sig| sig.return_type.as_ref()) + .and_then(|ty| { + if let Some(Sized(Bool, bitsize)) = ty.base_type() { + (*bitsize != 8).then_some(Bool) + } else if let Some(Sized(UInt, _) | Unsized(UInt)) = ty.base_type() { + Some(UInt) + } else { + None + } + }); + + let fn_call = Expression::FnCall(fn_call); + match return_type_requires_conversion { + Some(Bool) => Ok(convert("into", fn_call)), + Some(UInt) => Ok(convert("as_unsigned", fn_call)), + _ => Ok(fn_call), + } + } +} + +impl ToTokens for LLVMLink { + fn to_tokens(&self, tokens: &mut TokenStream) { + assert!( + self.signature.is_some() && self.links.is_some(), + "expression {self:#?} was not built before calling to_tokens" + ); + + let signature = self.signature.as_ref().unwrap(); + let links = self.links.as_ref().unwrap(); + tokens.append_all(quote! { + extern "C" { + #(#links),* + #signature; + } + }) + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum FunctionVisibility { + #[default] + Public, + Private, +} + +/// Whether to generate a load/store test, and which typeset index +/// represents the data type of the load/store target address +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Test { + #[default] + #[serde(skip)] + None, // Covered by `intrinsic-test` + Load(usize), + Store(usize), +} + +impl Test { + pub fn get_typeset_index(&self) -> Option { + match *self { + Test::Load(n) => Some(n), + Test::Store(n) => Some(n), + _ => None, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Safety { + Safe, + Unsafe(Vec), +} + +impl Safety { + /// Return `Ok(Safety::Safe)` if safety appears reasonable for the given `intrinsic`'s name and + /// prototype. Otherwise, return `Err()` with a suitable diagnostic. + fn safe_checked(intrinsic: &Intrinsic) -> Result { + let name = intrinsic.signature.doc_name(); + if name.starts_with("sv") { + let handles_pointers = intrinsic + .signature + .arguments + .iter() + .any(|arg| matches!(arg.kind, TypeKind::Pointer(..))); + if name.starts_with("svld") + || name.starts_with("svst") + || name.starts_with("svprf") + || name.starts_with("svundef") + || handles_pointers + { + let doc = intrinsic.doc.as_ref().map(|s| s.to_string()); + let doc = doc.as_deref().unwrap_or("..."); + Err(format!( + "`{name}` has no safety specification, but it looks like it should be unsafe. 
\ + Consider specifying (un)safety explicitly: + + - name: {name} + doc: {doc} + safety: + unsafe: + - ... + ... +" + )) + } else { + Ok(Self::Safe) + } + } else { + Err(format!( + "Safety::safe_checked() for non-SVE intrinsic: {name}" + )) + } + } + + fn is_safe(&self) -> bool { + match self { + Self::Safe => true, + Self::Unsafe(..) => false, + } + } + + fn is_unsafe(&self) -> bool { + !self.is_safe() + } + + fn has_doc_comments(&self) -> bool { + match self { + Self::Safe => false, + Self::Unsafe(v) => !v.is_empty(), + } + } + + fn doc_comments(&self) -> &[UnsafetyComment] { + match self { + Self::Safe => &[], + Self::Unsafe(v) => v.as_slice(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum UnsafetyComment { + Custom(String), + Uninitialized, + PointerOffset(GovernedBy), + PointerOffsetVnum(GovernedBy), + Dereference(GovernedBy), + UnpredictableOnFault, + NonTemporal, + NoProvenance(String), +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum GovernedBy { + #[default] + Predicated, + PredicatedNonFaulting, + PredicatedFirstFaulting, +} + +impl fmt::Display for GovernedBy { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Predicated => write!(f, " (governed by `pg`)"), + Self::PredicatedNonFaulting => write!( + f, + " (governed by `pg`, the first-fault register (`FFR`) \ + and non-faulting behaviour)" + ), + Self::PredicatedFirstFaulting => write!( + f, + " (governed by `pg`, the first-fault register (`FFR`) \ + and first-faulting behaviour)" + ), + } + } +} + +impl fmt::Display for UnsafetyComment { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Custom(s) => s.fmt(f), + Self::Uninitialized => write!( + f, + "This creates an uninitialized value, and may be unsound (like \ + [`core::mem::uninitialized`])." + ), + Self::PointerOffset(gov) => write!( + f, + "[`pointer::offset`](pointer#method.offset) safety constraints must \ + be met for the address calculation for each active element{gov}." + ), + Self::PointerOffsetVnum(gov) => write!( + f, + "[`pointer::offset`](pointer#method.offset) safety constraints must \ + be met for the address calculation for each active element{gov}. \ + In particular, note that `vnum` is scaled by the vector \ + length, `VL`, which is not known at compile time." + ), + Self::Dereference(gov) => write!( + f, + "This dereferences and accesses the calculated address for each \ + active element{gov}." + ), + Self::NonTemporal => write!( + f, + "Non-temporal accesses have special memory ordering rules, and \ + [explicit barriers may be required for some applications]\ + (https://developer.arm.com/documentation/den0024/a/Memory-Ordering/Barriers/Non-temporal-load-and-store-pair?lang=en)." + ), + Self::NoProvenance(arg) => write!( + f, + "Addresses passed in `{arg}` lack provenance, so this is similar to using a \ + `usize as ptr` cast (or [`core::ptr::from_exposed_addr`]) on each lane before \ + using it." + ), + Self::UnpredictableOnFault => write!( + f, + "Result lanes corresponding to inactive FFR lanes (either before or as a result \ + of this intrinsic) have \"CONSTRAINED UNPREDICTABLE\" values, irrespective of \ + predication. Refer to architectural documentation for details." 
+ ), + } + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct Intrinsic { + #[serde(default)] + pub visibility: FunctionVisibility, + #[serde(default)] + pub doc: Option, + #[serde(flatten)] + pub signature: Signature, + /// Function sequential composition + pub compose: Vec, + /// Input to generate the intrinsic against. Leave empty if the intrinsic + /// does not have any variants. + /// Specific variants contain one InputSet + #[serde(flatten, default)] + pub input: IntrinsicInput, + #[serde(default)] + pub constraints: Vec, + /// Additional target features to add to the global settings + #[serde(default)] + pub target_features: Vec, + /// Should the intrinsic be `unsafe`? By default, the generator will try to guess from the + /// prototype, but it errs on the side of `unsafe`, and prints a warning in that case. + #[serde(default)] + pub safety: Option, + #[serde(default)] + pub substitutions: HashMap, + /// List of the only indices in a typeset that require conversion to signed + /// when deferring unsigned intrinsics to signed. (optional, default + /// behaviour is all unsigned types are converted to signed) + #[serde(default)] + pub defer_to_signed_only_indices: HashSet, + pub assert_instr: Vec, + /// Whether we should generate a test for this intrinsic + #[serde(default)] + pub test: Test, + /// Primary base type, used for instruction assertion. + #[serde(skip)] + pub base_type: Option, +} + +impl Intrinsic { + pub fn llvm_link(&self) -> Option<&LLVMLink> { + self.compose.iter().find_map(|ex| { + if let Expression::LLVMLink(llvm_link) = ex { + Some(llvm_link) + } else { + None + } + }) + } + + pub fn llvm_link_mut(&mut self) -> Option<&mut LLVMLink> { + self.compose.iter_mut().find_map(|ex| { + if let Expression::LLVMLink(llvm_link) = ex { + Some(llvm_link) + } else { + None + } + }) + } + + pub fn generate_variants(&self, global_ctx: &GlobalContext) -> context::Result> { + let wrap_err = |err| format!("{}: {err}", self.signature.name); + + let mut group_ctx = GroupContext::default(); + self.input + .variants(self) + .map_err(wrap_err)? + .map(|input| { + self.generate_variant(input.clone(), &mut group_ctx, global_ctx) + .map_err(wrap_err) + .map(|variant| (variant, input)) + }) + .collect::>>() + .and_then(|mut variants| { + variants.sort_by_cached_key(|(_, input)| input.to_owned()); + + if variants.is_empty() { + let standalone_variant = self + .generate_variant(InputSet::default(), &mut group_ctx, global_ctx) + .map_err(wrap_err)?; + + Ok(vec![standalone_variant]) + } else { + Ok(variants + .into_iter() + .map(|(variant, _)| variant) + .collect_vec()) + } + }) + } + + pub fn generate_variant( + &self, + input: InputSet, + group_ctx: &mut GroupContext, + global_ctx: &GlobalContext, + ) -> context::Result { + let mut variant = self.clone(); + + variant.input.types = vec![InputSetEntry::new(vec![input.clone()])]; + + let mut local_ctx = LocalContext::new(input, self); + let mut ctx = Context { + local: &mut local_ctx, + group: group_ctx, + global: global_ctx, + }; + + variant.pre_build(&mut ctx)?; + + match ctx.local.predicate_form().cloned() { + Some(PredicateForm::DontCare(method)) => { + variant.compose = variant.generate_dont_care_pass_through(&mut ctx, method)? + } + Some(PredicateForm::Zeroing(method)) => { + variant.compose = variant.generate_zeroing_pass_through(&mut ctx, method)? 
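+ // Both pass-through forms above replace the body with a single call that defers to
+ // the merging (`_m`) variant of this intrinsic.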
+ } + _ => { + for idx in 0..variant.compose.len() { + let mut ex = variant.compose[idx].clone(); + ex.build(&variant, &mut ctx)?; + variant.compose[idx] = ex; + } + } + }; + + variant.post_build(&mut ctx)?; + + if let Some(n_variant_op) = ctx.local.n_variant_op().cloned() { + variant.generate_n_variant(n_variant_op, &mut ctx) + } else { + Ok(variant) + } + } + + /// Implement a "zeroing" (_z) method by calling an existing "merging" (_m) method, as required. + fn generate_zeroing_pass_through( + &mut self, + ctx: &mut Context, + method: ZeroingMethod, + ) -> context::Result> { + PredicationMask::try_from(&ctx.local.signature.name) + .ok() + .filter(|mask| mask.has_merging()) + .ok_or_else(|| format!("cannot generate zeroing passthrough for {}, no merging predicate form is specified", self.signature.name))?; + + // Determine the function to pass through to. + let mut target_ctx = ctx.local.clone(); + // Change target function predicate form to merging + *target_ctx.input.iter_mut() + .find_map(|arg| arg.predicate_form_mut()) + .expect("failed to generate zeroing pass through, could not find predicate form in the InputSet") = PredicateForm::Merging; + + let mut sig = target_ctx.signature.clone(); + sig.build(&target_ctx)?; + + let args_as_expressions = |arg: &Argument| -> context::Result { + let arg_name = arg.name.to_string(); + match &method { + ZeroingMethod::Drop { drop } if arg_name == drop.to_string() => { + Ok(PredicateForm::make_zeroinitializer(&arg.kind)) + } + ZeroingMethod::Select { select } if arg_name == select.to_string() => { + let pg = sig + .arguments + .iter() + .find_map(|arg| match arg.kind.vector() { + Some(ty) if ty.base_type().is_bool() => Some(arg.name.clone()), + _ => None, + }) + .ok_or_else(|| { + format!("cannot generate zeroing passthrough for {}, no predicate found in the signature for zero selection", self.signature.name) + })?; + Ok(PredicateForm::make_zeroselector( + pg, + select.clone(), + &arg.kind, + )) + } + _ => Ok(arg.into()), + } + }; + + let name: Expression = sig.fn_name().into(); + let args: Vec = sig + .arguments + .iter() + .map(args_as_expressions) + .try_collect()?; + let statics: Vec = sig + .static_defs + .iter() + .map(|sd| sd.try_into()) + .try_collect()?; + let mut call: Expression = FnCall(Box::new(name), args, statics).into(); + call.build(self, ctx)?; + Ok(vec![call]) + } + + /// Implement a "don't care" (_x) method by calling an existing "merging" (_m). + fn generate_dont_care_pass_through( + &mut self, + ctx: &mut Context, + method: DontCareMethod, + ) -> context::Result> { + PredicationMask::try_from(&ctx.local.signature.name).and_then(|mask| match method { + DontCareMethod::AsMerging if mask.has_merging() => Ok(()), + DontCareMethod::AsZeroing if mask.has_zeroing() => Ok(()), + _ => Err(format!( + "cannot generate don't care passthrough for {}, no {} predicate form is specified", + self.signature.name, + match method { + DontCareMethod::AsMerging => "merging", + DontCareMethod::AsZeroing => "zeroing", + _ => unreachable!(), + } + )), + })?; + + // Determine the function to pass through to. + let mut target_ctx = ctx.local.clone(); + // Change target function predicate form to merging + *target_ctx.input.iter_mut() + .find_map(|arg| arg.predicate_form_mut()) + .expect("failed to generate don't care passthrough, could not find predicate form in the InputSet") = PredicateForm::Merging; + + let mut sig = target_ctx.signature.clone(); + sig.build(&target_ctx)?; + + // We might need to drop an argument for a zeroing pass-through. 
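+ // For instance (hypothetical spec field name), `zeroing_method: { drop: fallback }` means the
+ // `_x` form has no `fallback` argument; the mapping below re-materialises it, preferring a
+ // reinterpretation of another argument over a zero-initialiser, before calling the `_m` form.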
+ let drop = match (method, &self.input.predication_methods.zeroing_method) { + (DontCareMethod::AsZeroing, Some(ZeroingMethod::Drop { drop })) => Some(drop), + _ => None, + }; + + let name: Expression = sig.fn_name().into(); + let args: Vec = sig + .arguments + .iter() + .map(|arg| { + if Some(arg.name.to_string()) == drop.as_ref().map(|v| v.to_string()) { + // This argument is present in the _m form, but missing from the _x form. Clang + // typically replaces these with an uninitialised vector, but to avoid + // materialising uninitialised values in Rust, we instead merge with a known + // vector. This usually results in the same code generation. + // TODO: In many cases, it'll be better to use an unpredicated (or zeroing) form. + sig.arguments + .iter() + .filter(|&other| arg.name.to_string() != other.name.to_string()) + .find_map(|other| { + arg.kind.express_reinterpretation_from(&other.kind, other) + }) + .unwrap_or_else(|| PredicateForm::make_zeroinitializer(&arg.kind)) + } else { + arg.into() + } + }) + .collect(); + let statics: Vec = sig + .static_defs + .iter() + .map(|sd| sd.try_into()) + .try_collect()?; + let mut call: Expression = FnCall(Box::new(name), args, statics).into(); + call.build(self, ctx)?; + Ok(vec![call]) + } + + /// Implement a "_n" variant based on the given operand + fn generate_n_variant( + &self, + mut n_variant_op: WildString, + ctx: &mut Context, + ) -> context::Result { + let mut variant = self.clone(); + + n_variant_op.build_acle(ctx.local)?; + + let n_op_arg_idx = variant + .signature + .arguments + .iter_mut() + .position(|arg| arg.name.to_string() == n_variant_op.to_string()) + .ok_or_else(|| { + format!( + "cannot generate `_n` variant for {}, operand `{n_variant_op}` not found", + variant.signature.name + ) + })?; + + let has_n_wildcard = ctx + .local + .signature + .name + .wildcards() + .any(|w| matches!(w, Wildcard::NVariant)); + + if !has_n_wildcard { + return Err(format!("cannot generate `_n` variant for {}, no wildcard {{_n}} was specified in the intrinsic's name", variant.signature.name)); + } + + // Build signature + variant.signature = ctx.local.signature.clone(); + if let Some(pf) = ctx.local.predicate_form() { + // WARN: this may break in the future according to the underlying implementation + // Drops unwanted arguments if needed (required for the collection of arguments to pass to the function) + pf.post_build(&mut variant)?; + } + + let sig = &mut variant.signature; + + ctx.local + .substitutions + .insert(Wildcard::NVariant, "_n".to_owned()); + + let arg_kind = &mut sig.arguments.get_mut(n_op_arg_idx).unwrap().kind; + *arg_kind = match arg_kind { + TypeKind::Wildcard(Wildcard::SVEType(idx, None)) => { + TypeKind::Wildcard(Wildcard::Type(*idx)) + } + _ => { + return Err(format!( + "cannot generate `_n` variant for {}, the given operand is not a valid SVE type", + variant.signature.name + )) + } + }; + + sig.build(ctx.local)?; + + // Build compose + let name: Expression = self.signature.fn_name().into(); + let args: Vec = sig + .arguments + .iter() + .enumerate() + .map(|(idx, arg)| { + let ty = arg.kind.acle_notation_repr(); + if idx == n_op_arg_idx { + FnCall::new_expression( + WildString::from(format!("svdup_n_{ty}")).into(), + vec![arg.into()], + ) + } else { + arg.into() + } + }) + .collect(); + let statics: Vec = sig + .static_defs + .iter() + .map(|sd| sd.try_into()) + .try_collect()?; + let mut call: Expression = FnCall(Box::new(name), args, statics).into(); + call.build(self, ctx)?; + + variant.compose = vec![call]; + 
variant.signature.predicate_needs_conversion = true; + + Ok(variant) + } + + fn pre_build(&mut self, ctx: &mut Context) -> context::Result { + self.substitutions + .iter_mut() + .try_for_each(|(k, v)| -> context::Result { + let mut ws = v.get(ctx.local)?; + ws.build_acle(ctx.local)?; + ctx.local + .substitutions + .insert(Wildcard::Custom(k.to_owned()), ws.to_string()); + Ok(()) + })?; + + self.signature.build(ctx.local)?; + + if self.safety.is_none() { + self.safety = match Safety::safe_checked(self) { + Ok(safe) => Some(safe), + Err(err) => { + eprintln!("{err}"); + return Err(format!( + "Refusing to infer unsafety for {name}", + name = self.signature.doc_name() + )); + } + } + } + + if let Some(doc) = &mut self.doc { + doc.build_acle(ctx.local)? + } + + // Add arguments to variable tracking + self.signature + .arguments + .iter() + .for_each(|arg| arg.populate_variables(&mut ctx.local.variables)); + + // Add constant expressions to variable tracking + self.signature + .static_defs + .iter() + .filter_map(StaticDefinition::as_variable) + .for_each(|(var_name, var_properties)| { + ctx.local.variables.insert(var_name, var_properties); + }); + + // Pre-build compose expressions + for idx in 0..self.compose.len() { + let mut ex = self.compose[idx].clone(); + ex.pre_build(ctx)?; + self.compose[idx] = ex; + } + + if !ctx.local.input.is_empty() { + // We simplify the LLVM link transmute logic by deferring to a variant employing the same LLVM link where possible + if let Some(link) = self.compose.iter().find_map(|ex| match ex { + Expression::LLVMLink(link) => Some(link), + _ => None, + }) { + let mut link = link.clone(); + link.build(ctx)?; + + for cfg in ctx.global.arch_cfgs.iter() { + let expected_link = link.resolve(cfg); + if let Some(target_inputset) = ctx.group.links.get(&expected_link) { + self.defer_to_existing_llvm_link(ctx.local, target_inputset)?; + break; + } + } + } + } + + self.assert_instr + .iter_mut() + .try_for_each(|ai| ai.build(ctx))?; + + // Prepend constraint assertions + self.constraints.iter_mut().try_for_each(|c| c.build(ctx))?; + let assertions: Vec<_> = self + .constraints + .iter() + .map(|c| ctx.local.make_assertion_from_constraint(c)) + .try_collect()?; + self.compose.splice(0..0, assertions); + + Ok(()) + } + + fn post_build(&mut self, ctx: &mut Context) -> context::Result { + if let Some(Expression::LLVMLink(link)) = self.compose.last() { + let mut fn_call = link.make_fn_call(&self.signature)?; + // Required to inject conversions + fn_call.build(self, ctx)?; + self.compose.push(fn_call) + } + + if let Some(llvm_link) = self.llvm_link_mut() { + // Turn all Rust unsigned types into signed + llvm_link.sanitise_uints(); + } + + if let Some(predicate_form) = ctx.local.predicate_form() { + predicate_form.post_build(self)? 
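+ // (this may drop arguments from the signature, e.g. the argument replaced by a
+ // zero-initialiser in the zeroing form)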
+ } + + // Set for ToTokens to display a generic svbool_t + self.signature.predicate_needs_conversion = true; + + // Set base type kind for instruction assertion + self.base_type = ctx + .local + .input + .get(0) + .and_then(|arg| arg.typekind()) + .and_then(|ty| ty.base_type()) + .map(BaseType::clone); + + // Add global target features + self.target_features = ctx + .global + .arch_cfgs + .iter() + .flat_map(|cfg| cfg.target_feature.clone()) + .chain(self.target_features.clone()) + .collect_vec(); + + Ok(()) + } + + fn defer_to_existing_llvm_link( + &mut self, + ctx: &LocalContext, + target_inputset: &InputSet, + ) -> context::Result { + let mut target_ctx = ctx.clone(); + target_ctx.input = target_inputset.clone(); + + let mut target_signature = target_ctx.signature.clone(); + target_signature.build(&target_ctx)?; + + let drop_var = if let Some(pred) = ctx.predicate_form().cloned() { + match pred { + PredicateForm::Zeroing(ZeroingMethod::Drop { drop }) => Some(drop), + PredicateForm::DontCare(DontCareMethod::AsZeroing) => { + if let Some(ZeroingMethod::Drop { drop }) = + self.input.predication_methods.zeroing_method.to_owned() + { + Some(drop) + } else { + None + } + } + _ => None, + } + } else { + None + }; + + let call_method = + |ex, method: &str| Expression::MethodCall(Box::new(ex), method.to_string(), vec![]); + let as_unsigned = |ex| call_method(ex, "as_unsigned"); + let as_signed = |ex| call_method(ex, "as_signed"); + let convert_if_required = |w: Option<&Wildcard>, from: &InputSet, to: &InputSet, ex| { + if let Some(w) = w { + if let Some(dest_idx) = w.get_typeset_index() { + let from_type = from.get(dest_idx); + let to_type = to.get(dest_idx); + + if from_type != to_type { + let from_base_type = from_type + .and_then(|in_arg| in_arg.typekind()) + .and_then(|ty| ty.base_type()) + .map(|bt| bt.kind()); + let to_base_type = to_type + .and_then(|in_arg| in_arg.typekind()) + .and_then(|ty| ty.base_type()) + .map(|bt| bt.kind()); + + match (from_base_type, to_base_type) { + // Use AsSigned for uint -> int + (Some(BaseTypeKind::UInt), Some(BaseTypeKind::Int)) => as_signed(ex), + // Use AsUnsigned for int -> uint + (Some(BaseTypeKind::Int), Some(BaseTypeKind::UInt)) => as_unsigned(ex), + (None, None) => ex, + _ => unreachable!("unsupported conversion case from {from_base_type:?} to {to_base_type:?} hit"), + } + } else { + ex + } + } else { + ex + } + } else { + ex + } + }; + + let args = ctx + .signature + .arguments + .iter() + .filter_map(|arg| { + let var = Expression::Identifier(arg.name.to_owned(), IdentifierType::Variable); + if drop_var.as_ref().map(|v| v.to_string()) != Some(arg.name.to_string()) { + Some(convert_if_required( + arg.kind.wildcard(), + &ctx.input, + target_inputset, + var, + )) + } else { + None + } + }) + .collect_vec(); + + let turbofish = self + .signature + .static_defs + .iter() + .map(|def| { + let name = match def { + StaticDefinition::Constant(Argument { name, .. 
}) => name.to_string(), + StaticDefinition::Generic(name) => name.to_string(), + }; + Expression::Identifier(name.into(), IdentifierType::Symbol) + }) + .collect_vec(); + + let ret_wildcard = ctx + .signature + .return_type + .as_ref() + .and_then(|t| t.wildcard()); + let call = FnCall(Box::new(target_signature.fn_name().into()), args, turbofish).into(); + + self.compose = vec![convert_if_required( + ret_wildcard, + target_inputset, + &ctx.input, + call, + )]; + + Ok(()) + } +} + +impl ToTokens for Intrinsic { + fn to_tokens(&self, tokens: &mut TokenStream) { + let signature = &self.signature; + let fn_name = signature.fn_name().to_string(); + let target_feature = self.target_features.join(","); + let safety = self + .safety + .as_ref() + .expect("safety should be determined during `pre_build`"); + + if let Some(doc) = &self.doc { + let mut doc = vec![doc.to_string()]; + + doc.push(String::new()); + doc.push(format!("[Arm's documentation](https://developer.arm.com/architectures/instruction-sets/intrinsics/{})", &signature.doc_name())); + + if safety.has_doc_comments() { + doc.push(String::new()); + doc.push("## Safety".to_string()); + for comment in safety.doc_comments() { + doc.push(format!(" * {comment}")); + } + } else { + assert!( + safety.is_safe(), + "{fn_name} is both public and unsafe, and so needs safety documentation" + ); + } + + tokens.append_all(quote! { #(#[doc = #doc])* }); + } else { + assert!( + matches!(self.visibility, FunctionVisibility::Private), + "{fn_name} needs to be private, or to have documentation." + ); + assert!( + !safety.has_doc_comments(), + "{fn_name} needs a documentation section for its safety comments." + ); + } + + tokens.append_all(quote! { + #[inline] + #[target_feature(enable = #target_feature)] + }); + + if !self.assert_instr.is_empty() { + InstructionAssertionsForBaseType(&self.assert_instr, &self.base_type.as_ref()) + .to_tokens(tokens) + } + + match &self.visibility { + FunctionVisibility::Public => tokens.append_all(quote! { pub }), + FunctionVisibility::Private => {} + } + if safety.is_unsafe() { + tokens.append_all(quote! { unsafe }); + } + tokens.append_all(quote! { #signature }); + tokens.append(Punct::new('{', Spacing::Alone)); + + let mut body_unsafe = false; + let mut expressions = self.compose.iter().peekable(); + while let Some(ex) = expressions.next() { + if !body_unsafe && safety.is_safe() && ex.requires_unsafe_wrapper(&fn_name) { + body_unsafe = true; + tokens.append_all(quote! { unsafe }); + tokens.append(Punct::new('{', Spacing::Alone)); + } + // If it's not the last and not a LLVM link, add a trailing semicolon + if expressions.peek().is_some() && !matches!(ex, Expression::LLVMLink(_)) { + tokens.append_all(quote! 
{ #ex; }) + } else { + ex.to_tokens(tokens) + } + } + if body_unsafe { + tokens.append(Punct::new('}', Spacing::Alone)); + } + + tokens.append(Punct::new('}', Spacing::Alone)); + } +} diff --git a/crates/stdarch-gen2/src/load_store_tests.rs b/crates/stdarch-gen2/src/load_store_tests.rs new file mode 100644 index 0000000000..d697a8d22d --- /dev/null +++ b/crates/stdarch-gen2/src/load_store_tests.rs @@ -0,0 +1,818 @@ +use std::fs::File; +use std::io::Write; +use std::path::PathBuf; +use std::str::FromStr; + +use crate::format_code; +use crate::input::InputType; +use crate::intrinsic::Intrinsic; +use crate::typekinds::BaseType; +use crate::typekinds::{ToRepr, TypeKind}; + +use itertools::Itertools; +use lazy_static::lazy_static; +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; + +// Number of vectors in our buffers - the maximum tuple size, 4, plus 1 as we set the vnum +// argument to 1. +const NUM_VECS: usize = 5; +// The maximum vector length (in bits) +const VL_MAX_BITS: usize = 2048; +// The maximum vector length (in bytes) +const VL_MAX_BYTES: usize = VL_MAX_BITS / 8; +// The maximum number of elements in each vector type +const LEN_F32: usize = VL_MAX_BYTES / core::mem::size_of::(); +const LEN_F64: usize = VL_MAX_BYTES / core::mem::size_of::(); +const LEN_I8: usize = VL_MAX_BYTES / core::mem::size_of::(); +const LEN_I16: usize = VL_MAX_BYTES / core::mem::size_of::(); +const LEN_I32: usize = VL_MAX_BYTES / core::mem::size_of::(); +const LEN_I64: usize = VL_MAX_BYTES / core::mem::size_of::(); +const LEN_U8: usize = VL_MAX_BYTES / core::mem::size_of::(); +const LEN_U16: usize = VL_MAX_BYTES / core::mem::size_of::(); +const LEN_U32: usize = VL_MAX_BYTES / core::mem::size_of::(); +const LEN_U64: usize = VL_MAX_BYTES / core::mem::size_of::(); + +/// `load_intrinsics` and `store_intrinsics` is a vector of intrinsics +/// variants, while `out_path` is a file to write to. +pub fn generate_load_store_tests( + load_intrinsics: Vec, + store_intrinsics: Vec, + out_path: Option<&PathBuf>, +) -> Result<(), String> { + let output = match out_path { + Some(out) => { + Box::new(File::create(out).map_err(|e| format!("couldn't create tests file: {e}"))?) + as Box + } + None => Box::new(std::io::stdout()) as Box, + }; + let mut used_stores = vec![false; store_intrinsics.len()]; + let tests: Vec<_> = load_intrinsics + .iter() + .map(|load| { + let store_candidate = load + .signature + .fn_name() + .to_string() + .replace("svld1s", "svst1") + .replace("svld1u", "svst1") + .replace("svldnt1s", "svstnt1") + .replace("svldnt1u", "svstnt1") + .replace("svld", "svst") + .replace("gather", "scatter"); + + let store_index = store_intrinsics + .iter() + .position(|i| i.signature.fn_name().to_string() == store_candidate); + if let Some(i) = store_index { + used_stores[i] = true; + } + + generate_single_test( + load.clone(), + store_index.map(|i| store_intrinsics[i].clone()), + ) + }) + .try_collect()?; + + assert!(used_stores.into_iter().all(|b| b), "Not all store tests have been paired with a load. Consider generating specifc store-only tests"); + + let preamble = + TokenStream::from_str(&PREAMBLE).map_err(|e| format!("Preamble is invalid: {e}"))?; + // Only output manual tests for the SVE set + let manual_tests = match &load_intrinsics[0].target_features[..] { + [s] if s == "sve" => TokenStream::from_str(&MANUAL_TESTS) + .map_err(|e| format!("Manual tests are invalid: {e}"))?, + _ => quote!(), + }; + format_code( + output, + format!( + "// This code is automatically generated. DO NOT MODIFY. 
+// +// Instead, modify `crates/stdarch-gen2/spec/sve` and run the following command to re-generate this +// file: +// +// ``` +// cargo run --bin=stdarch-gen2 -- crates/stdarch-gen2/spec +// ``` +{}", + quote! { #preamble #(#tests)* #manual_tests } + ), + ) + .map_err(|e| format!("couldn't write tests: {e}")) +} + +/// A test looks like this: +/// ``` +/// let data = [scalable vector]; +/// +/// let mut storage = [0; N]; +/// +/// store_intrinsic([true_predicate], storage.as_mut_ptr(), data); +/// [test contents of storage] +/// +/// let loaded == load_intrinsic([true_predicate], storage.as_ptr()) +/// assert!(loaded == data); +/// ``` +/// We intialise our data such that the value stored matches the index it's stored to. +/// By doing this we can validate scatters by checking that each value in the storage +/// array is either 0 or the same as its index. +fn generate_single_test( + load: Intrinsic, + store: Option, +) -> Result { + let chars = LdIntrCharacteristics::new(&load)?; + let fn_name = load.signature.fn_name().to_string(); + + if let Some(ty) = &chars.gather_bases_type { + if ty.base_type().unwrap().get_size() == Ok(32) + && chars.gather_index_type.is_none() + && chars.gather_offset_type.is_none() + { + // We lack a way to ensure data is in the bottom 32 bits of the address space + println!("Skipping test for {fn_name}"); + return Ok(quote!()); + } + } + + if fn_name.starts_with("svldff1") && fn_name.contains("gather") { + // TODO: We can remove this check when first-faulting gathers are fixed in CI's QEMU + // https://gitlab.com/qemu-project/qemu/-/issues/1612 + println!("Skipping test for {fn_name}"); + return Ok(quote!()); + } + + let fn_ident = format_ident!("{fn_name}"); + let test_name = format_ident!( + "test_{fn_name}{}", + if let Some(ref store) = store { + format!("_with_{}", store.signature.fn_name()) + } else { + String::new() + } + ); + + let load_type = &chars.load_type; + let acle_type = load_type.acle_notation_repr(); + + // If there's no return type, fallback to the load type for things that depend on it + let ret_type = &load + .signature + .return_type + .as_ref() + .and_then(TypeKind::base_type) + .unwrap_or(load_type); + + let pred_fn = format_ident!("svptrue_b{}", load_type.size()); + + let load_type_caps = load_type.rust_repr().to_uppercase(); + let data_array = format_ident!("{load_type_caps}_DATA"); + + let size_fn = format_ident!("svcnt{}", ret_type.size_literal()); + + let rust_ret_type = ret_type.rust_repr(); + let assert_fn = format_ident!("assert_vector_matches_{rust_ret_type}"); + + // Use vnum=1, so adjust all values by one vector length + let (length_call, vnum_arg) = if chars.vnum { + if chars.is_prf { + (quote!(), quote!(, 1)) + } else { + (quote!(let len = #size_fn() as usize;), quote!(, 1)) + } + } else { + (quote!(), quote!()) + }; + + let (bases_load, bases_arg) = if let Some(ty) = &chars.gather_bases_type { + // Bases is a vector of (sometimes 32-bit) pointers + // When we combine bases with an offset/index argument, we load from the data arrays + // starting at 1 + let base_ty = ty.base_type().unwrap(); + let rust_type = format_ident!("{}", base_ty.rust_repr()); + let index_fn = format_ident!("svindex_{}", base_ty.acle_notation_repr()); + let size_in_bytes = chars.load_type.get_size().unwrap() / 8; + + if base_ty.get_size().unwrap() == 32 { + // Treat bases as a vector of offsets here - we don't test this without an offset or + // index argument + ( + Some(quote!( + let bases = #index_fn(0, #size_in_bytes.try_into().unwrap()); + )), + 
quote!(, bases), + ) + } else { + // Treat bases as a vector of pointers + let base_fn = format_ident!("svdup_n_{}", base_ty.acle_notation_repr()); + let data_array = if store.is_some() { + format_ident!("storage") + } else { + format_ident!("{}_DATA", chars.load_type.rust_repr().to_uppercase()) + }; + + let add_fn = format_ident!("svadd_{}_x", base_ty.acle_notation_repr()); + ( + Some(quote! { + let bases = #base_fn(#data_array.as_ptr() as #rust_type); + let offsets = #index_fn(0, #size_in_bytes.try_into().unwrap()); + let bases = #add_fn(#pred_fn(), bases, offsets); + }), + quote!(, bases), + ) + } + } else { + (None, quote!()) + }; + + let index_arg = if let Some(ty) = &chars.gather_index_type { + let rust_type = format_ident!("{}", ty.rust_repr()); + if chars + .gather_bases_type + .as_ref() + .and_then(TypeKind::base_type) + .map_or(Err(String::new()), BaseType::get_size) + .unwrap() + == 32 + { + // Let index be the base of the data array + let data_array = if store.is_some() { + format_ident!("storage") + } else { + format_ident!("{}_DATA", chars.load_type.rust_repr().to_uppercase()) + }; + let size_in_bytes = chars.load_type.get_size().unwrap() / 8; + quote!(, #data_array.as_ptr() as #rust_type / (#size_in_bytes as #rust_type) + 1) + } else { + quote!(, 1.try_into().unwrap()) + } + } else { + quote!() + }; + + let offset_arg = if let Some(ty) = &chars.gather_offset_type { + let size_in_bytes = chars.load_type.get_size().unwrap() / 8; + if chars + .gather_bases_type + .as_ref() + .and_then(TypeKind::base_type) + .map_or(Err(String::new()), BaseType::get_size) + .unwrap() + == 32 + { + // Let offset be the base of the data array + let rust_type = format_ident!("{}", ty.rust_repr()); + let data_array = if store.is_some() { + format_ident!("storage") + } else { + format_ident!("{}_DATA", chars.load_type.rust_repr().to_uppercase()) + }; + quote!(, #data_array.as_ptr() as #rust_type + #size_in_bytes as #rust_type) + } else { + quote!(, #size_in_bytes.try_into().unwrap()) + } + } else { + quote!() + }; + + let (offsets_load, offsets_arg) = if let Some(ty) = &chars.gather_offsets_type { + // Offsets is a scalable vector of per-element offsets in bytes. We re-use the contiguous + // data for this, then multiply to get indices + let offsets_fn = format_ident!("svindex_{}", ty.base_type().unwrap().acle_notation_repr()); + let size_in_bytes = chars.load_type.get_size().unwrap() / 8; + ( + Some(quote! { + let offsets = #offsets_fn(0, #size_in_bytes.try_into().unwrap()); + }), + quote!(, offsets), + ) + } else { + (None, quote!()) + }; + + let (indices_load, indices_arg) = if let Some(ty) = &chars.gather_indices_type { + // There's no need to multiply indices by the load type width + let base_ty = ty.base_type().unwrap(); + let indices_fn = format_ident!("svindex_{}", base_ty.acle_notation_repr()); + ( + Some(quote! { + let indices = #indices_fn(0, 1); + }), + quote! {, indices}, + ) + } else { + (None, quote!()) + }; + + let ptr = if chars.gather_bases_type.is_some() { + quote!() + } else if chars.is_prf { + quote!(, I64_DATA.as_ptr()) + } else { + quote!(, #data_array.as_ptr()) + }; + + let tuple_len = &chars.tuple_len; + let expecteds = if chars.is_prf { + // No return value for prefetches + vec![] + } else { + (0..*tuple_len) + .map(|i| get_expected_range(i, &chars)) + .collect() + }; + let asserts: Vec<_> = + if *tuple_len > 1 { + let svget = format_ident!("svget{tuple_len}_{acle_type}"); + expecteds.iter().enumerate().map(|(i, expected)| { + quote! 
(#assert_fn(#svget::<{ #i as i32 }>(loaded), #expected);) + }).collect() + } else { + expecteds + .iter() + .map(|expected| quote! (#assert_fn(loaded, #expected);)) + .collect() + }; + + let function = if chars.is_prf { + if fn_name.contains("gather") && fn_name.contains("base") && !fn_name.starts_with("svprf_") + { + // svprf(b|h|w|d)_gather base intrinsics do not have a generic type parameter + quote!(#fn_ident::<{ svprfop::SV_PLDL1KEEP }>) + } else { + quote!(#fn_ident::<{ svprfop::SV_PLDL1KEEP }, i64>) + } + } else { + quote!(#fn_ident) + }; + + let octaword_guard = if chars.replicate_width == Some(256) { + let msg = format!("Skipping {test_name} due to SVE vector length"); + quote! { + if svcntb() < 32 { + println!(#msg); + return; + } + } + } else { + quote!() + }; + + let feats = load.target_features.join(","); + + if let Some(store) = store { + let data_init = if *tuple_len == 1 { + quote!(#(#expecteds)*) + } else { + let create = format_ident!("svcreate{tuple_len}_{acle_type}"); + quote!(#create(#(#expecteds),*)) + }; + let input = store.input.types.get(0).unwrap().get(0).unwrap(); + let store_type = input + .get(store.test.get_typeset_index().unwrap()) + .and_then(InputType::typekind) + .and_then(TypeKind::base_type) + .unwrap(); + + let store_type = format_ident!("{}", store_type.rust_repr()); + let storage_len = NUM_VECS * VL_MAX_BITS / chars.load_type.get_size()? as usize; + let store_fn = format_ident!("{}", store.signature.fn_name().to_string()); + let load_type = format_ident!("{}", chars.load_type.rust_repr()); + let (store_ptr, store_mut_ptr) = if chars.gather_bases_type.is_none() { + ( + quote!(, storage.as_ptr() as *const #load_type), + quote!(, storage.as_mut_ptr()), + ) + } else { + (quote!(), quote!()) + }; + let args = quote!(#pred_fn() #store_ptr #vnum_arg #bases_arg #offset_arg #index_arg #offsets_arg #indices_arg); + let call = if chars.uses_ffr { + // Doing a normal load first maximises the number of elements our ff/nf test loads + let non_ffr_fn_name = format_ident!( + "{}", + fn_name + .replace("svldff1", "svld1") + .replace("svldnf1", "svld1") + ); + quote! { + svsetffr(); + let _ = #non_ffr_fn_name(#args); + let loaded = #function(#args); + } + } else { + // Note that the FFR must be set for all tests as the assert functions mask against it + quote! { + svsetffr(); + let loaded = #function(#args); + } + }; + + Ok(quote! { + #[simd_test(enable = #feats)] + unsafe fn #test_name() { + #octaword_guard + #length_call + let mut storage = [0 as #store_type; #storage_len]; + let data = #data_init; + #bases_load + #offsets_load + #indices_load + + #store_fn(#pred_fn() #store_mut_ptr #vnum_arg #bases_arg #offset_arg #index_arg #offsets_arg #indices_arg, data); + for (i, &val) in storage.iter().enumerate() { + assert!(val == 0 as #store_type || val == i as #store_type); + } + + #call + #(#asserts)* + + } + }) + } else { + let args = quote!(#pred_fn() #ptr #vnum_arg #bases_arg #offset_arg #index_arg #offsets_arg #indices_arg); + let call = if chars.uses_ffr { + // Doing a normal load first maximises the number of elements our ff/nf test loads + let non_ffr_fn_name = format_ident!( + "{}", + fn_name + .replace("svldff1", "svld1") + .replace("svldnf1", "svld1") + ); + quote! { + svsetffr(); + let _ = #non_ffr_fn_name(#args); + let loaded = #function(#args); + } + } else { + // Note that the FFR must be set for all tests as the assert functions mask against it + quote! { + svsetffr(); + let loaded = #function(#args); + } + }; + Ok(quote! 
{ + #[simd_test(enable = #feats)] + unsafe fn #test_name() { + #octaword_guard + #bases_load + #offsets_load + #indices_load + #call + #length_call + + #(#asserts)* + } + }) + } +} + +/// Assumes chars.ret_type is not None +fn get_expected_range(tuple_idx: usize, chars: &LdIntrCharacteristics) -> proc_macro2::TokenStream { + // vnum=1 + let vnum_adjust = if chars.vnum { quote!(len+) } else { quote!() }; + + let bases_adjust = + (chars.gather_index_type.is_some() || chars.gather_offset_type.is_some()) as usize; + + let tuple_len = chars.tuple_len; + let size = chars + .ret_type + .as_ref() + .and_then(TypeKind::base_type) + .unwrap_or(&chars.load_type) + .get_size() + .unwrap() as usize; + + if chars.replicate_width == Some(128) { + // svld1rq + let ty_rust = format_ident!( + "{}", + chars + .ret_type + .as_ref() + .unwrap() + .base_type() + .unwrap() + .rust_repr() + ); + let args: Vec<_> = (0..(128 / size)).map(|i| quote!(#i as #ty_rust)).collect(); + let dup = format_ident!( + "svdupq_n_{}", + chars.ret_type.as_ref().unwrap().acle_notation_repr() + ); + quote!(#dup(#(#args,)*)) + } else if chars.replicate_width == Some(256) { + // svld1ro - we use two interleaved svdups to create a repeating 256-bit pattern + let ty_rust = format_ident!( + "{}", + chars + .ret_type + .as_ref() + .unwrap() + .base_type() + .unwrap() + .rust_repr() + ); + let ret_acle = chars.ret_type.as_ref().unwrap().acle_notation_repr(); + let args: Vec<_> = (0..(128 / size)).map(|i| quote!(#i as #ty_rust)).collect(); + let args2: Vec<_> = ((128 / size)..(256 / size)) + .map(|i| quote!(#i as #ty_rust)) + .collect(); + let dup = format_ident!("svdupq_n_{ret_acle}"); + let interleave = format_ident!("svtrn1q_{ret_acle}"); + quote!(#interleave(#dup(#(#args,)*), #dup(#(#args2,)*))) + } else { + let start = bases_adjust + tuple_idx; + if chars + .ret_type + .as_ref() + .unwrap() + .base_type() + .unwrap() + .is_float() + { + // Use svcvt to create a linear sequence of floats + let cvt_fn = format_ident!("svcvt_f{size}_s{size}_x"); + let pred_fn = format_ident!("svptrue_b{size}"); + let svindex_fn = format_ident!("svindex_s{size}"); + quote! { #cvt_fn(#pred_fn(), #svindex_fn((#vnum_adjust #start).try_into().unwrap(), #tuple_len.try_into().unwrap()))} + } else { + let ret_acle = chars.ret_type.as_ref().unwrap().acle_notation_repr(); + let svindex = format_ident!("svindex_{ret_acle}"); + quote!(#svindex((#vnum_adjust #start).try_into().unwrap(), #tuple_len.try_into().unwrap())) + } + } +} + +struct LdIntrCharacteristics { + // The data type to load from (not necessarily the data type returned) + load_type: BaseType, + // The data type to return (None for unit) + ret_type: Option, + // The size of tuple to load/store + tuple_len: usize, + // Whether a vnum argument is present + vnum: bool, + // Is the intrinsic first/non-faulting? + uses_ffr: bool, + // Is it a prefetch? 
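+ // (prefetch intrinsics return no data, so their generated tests only exercise the call)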
+ is_prf: bool, + // The size of data loaded with svld1ro/q intrinsics + replicate_width: Option, + // Scalable vector of pointers to load from + gather_bases_type: Option, + // Scalar offset, paired with bases + gather_offset_type: Option, + // Scalar index, paired with bases + gather_index_type: Option, + // Scalable vector of offsets + gather_offsets_type: Option, + // Scalable vector of indices + gather_indices_type: Option, +} + +impl LdIntrCharacteristics { + fn new(intr: &Intrinsic) -> Result { + let input = intr.input.types.get(0).unwrap().get(0).unwrap(); + let load_type = input + .get(intr.test.get_typeset_index().unwrap()) + .and_then(InputType::typekind) + .and_then(TypeKind::base_type) + .unwrap(); + + let ret_type = intr.signature.return_type.clone(); + + let name = intr.signature.fn_name().to_string(); + let tuple_len = name + .chars() + .find(|c| c.is_numeric()) + .and_then(|c| c.to_digit(10)) + .unwrap_or(1) as usize; + + let uses_ffr = name.starts_with("svldff") || name.starts_with("svldnf"); + + let is_prf = name.starts_with("svprf"); + + let replicate_width = if name.starts_with("svld1ro") { + Some(256) + } else if name.starts_with("svld1rq") { + Some(128) + } else { + None + }; + + let get_ty_of_arg = |name: &str| { + intr.signature + .arguments + .iter() + .find(|a| a.name.to_string() == name) + .map(|a| a.kind.clone()) + }; + + let gather_bases_type = get_ty_of_arg("bases"); + let gather_offset_type = get_ty_of_arg("offset"); + let gather_index_type = get_ty_of_arg("index"); + let gather_offsets_type = get_ty_of_arg("offsets"); + let gather_indices_type = get_ty_of_arg("indices"); + + Ok(LdIntrCharacteristics { + load_type: *load_type, + ret_type, + tuple_len, + vnum: name.contains("vnum"), + uses_ffr, + is_prf, + replicate_width, + gather_bases_type, + gather_offset_type, + gather_index_type, + gather_offsets_type, + gather_indices_type, + }) + } +} + +lazy_static! 
{ + static ref PREAMBLE: String = format!( + r#"#![allow(unused)] + +use super::*; +use std::boxed::Box; +use std::convert::{{TryFrom, TryInto}}; +use std::sync::LazyLock; +use std::vec::Vec; +use stdarch_test::simd_test; + +static F32_DATA: LazyLock<[f32; {LEN_F32} * {NUM_VECS}]> = LazyLock::new(|| {{ + (0..{LEN_F32} * {NUM_VECS}) + .map(|i| i as f32) + .collect::>() + .try_into() + .expect("f32 data incorrectly initialised") +}}); +static F64_DATA: LazyLock<[f64; {LEN_F64} * {NUM_VECS}]> = LazyLock::new(|| {{ + (0..{LEN_F64} * {NUM_VECS}) + .map(|i| i as f64) + .collect::>() + .try_into() + .expect("f64 data incorrectly initialised") +}}); +static I8_DATA: LazyLock<[i8; {LEN_I8} * {NUM_VECS}]> = LazyLock::new(|| {{ + (0..{LEN_I8} * {NUM_VECS}) + .map(|i| ((i + 128) % 256 - 128) as i8) + .collect::>() + .try_into() + .expect("i8 data incorrectly initialised") +}}); +static I16_DATA: LazyLock<[i16; {LEN_I16} * {NUM_VECS}]> = LazyLock::new(|| {{ + (0..{LEN_I16} * {NUM_VECS}) + .map(|i| i as i16) + .collect::>() + .try_into() + .expect("i16 data incorrectly initialised") +}}); +static I32_DATA: LazyLock<[i32; {LEN_I32} * {NUM_VECS}]> = LazyLock::new(|| {{ + (0..{LEN_I32} * {NUM_VECS}) + .map(|i| i as i32) + .collect::>() + .try_into() + .expect("i32 data incorrectly initialised") +}}); +static I64_DATA: LazyLock<[i64; {LEN_I64} * {NUM_VECS}]> = LazyLock::new(|| {{ + (0..{LEN_I64} * {NUM_VECS}) + .map(|i| i as i64) + .collect::>() + .try_into() + .expect("i64 data incorrectly initialised") +}}); +static U8_DATA: LazyLock<[u8; {LEN_U8} * {NUM_VECS}]> = LazyLock::new(|| {{ + (0..{LEN_U8} * {NUM_VECS}) + .map(|i| i as u8) + .collect::>() + .try_into() + .expect("u8 data incorrectly initialised") +}}); +static U16_DATA: LazyLock<[u16; {LEN_U16} * {NUM_VECS}]> = LazyLock::new(|| {{ + (0..{LEN_U16} * {NUM_VECS}) + .map(|i| i as u16) + .collect::>() + .try_into() + .expect("u16 data incorrectly initialised") +}}); +static U32_DATA: LazyLock<[u32; {LEN_U32} * {NUM_VECS}]> = LazyLock::new(|| {{ + (0..{LEN_U32} * {NUM_VECS}) + .map(|i| i as u32) + .collect::>() + .try_into() + .expect("u32 data incorrectly initialised") +}}); +static U64_DATA: LazyLock<[u64; {LEN_U64} * {NUM_VECS}]> = LazyLock::new(|| {{ + (0..{LEN_U64} * {NUM_VECS}) + .map(|i| i as u64) + .collect::>() + .try_into() + .expect("u64 data incorrectly initialised") +}}); + +#[target_feature(enable = "sve")] +fn assert_vector_matches_f32(vector: svfloat32_t, expected: svfloat32_t) {{ + let defined = svrdffr(); + assert!(svptest_first(svptrue_b32(), defined)); + let cmp = svcmpne_f32(defined, vector, expected); + assert!(!svptest_any(defined, cmp)) +}} + +#[target_feature(enable = "sve")] +fn assert_vector_matches_f64(vector: svfloat64_t, expected: svfloat64_t) {{ + let defined = svrdffr(); + assert!(svptest_first(svptrue_b64(), defined)); + let cmp = svcmpne_f64(defined, vector, expected); + assert!(!svptest_any(defined, cmp)) +}} + +#[target_feature(enable = "sve")] +fn assert_vector_matches_i8(vector: svint8_t, expected: svint8_t) {{ + let defined = svrdffr(); + assert!(svptest_first(svptrue_b8(), defined)); + let cmp = svcmpne_s8(defined, vector, expected); + assert!(!svptest_any(defined, cmp)) +}} + +#[target_feature(enable = "sve")] +fn assert_vector_matches_i16(vector: svint16_t, expected: svint16_t) {{ + let defined = svrdffr(); + assert!(svptest_first(svptrue_b16(), defined)); + let cmp = svcmpne_s16(defined, vector, expected); + assert!(!svptest_any(defined, cmp)) +}} + +#[target_feature(enable = "sve")] +fn 
assert_vector_matches_i32(vector: svint32_t, expected: svint32_t) {{ + let defined = svrdffr(); + assert!(svptest_first(svptrue_b32(), defined)); + let cmp = svcmpne_s32(defined, vector, expected); + assert!(!svptest_any(defined, cmp)) +}} + +#[target_feature(enable = "sve")] +fn assert_vector_matches_i64(vector: svint64_t, expected: svint64_t) {{ + let defined = svrdffr(); + assert!(svptest_first(svptrue_b64(), defined)); + let cmp = svcmpne_s64(defined, vector, expected); + assert!(!svptest_any(defined, cmp)) +}} + +#[target_feature(enable = "sve")] +fn assert_vector_matches_u8(vector: svuint8_t, expected: svuint8_t) {{ + let defined = svrdffr(); + assert!(svptest_first(svptrue_b8(), defined)); + let cmp = svcmpne_u8(defined, vector, expected); + assert!(!svptest_any(defined, cmp)) +}} + +#[target_feature(enable = "sve")] +fn assert_vector_matches_u16(vector: svuint16_t, expected: svuint16_t) {{ + let defined = svrdffr(); + assert!(svptest_first(svptrue_b16(), defined)); + let cmp = svcmpne_u16(defined, vector, expected); + assert!(!svptest_any(defined, cmp)) +}} + +#[target_feature(enable = "sve")] +fn assert_vector_matches_u32(vector: svuint32_t, expected: svuint32_t) {{ + let defined = svrdffr(); + assert!(svptest_first(svptrue_b32(), defined)); + let cmp = svcmpne_u32(defined, vector, expected); + assert!(!svptest_any(defined, cmp)) +}} + +#[target_feature(enable = "sve")] +fn assert_vector_matches_u64(vector: svuint64_t, expected: svuint64_t) {{ + let defined = svrdffr(); + assert!(svptest_first(svptrue_b64(), defined)); + let cmp = svcmpne_u64(defined, vector, expected); + assert!(!svptest_any(defined, cmp)) +}} +"# + ); +} + +lazy_static! { + static ref MANUAL_TESTS: String = format!( + "#[simd_test(enable = \"sve\")] +unsafe fn test_ffr() {{ + svsetffr(); + let ffr = svrdffr(); + assert_vector_matches_u8(svdup_n_u8_z(ffr, 1), svindex_u8(1, 0)); + let pred = svdupq_n_b8(true, false, true, false, true, false, true, false, + true, false, true, false, true, false, true, false); + svwrffr(pred); + let ffr = svrdffr_z(svptrue_b8()); + assert_vector_matches_u8(svdup_n_u8_z(ffr, 1), svdup_n_u8_z(pred, 1)); +}} +" + ); +} diff --git a/crates/stdarch-gen2/src/main.rs b/crates/stdarch-gen2/src/main.rs new file mode 100644 index 0000000000..5379d18404 --- /dev/null +++ b/crates/stdarch-gen2/src/main.rs @@ -0,0 +1,273 @@ +#![feature(pattern)] + +mod assert_instr; +mod context; +mod expression; +mod input; +mod intrinsic; +mod load_store_tests; +mod matching; +mod predicate_forms; +mod typekinds; +mod wildcards; +mod wildstring; + +use intrinsic::Test; +use itertools::Itertools; +use quote::quote; +use std::fs::File; +use std::io::Write; +use std::path::{Path, PathBuf}; +use std::process::{Command, Stdio}; +use walkdir::WalkDir; + +fn main() -> Result<(), String> { + parse_args() + .into_iter() + .map(|(filepath, out)| { + File::open(&filepath) + .map(|f| (f, filepath, out)) + .map_err(|e| format!("could not read input file: {e}")) + }) + .map(|res| { + let (file, filepath, out) = res?; + serde_yaml::from_reader(file) + .map(|input: input::GeneratorInput| (input, filepath, out)) + .map_err(|e| format!("could not parse input file: {e}")) + }) + .collect::, _>>()? 
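+ // At this point each entry is (parsed GeneratorInput, input path, optional output dir);
+ // the stages below expand every intrinsic into its variants, write the load/store test
+ // file, and then emit the generated intrinsics source.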
+ .into_iter() + .map(|(input, filepath, out)| { + let intrinsics = input.intrinsics.into_iter() + .map(|intrinsic| intrinsic.generate_variants(&input.ctx)) + .try_collect() + .map(|mut vv: Vec<_>| { + vv.sort_by_cached_key(|variants| { + variants.first().map_or_else(String::default, |variant| { + variant.signature.fn_name().to_string() + }) + }); + vv.into_iter().flatten().collect_vec() + })?; + + let loads = intrinsics.iter() + .filter_map(|i| { + if matches!(i.test, Test::Load(..)) { + Some(i.clone()) + } else { + None + } + }).collect(); + let stores = intrinsics.iter() + .filter_map(|i| { + if matches!(i.test, Test::Store(..)) { + Some(i.clone()) + } else { + None + } + }).collect(); + load_store_tests::generate_load_store_tests(loads, stores, out.as_ref().map(|o| make_tests_filepath(&filepath, o)).as_ref())?; + Ok(( + input::GeneratorInput { + intrinsics, + ctx: input.ctx, + }, + filepath, + out, + )) + }) + .try_for_each( + |result: context::Result<(input::GeneratorInput, PathBuf, Option)>| -> context::Result { + let (generated, filepath, out) = result?; + + let w = match out { + Some(out) => Box::new( + File::create(make_output_filepath(&filepath, &out)) + .map_err(|e| format!("could not create output file: {e}"))?, + ) as Box, + None => Box::new(std::io::stdout()) as Box, + }; + + generate_file(generated, w) + .map_err(|e| format!("could not generate output file: {e}")) + }, + ) +} + +fn parse_args() -> Vec<(PathBuf, Option)> { + let mut args_it = std::env::args().skip(1); + assert!( + 1 <= args_it.len() && args_it.len() <= 2, + "Usage: cargo run -p stdarch-gen2 -- INPUT_DIR [OUTPUT_DIR]" + ); + + let in_path = Path::new(args_it.next().unwrap().as_str()).to_path_buf(); + assert!( + in_path.exists() && in_path.is_dir(), + "invalid path {in_path:#?} given" + ); + + let out_dir = if let Some(dir) = args_it.next() { + let out_path = Path::new(dir.as_str()).to_path_buf(); + assert!( + out_path.exists() && out_path.is_dir(), + "invalid path {out_path:#?} given" + ); + Some(out_path) + } else { + std::env::current_exe() + .map(|mut f| { + f.pop(); + f.push("../../crates/core_arch/src/aarch64/"); + f.exists().then_some(f) + }) + .ok() + .flatten() + }; + + WalkDir::new(in_path) + .into_iter() + .filter_map(Result::ok) + .filter(|f| f.file_type().is_file()) + .map(|f| (f.into_path(), out_dir.clone())) + .collect() +} + +fn generate_file( + generated_input: input::GeneratorInput, + mut out: Box, +) -> std::io::Result<()> { + write!( + out, + r#"// This code is automatically generated. DO NOT MODIFY. +// +// Instead, modify `crates/stdarch-gen2/spec/` and run the following command to re-generate this file: +// +// ``` +// cargo run --bin=stdarch-gen2 -- crates/stdarch-gen2/spec +// ``` +#![allow(improper_ctypes)] + +#[cfg(test)] +use stdarch_test::assert_instr; + +use super::*;{uses_neon} + +"#, + uses_neon = generated_input + .ctx + .uses_neon_types + .then_some("\nuse crate::core_arch::arch::aarch64::*;") + .unwrap_or_default(), + )?; + let intrinsics = generated_input.intrinsics; + format_code(out, quote! { #(#intrinsics)* })?; + Ok(()) +} + +pub fn format_code( + mut output: impl std::io::Write, + input: impl std::fmt::Display, +) -> std::io::Result<()> { + let proc = Command::new("rustfmt") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .spawn()?; + write!(proc.stdin.as_ref().unwrap(), "{input}")?; + output.write_all(proc.wait_with_output()?.stdout.as_slice()) +} + +/// Derive an output file name from an input file and an output directory. 
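+///
+/// For example, an input of `spec/sve/NAME.spec.yml` emitted to the output directory `out`
+/// becomes `out/sve/NAME.rs`.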
+/// +/// The name is formed by: +/// +/// - ... taking in_filepath.file_name() (dropping all directory components), +/// - ... dropping a .yml or .yaml extension (if present), +/// - ... then dropping a .spec extension (if present). +/// +/// Panics if the resulting name is empty, or if file_name() is not UTF-8. +fn make_output_filepath(in_filepath: &Path, out_dirpath: &Path) -> PathBuf { + make_filepath(in_filepath, out_dirpath, |name: &str| format!("{name}.rs")) +} + +fn make_tests_filepath(in_filepath: &Path, out_dirpath: &Path) -> PathBuf { + make_filepath(in_filepath, out_dirpath, |name: &str| { + format!("ld_st_tests_{name}.rs") + }) +} + +fn make_filepath String>( + in_filepath: &Path, + out_dirpath: &Path, + name_formatter: F, +) -> PathBuf { + let mut parts = in_filepath.iter(); + let name = parts + .next_back() + .and_then(|f| f.to_str()) + .expect("Inputs must have valid, UTF-8 file_name()"); + let dir = parts.next_back().unwrap(); + + let name = name + .trim_end_matches(".yml") + .trim_end_matches(".yaml") + .trim_end_matches(".spec"); + assert!(!name.is_empty()); + + let mut output = out_dirpath.to_path_buf(); + output.push(dir); + output.push(name_formatter(name)); + output +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn infer_output_file() { + macro_rules! t { + ($src:expr, $outdir:expr, $dst:expr) => { + let src: PathBuf = $src.iter().collect(); + let outdir: PathBuf = $outdir.iter().collect(); + let dst: PathBuf = $dst.iter().collect(); + assert_eq!(make_output_filepath(&src, &outdir), dst); + }; + } + // Documented usage. + t!(["x", "NAME.spec.yml"], [""], ["x", "NAME.rs"]); + t!( + ["x", "NAME.spec.yml"], + ["a", "b"], + ["a", "b", "x", "NAME.rs"] + ); + t!( + ["x", "y", "NAME.spec.yml"], + ["out"], + ["out", "y", "NAME.rs"] + ); + t!(["x", "NAME.spec.yaml"], ["out"], ["out", "x", "NAME.rs"]); + t!(["x", "NAME.spec"], ["out"], ["out", "x", "NAME.rs"]); + t!(["x", "NAME.yml"], ["out"], ["out", "x", "NAME.rs"]); + t!(["x", "NAME.yaml"], ["out"], ["out", "x", "NAME.rs"]); + // Unrecognised extensions get treated as part of the stem. + t!( + ["x", "NAME.spac.yml"], + ["out"], + ["out", "x", "NAME.spac.rs"] + ); + t!(["x", "NAME.txt"], ["out"], ["out", "x", "NAME.txt.rs"]); + // Always take the top-level directory from the input path + t!( + ["x", "y", "z", "NAME.spec.yml"], + ["out"], + ["out", "z", "NAME.rs"] + ); + } + + #[test] + #[should_panic] + fn infer_output_file_no_stem() { + make_output_filepath(Path::new(".spec.yml"), Path::new("")); + } +} diff --git a/crates/stdarch-gen2/src/matching.rs b/crates/stdarch-gen2/src/matching.rs new file mode 100644 index 0000000000..0c48062042 --- /dev/null +++ b/crates/stdarch-gen2/src/matching.rs @@ -0,0 +1,170 @@ +use proc_macro2::TokenStream; +use quote::ToTokens; +use serde::{Deserialize, Serialize}; +use std::fmt; + +use crate::context::{self, LocalContext}; +use crate::typekinds::{BaseType, BaseTypeKind, TypeKind}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct MatchSizeValues { + pub default: T, + pub byte: Option, + pub halfword: Option, + pub doubleword: Option, +} + +impl MatchSizeValues { + pub fn get(&mut self, ty: &TypeKind, ctx: &LocalContext) -> context::Result<&T> { + let base_ty = if let Some(w) = ty.wildcard() { + ctx.provide_type_wildcard(w)? 
+ } else { + ty.clone() + }; + + if let BaseType::Sized(_, bitsize) = base_ty.base_type().unwrap() { + match (bitsize, &self.byte, &self.halfword, &self.doubleword) { + (64, _, _, Some(v)) | (16, _, Some(v), _) | (8, Some(v), _, _) => Ok(v), + _ => Ok(&self.default), + } + } else { + Err(format!("cannot match bitsize to unsized type {ty:?}!")) + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct MatchKindValues { + pub default: T, + pub float: Option, + pub unsigned: Option, +} + +impl MatchKindValues { + pub fn get(&mut self, ty: &TypeKind, ctx: &LocalContext) -> context::Result<&T> { + let base_ty = if let Some(w) = ty.wildcard() { + ctx.provide_type_wildcard(w)? + } else { + ty.clone() + }; + + match ( + base_ty.base_type().unwrap().kind(), + &self.float, + &self.unsigned, + ) { + (BaseTypeKind::Float, Some(v), _) | (BaseTypeKind::UInt, _, Some(v)) => Ok(v), + _ => Ok(&self.default), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged, deny_unknown_fields)] +pub enum SizeMatchable { + Matched(T), + Unmatched { + match_size: Option, + #[serde(flatten)] + values: MatchSizeValues>, + }, +} + +impl SizeMatchable { + pub fn perform_match(&mut self, ctx: &LocalContext) -> context::Result { + match self { + Self::Unmatched { + match_size: None, + values: MatchSizeValues { default, .. }, + } => *self = Self::Matched(*default.to_owned()), + Self::Unmatched { + match_size: Some(ty), + values, + } => *self = Self::Matched(*values.get(ty, ctx)?.to_owned()), + _ => {} + } + Ok(()) + } +} + +impl AsRef for SizeMatchable { + fn as_ref(&self) -> &T { + if let SizeMatchable::Matched(v) = self { + v + } else { + panic!("no match for {self:?} was performed"); + } + } +} + +impl AsMut for SizeMatchable { + fn as_mut(&mut self) -> &mut T { + if let SizeMatchable::Matched(v) = self { + v + } else { + panic!("no match for {self:?} was performed"); + } + } +} + +impl ToTokens for SizeMatchable { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.as_ref().to_tokens(tokens) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged, deny_unknown_fields)] +pub enum KindMatchable { + Matched(T), + Unmatched { + match_kind: Option, + #[serde(flatten)] + values: MatchKindValues>, + }, +} + +impl KindMatchable { + pub fn perform_match(&mut self, ctx: &LocalContext) -> context::Result { + match self { + Self::Unmatched { + match_kind: None, + values: MatchKindValues { default, .. 
}, + } => *self = Self::Matched(*default.to_owned()), + Self::Unmatched { + match_kind: Some(ty), + values, + } => *self = Self::Matched(*values.get(ty, ctx)?.to_owned()), + _ => {} + } + Ok(()) + } +} + +impl AsRef for KindMatchable { + fn as_ref(&self) -> &T { + if let KindMatchable::Matched(v) = self { + v + } else { + panic!("no match for {self:?} was performed"); + } + } +} + +impl AsMut for KindMatchable { + fn as_mut(&mut self) -> &mut T { + if let KindMatchable::Matched(v) = self { + v + } else { + panic!("no match for {self:?} was performed"); + } + } +} + +impl ToTokens for KindMatchable { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.as_ref().to_tokens(tokens) + } +} diff --git a/crates/stdarch-gen2/src/predicate_forms.rs b/crates/stdarch-gen2/src/predicate_forms.rs new file mode 100644 index 0000000000..02789bf7eb --- /dev/null +++ b/crates/stdarch-gen2/src/predicate_forms.rs @@ -0,0 +1,249 @@ +use serde::{Deserialize, Serialize}; +use serde_with::{DeserializeFromStr, SerializeDisplay}; +use std::fmt; +use std::str::FromStr; + +use crate::context; +use crate::expression::{Expression, FnCall, IdentifierType}; +use crate::intrinsic::Intrinsic; +use crate::typekinds::{ToRepr, TypeKind}; +use crate::wildcards::Wildcard; +use crate::wildstring::WildString; + +const ZEROING_SUFFIX: &str = "_z"; +const MERGING_SUFFIX: &str = "_m"; +const DONT_CARE_SUFFIX: &str = "_x"; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ZeroingMethod { + /// Drop the specified argument and replace it with a zeroinitializer + Drop { drop: WildString }, + /// Apply zero selection to the specified variable when zeroing + Select { select: WildString }, +} + +impl PartialOrd for ZeroingMethod { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for ZeroingMethod { + fn cmp(&self, _: &Self) -> std::cmp::Ordering { + std::cmp::Ordering::Equal + } +} + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum DontCareMethod { + #[default] + Inferred, + AsZeroing, + AsMerging, +} + +#[derive(Debug, Clone, Default, PartialEq, Eq, Deserialize, Serialize)] +pub struct PredicationMethods { + /// Zeroing method, if the zeroing predicate form is used + #[serde(default)] + pub zeroing_method: Option, + /// Don't care method, if the don't care predicate form is used + #[serde(default)] + pub dont_care_method: DontCareMethod, +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum PredicateForm { + /// Enables merging predicate form + Merging, + /// Enables "don't care" predicate form. + DontCare(DontCareMethod), + /// Enables zeroing predicate form. If LLVM zeroselection is performed, then + /// set the `select` field to the variable that gets set. Otherwise set the + /// `drop` field if the zeroinitializer replaces a predicate when merging. + Zeroing(ZeroingMethod), +} + +impl PredicateForm { + pub fn get_suffix(&self) -> &'static str { + match self { + PredicateForm::Zeroing { .. } => ZEROING_SUFFIX, + PredicateForm::Merging => MERGING_SUFFIX, + PredicateForm::DontCare { .. 
} => DONT_CARE_SUFFIX, + } + } + + pub fn make_zeroinitializer(ty: &TypeKind) -> Expression { + FnCall::new_expression( + format!("svdup_n_{}", ty.acle_notation_repr()) + .parse() + .unwrap(), + vec![if ty.base_type().unwrap().is_float() { + Expression::FloatConstant(0.0) + } else { + Expression::IntConstant(0) + }], + ) + } + + pub fn make_zeroselector(pg_var: WildString, op_var: WildString, ty: &TypeKind) -> Expression { + FnCall::new_expression( + format!("svsel_{}", ty.acle_notation_repr()) + .parse() + .unwrap(), + vec![ + Expression::Identifier(pg_var, IdentifierType::Variable), + Expression::Identifier(op_var, IdentifierType::Variable), + Self::make_zeroinitializer(ty), + ], + ) + } + + pub fn post_build(&self, intrinsic: &mut Intrinsic) -> context::Result { + // Drop the argument + match self { + PredicateForm::Zeroing(ZeroingMethod::Drop { drop: drop_var }) => { + intrinsic.signature.drop_argument(drop_var)? + } + PredicateForm::DontCare(DontCareMethod::AsZeroing) => { + if let ZeroingMethod::Drop { drop } = intrinsic + .input + .predication_methods + .zeroing_method + .to_owned() + .ok_or_else(|| { + "DontCareMethod::AsZeroing without zeroing method.".to_string() + })? + { + intrinsic.signature.drop_argument(&drop)? + } + } + _ => {} + } + + Ok(()) + } + + fn infer_dont_care(mask: &PredicationMask, methods: &PredicationMethods) -> PredicateForm { + let method = if methods.dont_care_method == DontCareMethod::Inferred { + if mask.has_zeroing() + && matches!(methods.zeroing_method, Some(ZeroingMethod::Drop { .. })) + { + DontCareMethod::AsZeroing + } else { + DontCareMethod::AsMerging + } + } else { + methods.dont_care_method + }; + + PredicateForm::DontCare(method) + } + + pub fn compile_list( + mask: &PredicationMask, + methods: &PredicationMethods, + ) -> context::Result> { + let mut forms = Vec::new(); + + if mask.has_merging() { + forms.push(PredicateForm::Merging) + } + + if mask.has_dont_care() { + forms.push(Self::infer_dont_care(mask, methods)) + } + + if mask.has_zeroing() { + if let Some(method) = methods.zeroing_method.to_owned() { + forms.push(PredicateForm::Zeroing(method)) + } else { + return Err( + "cannot create a zeroing variant without a zeroing method specified!" 
+ .to_string(), + ); + } + } + + Ok(forms) + } +} + +#[derive( + Debug, Clone, Copy, Default, PartialEq, Eq, Hash, DeserializeFromStr, SerializeDisplay, +)] +pub struct PredicationMask { + /// Merging + m: bool, + /// Don't care + x: bool, + /// Zeroing + z: bool, +} + +impl PredicationMask { + pub fn has_merging(&self) -> bool { + self.m + } + + pub fn has_dont_care(&self) -> bool { + self.x + } + + pub fn has_zeroing(&self) -> bool { + self.z + } +} + +impl FromStr for PredicationMask { + type Err = String; + + fn from_str(s: &str) -> Result { + let mut result = Self::default(); + for kind in s.bytes() { + match kind { + b'm' => result.m = true, + b'x' => result.x = true, + b'z' => result.z = true, + _ => { + return Err(format!( + "unknown predicate form modifier: {}", + char::from(kind) + )); + } + } + } + + if result.m || result.x || result.z { + Ok(result) + } else { + Err("invalid predication mask".to_string()) + } + } +} + +impl fmt::Display for PredicationMask { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.m.then(|| write!(f, "m")).transpose()?; + self.x.then(|| write!(f, "x")).transpose()?; + self.z.then(|| write!(f, "z")).transpose().map(|_| ()) + } +} + +impl TryFrom<&WildString> for PredicationMask { + type Error = String; + + fn try_from(value: &WildString) -> Result { + value + .wildcards() + .find_map(|w| { + if let Wildcard::PredicateForms(mask) = w { + Some(*mask) + } else { + None + } + }) + .ok_or_else(|| "no predicate forms were specified in the name".to_string()) + } +} diff --git a/crates/stdarch-gen2/src/typekinds.rs b/crates/stdarch-gen2/src/typekinds.rs new file mode 100644 index 0000000000..71f6297d94 --- /dev/null +++ b/crates/stdarch-gen2/src/typekinds.rs @@ -0,0 +1,1024 @@ +use lazy_static::lazy_static; +use proc_macro2::TokenStream; +use quote::{quote, ToTokens, TokenStreamExt}; +use regex::Regex; +use serde_with::{DeserializeFromStr, SerializeDisplay}; +use std::fmt; +use std::str::FromStr; + +use crate::context; +use crate::expression::{Expression, FnCall}; +use crate::intrinsic::AccessLevel; +use crate::wildcards::Wildcard; + +const VECTOR_FULL_REGISTER_SIZE: u32 = 128; +const VECTOR_HALF_REGISTER_SIZE: u32 = VECTOR_FULL_REGISTER_SIZE / 2; + +#[derive(Debug, Clone, Copy)] +pub enum TypeRepr { + C, + Rust, + LLVMMachine, + ACLENotation, + Size, + SizeLiteral, + TypeKind, + SizeInBytesLog2, +} + +pub trait ToRepr { + fn repr(&self, repr: TypeRepr) -> String; + + fn c_repr(&self) -> String { + self.repr(TypeRepr::C) + } + + fn rust_repr(&self) -> String { + self.repr(TypeRepr::Rust) + } + + fn llvm_machine_repr(&self) -> String { + self.repr(TypeRepr::LLVMMachine) + } + + fn acle_notation_repr(&self) -> String { + self.repr(TypeRepr::ACLENotation) + } + + fn size(&self) -> String { + self.repr(TypeRepr::Size) + } + + fn size_literal(&self) -> String { + self.repr(TypeRepr::SizeLiteral) + } + + fn type_kind(&self) -> String { + self.repr(TypeRepr::TypeKind) + } + + fn size_in_bytes_log2(&self) -> String { + self.repr(TypeRepr::SizeInBytesLog2) + } +} + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)] +pub struct TypeKindOptions { + f: bool, + s: bool, + u: bool, +} + +impl TypeKindOptions { + pub fn contains(&self, kind: BaseTypeKind) -> bool { + match kind { + BaseTypeKind::Float => self.f, + BaseTypeKind::Int => self.s, + BaseTypeKind::UInt => self.u, + BaseTypeKind::Bool => false, + } + } +} + +impl FromStr for TypeKindOptions { + type Err = String; + + fn from_str(s: &str) -> Result { + let mut result = Self::default(); + 
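+ // Each character enables one kind: 'f' (float), 's' (signed int) and
+ // 'u' (unsigned int); for example, "fsu" accepts all three.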
for kind in s.bytes() { + match kind { + b'f' => result.f = true, + b's' => result.s = true, + b'u' => result.u = true, + _ => { + return Err(format!("unknown type kind: {}", char::from(kind))); + } + } + } + Ok(result) + } +} + +impl fmt::Display for TypeKindOptions { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.f.then(|| write!(f, "f")).transpose()?; + self.s.then(|| write!(f, "s")).transpose()?; + self.u.then(|| write!(f, "u")).transpose().map(|_| ()) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum BaseTypeKind { + Float, + Int, + UInt, + Bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum BaseType { + Sized(BaseTypeKind, u32), + Unsized(BaseTypeKind), +} + +#[derive( + Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, SerializeDisplay, DeserializeFromStr, +)] +pub enum VectorTupleSize { + Two, + Three, + Four, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct VectorType { + base_type: BaseType, + lanes: u32, + is_scalable: bool, + tuple_size: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, SerializeDisplay, DeserializeFromStr)] +pub enum TypeKind { + Vector(VectorType), + Base(BaseType), + Pointer(Box, AccessLevel), + Custom(String), + Wildcard(Wildcard), +} + +impl TypeKind { + pub fn base_type(&self) -> Option<&BaseType> { + match self { + Self::Vector(t) => Some(t.base_type()), + Self::Pointer(t, _) => t.base_type(), + Self::Base(t) => Some(t), + Self::Wildcard(..) => None, + Self::Custom(..) => None, + } + } + + pub fn base_type_mut(&mut self) -> Option<&mut BaseType> { + match self { + Self::Vector(t) => Some(t.base_type_mut()), + Self::Pointer(t, _) => t.base_type_mut(), + Self::Base(t) => Some(t), + Self::Wildcard(..) => None, + Self::Custom(..) => None, + } + } + + pub fn populate_wildcard(&mut self, type_kind: TypeKind) -> context::Result { + match self { + Self::Wildcard(..) => *self = type_kind, + Self::Pointer(t, _) => t.populate_wildcard(type_kind)?, + _ => return Err("no wildcard available to populate".to_string()), + } + Ok(()) + } + + pub fn base(&self) -> Option<&BaseType> { + match self { + Self::Base(ty) => Some(ty), + Self::Pointer(tk, _) => tk.base(), + _ => None, + } + } + + pub fn vector(&self) -> Option<&VectorType> { + match self { + Self::Vector(ty) => Some(ty), + _ => None, + } + } + + pub fn vector_mut(&mut self) -> Option<&mut VectorType> { + match self { + Self::Vector(ty) => Some(ty), + _ => None, + } + } + + pub fn wildcard(&self) -> Option<&Wildcard> { + match self { + Self::Wildcard(w) => Some(w), + Self::Pointer(w, _) => w.wildcard(), + _ => None, + } + } + + pub fn make_predicate_from(ty: &TypeKind) -> context::Result { + Ok(TypeKind::Vector(VectorType::make_predicate_from_bitsize( + ty.base_type() + .ok_or_else(|| format!("cannot infer predicate from type {ty}"))? + .get_size() + .map_err(|_| format!("cannot infer predicate from unsized type {ty}"))?, + ))) + } + + pub fn make_vector( + from: TypeKind, + is_scalable: bool, + tuple_size: Option, + ) -> context::Result { + from.base().cloned().map_or_else( + || Err(format!("cannot make a vector type out of {from}!")), + |base| { + let vt = VectorType::make_from_base(base, is_scalable, tuple_size); + Ok(TypeKind::Vector(vt)) + }, + ) + } + + /// Return a new expression that converts the provided `expr` from type `other` to `self`. 
+ /// + /// Conversions are bitwise over the whole value, like `transmute`, though `transmute` + /// itself is only used as a last resort. + /// + /// This can fail (returning `None`) due to incompatible types, and many conversions are simply + /// unimplemented. + pub fn express_reinterpretation_from( + &self, + other: &TypeKind, + expr: impl Into, + ) -> Option { + if self == other { + Some(expr.into()) + } else if let (Some(self_vty), Some(other_vty)) = (self.vector(), other.vector()) { + if self_vty.is_scalable + && self_vty.tuple_size.is_none() + && other_vty.is_scalable + && other_vty.tuple_size.is_none() + { + // Plain scalable vectors. + use BaseTypeKind::*; + match (self_vty.base_type, other_vty.base_type) { + (BaseType::Sized(Int, self_size), BaseType::Sized(UInt, other_size)) + if self_size == other_size => + { + Some(Expression::MethodCall( + Box::new(expr.into()), + "as_signed".parse().unwrap(), + vec![], + )) + } + (BaseType::Sized(UInt, self_size), BaseType::Sized(Int, other_size)) + if self_size == other_size => + { + Some(Expression::MethodCall( + Box::new(expr.into()), + "as_unsigned".parse().unwrap(), + vec![], + )) + } + ( + BaseType::Sized(Float | Int | UInt, _), + BaseType::Sized(Float | Int | UInt, _), + ) => Some(FnCall::new_expression( + // Conversions between float and (u)int, or where the lane size changes. + "simd_reinterpret".parse().unwrap(), + vec![expr.into()], + )), + _ => None, + } + } else { + // Tuples and fixed-width vectors. + None + } + } else { + // Scalar types. + None + } + } +} + +impl FromStr for TypeKind { + type Err = String; + + fn from_str(s: &str) -> Result { + Ok(match s { + s if s.starts_with('{') && s.ends_with('}') => { + Self::Wildcard(s[1..s.len() - 1].trim().parse()?) + } + s if s.starts_with('*') => { + let mut split = s[1..].split_whitespace(); + let (ty, rw) = match (split.clone().count(), split.next(), split.next()) { + (2, Some("mut"), Some(ty)) => (ty, AccessLevel::RW), + (2, Some("const"), Some(ty)) => (ty, AccessLevel::R), + (1, Some(ty), None) => (ty, AccessLevel::R), + _ => return Err(format!("invalid pointer type {s:#?} given")), + }; + Self::Pointer(Box::new(ty.parse()?), rw) + } + _ => s + .parse::() + .map(TypeKind::Vector) + .or_else(|_| s.parse::().map(TypeKind::Base)) + .unwrap_or_else(|_| TypeKind::Custom(s.to_string())), + }) + } +} + +impl fmt::Display for TypeKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Vector(ty) => write!(f, "{ty}"), + Self::Pointer(ty, _) => write!(f, "{ty}"), + Self::Base(ty) => write!(f, "{ty}"), + Self::Wildcard(w) => write!(f, "{{{w}}}"), + Self::Custom(s) => write!(f, "{s}"), + } + } +} + +impl ToRepr for TypeKind { + fn repr(&self, repr: TypeRepr) -> String { + match self { + Self::Vector(ty) => ty.repr(repr), + Self::Pointer(ty, _) => ty.repr(repr), + Self::Base(ty) => ty.repr(repr), + Self::Wildcard(w) => format!("{w}"), + Self::Custom(s) => s.to_string(), + } + } +} + +impl ToTokens for TypeKind { + fn to_tokens(&self, tokens: &mut TokenStream) { + if let Self::Pointer(_, rw) = self { + tokens.append_all(match rw { + AccessLevel::RW => quote! { *mut }, + AccessLevel::R => quote! 
{ *const }, + }) + } + + tokens.append_all( + self.to_string() + .parse::() + .expect("invalid syntax"), + ) + } +} + +impl PartialOrd for TypeKind { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for TypeKind { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + use std::cmp::Ordering::*; + + impl From<&TypeKind> for usize { + fn from(ty: &TypeKind) -> Self { + match ty { + TypeKind::Base(_) => 1, + TypeKind::Pointer(_, _) => 2, + TypeKind::Vector(_) => 3, + TypeKind::Custom(_) => 4, + TypeKind::Wildcard(_) => 5, + } + } + } + + let self_int: usize = self.into(); + let other_int: usize = other.into(); + + if self_int == other_int { + match (self, other) { + (TypeKind::Base(ty1), TypeKind::Base(ty2)) => ty1.cmp(ty2), + (TypeKind::Pointer(ty1, _), TypeKind::Pointer(ty2, _)) => ty1.cmp(ty2), + (TypeKind::Vector(vt1), TypeKind::Vector(vt2)) => vt1.cmp(vt2), + (TypeKind::Custom(s1), TypeKind::Custom(s2)) => s1.cmp(s2), + (TypeKind::Wildcard(..), TypeKind::Wildcard(..)) => Equal, + _ => unreachable!(), + } + } else { + self_int.cmp(&other_int) + } + } +} + +impl VectorType { + pub fn base_type(&self) -> &BaseType { + &self.base_type + } + + pub fn base_type_mut(&mut self) -> &mut BaseType { + &mut self.base_type + } + + fn sanitise_lanes( + mut base_type: BaseType, + lanes: Option, + ) -> Result<(BaseType, u32), String> { + let lanes = match (base_type, lanes) { + (BaseType::Sized(BaseTypeKind::Bool, lanes), None) => { + base_type = BaseType::Sized(BaseTypeKind::Bool, VECTOR_FULL_REGISTER_SIZE / lanes); + lanes + } + (BaseType::Unsized(BaseTypeKind::Bool), None) => { + base_type = BaseType::Sized(BaseTypeKind::Bool, 8); + 16 + } + (BaseType::Sized(_, size), None) => VECTOR_FULL_REGISTER_SIZE / size, + (BaseType::Sized(_, size), Some(lanes)) => match size * lanes { + VECTOR_FULL_REGISTER_SIZE | VECTOR_HALF_REGISTER_SIZE => lanes, + _ => return Err("invalid number of lanes".to_string()), + }, + _ => return Err("cannot infer number of lanes".to_string()), + }; + + Ok((base_type, lanes)) + } + + pub fn make_from_base( + base_ty: BaseType, + is_scalable: bool, + tuple_size: Option, + ) -> VectorType { + if is_scalable { + if let BaseType::Sized(BaseTypeKind::Bool, size) = base_ty { + return Self::make_predicate_from_bitsize(size); + } + } + + let (base_type, lanes) = Self::sanitise_lanes(base_ty, None).unwrap(); + + VectorType { + base_type, + lanes, + is_scalable, + tuple_size, + } + } + + pub fn make_predicate_from_bitsize(size: u32) -> VectorType { + VectorType { + base_type: BaseType::Sized(BaseTypeKind::Bool, size), + lanes: (VECTOR_FULL_REGISTER_SIZE / size), + is_scalable: true, + tuple_size: None, + } + } + + pub fn cast_base_type_as(&mut self, ty: BaseType) { + self.base_type = ty + } +} + +impl FromStr for VectorType { + type Err = String; + + fn from_str(s: &str) -> Result { + lazy_static! 
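+ // The pattern accepts scalable SVE names (e.g. "svuint8_t", "svint64x3_t")
+ // and fixed-width names (e.g. "uint32x4_t", "uint64x2x4_t"), with an optional
+ // x2/x3/x4 tuple suffix.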
{ + static ref RE: Regex = Regex::new(r"^(?:(?:sv(?P(?:uint|int|bool|float)(?:\d+)?))|(?:(?P(?:uint|int|bool|float)(?:\d+)?)x(?P[0-9])))(?:x(?P2|3|4))?_t$").unwrap(); + } + + if let Some(c) = RE.captures(s) { + let (base_type, lanes) = Self::sanitise_lanes( + c.name("sv_ty") + .or_else(|| c.name("ty")) + .map(<&str>::from) + .map(BaseType::from_str) + .unwrap()?, + c.name("lanes") + .map(<&str>::from) + .map(u32::from_str) + .transpose() + .unwrap(), + ) + .map_err(|e| format!("invalid {s:#?} vector type: {e}"))?; + + let tuple_size = c + .name("tuple_size") + .map(<&str>::from) + .map(VectorTupleSize::from_str) + .transpose() + .unwrap(); + + Ok(VectorType { + base_type, + is_scalable: c.name("sv_ty").is_some(), + lanes, + tuple_size, + }) + } else { + Err(format!("invalid vector type {s:#?} given")) + } + } +} + +impl ToRepr for VectorType { + fn repr(&self, repr: TypeRepr) -> String { + let make_llvm_repr = |show_unsigned| { + format!( + "{}v{}{}", + if self.is_scalable { "nx" } else { "" }, + self.lanes * (self.tuple_size.map(usize::from).unwrap_or(1) as u32), + match self.base_type { + BaseType::Sized(BaseTypeKind::UInt, size) if show_unsigned => + format!("u{size}"), + _ => self.base_type.llvm_machine_repr(), + } + ) + }; + + if matches!(repr, TypeRepr::ACLENotation) { + self.base_type.acle_notation_repr() + } else if matches!(repr, TypeRepr::LLVMMachine) { + make_llvm_repr(false) + } else if self.is_scalable { + match (self.base_type, self.lanes, self.tuple_size) { + (BaseType::Sized(BaseTypeKind::Bool, _), 16, _) => "svbool_t".to_string(), + (BaseType::Sized(BaseTypeKind::Bool, _), lanes, _) => format!("svbool{lanes}_t"), + (BaseType::Sized(_, size), lanes, _) + if VECTOR_FULL_REGISTER_SIZE != (size * lanes) => + { + // Special internal type case + make_llvm_repr(true) + } + (ty, _, None) => format!("sv{}_t", ty.c_repr()), + (ty, _, Some(tuple_size)) => format!("sv{}x{tuple_size}_t", ty.c_repr()), + } + } else { + match self.tuple_size { + Some(tuple_size) => format!( + "{}x{}x{}_t", + self.base_type.c_repr(), + self.lanes, + tuple_size + ), + None => format!("{}x{}_t", self.base_type.c_repr(), self.lanes), + } + } + } +} + +impl fmt::Display for VectorType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.c_repr()) + } +} + +impl From for usize { + fn from(t: VectorTupleSize) -> Self { + match t { + VectorTupleSize::Two => 2, + VectorTupleSize::Three => 3, + VectorTupleSize::Four => 4, + } + } +} + +impl FromStr for VectorTupleSize { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "2" => Ok(Self::Two), + "3" => Ok(Self::Three), + "4" => Ok(Self::Four), + _ => Err(format!("invalid vector tuple size `{s}` provided")), + } + } +} + +impl TryFrom for VectorTupleSize { + type Error = String; + + fn try_from(value: usize) -> Result { + match value { + 2 => Ok(Self::Two), + 3 => Ok(Self::Three), + 4 => Ok(Self::Four), + _ => Err(format!("invalid vector tuple size `{value}` provided")), + } + } +} + +impl fmt::Display for VectorTupleSize { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", usize::from(*self)) + } +} + +impl FromStr for BaseTypeKind { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "float" | "f" => Ok(Self::Float), + "int" | "i" => Ok(Self::Int), + "uint" | "u" => Ok(Self::UInt), + "bool" | "b" => Ok(Self::Bool), + _ => Err(format!("no match for {s}")), + } + } +} + +impl ToRepr for BaseTypeKind { + fn repr(&self, repr: TypeRepr) -> String { + match (repr, self) 
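+ // For example, UInt renders as "uint" in C, "u" in Rust and ACLE notation, and
+ // "i" for the LLVM machine type, since LLVM integer types carry no signedness.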
{ + (TypeRepr::C, Self::Float) => "float", + (TypeRepr::C, Self::Int) => "int", + (TypeRepr::C, Self::UInt) => "uint", + (TypeRepr::Rust | TypeRepr::LLVMMachine | TypeRepr::ACLENotation, Self::Float) => "f", + (TypeRepr::Rust, Self::Int) | (TypeRepr::LLVMMachine, Self::Int | Self::UInt) => "i", + (TypeRepr::Rust | TypeRepr::ACLENotation, Self::UInt) => "u", + (TypeRepr::ACLENotation, Self::Int) => "s", + (TypeRepr::ACLENotation, Self::Bool) => "b", + (_, Self::Bool) => "bool", + _ => { + unreachable!("no base type kind available for representation {repr:?}") + } + } + .to_string() + } +} + +impl fmt::Display for BaseTypeKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.c_repr()) + } +} + +impl BaseType { + pub fn get_size(&self) -> Result { + match self { + Self::Sized(_, size) => Ok(*size), + _ => Err(format!("unexpected invalid base type given {self:#?}")), + } + } + + pub fn kind(&self) -> &BaseTypeKind { + match self { + BaseType::Sized(kind, _) | BaseType::Unsized(kind) => kind, + } + } + + pub fn is_bool(&self) -> bool { + self.kind() == &BaseTypeKind::Bool + } + + pub fn is_float(&self) -> bool { + self.kind() == &BaseTypeKind::Float + } +} + +impl FromStr for BaseType { + type Err = String; + + fn from_str(s: &str) -> Result { + lazy_static! { + static ref RE: Regex = Regex::new(r"^(?P[a-zA-Z]+)(?P\d+)?(_t)?$").unwrap(); + } + + if let Some(c) = RE.captures(s) { + let kind = c["kind"].parse()?; + let size = c + .name("size") + .map(<&str>::from) + .map(u32::from_str) + .transpose() + .unwrap(); + + match size { + Some(size) => Ok(Self::Sized(kind, size)), + None => Ok(Self::Unsized(kind)), + } + } else { + Err(format!("failed to parse type `{s}`")) + } + } +} + +impl ToRepr for BaseType { + fn repr(&self, repr: TypeRepr) -> String { + use BaseType::*; + use BaseTypeKind::*; + use TypeRepr::*; + match (self, &repr) { + (Sized(Bool, _) | Unsized(Bool), LLVMMachine) => "i1".to_string(), + (Sized(_, size), SizeLiteral) if *size == 8 => "b".to_string(), + (Sized(_, size), SizeLiteral) if *size == 16 => "h".to_string(), + (Sized(_, size), SizeLiteral) if *size == 32 => "w".to_string(), + (Sized(_, size), SizeLiteral) if *size == 64 => "d".to_string(), + (_, SizeLiteral) => unreachable!("cannot represent {self:#?} as size literal"), + (Sized(Float, _) | Unsized(Float), TypeKind) => "f".to_string(), + (Sized(Int, _) | Unsized(Int), TypeKind) => "s".to_string(), + (Sized(UInt, _) | Unsized(UInt), TypeKind) => "u".to_string(), + (Sized(_, size), Size) => size.to_string(), + (Sized(_, size), SizeInBytesLog2) => { + assert!(size.is_power_of_two() && *size >= 8); + (size >> 3).trailing_zeros().to_string() + } + (Sized(kind, size), _) => format!("{}{size}", kind.repr(repr)), + (Unsized(kind), _) => kind.repr(repr), + } + } +} + +impl fmt::Display for BaseType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.rust_repr()) + } +} + +#[cfg(test)] +mod tests { + use crate::typekinds::*; + + #[test] + fn test_predicate() { + assert_eq!( + "svbool_t".parse::().unwrap(), + TypeKind::Vector(VectorType { + base_type: BaseType::Sized(BaseTypeKind::Bool, 8), + is_scalable: true, + lanes: 16, + tuple_size: None + }) + ); + } + + #[test] + fn test_llvm_internal_predicate() { + assert_eq!( + "svbool4_t".parse::().unwrap(), + TypeKind::Vector(VectorType { + base_type: BaseType::Sized(BaseTypeKind::Bool, 32), + is_scalable: true, + lanes: 4, + tuple_size: None + }) + ); + } + + #[test] + fn test_llvm_internal_predicate_llvm() { + 
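+ // A predicate with 32-bit granularity covers 4 lanes of a 128-bit register,
+ // hence the scalable LLVM type "nxv4i1".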
assert_eq!( + "svbool4_t".parse::().unwrap().llvm_machine_repr(), + "nxv4i1" + ); + } + + #[test] + fn test_llvm_internal_predicate_acle() { + assert_eq!( + "svbool4_t" + .parse::() + .unwrap() + .acle_notation_repr(), + "b32" + ); + } + + #[test] + fn test_predicate_from_bitsize() { + let pg = VectorType::make_predicate_from_bitsize(32); + assert_eq!(pg.acle_notation_repr(), "b32"); + assert_eq!(pg, "svbool4_t".parse().unwrap()); + assert_eq!(pg.lanes, 4); + assert_eq!(pg.base_type, BaseType::Sized(BaseTypeKind::Bool, 32)); + } + + #[test] + fn test_scalable_single() { + assert_eq!( + "svuint8_t".parse::().unwrap(), + TypeKind::Vector(VectorType { + base_type: BaseType::Sized(BaseTypeKind::UInt, 8), + is_scalable: true, + lanes: 16, + tuple_size: None + }) + ); + } + + #[test] + fn test_scalable_tuple() { + assert_eq!( + "svint64x3_t".parse::().unwrap(), + TypeKind::Vector(VectorType { + base_type: BaseType::Sized(BaseTypeKind::Int, 64), + is_scalable: true, + lanes: 2, + tuple_size: Some(VectorTupleSize::Three), + }) + ); + } + + #[test] + fn test_scalable_single_llvm() { + assert_eq!( + "svuint32_t" + .parse::() + .unwrap() + .llvm_machine_repr(), + "nxv4i32" + ); + } + + #[test] + fn test_scalable_tuple_llvm() { + assert_eq!( + "svint32x4_t" + .parse::() + .unwrap() + .llvm_machine_repr(), + "nxv16i32" + ); + } + + #[test] + fn test_vector_single_full() { + assert_eq!( + "uint32x4_t".parse::().unwrap(), + TypeKind::Vector(VectorType { + base_type: BaseType::Sized(BaseTypeKind::UInt, 32), + is_scalable: false, + lanes: 4, + tuple_size: None, + }) + ); + } + + #[test] + fn test_vector_single_half() { + assert_eq!( + "uint32x2_t".parse::().unwrap(), + TypeKind::Vector(VectorType { + base_type: BaseType::Sized(BaseTypeKind::UInt, 32), + is_scalable: false, + lanes: 2, + tuple_size: None, + }) + ); + } + + #[test] + fn test_vector_tuple() { + assert_eq!( + "uint64x2x4_t".parse::().unwrap(), + TypeKind::Vector(VectorType { + base_type: BaseType::Sized(BaseTypeKind::UInt, 64), + is_scalable: false, + lanes: 2, + tuple_size: Some(VectorTupleSize::Four), + }) + ); + } + + #[test] + fn test_const_pointer() { + let p = "*u32".parse::().unwrap(); + assert_eq!( + p, + TypeKind::Pointer( + Box::new(TypeKind::Base(BaseType::Sized(BaseTypeKind::UInt, 32))), + AccessLevel::R + ) + ); + assert_eq!(p.to_token_stream().to_string(), "* const u32") + } + + #[test] + fn test_mut_pointer() { + let p = "*mut u32".parse::().unwrap(); + assert_eq!( + p, + TypeKind::Pointer( + Box::new(TypeKind::Base(BaseType::Sized(BaseTypeKind::UInt, 32))), + AccessLevel::RW + ) + ); + assert_eq!(p.to_token_stream().to_string(), "* mut u32") + } + + #[test] + #[should_panic] + fn test_invalid_vector_single() { + assert_eq!( + "uint32x8_t".parse::().unwrap(), + TypeKind::Vector(VectorType { + base_type: BaseType::Sized(BaseTypeKind::UInt, 32), + is_scalable: false, + lanes: 8, + tuple_size: None, + }) + ); + } + + #[test] + #[should_panic] + fn test_invalid_vector_tuple() { + assert_eq!( + "uint32x4x5_t".parse::().unwrap(), + TypeKind::Vector(VectorType { + base_type: BaseType::Sized(BaseTypeKind::UInt, 32), + is_scalable: false, + lanes: 8, + tuple_size: None, // cannot represent + }) + ); + } + + #[test] + fn test_base() { + assert_eq!( + "u32".parse::().unwrap(), + TypeKind::Base(BaseType::Sized(BaseTypeKind::UInt, 32)), + ) + } + + #[test] + fn test_custom() { + assert_eq!( + "svpattern".parse::().unwrap(), + TypeKind::Custom("svpattern".to_string()), + ) + } + + #[test] + fn test_wildcard_type() { + assert_eq!( + 
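+ // An unindexed {type} defaults to entry 0 of the type set (see
+ // get_typeset_index); {type[N]} selects an entry explicitly, as in the next test.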
"{type}".parse::().unwrap(), + TypeKind::Wildcard(Wildcard::Type(None)), + ) + } + + #[test] + fn test_wildcard_typeset() { + assert_eq!( + "{type[0]}".parse::().unwrap(), + TypeKind::Wildcard(Wildcard::Type(Some(0))), + ) + } + + #[test] + fn test_wildcard_sve_type() { + assert_eq!( + "{sve_type}".parse::().unwrap(), + TypeKind::Wildcard(Wildcard::SVEType(None, None)), + ) + } + + #[test] + fn test_wildcard_sve_typeset() { + assert_eq!( + "{sve_type[0]}".parse::().unwrap(), + TypeKind::Wildcard(Wildcard::SVEType(Some(0), None)), + ) + } + + #[test] + fn test_wildcard_sve_tuple_type() { + assert_eq!( + "{sve_type_x2}".parse::().unwrap(), + TypeKind::Wildcard(Wildcard::SVEType(None, Some(VectorTupleSize::Two))), + ) + } + + #[test] + fn test_wildcard_sve_tuple_typeset() { + assert_eq!( + "{sve_type_x2[0]}".parse::().unwrap(), + TypeKind::Wildcard(Wildcard::SVEType(Some(0), Some(VectorTupleSize::Two))), + ) + } + + #[test] + fn test_wildcard_predicate() { + assert_eq!( + "{predicate}".parse::().unwrap(), + TypeKind::Wildcard(Wildcard::Predicate(None)) + ) + } + + #[test] + fn test_wildcard_scale() { + assert_eq!( + "{sve_type as i8}".parse::().unwrap(), + TypeKind::Wildcard(Wildcard::Scale( + Box::new(Wildcard::SVEType(None, None)), + Box::new(TypeKind::Base(BaseType::Sized(BaseTypeKind::Int, 8))) + )) + ) + } + + #[test] + fn test_size_in_bytes_log2() { + assert_eq!("i8".parse::().unwrap().size_in_bytes_log2(), "0"); + assert_eq!("i16".parse::().unwrap().size_in_bytes_log2(), "1"); + assert_eq!("i32".parse::().unwrap().size_in_bytes_log2(), "2"); + assert_eq!("i64".parse::().unwrap().size_in_bytes_log2(), "3") + } + + #[test] + #[should_panic] + fn test_invalid_size_in_bytes_log2() { + "i9".parse::().unwrap().size_in_bytes_log2(); + } +} diff --git a/crates/stdarch-gen2/src/wildcards.rs b/crates/stdarch-gen2/src/wildcards.rs new file mode 100644 index 0000000000..9d6194d517 --- /dev/null +++ b/crates/stdarch-gen2/src/wildcards.rs @@ -0,0 +1,179 @@ +use lazy_static::lazy_static; +use regex::Regex; +use serde_with::{DeserializeFromStr, SerializeDisplay}; +use std::fmt; +use std::str::FromStr; + +use crate::{ + predicate_forms::PredicationMask, + typekinds::{ToRepr, TypeKind, TypeKindOptions, VectorTupleSize}, +}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash, SerializeDisplay, DeserializeFromStr)] +pub enum Wildcard { + Type(Option), + /// NEON type derivated by a base type + NEONType(Option, Option), + /// SVE type derivated by a base type + SVEType(Option, Option), + /// Integer representation of bitsize + Size(Option), + /// Integer representation of bitsize minus one + SizeMinusOne(Option), + /// Literal representation of the bitsize: b(yte), h(half), w(ord) or d(ouble) + SizeLiteral(Option), + /// Literal representation of the type kind: f(loat), s(igned), u(nsigned) + TypeKind(Option, Option), + /// Log2 of the size in bytes + SizeInBytesLog2(Option), + /// Predicate to be inferred from the specified type + Predicate(Option), + /// Predicate to be inferred from the greatest type + MaxPredicate, + + Scale(Box, Box), + + // Other wildcards + LLVMLink, + NVariant, + /// Predicate forms to use and placeholder for a predicate form function name modifier + PredicateForms(PredicationMask), + + /// User-set wildcard through `substitutions` + Custom(String), +} + +impl Wildcard { + pub fn is_nonpredicate_type(&self) -> bool { + matches!( + self, + Wildcard::Type(..) | Wildcard::NEONType(..) | Wildcard::SVEType(..) 
+ ) + } + + pub fn get_typeset_index(&self) -> Option { + match self { + Wildcard::Type(idx) | Wildcard::NEONType(idx, ..) | Wildcard::SVEType(idx, ..) => { + Some(idx.unwrap_or(0)) + } + _ => None, + } + } +} + +impl FromStr for Wildcard { + type Err = String; + + fn from_str(s: &str) -> Result { + lazy_static! { + static ref RE: Regex = Regex::new(r"^(?P\w+?)(?:_x(?P[2-4]))?(?:\[(?P\d+)\])?(?:\.(?P\w+))?(?:\s+as\s+(?P.*?))?$").unwrap(); + } + + if let Some(c) = RE.captures(s) { + let wildcard_name = &c["wildcard"]; + let inputset_index = c + .name("index") + .map(<&str>::from) + .map(usize::from_str) + .transpose() + .map_err(|_| format!("{:#?} is not a valid type index", &c["index"]))?; + let tuple_size = c + .name("tuple_size") + .map(<&str>::from) + .map(VectorTupleSize::from_str) + .transpose() + .map_err(|_| format!("{:#?} is not a valid tuple size", &c["tuple_size"]))?; + let modifiers = c.name("modifiers").map(<&str>::from); + + let wildcard = match (wildcard_name, inputset_index, tuple_size, modifiers) { + ("type", index, None, None) => Ok(Wildcard::Type(index)), + ("neon_type", index, tuple, None) => Ok(Wildcard::NEONType(index, tuple)), + ("sve_type", index, tuple, None) => Ok(Wildcard::SVEType(index, tuple)), + ("size", index, None, None) => Ok(Wildcard::Size(index)), + ("size_minus_one", index, None, None) => Ok(Wildcard::SizeMinusOne(index)), + ("size_literal", index, None, None) => Ok(Wildcard::SizeLiteral(index)), + ("type_kind", index, None, modifiers) => Ok(Wildcard::TypeKind( + index, + modifiers.map(|modifiers| modifiers.parse()).transpose()?, + )), + ("size_in_bytes_log2", index, None, None) => Ok(Wildcard::SizeInBytesLog2(index)), + ("predicate", index, None, None) => Ok(Wildcard::Predicate(index)), + ("max_predicate", None, None, None) => Ok(Wildcard::MaxPredicate), + ("llvm_link", None, None, None) => Ok(Wildcard::LLVMLink), + ("_n", None, None, None) => Ok(Wildcard::NVariant), + (w, None, None, None) if w.starts_with('_') => { + // test for predicate forms + let pf_mask = PredicationMask::from_str(&w[1..]); + if let Ok(mask) = pf_mask { + if mask.has_merging() { + Ok(Wildcard::PredicateForms(mask)) + } else { + Err("cannot add predication without a Merging form".to_string()) + } + } else { + Err(format!("invalid wildcard `{s:#?}`")) + } + } + (cw, None, None, None) => Ok(Wildcard::Custom(cw.to_string())), + _ => Err(format!("invalid wildcard `{s:#?}`")), + }?; + + let scale_to = c + .name("scale_to") + .map(<&str>::from) + .map(TypeKind::from_str) + .transpose() + .map_err(|_| format!("{:#?} is not a valid type", &c["scale_to"]))?; + + if let Some(scale_to) = scale_to { + Ok(Wildcard::Scale(Box::new(wildcard), Box::new(scale_to))) + } else { + Ok(wildcard) + } + } else { + Err(format!("invalid wildcard `{s:#?}`")) + } + } +} + +impl fmt::Display for Wildcard { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Type(None) => write!(f, "type"), + Self::Type(Some(index)) => write!(f, "type[{index}]"), + Self::NEONType(None, None) => write!(f, "neon_type"), + Self::NEONType(Some(index), None) => write!(f, "neon_type[{index}]"), + Self::NEONType(None, Some(tuple_size)) => write!(f, "neon_type_x{tuple_size}"), + Self::NEONType(Some(index), Some(tuple_size)) => { + write!(f, "neon_type_x{tuple_size}[{index}]") + } + Self::SVEType(None, None) => write!(f, "sve_type"), + Self::SVEType(Some(index), None) => write!(f, "sve_type[{index}]"), + Self::SVEType(None, Some(tuple_size)) => write!(f, "sve_type_x{tuple_size}"), + 
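+ // Display mirrors FromStr above, so a wildcard such as "sve_type_x2[0]"
+ // round-trips through parsing and formatting.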
Self::SVEType(Some(index), Some(tuple_size)) => { + write!(f, "sve_type_x{tuple_size}[{index}]") + } + Self::Size(None) => write!(f, "size"), + Self::Size(Some(index)) => write!(f, "size[{index}]"), + Self::SizeMinusOne(None) => write!(f, "size_minus_one"), + Self::SizeMinusOne(Some(index)) => write!(f, "size_minus_one[{index}]"), + Self::SizeLiteral(None) => write!(f, "size_literal"), + Self::SizeLiteral(Some(index)) => write!(f, "size_literal[{index}]"), + Self::TypeKind(None, None) => write!(f, "type_kind"), + Self::TypeKind(None, Some(opts)) => write!(f, "type_kind.{opts}"), + Self::TypeKind(Some(index), None) => write!(f, "type_kind[{index}]"), + Self::TypeKind(Some(index), Some(opts)) => write!(f, "type_kind[{index}].{opts}"), + Self::SizeInBytesLog2(None) => write!(f, "size_in_bytes_log2"), + Self::SizeInBytesLog2(Some(index)) => write!(f, "size_in_bytes_log2[{index}]"), + Self::Predicate(None) => write!(f, "predicate"), + Self::Predicate(Some(index)) => write!(f, "predicate[{index}]"), + Self::MaxPredicate => write!(f, "max_predicate"), + Self::LLVMLink => write!(f, "llvm_link"), + Self::NVariant => write!(f, "_n"), + Self::PredicateForms(mask) => write!(f, "_{mask}"), + + Self::Scale(wildcard, ty) => write!(f, "{wildcard} as {}", ty.rust_repr()), + Self::Custom(cw) => write!(f, "{cw}"), + } + } +} diff --git a/crates/stdarch-gen2/src/wildstring.rs b/crates/stdarch-gen2/src/wildstring.rs new file mode 100644 index 0000000000..1f9e6c9ada --- /dev/null +++ b/crates/stdarch-gen2/src/wildstring.rs @@ -0,0 +1,353 @@ +use itertools::Itertools; +use proc_macro2::TokenStream; +use quote::{quote, ToTokens, TokenStreamExt}; +use serde_with::{DeserializeFromStr, SerializeDisplay}; +use std::str::pattern::Pattern; +use std::{fmt, str::FromStr}; + +use crate::context::LocalContext; +use crate::typekinds::{ToRepr, TypeRepr}; +use crate::wildcards::Wildcard; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum WildStringPart { + String(String), + Wildcard(Wildcard), +} + +/// Wildcard-able string +#[derive(Debug, Clone, PartialEq, Eq, Default, SerializeDisplay, DeserializeFromStr)] +pub struct WildString(Vec); + +impl WildString { + pub fn has_wildcards(&self) -> bool { + for part in self.0.iter() { + if let WildStringPart::Wildcard(..) 
= part { + return true; + } + } + + false + } + + pub fn wildcards(&self) -> impl Iterator + '_ { + self.0.iter().filter_map(|part| match part { + WildStringPart::Wildcard(w) => Some(w), + _ => None, + }) + } + + pub fn iter(&self) -> impl Iterator + '_ { + self.0.iter() + } + + pub fn iter_mut(&mut self) -> impl Iterator + '_ { + self.0.iter_mut() + } + + pub fn starts_with(&self, s2: &str) -> bool { + self.to_string().starts_with(s2) + } + + pub fn prepend_str(&mut self, s: impl Into) { + self.0.insert(0, WildStringPart::String(s.into())) + } + + pub fn push_str(&mut self, s: impl Into) { + self.0.push(WildStringPart::String(s.into())) + } + + pub fn push_wildcard(&mut self, w: Wildcard) { + self.0.push(WildStringPart::Wildcard(w)) + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub fn replace<'a, P>(&'a self, from: P, to: &str) -> WildString + where + P: Pattern<'a> + Copy, + { + WildString( + self.0 + .iter() + .map(|part| match part { + WildStringPart::String(s) => WildStringPart::String(s.replace(from, to)), + part => part.clone(), + }) + .collect_vec(), + ) + } + + pub fn build_acle(&mut self, ctx: &LocalContext) -> Result<(), String> { + self.build(ctx, TypeRepr::ACLENotation) + } + + pub fn build(&mut self, ctx: &LocalContext, repr: TypeRepr) -> Result<(), String> { + self.iter_mut().try_for_each(|wp| -> Result<(), String> { + if let WildStringPart::Wildcard(w) = wp { + let value = ctx + .provide_substitution_wildcard(w) + .or_else(|_| ctx.provide_type_wildcard(w).map(|ty| ty.repr(repr)))?; + *wp = WildStringPart::String(value); + } + Ok(()) + }) + } +} + +impl From for WildString { + fn from(s: String) -> Self { + WildString(vec![WildStringPart::String(s)]) + } +} + +impl FromStr for WildString { + type Err = String; + + fn from_str(s: &str) -> Result { + enum State { + Normal { start: usize }, + Wildcard { start: usize, count: usize }, + EscapeTokenOpen { start: usize, at: usize }, + EscapeTokenClose { start: usize, at: usize }, + } + + let mut ws = WildString::default(); + match s + .char_indices() + .try_fold(State::Normal { start: 0 }, |state, (idx, ch)| { + match (state, ch) { + (State::Normal { start }, '{') => Ok(State::EscapeTokenOpen { start, at: idx }), + (State::Normal { start }, '}') => { + Ok(State::EscapeTokenClose { start, at: idx }) + } + (State::EscapeTokenOpen { start, at }, '{') + | (State::EscapeTokenClose { start, at }, '}') => { + if start < at { + ws.push_str(&s[start..at]) + } + + Ok(State::Normal { start: idx }) + } + (State::EscapeTokenOpen { at, .. }, '}') => Err(format!( + "empty wildcard given in string {s:?} at position {at}" + )), + (State::EscapeTokenOpen { start, at }, _) => { + if start < at { + ws.push_str(&s[start..at]) + } + + Ok(State::Wildcard { + start: idx, + count: 0, + }) + } + (State::EscapeTokenClose { at, .. }, _) => Err(format!( + "closing a non-wildcard/bad escape in string {s:?} at position {at}" + )), + // Nesting wildcards is only supported for `{foo as {bar}}`, wildcards cannot be + // nested at the start of a WildString. + (State::Wildcard { start, count }, '{') => Ok(State::Wildcard { + start, + count: count + 1, + }), + (State::Wildcard { start, count: 0 }, '}') => { + ws.push_wildcard(s[start..idx].parse()?); + Ok(State::Normal { start: idx + 1 }) + } + (State::Wildcard { start, count }, '}') => Ok(State::Wildcard { + start, + count: count - 1, + }), + (state @ State::Normal { .. }, _) | (state @ State::Wildcard { .. }, _) => { + Ok(state) + } + } + })? 
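+ // Only a final State::Normal is a complete parse; any other final state means
+ // an unbalanced '{' or '}' and is reported as an error below.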
{ + State::Normal { start } => { + if start < s.len() { + ws.push_str(&s[start..]); + } + + Ok(ws) + } + State::EscapeTokenOpen { at, .. } | State::Wildcard { start: at, .. } => Err(format!( + "unclosed wildcard in string {s:?} at position {at}" + )), + State::EscapeTokenClose { at, .. } => Err(format!( + "closing a non-wildcard/bad escape in string {s:?} at position {at}" + )), + } + } +} + +impl fmt::Display for WildString { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}", + self.0 + .iter() + .map(|part| match part { + WildStringPart::String(s) => s.to_owned(), + WildStringPart::Wildcard(w) => format!("{{{w}}}"), + }) + .join("") + ) + } +} + +impl ToTokens for WildString { + fn to_tokens(&self, tokens: &mut TokenStream) { + assert!( + !self.has_wildcards(), + "cannot convert string with wildcards {self:?} to TokenStream" + ); + let str = self.to_string(); + tokens.append_all(quote! { #str }) + } +} + +#[cfg(test)] +mod tests { + use crate::typekinds::*; + use crate::wildstring::*; + + #[test] + fn test_empty_string() { + let ws: WildString = "".parse().unwrap(); + assert_eq!(ws.0.len(), 0); + } + + #[test] + fn test_plain_string() { + let ws: WildString = "plain string".parse().unwrap(); + assert_eq!(ws.0.len(), 1); + assert_eq!( + ws, + WildString(vec![WildStringPart::String("plain string".to_string())]) + ) + } + + #[test] + fn test_escaped_curly_brackets() { + let ws: WildString = "VALUE = {{value}}".parse().unwrap(); + assert_eq!(ws.to_string(), "VALUE = {value}"); + assert!(!ws.has_wildcards()); + } + + #[test] + fn test_escaped_curly_brackets_wildcard() { + let ws: WildString = "TYPE = {{{type}}}".parse().unwrap(); + assert_eq!(ws.to_string(), "TYPE = {{type}}"); + assert_eq!(ws.0.len(), 4); + assert!(ws.has_wildcards()); + } + + #[test] + fn test_wildcard_right_boundary() { + let s = "string test {type}"; + let ws: WildString = s.parse().unwrap(); + assert_eq!(&ws.to_string(), s); + assert!(ws.has_wildcards()); + } + + #[test] + fn test_wildcard_left_boundary() { + let s = "{type} string test"; + let ws: WildString = s.parse().unwrap(); + assert_eq!(&ws.to_string(), s); + assert!(ws.has_wildcards()); + } + + #[test] + fn test_recursive_wildcard() { + let s = "string test {type[0] as {type[1]}}"; + let ws: WildString = s.parse().unwrap(); + + assert_eq!(ws.0.len(), 2); + assert_eq!( + ws, + WildString(vec![ + WildStringPart::String("string test ".to_string()), + WildStringPart::Wildcard(Wildcard::Scale( + Box::new(Wildcard::Type(Some(0))), + Box::new(TypeKind::Wildcard(Wildcard::Type(Some(1)))), + )) + ]) + ); + } + + #[test] + fn test_scale_wildcard() { + let s = "string {type[0] as i8} test"; + let ws: WildString = s.parse().unwrap(); + + assert_eq!(ws.0.len(), 3); + assert_eq!( + ws, + WildString(vec![ + WildStringPart::String("string ".to_string()), + WildStringPart::Wildcard(Wildcard::Scale( + Box::new(Wildcard::Type(Some(0))), + Box::new(TypeKind::Base(BaseType::Sized(BaseTypeKind::Int, 8))), + )), + WildStringPart::String(" test".to_string()) + ]) + ); + } + + #[test] + fn test_solitaire_wildcard() { + let ws: WildString = "{type}".parse().unwrap(); + assert_eq!(ws.0.len(), 1); + assert_eq!( + ws, + WildString(vec![WildStringPart::Wildcard(Wildcard::Type(None))]) + ) + } + + #[test] + fn test_empty_wildcard() { + "string {}" + .parse::() + .expect_err("expected parse error"); + } + + #[test] + fn test_invalid_open_wildcard_right() { + "string {" + .parse::() + .expect_err("expected parse error"); + } + + #[test] + fn 
test_invalid_close_wildcard_right() {
+ "string }"
+ .parse::<WildString>()
+ .expect_err("expected parse error");
+ }
+
+ #[test]
+ fn test_invalid_open_wildcard_left() {
+ "{string"
+ .parse::<WildString>()
+ .expect_err("expected parse error");
+ }
+
+ #[test]
+ fn test_invalid_close_wildcard_left() {
+ "}string"
+ .parse::<WildString>()
+ .expect_err("expected parse error");
+ }
+
+ #[test]
+ fn test_consecutive_wildcards() {
+ let s = "svprf{size_literal[1]}_gather_{type[0]}{index_or_offset}";
+ let ws: WildString = s.parse().unwrap();
+ assert_eq!(ws.to_string(), s)
+ }
+}