From 07b2f58099d18d4e125f1a134a81418c014b511a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 11 Sep 2024 14:59:33 +0100 Subject: [PATCH] feat!: Allow CustomConsts to (optionally) be hashable (#1397) * Add trait TryHash as prereq for CustomConst * Automatically impl'd if your const impl's Hash * Can also trivially implement (i.e. `impl TryHash for Foo { }`) to say "no, not hashable" * Derive Hash for most consts, but not ConstF64 BREAKING CHANGE: any `impl CustomConst` will need to either `impl Hash` or `impl MaybeHash` --- hugr-core/src/extension/prelude.rs | 8 +- hugr-core/src/ops/constant.rs | 75 ++++++++++++++++++- hugr-core/src/ops/constant/custom.rs | 44 +++++++++-- .../std_extensions/arithmetic/float_types.rs | 4 +- .../std_extensions/arithmetic/int_types.rs | 2 +- hugr-core/src/std_extensions/collections.rs | 13 +++- 6 files changed, 131 insertions(+), 15 deletions(-) diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index 952c8a7e2..daca033f3 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -149,7 +149,7 @@ pub const STRING_CUSTOM_TYPE: CustomType = /// String type. pub const STRING_TYPE: Type = Type::new_extension(STRING_CUSTOM_TYPE); -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize)] /// Structure for holding constant string values. pub struct ConstString(String); @@ -329,7 +329,7 @@ pub fn const_fail_tuple( const_left_tuple(values, ty_ok) } -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize)] /// Structure for holding constant usize values. pub struct ConstUsize(u64); @@ -364,7 +364,7 @@ impl CustomConst for ConstUsize { } } -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize)] /// Structure for holding constant usize values. pub struct ConstError { /// Integer tag/signal for the error. @@ -409,7 +409,7 @@ impl CustomConst for ConstError { } } -#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] /// A structure for holding references to external symbols. pub struct ConstExternalSymbol { /// The symbol name that this value refers to. Must be nonempty. diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index 59a79f8f2..5b9309407 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -2,6 +2,9 @@ mod custom; +use std::collections::hash_map::DefaultHasher; // Moves into std::hash in Rust 1.76. +use std::hash::{Hash, Hasher}; + use super::{NamedOp, OpName, OpTrait, StaticTag}; use super::{OpTag, OpType}; use crate::extension::ExtensionSet; @@ -16,7 +19,7 @@ use thiserror::Error; pub use custom::{ downcast_equal_consts, get_pair_of_input_values, get_single_input_value, CustomConst, - CustomSerialized, + CustomSerialized, TryHash, }; #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] @@ -134,6 +137,24 @@ impl Sum { // For valid instances, the type row will not have any row variables. self.sum_type.as_tuple().map(|_| self.values.as_ref()) } + + fn try_hash(&self, st: &mut H) -> bool { + maybe_hash_values(&self.values, st) && { + st.write_usize(self.tag); + self.sum_type.hash(st); + true + } + } +} + +pub(crate) fn maybe_hash_values(vals: &[Value], st: &mut H) -> bool { + // We can't mutate the Hasher with the first element + // if any element, even the last, fails. + let mut hasher = DefaultHasher::new(); + vals.iter().all(|e| e.try_hash(&mut hasher)) && { + st.write_u64(hasher.finish()); + true + } } impl TryFrom for Sum { @@ -508,6 +529,17 @@ impl Value { None } } + + /// Hashes this value, if possible. [Value::Extension]s are hashable according + /// to their implementation of [TryHash]; [Value::Function]s never are; + /// [Value::Sum]s are if their contents are. + pub fn try_hash(&self, st: &mut H) -> bool { + match self { + Value::Extension { e } => e.value().try_hash(&mut *st), + Value::Function { .. } => false, + Value::Sum(s) => s.try_hash(st), + } + } } impl From for Value @@ -527,6 +559,8 @@ pub type ValueNameRef = str; #[cfg(test)] mod test { + use std::collections::HashSet; + use super::Value; use crate::builder::inout_sig; use crate::builder::test::simple_dfg_hugr; @@ -547,7 +581,7 @@ mod test { use super::*; - #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] + #[derive(Debug, Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize)] /// A custom constant value used in testing pub(crate) struct CustomTestValue(pub CustomType); @@ -727,6 +761,43 @@ mod test { assert_ne!(json_const.get_type(), t); } + #[rstest] + fn hash_tuple(const_tuple: Value) { + let vals = [ + Value::unit(), + Value::true_val(), + Value::false_val(), + ConstUsize::new(13).into(), + Value::tuple([ConstUsize::new(13).into()]), + Value::tuple([ConstUsize::new(13).into(), ConstUsize::new(14).into()]), + Value::tuple([ConstUsize::new(13).into(), ConstUsize::new(15).into()]), + const_tuple, + ]; + + let num_vals = vals.len(); + let hashes = vals.map(|v| { + let mut h = DefaultHasher::new(); + v.try_hash(&mut h).then_some(()).unwrap(); + h.finish() + }); + assert_eq!(HashSet::from(hashes).len(), num_vals); // all distinct + } + + #[test] + fn unhashable_tuple() { + let tup = Value::tuple([ConstUsize::new(5).into(), ConstF64::new(4.97).into()]); + let mut h1 = DefaultHasher::new(); + let r = tup.try_hash(&mut h1); + assert!(!r); + + // Check that didn't do anything, by checking the hasher behaves + // just like one which never saw the tuple + h1.write_usize(5); + let mut h2 = DefaultHasher::new(); + h2.write_usize(5); + assert_eq!(h1.finish(), h2.finish()); + } + mod proptest { use super::super::{OpaqueValue, Sum}; use crate::{ diff --git a/hugr-core/src/ops/constant/custom.rs b/hugr-core/src/ops/constant/custom.rs index 7e685bc6e..a69f76ec0 100644 --- a/hugr-core/src/ops/constant/custom.rs +++ b/hugr-core/src/ops/constant/custom.rs @@ -5,19 +5,17 @@ //! [`Const`]: crate::ops::Const use std::any::Any; +use std::hash::{Hash, Hasher}; use downcast_rs::{impl_downcast, Downcast}; use thiserror::Error; use crate::extension::ExtensionSet; use crate::macros::impl_box_clone; - use crate::types::{CustomCheckFailure, Type}; use crate::IncomingPort; -use super::Value; - -use super::ValueName; +use super::{Value, ValueName}; /// Extensible constant values. /// @@ -37,7 +35,7 @@ use super::ValueName; /// extension::ExtensionSet, std_extensions::arithmetic::int_types}; /// use serde_json::json; /// -/// #[derive(std::fmt::Debug, Clone, Serialize,Deserialize)] +/// #[derive(std::fmt::Debug, Clone, Hash, Serialize,Deserialize)] /// struct CC(i64); /// /// #[typetag::serde] @@ -55,7 +53,7 @@ use super::ValueName; /// ``` #[typetag::serde(tag = "c", content = "v")] pub trait CustomConst: - Send + Sync + std::fmt::Debug + CustomConstBoxClone + Any + Downcast + Send + Sync + std::fmt::Debug + TryHash + CustomConstBoxClone + Any + Downcast { /// An identifier for the constant. fn name(&self) -> ValueName; @@ -90,6 +88,32 @@ pub trait CustomConst: fn get_type(&self) -> Type; } +/// Prerequisite for `CustomConst`. Allows to declare a custom hash function, +/// but the easiest options are either to `impl TryHash for ... {}` to indicate +/// "not hashable", or else to implement/derive [Hash]. +pub trait TryHash { + /// Hashes the value, if possible; else return `false` without mutating the `Hasher`. + /// This relates with [CustomConst::equal_consts] just like [Hash] with [Eq]: + /// * if `x.equal_consts(y)` ==> `x.try_hash(s)` behaves equivalently to `y.try_hash(s)` + /// * if `x.hash(s)` behaves differently from `y.hash(s)` ==> `x.equal_consts(y) == false` + /// + /// As with [Hash], these requirements can trivially be satisfied by either + /// * `equal_consts` always returning `false`, or + /// * `try_hash` always behaving the same (e.g. returning `false`, as it does by default) + /// + /// Note: uses `dyn` rather than being parametrized by `` to be object-safe. + fn try_hash(&self, _state: &mut dyn Hasher) -> bool { + false + } +} + +impl TryHash for T { + fn try_hash(&self, mut st: &mut dyn Hasher) -> bool { + Hash::hash(self, &mut st); + true + } +} + impl PartialEq for dyn CustomConst { fn eq(&self, other: &Self) -> bool { (*self).equal_consts(other) @@ -253,6 +277,14 @@ impl CustomSerialized { } } +impl TryHash for CustomSerialized { + fn try_hash(&self, mut st: &mut dyn Hasher) -> bool { + // Consistent with equality, same serialization <=> same hash. + self.value.to_string().hash(&mut st); + true + } +} + #[typetag::serde] impl CustomConst for CustomSerialized { fn name(&self) -> ValueName { diff --git a/hugr-core/src/std_extensions/arithmetic/float_types.rs b/hugr-core/src/std_extensions/arithmetic/float_types.rs index 335009e92..a046ebe0e 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_types.rs @@ -1,6 +1,6 @@ //! Basic floating-point types -use crate::ops::constant::ValueName; +use crate::ops::constant::{TryHash, ValueName}; use crate::types::TypeName; use crate::{ extension::{ExtensionId, ExtensionSet}, @@ -56,6 +56,8 @@ impl ConstF64 { } } +impl TryHash for ConstF64 {} + #[typetag::serde] impl CustomConst for ConstF64 { fn name(&self) -> ValueName { diff --git a/hugr-core/src/std_extensions/arithmetic/int_types.rs b/hugr-core/src/std_extensions/arithmetic/int_types.rs index 6ec3b7724..522f8b2b9 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_types.rs @@ -87,7 +87,7 @@ const fn type_arg(log_width: u8) -> TypeArg { } /// An integer (either signed or unsigned) -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)] +#[derive(Clone, Debug, Eq, PartialEq, Hash, serde::Serialize, serde::Deserialize)] pub struct ConstInt { log_width: u8, // We always use a u64 for the value. The interpretation is: diff --git a/hugr-core/src/std_extensions/collections.rs b/hugr-core/src/std_extensions/collections.rs index 270f234b6..24e3cfdf1 100644 --- a/hugr-core/src/std_extensions/collections.rs +++ b/hugr-core/src/std_extensions/collections.rs @@ -1,5 +1,7 @@ //! List type and operations. +use std::hash::{Hash, Hasher}; + mod list_fold; use std::str::FromStr; @@ -12,7 +14,7 @@ use strum_macros::{EnumIter, EnumString, IntoStaticStr}; use crate::extension::prelude::{either_type, option_type, USIZE_T}; use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp}; use crate::extension::{ExtensionBuildError, OpDef, SignatureFunc, PRELUDE}; -use crate::ops::constant::ValueName; +use crate::ops::constant::{maybe_hash_values, TryHash, ValueName}; use crate::ops::{OpName, Value}; use crate::types::{TypeName, TypeRowRV}; use crate::{ @@ -58,6 +60,15 @@ impl ListValue { } } +impl TryHash for ListValue { + fn try_hash(&self, mut st: &mut dyn Hasher) -> bool { + maybe_hash_values(&self.0, &mut st) && { + self.1.hash(&mut st); + true + } + } +} + #[typetag::serde] impl CustomConst for ListValue { fn name(&self) -> ValueName {