diff --git a/specification/hugr.md b/specification/hugr.md index 0632882d5..ca9832447 100644 --- a/specification/hugr.md +++ b/specification/hugr.md @@ -907,17 +907,17 @@ The declaration of the `params` uses a language that is a distinct, simplified form of the [Type System](#type-system) - writing terminals that appear in the YAML in quotes, the value of each member of `params` is given by the following production: ``` -TypeParam ::= "Type"("Any"|"Copy"|"Eq") | "USize" | "Extensions" | "List"(TypeParam) | "Tuple"([TypeParam]) | Opaque +TypeParam ::= "Type"("Any"|"Copy"|"Eq") | "BoundedUSize(u64)" | "Extensions" | "List"(TypeParam) | "Tuple"([TypeParam]) | Opaque Opaque ::= string<[TypeArgs]> -TypeArgs ::= Type(Type) | USize(u64) | Extensions | List([TypeArg]) | Tuple([TypeArg]) +TypeArgs ::= Type(Type) | BoundedUSize(u64) | Extensions | List([TypeArg]) | Tuple([TypeArg]) Type ::= Name<[TypeArg]> ``` (We write `[Foo]` to indicate a list of Foo's; and omit `<>` where the contents is the empty list). -To use an OpDef as an Op, or a TypeDef as a type, the user must provide a type argument for each type param in the def: a type in the appropriate class, a constant usize, a set of extensions, a list or tuple of arguments. +To use an OpDef as an Op, or a TypeDef as a type, the user must provide a type argument for each type param in the def: a type in the appropriate class, a bounded usize, a set of extensions, a list or tuple of arguments. **Implementation note** Reading this format into Rust is made easy by `serde` and [serde\_yaml](https://github.com/dtolnay/serde-yaml) (see the diff --git a/src/extension/prelude.rs b/src/extension/prelude.rs index e4dfc92fa..a79497a61 100644 --- a/src/extension/prelude.rs +++ b/src/extension/prelude.rs @@ -30,7 +30,7 @@ lazy_static! { prelude .add_type( SmolStr::new_inline("array"), - vec![TypeParam::Type(TypeBound::Any), TypeParam::USize], + vec![TypeParam::Type(TypeBound::Any), TypeParam::max_nat()], "array".into(), TypeDefBound::FromParams(vec![0]), ) @@ -68,7 +68,7 @@ pub(crate) const BOOL_T: Type = Type::new_simple_predicate(2); pub fn new_array(typ: Type, size: u64) -> Type { let array_def = PRELUDE.get_type("array").unwrap(); let custom_t = array_def - .instantiate_concrete(vec![TypeArg::Type(typ), TypeArg::USize(size)]) + .instantiate_concrete(vec![TypeArg::Type(typ), TypeArg::BoundedNat(size)]) .unwrap(); Type::new_extension(custom_t) } diff --git a/src/ops/constant.rs b/src/ops/constant.rs index ed2c0587d..6f367d07b 100644 --- a/src/ops/constant.rs +++ b/src/ops/constant.rs @@ -199,7 +199,12 @@ mod test { #[test] fn test_yaml_const() { - let typ_int = CustomType::new("mytype", vec![TypeArg::USize(8)], "myrsrc", TypeBound::Eq); + let typ_int = CustomType::new( + "mytype", + vec![TypeArg::BoundedNat(8)], + "myrsrc", + TypeBound::Eq, + ); let val: Value = CustomSerialized::new(typ_int.clone(), YamlValue::Number(6.into())).into(); let classic_t = Type::new_extension(typ_int.clone()); assert_matches!(classic_t.least_upper_bound(), TypeBound::Eq); diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index 6cca97c80..c2cad2fa0 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -7,27 +7,23 @@ use smol_str::SmolStr; use crate::{ extension::{ExtensionSet, SignatureError}, type_row, - types::{ - type_param::{TypeArg, TypeParam}, - Type, TypeRow, - }, + types::{type_param::TypeArg, Type, TypeRow}, utils::collect_array, Extension, }; -use super::float_types::FLOAT64_TYPE; -use super::int_types::{get_width, int_type}; +use super::int_types::int_type; +use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM}; /// The extension identifier. pub const EXTENSION_ID: SmolStr = SmolStr::new_inline("arithmetic.conversions"); fn ftoi_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { let [arg] = collect_array(arg_values); - let n: u8 = get_width(arg)?; Ok(( type_row![FLOAT64_TYPE], vec![Type::new_sum(vec![ - int_type(n), + int_type(arg.clone()), crate::extension::prelude::ERROR_TYPE, ])] .into(), @@ -37,9 +33,8 @@ fn ftoi_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), fn itof_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { let [arg] = collect_array(arg_values); - let n: u8 = get_width(arg)?; Ok(( - vec![int_type(n)].into(), + vec![int_type(arg.clone())].into(), type_row![FLOAT64_TYPE], ExtensionSet::default(), )) @@ -59,7 +54,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "trunc_u".into(), "float to unsigned int".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], ftoi_sig, ) .unwrap(); @@ -67,7 +62,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "trunc_s".into(), "float to signed int".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], ftoi_sig, ) .unwrap(); @@ -75,7 +70,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "convert_u".into(), "unsigned int to float".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], itof_sig, ) .unwrap(); @@ -83,7 +78,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "convert_s".into(), "signed int to float".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], itof_sig, ) .unwrap(); diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index 9408f1bfe..eae8847de 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -2,10 +2,9 @@ use smol_str::SmolStr; -use super::int_types::{get_width, int_type}; +use super::int_types::{get_log_width, int_type, type_arg, LOG_WIDTH_TYPE_PARAM}; use crate::extension::prelude::{BOOL_T, ERROR_TYPE}; use crate::type_row; -use crate::types::type_param::TypeParam; use crate::utils::collect_array; use crate::{ extension::{ExtensionSet, SignatureError}, @@ -18,35 +17,35 @@ pub const EXTENSION_ID: SmolStr = SmolStr::new_inline("arithmetic.int"); fn iwiden_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { let [arg0, arg1] = collect_array(arg_values); - let m: u8 = get_width(arg0)?; - let n: u8 = get_width(arg1)?; + let m: u8 = get_log_width(arg0)?; + let n: u8 = get_log_width(arg1)?; if m > n { return Err(SignatureError::InvalidTypeArgs); } Ok(( - vec![int_type(m)].into(), - vec![int_type(n)].into(), + vec![int_type(arg0.clone())].into(), + vec![int_type(arg1.clone())].into(), ExtensionSet::default(), )) } fn inarrow_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { let [arg0, arg1] = collect_array(arg_values); - let m: u8 = get_width(arg0)?; - let n: u8 = get_width(arg1)?; + let m: u8 = get_log_width(arg0)?; + let n: u8 = get_log_width(arg1)?; if m < n { return Err(SignatureError::InvalidTypeArgs); } Ok(( - vec![int_type(m)].into(), - vec![Type::new_sum(vec![int_type(n), ERROR_TYPE])].into(), + vec![int_type(arg0.clone())].into(), + vec![Type::new_sum(vec![int_type(arg1.clone()), ERROR_TYPE])].into(), ExtensionSet::default(), )) } fn itob_sig(_arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { Ok(( - vec![int_type(1)].into(), + vec![int_type(type_arg(0))].into(), type_row![BOOL_T], ExtensionSet::default(), )) @@ -55,16 +54,15 @@ fn itob_sig(_arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), fn btoi_sig(_arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { Ok(( type_row![BOOL_T], - vec![int_type(1)].into(), + vec![int_type(type_arg(0))].into(), ExtensionSet::default(), )) } fn icmp_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { let [arg] = collect_array(arg_values); - let n: u8 = get_width(arg)?; Ok(( - vec![int_type(n); 2].into(), + vec![int_type(arg.clone()); 2].into(), type_row![BOOL_T], ExtensionSet::default(), )) @@ -72,29 +70,25 @@ fn icmp_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), fn ibinop_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { let [arg] = collect_array(arg_values); - let n: u8 = get_width(arg)?; Ok(( - vec![int_type(n); 2].into(), - vec![int_type(n)].into(), + vec![int_type(arg.clone()); 2].into(), + vec![int_type(arg.clone())].into(), ExtensionSet::default(), )) } fn iunop_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { let [arg] = collect_array(arg_values); - let n: u8 = get_width(arg)?; Ok(( - vec![int_type(n)].into(), - vec![int_type(n)].into(), + vec![int_type(arg.clone())].into(), + vec![int_type(arg.clone())].into(), ExtensionSet::default(), )) } fn idivmod_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { let [arg0, arg1] = collect_array(arg_values); - let n: u8 = get_width(arg0)?; - let m: u8 = get_width(arg1)?; - let intpair: TypeRow = vec![int_type(n), int_type(m)].into(); + let intpair: TypeRow = vec![int_type(arg0.clone()), int_type(arg1.clone())].into(); Ok(( intpair.clone(), vec![Type::new_sum(vec![Type::new_tuple(intpair), ERROR_TYPE])].into(), @@ -104,33 +98,27 @@ fn idivmod_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet fn idiv_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { let [arg0, arg1] = collect_array(arg_values); - let n: u8 = get_width(arg0)?; - let m: u8 = get_width(arg1)?; Ok(( - vec![int_type(n), int_type(m)].into(), - vec![Type::new_sum(vec![int_type(n), ERROR_TYPE])].into(), + vec![int_type(arg0.clone()), int_type(arg1.clone())].into(), + vec![Type::new_sum(vec![int_type(arg0.clone()), ERROR_TYPE])].into(), ExtensionSet::default(), )) } fn imod_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { let [arg0, arg1] = collect_array(arg_values); - let n: u8 = get_width(arg0)?; - let m: u8 = get_width(arg1)?; Ok(( - vec![int_type(n), int_type(m)].into(), - vec![Type::new_sum(vec![int_type(m), ERROR_TYPE])].into(), + vec![int_type(arg0.clone()), int_type(arg1.clone())].into(), + vec![Type::new_sum(vec![int_type(arg1.clone()), ERROR_TYPE])].into(), ExtensionSet::default(), )) } fn ish_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { let [arg0, arg1] = collect_array(arg_values); - let n: u8 = get_width(arg0)?; - let m: u8 = get_width(arg1)?; Ok(( - vec![int_type(n), int_type(m)].into(), - vec![int_type(n)].into(), + vec![int_type(arg0.clone()), int_type(arg1.clone())].into(), + vec![int_type(arg0.clone())].into(), ExtensionSet::default(), )) } @@ -146,7 +134,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "iwiden_u".into(), "widen an unsigned integer to a wider one with the same value".to_owned(), - vec![TypeParam::USize, TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM, LOG_WIDTH_TYPE_PARAM], iwiden_sig, ) .unwrap(); @@ -154,7 +142,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "iwiden_s".into(), "widen a signed integer to a wider one with the same value".to_owned(), - vec![TypeParam::USize, TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM, LOG_WIDTH_TYPE_PARAM], iwiden_sig, ) .unwrap(); @@ -163,7 +151,7 @@ pub fn extension() -> Extension { "inarrow_u".into(), "narrow an unsigned integer to a narrower one with the same value if possible" .to_owned(), - vec![TypeParam::USize, TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM, LOG_WIDTH_TYPE_PARAM], inarrow_sig, ) .unwrap(); @@ -171,7 +159,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "inarrow_s".into(), "narrow a signed integer to a narrower one with the same value if possible".to_owned(), - vec![TypeParam::USize, TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM, LOG_WIDTH_TYPE_PARAM], inarrow_sig, ) .unwrap(); @@ -195,7 +183,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "ieq".into(), "equality test".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], icmp_sig, ) .unwrap(); @@ -203,7 +191,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "ine".into(), "inequality test".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], icmp_sig, ) .unwrap(); @@ -211,7 +199,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "ilt_u".into(), "\"less than\" as unsigned integers".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], icmp_sig, ) .unwrap(); @@ -219,7 +207,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "ilt_s".into(), "\"less than\" as signed integers".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], icmp_sig, ) .unwrap(); @@ -227,7 +215,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "igt_u".into(), "\"greater than\" as unsigned integers".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], icmp_sig, ) .unwrap(); @@ -235,7 +223,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "igt_s".into(), "\"greater than\" as signed integers".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], icmp_sig, ) .unwrap(); @@ -243,7 +231,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "ile_u".into(), "\"less than or equal\" as unsigned integers".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], icmp_sig, ) .unwrap(); @@ -251,7 +239,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "ile_s".into(), "\"less than or equal\" as signed integers".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], icmp_sig, ) .unwrap(); @@ -259,7 +247,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "ige_u".into(), "\"greater than or equal\" as unsigned integers".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], icmp_sig, ) .unwrap(); @@ -267,7 +255,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "ige_s".into(), "\"greater than or equal\" as signed integers".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], icmp_sig, ) .unwrap(); @@ -275,7 +263,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "imax_u".into(), "maximum of unsigned integers".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], ibinop_sig, ) .unwrap(); @@ -283,7 +271,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "imax_s".into(), "maximum of signed integers".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], ibinop_sig, ) .unwrap(); @@ -291,7 +279,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "imin_u".into(), "minimum of unsigned integers".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], ibinop_sig, ) .unwrap(); @@ -299,7 +287,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "imin_s".into(), "minimum of signed integers".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], ibinop_sig, ) .unwrap(); @@ -307,7 +295,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "iadd".into(), "addition modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], ibinop_sig, ) .unwrap(); @@ -315,7 +303,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "isub".into(), "subtraction modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], ibinop_sig, ) .unwrap(); @@ -323,7 +311,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "ineg".into(), "negation modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], iunop_sig, ) .unwrap(); @@ -331,7 +319,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "imul".into(), "multiplication modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], ibinop_sig, ) .unwrap(); @@ -341,7 +329,7 @@ pub fn extension() -> Extension { "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^M, generates unsigned q, r where \ q*m+r=n, 0<=r Extension { "given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^M, generates \ signed q and unsigned r where q*m+r=n, 0<=r Extension { .add_op_custom_sig_simple( "idiv_u".into(), "as idivmod_u but discarding the second output".to_owned(), - vec![TypeParam::USize, TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM, LOG_WIDTH_TYPE_PARAM], idiv_sig, ) .unwrap(); @@ -367,7 +355,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "imod_u".into(), "as idivmod_u but discarding the first output".to_owned(), - vec![TypeParam::USize, TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM, LOG_WIDTH_TYPE_PARAM], idiv_sig, ) .unwrap(); @@ -375,7 +363,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "idiv_s".into(), "as idivmod_s but discarding the second output".to_owned(), - vec![TypeParam::USize, TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM, LOG_WIDTH_TYPE_PARAM], imod_sig, ) .unwrap(); @@ -383,7 +371,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "imod_s".into(), "as idivmod_s but discarding the first output".to_owned(), - vec![TypeParam::USize, TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM, LOG_WIDTH_TYPE_PARAM], imod_sig, ) .unwrap(); @@ -391,7 +379,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "iabs".into(), "convert signed to unsigned by taking absolute value".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], iunop_sig, ) .unwrap(); @@ -399,7 +387,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "iand".into(), "bitwise AND".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], ibinop_sig, ) .unwrap(); @@ -407,7 +395,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "ior".into(), "bitwise OR".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], ibinop_sig, ) .unwrap(); @@ -415,7 +403,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "ixor".into(), "bitwise XOR".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], ibinop_sig, ) .unwrap(); @@ -423,7 +411,7 @@ pub fn extension() -> Extension { .add_op_custom_sig_simple( "inot".into(), "bitwise NOT".to_owned(), - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], iunop_sig, ) .unwrap(); @@ -433,7 +421,7 @@ pub fn extension() -> Extension { "shift first input left by k bits where k is unsigned interpretation of second input \ (leftmost bits dropped, rightmost bits set to zero" .to_owned(), - vec![TypeParam::USize, TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM, LOG_WIDTH_TYPE_PARAM], ish_sig, ) .unwrap(); @@ -443,7 +431,7 @@ pub fn extension() -> Extension { "shift first input right by k bits where k is unsigned interpretation of second input \ (rightmost bits dropped, leftmost bits set to zero)" .to_owned(), - vec![TypeParam::USize, TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM, LOG_WIDTH_TYPE_PARAM], ish_sig, ) .unwrap(); @@ -453,7 +441,7 @@ pub fn extension() -> Extension { "rotate first input left by k bits where k is unsigned interpretation of second input \ (leftmost bits replace rightmost bits)" .to_owned(), - vec![TypeParam::USize, TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM, LOG_WIDTH_TYPE_PARAM], ish_sig, ) .unwrap(); @@ -463,7 +451,7 @@ pub fn extension() -> Extension { "rotate first input right by k bits where k is unsigned interpretation of second input \ (rightmost bits replace leftmost bits)" .to_owned(), - vec![TypeParam::USize, TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM, LOG_WIDTH_TYPE_PARAM], ish_sig, ) .unwrap(); diff --git a/src/std_extensions/arithmetic/int_types.rs b/src/std_extensions/arithmetic/int_types.rs index 84ca1ccfa..4016c7355 100644 --- a/src/std_extensions/arithmetic/int_types.rs +++ b/src/std_extensions/arithmetic/int_types.rs @@ -1,9 +1,10 @@ //! Basic integer types +use std::num::NonZeroU64; + use smol_str::SmolStr; use crate::{ - extension::SignatureError, types::{ type_param::{TypeArg, TypeArgError, TypeParam}, ConstTypeError, CustomCheckFailure, CustomType, Type, TypeBound, @@ -11,117 +12,121 @@ use crate::{ values::CustomConst, Extension, }; - +use lazy_static::lazy_static; /// The extension identifier. pub const EXTENSION_ID: SmolStr = SmolStr::new_inline("arithmetic.int.types"); -/// Identfier for the integer type. +/// Identifier for the integer type. const INT_TYPE_ID: SmolStr = SmolStr::new_inline("int"); -fn int_custom_type(n: u8) -> CustomType { - CustomType::new( - INT_TYPE_ID, - [TypeArg::USize(n as u64)], - EXTENSION_ID, - TypeBound::Copyable, - ) +fn int_custom_type(width_arg: TypeArg) -> CustomType { + CustomType::new(INT_TYPE_ID, [width_arg], EXTENSION_ID, TypeBound::Copyable) } -/// Integer type of a given bit width. +/// Integer type of a given bit width (specified by the TypeArg). /// Depending on the operation, the semantic interpretation may be unsigned integer, signed integer /// or bit string. -pub fn int_type(n: u8) -> Type { - Type::new_extension(int_custom_type(n)) +pub(super) fn int_type(width_arg: TypeArg) -> Type { + Type::new_extension(int_custom_type(width_arg)) +} + +lazy_static! { + /// Array of valid integer types, indexed by log width of the integer. + pub static ref INT_TYPES: [Type; (MAX_LOG_WIDTH + 1) as usize] = (0..MAX_LOG_WIDTH + 1) + .map(|i| int_type(TypeArg::BoundedNat(i as u64))) + .collect::>() + .try_into() + .unwrap(); } -fn is_valid_width(n: u8) -> bool { - (n == 1) - || (n == 2) - || (n == 4) - || (n == 8) - || (n == 16) - || (n == 32) - || (n == 64) - || (n == 128) +const fn is_valid_log_width(n: u8) -> bool { + n <= MAX_LOG_WIDTH } -/// Get the bit width of the specified integer type, or error if the width is not supported. -pub fn get_width(arg: &TypeArg) -> Result { - let n: u8 = match arg { - TypeArg::USize(n) => *n as u8, - _ => { - return Err(TypeArgError::TypeMismatch { - arg: arg.clone(), - param: TypeParam::USize, - } - .into()); - } - }; - if !is_valid_width(n) { - return Err(TypeArgError::InvalidValue(arg.clone()).into()); +/// The largest allowed log width. +pub const MAX_LOG_WIDTH: u8 = 7; + +/// Type parameter for the log width of the integer. +// SAFETY: unsafe block should be ok as the value is definitely not zero. +pub const LOG_WIDTH_TYPE_PARAM: TypeParam = + TypeParam::bounded_nat(unsafe { NonZeroU64::new_unchecked(MAX_LOG_WIDTH as u64 + 1) }); + +/// Get the log width of the specified type argument or error if the argument +/// is invalid. +pub(super) fn get_log_width(arg: &TypeArg) -> Result { + match arg { + TypeArg::BoundedNat(n) if is_valid_log_width(*n as u8) => Ok(*n as u8), + _ => Err(TypeArgError::TypeMismatch { + arg: arg.clone(), + param: LOG_WIDTH_TYPE_PARAM, + }), } - Ok(n) } +pub(super) const fn type_arg(log_width: u8) -> TypeArg { + TypeArg::BoundedNat(log_width as u64) +} /// An unsigned integer #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)] pub struct ConstIntU { - width: u8, + log_width: u8, value: u128, } /// A signed integer #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)] pub struct ConstIntS { - width: u8, + log_width: u8, value: i128, } impl ConstIntU { /// Create a new [`ConstIntU`] - pub fn new(width: u8, value: u128) -> Result { - if !is_valid_width(width) { + pub fn new(log_width: u8, value: u128) -> Result { + if !is_valid_log_width(log_width) { return Err(ConstTypeError::CustomCheckFail( crate::types::CustomCheckFailure::Message("Invalid integer width.".to_owned()), )); } - if (width <= 64) && (value >= (1u128 << width)) { + if (log_width <= 6) && (value >= (1u128 << (1u8 << log_width))) { return Err(ConstTypeError::CustomCheckFail( crate::types::CustomCheckFailure::Message( "Invalid unsigned integer value.".to_owned(), ), )); } - Ok(Self { width, value }) + Ok(Self { log_width, value }) } } impl ConstIntS { /// Create a new [`ConstIntS`] - pub fn new(width: u8, value: i128) -> Result { - if !is_valid_width(width) { + pub fn new(log_width: u8, value: i128) -> Result { + if !is_valid_log_width(log_width) { return Err(ConstTypeError::CustomCheckFail( crate::types::CustomCheckFailure::Message("Invalid integer width.".to_owned()), )); } - if (width <= 64) && (value >= (1i128 << (width - 1)) || value < -(1i128 << (width - 1))) { + let width = 1u8 << log_width; + if (log_width <= 6) && (value >= (1i128 << (width - 1)) || value < -(1i128 << (width - 1))) + { return Err(ConstTypeError::CustomCheckFail( crate::types::CustomCheckFailure::Message( "Invalid signed integer value.".to_owned(), ), )); } - Ok(Self { width, value }) + Ok(Self { log_width, value }) } } #[typetag::serde] impl CustomConst for ConstIntU { fn name(&self) -> SmolStr { - format!("u{}({})", self.width, self.value).into() + format!("u{}({})", self.log_width, self.value).into() } fn check_custom_type(&self, typ: &CustomType) -> Result<(), CustomCheckFailure> { - if typ.clone() == int_custom_type(self.width) { + if typ.clone() == int_custom_type(type_arg(self.log_width)) { Ok(()) } else { Err(CustomCheckFailure::Message( @@ -137,10 +142,10 @@ impl CustomConst for ConstIntU { #[typetag::serde] impl CustomConst for ConstIntS { fn name(&self) -> SmolStr { - format!("i{}({})", self.width, self.value).into() + format!("i{}({})", self.log_width, self.value).into() } fn check_custom_type(&self, typ: &CustomType) -> Result<(), CustomCheckFailure> { - if typ.clone() == int_custom_type(self.width) { + if typ.clone() == int_custom_type(type_arg(self.log_width)) { Ok(()) } else { Err(CustomCheckFailure::Message( @@ -160,7 +165,7 @@ pub fn extension() -> Extension { extension .add_type( INT_TYPE_ID, - vec![TypeParam::USize], + vec![LOG_WIDTH_TYPE_PARAM], "integral value of a given bit width".to_owned(), TypeBound::Copyable.into(), ) @@ -185,35 +190,28 @@ mod test { #[test] fn test_int_widths() { - let type_arg_32 = TypeArg::USize(32); - assert_matches!(get_width(&type_arg_32), Ok(32)); - - let type_arg_33 = TypeArg::USize(33); - assert_matches!( - get_width(&type_arg_33), - Err(SignatureError::TypeArgMismatch(_)) - ); - - let type_arg_128 = TypeArg::USize(128); - assert_matches!(get_width(&type_arg_128), Ok(128)); + let type_arg_32 = TypeArg::BoundedNat(5); + assert_matches!(get_log_width(&type_arg_32), Ok(5)); - let type_arg_256 = TypeArg::USize(256); + let type_arg_128 = TypeArg::BoundedNat(7); + assert_matches!(get_log_width(&type_arg_128), Ok(7)); + let type_arg_256 = TypeArg::BoundedNat(8); assert_matches!( - get_width(&type_arg_256), - Err(SignatureError::TypeArgMismatch(_)) + get_log_width(&type_arg_256), + Err(TypeArgError::TypeMismatch { .. }) ); } #[test] fn test_int_consts() { - let const_u32_7 = ConstIntU::new(32, 7); - let const_u64_7 = ConstIntU::new(64, 7); - let const_u32_8 = ConstIntU::new(32, 8); + let const_u32_7 = ConstIntU::new(5, 7); + let const_u64_7 = ConstIntU::new(6, 7); + let const_u32_8 = ConstIntU::new(5, 8); assert_ne!(const_u32_7, const_u64_7); assert_ne!(const_u32_7, const_u32_8); - assert_eq!(const_u32_7, ConstIntU::new(32, 7)); + assert_eq!(const_u32_7, ConstIntU::new(5, 7)); assert_matches!( - ConstIntU::new(8, 256), + ConstIntU::new(3, 256), Err(ConstTypeError::CustomCheckFail(_)) ); assert_matches!( @@ -221,9 +219,9 @@ mod test { Err(ConstTypeError::CustomCheckFail(_)) ); assert_matches!( - ConstIntS::new(8, 128), + ConstIntS::new(3, 128), Err(ConstTypeError::CustomCheckFail(_)) ); - assert_matches!(ConstIntS::new(8, -128), Ok(_)); + assert_matches!(ConstIntS::new(3, -128), Ok(_)); } } diff --git a/src/std_extensions/arithmetic/mod.rs b/src/std_extensions/arithmetic/mod.rs index 1d717eed5..76e8263aa 100644 --- a/src/std_extensions/arithmetic/mod.rs +++ b/src/std_extensions/arithmetic/mod.rs @@ -5,3 +5,23 @@ pub mod float_ops; pub mod float_types; pub mod int_ops; pub mod int_types; + +#[cfg(test)] +mod test { + use crate::{ + std_extensions::arithmetic::int_types::{int_type, INT_TYPES}, + types::type_param::TypeArg, + }; + + use super::int_types::MAX_LOG_WIDTH; + + #[test] + fn test_int_types() { + for i in 0..MAX_LOG_WIDTH + 1 { + assert_eq!( + INT_TYPES[i as usize], + int_type(TypeArg::BoundedNat(i as u64)) + ) + } + } +} diff --git a/src/std_extensions/collections.rs b/src/std_extensions/collections.rs index 7c1bb50ee..9718edd9e 100644 --- a/src/std_extensions/collections.rs +++ b/src/std_extensions/collections.rs @@ -159,7 +159,9 @@ mod test { .instantiate_concrete([TypeArg::Type(USIZE_T)]) .unwrap(); - assert!(list_def.instantiate_concrete([TypeArg::USize(3)]).is_err()); + assert!(list_def + .instantiate_concrete([TypeArg::BoundedNat(3)]) + .is_err()); list_def.check_custom(&list_type).unwrap(); let list_value = ListValue(vec![ConstUsize::new(3).into()]); diff --git a/src/std_extensions/logic.rs b/src/std_extensions/logic.rs index c5d45bd62..957985482 100644 --- a/src/std_extensions/logic.rs +++ b/src/std_extensions/logic.rs @@ -27,7 +27,7 @@ pub const EXTENSION_ID: SmolStr = SmolStr::new_inline("logic"); /// Extension for basic logical operations. fn extension() -> Extension { - const H_INT: TypeParam = TypeParam::USize; + const H_INT: TypeParam = TypeParam::max_nat(); let mut extension = Extension::new(EXTENSION_ID); extension @@ -53,7 +53,7 @@ fn extension() -> Extension { |arg_values: &[TypeArg]| { let a = arg_values.iter().exactly_one().unwrap(); let n: u64 = match a { - TypeArg::USize(n) => *n, + TypeArg::BoundedNat(n) => *n, _ => { return Err(TypeArgError::TypeMismatch { arg: a.clone(), @@ -79,7 +79,7 @@ fn extension() -> Extension { |arg_values: &[TypeArg]| { let a = arg_values.iter().exactly_one().unwrap(); let n: u64 = match a { - TypeArg::USize(n) => *n, + TypeArg::BoundedNat(n) => *n, _ => { return Err(TypeArgError::TypeMismatch { arg: a.clone(), @@ -139,7 +139,7 @@ pub(crate) mod test { /// Generate a logic extension and operation over [`crate::prelude::BOOL_T`] pub(crate) fn and_op() -> LeafOp { EXTENSION - .instantiate_extension_op(AND_NAME, [TypeArg::USize(2)]) + .instantiate_extension_op(AND_NAME, [TypeArg::BoundedNat(2)]) .unwrap() .into() } diff --git a/src/types/type_param.rs b/src/types/type_param.rs index 3d15effca..fcf02b103 100644 --- a/src/types/type_param.rs +++ b/src/types/type_param.rs @@ -4,6 +4,8 @@ //! //! [`TypeDef`]: crate::extension::TypeDef +use std::num::NonZeroU64; + use thiserror::Error; use crate::extension::ExtensionSet; @@ -12,6 +14,20 @@ use super::CustomType; use super::Type; use super::TypeBound; +#[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +/// The upper non-inclusive bound of a [`TypeParam::BoundedNat`] +// A None inner value implies the maximum bound: u64::MAX + 1 (all u64 values valid) +pub struct UpperBound(Option); +impl UpperBound { + fn valid_value(&self, val: u64) -> bool { + match (val, self.0) { + (0, _) | (_, None) => true, + (val, Some(inner)) if NonZeroU64::new(val).unwrap() < inner => true, + _ => false, + } + } +} + /// A parameter declared by an OpDef. Specifies a value /// that must be provided by each operation node. #[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)] @@ -19,8 +35,8 @@ use super::TypeBound; pub enum TypeParam { /// Argument is a [TypeArg::Type]. Type(TypeBound), - /// Argument is a [TypeArg::USize]. - USize, + /// Argument is a [TypeArg::BoundedNat] that is less than the upper bound. + BoundedNat(UpperBound), /// Argument is a [TypeArg::Opaque], defined by a [CustomType]. Opaque(CustomType), /// Argument is a [TypeArg::Sequence]. A list of indeterminate size containing parameters. @@ -33,14 +49,26 @@ pub enum TypeParam { Extensions, } +impl TypeParam { + /// [`TypeParam::BoundedNat`] with the maximum bound (`u64::MAX` + 1) + pub const fn max_nat() -> Self { + Self::BoundedNat(UpperBound(None)) + } + + /// [`TypeParam::BoundedNat`] with the stated upper bound (non-exclusive) + pub const fn bounded_nat(upper_bound: NonZeroU64) -> Self { + Self::BoundedNat(UpperBound(Some(upper_bound))) + } +} + /// A statically-known argument value to an operation. #[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)] #[non_exhaustive] pub enum TypeArg { /// Where the (Type/Op)Def declares that an argument is a [TypeParam::Type] Type(Type), - /// Instance of [TypeParam::USize]. 64-bit unsigned integer. - USize(u64), + /// Instance of [TypeParam::BoundedNat]. 64-bit unsigned integer. + BoundedNat(u64), ///Instance of [TypeParam::Opaque] An opaque value, stored as serialized blob. Opaque(CustomTypeArg), /// Instance of [TypeParam::List] or [TypeParam::Tuple], defined by a @@ -92,7 +120,10 @@ pub fn check_type_arg(arg: &TypeArg, param: &TypeParam) -> Result<(), TypeArgErr .try_for_each(|(arg, param)| check_type_arg(arg, param)) } } - (TypeArg::USize(_), TypeParam::USize) => Ok(()), + (TypeArg::BoundedNat(val), TypeParam::BoundedNat(bound)) if bound.valid_value(*val) => { + Ok(()) + } + (TypeArg::Opaque(arg), TypeParam::Opaque(param)) if param.bound() == TypeBound::Eq && &arg.typ == param => {