From 0d333fb0b9aeedc586a27a1df17e76cfd556456a Mon Sep 17 00:00:00 2001
From: Terence Liu
Date: Tue, 13 Aug 2024 16:01:14 -0400
Subject: [PATCH] broaden dep versions; use half num-traits feature

---
 Cargo.lock  | 12 ++++++++++--
 Cargo.toml  | 21 ++++++++++-----------
 src/core.rs | 46 ++++++----------------------------------------
 3 files changed, 26 insertions(+), 53 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 8aa2783..4c741da 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -166,6 +166,7 @@ checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
 dependencies = [
  "cfg-if",
  "crunchy",
+ "num-traits",
 ]
 
 [[package]]
@@ -184,11 +185,17 @@ dependencies = [
  "hashbrown",
 ]
 
+[[package]]
+name = "libm"
+version = "0.2.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
+
 [[package]]
 name = "matrixmultiply"
-version = "0.3.8"
+version = "0.3.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2"
+checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a"
 dependencies = [
  "autocfg",
  "rawpointer",
@@ -254,6 +261,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
 dependencies = [
  "autocfg",
+ "libm",
 ]
 
 [[package]]
diff --git a/Cargo.toml b/Cargo.toml
index 28ef1a8..cc6774f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,19 +12,18 @@ keywords = ["msgpack", "numpy", "serde", "serialization", "ndarray"]
 categories = ["encoding", "science"]
 
 [dependencies]
-anyhow = "1"
-half = "2"
-ndarray = "0.15"
-num-traits = "0.2"
-serde = { version = "1", features = ["derive"] }
-thiserror = "1"
-rmp-serde = "1.3"
-serde_bytes = "0.11"
+anyhow = "^1"
+thiserror = "^1"
+half = { version = "^2", features = ["num-traits"] }
+num-traits = ">=0.2, <1"
+ndarray = ">=0.15, <1"
+serde = { version = "^1", features = ["derive"] }
+serde_bytes = ">=0.11, <1"
+rmp-serde = "^1.1"
 
 [dev-dependencies]
-rmp-serde = "1.3"
-ctor = "0.2"
-rstest = "0.21"
+ctor = ">=0.2, <1"
+rstest = ">=0.21, <1"
 
 [[bin]]
 name = "test_helpers_serialize"
diff --git a/src/core.rs b/src/core.rs
index 392d7ef..3521b69 100644
--- a/src/core.rs
+++ b/src/core.rs
@@ -46,9 +46,8 @@ impl Scalar {
         self.to()
     }
 
-    // f16 doesn't implement NumCast, so we need to convert it to f32 first
     pub fn to_f16(&self) -> Option<f16> {
-        self.to::<f32>().map(f16::from_f32)
+        self.to()
     }
 
     pub fn to_u32(&self) -> Option<u32> {
@@ -83,8 +82,7 @@ impl Scalar {
             Scalar::I8(v) => NumCast::from(*v),
             Scalar::U16(v) => NumCast::from(*v),
             Scalar::I16(v) => NumCast::from(*v),
-            // f16 doesn't implement ToPrimitive, so we need to convert it to f32 first
-            Scalar::F16(v) => NumCast::from(v.to_f32()),
+            Scalar::F16(v) => NumCast::from(*v),
             Scalar::U32(v) => NumCast::from(*v),
             Scalar::I32(v) => NumCast::from(*v),
             Scalar::F32(v) => NumCast::from(*v),
@@ -158,14 +156,7 @@ impl NDArray {
     pub fn into_f16_array(self) -> Option<Array<f16, IxDyn>> {
         match self {
             NDArray::F16(arr) => Some(arr),
-            // round trip through f32 if not already f16
-            _ => self.convert_into::<f32>().map(|arr| {
-                Array::from_shape_vec(
-                    arr.raw_dim(),
-                    arr.into_iter().map(f16::from_f32).collect::<Vec<_>>(),
-                )
-                .unwrap()
-            }),
+            _ => self.convert_into::<f16>(),
         }
     }
 
@@ -218,7 +209,7 @@ impl NDArray {
             NDArray::I8(arr) => Self::convert_array(arr),
             NDArray::U16(arr) => Self::convert_array(arr),
             NDArray::I16(arr) => Self::convert_array(arr),
-            NDArray::F16(arr) => Self::convert_f16_array(arr),
+            NDArray::F16(arr) => Self::convert_array(arr),
             NDArray::U32(arr) => Self::convert_array(arr),
             NDArray::I32(arr) => Self::convert_array(arr),
             NDArray::F32(arr) => Self::convert_array(arr),
@@ -248,15 +239,6 @@ impl NDArray {
             .ok()
             .map(|vec| Array::from_shape_vec(raw_dim, vec).unwrap())
     }
-
-    fn convert_f16_array<T: NumCast>(arr: Array<f16, IxDyn>) -> Option<Array<T, IxDyn>> {
-        let raw_dim = arr.raw_dim();
-        arr.into_iter()
-            .map(|v| NumCast::from(v.to_f32()).ok_or(()))
-            .collect::<Result<Vec<_>, _>>()
-            .ok()
-            .map(|vec| Array::from_shape_vec(raw_dim, vec).unwrap())
-    }
 }
 
 /*********************************************************************************/
@@ -322,14 +304,7 @@ impl<'a> CowNDArray<'a> {
     pub fn into_f16_array(self) -> Option<CowArray<'a, f16, IxDyn>> {
         match self {
             CowNDArray::F16(arr) => Some(arr),
             // round trip through f32 if not already f16
-            _ => self.convert_into::<f32>().map(|arr| {
-                Array::from_shape_vec(
-                    arr.raw_dim(),
-                    arr.into_iter().map(f16::from_f32).collect::<Vec<_>>(),
-                )
-                .unwrap()
-                .into()
-            }),
+            _ => self.convert_into::<f16>(),
         }
     }
@@ -382,7 +357,7 @@ impl<'a> CowNDArray<'a> {
             CowNDArray::I8(arr) => Self::convert_array(arr),
             CowNDArray::U16(arr) => Self::convert_array(arr),
             CowNDArray::I16(arr) => Self::convert_array(arr),
-            CowNDArray::F16(arr) => Self::convert_f16_array(arr),
+            CowNDArray::F16(arr) => Self::convert_array(arr),
             CowNDArray::U32(arr) => Self::convert_array(arr),
             CowNDArray::I32(arr) => Self::convert_array(arr),
             CowNDArray::F32(arr) => Self::convert_array(arr),
@@ -412,13 +387,4 @@ impl<'a> CowNDArray<'a> {
             .ok()
             .map(|vec| Array::from_shape_vec(raw_dim, vec).unwrap().into())
     }
-
-    fn convert_f16_array<T: NumCast>(arr: CowArray<'a, f16, IxDyn>) -> Option<CowArray<'a, T, IxDyn>> {
-        let raw_dim = arr.raw_dim();
-        arr.into_iter()
-            .map(|v| NumCast::from(v.to_f32()).ok_or(()))
-            .collect::<Result<Vec<_>, _>>()
-            .ok()
-            .map(|vec| Array::from_shape_vec(raw_dim, vec).unwrap().into())
-    }
 }