From 0ff68af456a27c727662719bb89a4b695429867c Mon Sep 17 00:00:00 2001 From: IgorSusmelj Date: Sun, 15 Sep 2024 21:29:37 +0200 Subject: [PATCH] Add concatenate, exp, log and sigmoid --- README.md | 6 +- bindings/python/rustynum/__init__.py | 106 ++++++++ bindings/python/src/lib.rs | 102 +++++++ bindings/python/tests/test_basics.py | 52 ++++ rustynum-rs/src/num_array/linalg.rs | 13 +- rustynum-rs/src/num_array/num_array.rs | 336 +++++++++++++++++++++++- rustynum-rs/src/num_array/operations.rs | 36 ++- rustynum-rs/src/traits.rs | 36 ++- 8 files changed, 676 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 9b51286..47de6cd 100644 --- a/README.md +++ b/README.md @@ -91,8 +91,12 @@ RustyNum offers a variety of numerical operations and data types, with more feat | Mean | `np.mean(a)` | `rnp.mean(a)` | | Min | `np.min(a)` | `rnp.min(a)` | | Max | `np.max(a)` | `rnp.max(a)` | +| Exp | `np.exp(a)` | `rnp.exp(a)` | +| Log | `np.log(a)` | `rnp.log(a)` | +| Sigmoid | `1 / (1 + np.exp(-a))` | `rnp.sigmoid(a)` | | Dot Product | `np.dot(a, b)` | `rnp.dot(a, b)` | | Reshape | `a.reshape((2, 3))` | `a.reshape([2, 3])` | +| Concatenate | `np.concatenate([a,b], axis=0)` | `rnp.concatenate([a,b], axis=0)` | | Element-wise Add | `a + b` | `a + b` | | Element-wise Sub | `a - b` | `a - b` | | Element-wise Mul | `a * b` | `a * b` | @@ -187,7 +191,7 @@ Planned Features: - N-dimensional arrays - Useful for filters, image processing, and machine learning -- Additional operations: concat, exp, sigmoid, log, median, argmin, argmax, sort, std, var, zeros, cumsum, interp +- Additional operations: median, argmin, argmax, sort, std, var, zeros, cumsum, interp - Integer support - Extended shaping and reshaping capabilities - C++ and WASM bindings diff --git a/bindings/python/rustynum/__init__.py b/bindings/python/rustynum/__init__.py index 98ee8fb..759eab1 100644 --- a/bindings/python/rustynum/__init__.py +++ b/bindings/python/rustynum/__init__.py @@ -407,6 +407,58 @@ def tolist(self) -> Union[List[float], List[List[float]]]: flat_list[i * shape[1] : (i + 1) * shape[1]] for i in range(shape[0]) ] + def exp(self) -> "NumArray": + """ + Computes the exponential of all elements in the NumArray. + + Returns: + A new NumArray with the exponential of all elements. + """ + return NumArray(self.inner.exp(), dtype=self.dtype) + + def log(self) -> "NumArray": + """ + Computes the natural logarithm of all elements in the NumArray. + + Returns: + A new NumArray with the natural logarithm of all elements. + """ + return NumArray(self.inner.log(), dtype=self.dtype) + + def sigmoid(self) -> "NumArray": + """ + Computes the sigmoid of all elements in the NumArray. + + Returns: + A new NumArray with the sigmoid of all elements. + """ + return NumArray(self.inner.sigmoid(), dtype=self.dtype) + + def concatenate(self, other: "NumArray", axis: int) -> "NumArray": + """ + Concatenates the NumArray with another NumArray along the specified axis. + + Parameters: + other: Another NumArray to concatenate with. + axis: Axis along which to concatenate. + + Returns: + A new NumArray containing the concatenated data. 
+ """ + if self.dtype != other.dtype: + raise ValueError("dtype mismatch between arrays") + if self.shape[1 - axis] != other.shape[1 - axis]: + raise ValueError("Arrays must have the same shape along the specified axis") + + if self.dtype == "float32": + result = _rustynum.concatenate_f32([self.inner, other.inner], axis) + elif self.dtype == "float64": + result = _rustynum.concatenate_f64([self.inner, other.inner], axis) + else: + raise ValueError("Unsupported dtype for concatenation") + + return NumArray(result, dtype=self.dtype) + def zeros(shape: List[int], dtype: str = "float32") -> "NumArray": """ @@ -531,3 +583,57 @@ def dot(a: "NumArray", b: "NumArray") -> Union[float, "NumArray"]: return NumArray([out], dtype="float32").item() else: raise TypeError("Both arguments must be NumArray instances.") + + +def exp(a: "NumArray") -> "NumArray": + if isinstance(a, NumArray): + return a.exp() + elif isinstance(a, (int, float)): + return NumArray([a], dtype="float32").exp() + else: + raise TypeError( + "Unsupported operand type for exp: '{}'".format(type(a).__name__) + ) + + +def log(a: "NumArray") -> "NumArray": + if isinstance(a, NumArray): + return a.log() + elif isinstance(a, (int, float)): + return NumArray([a], dtype="float32").log() + else: + raise TypeError( + "Unsupported operand type for log: '{}'".format(type(a).__name__) + ) + + +def sigmoid(a: "NumArray") -> "NumArray": + if isinstance(a, NumArray): + return a.sigmoid() + elif isinstance(a, (int, float)): + return NumArray([a], dtype="float32").sigmoid() + else: + raise TypeError( + "Unsupported operand type for sigmoid: '{}'".format(type(a).__name__) + ) + + +def concatenate(arrays: List["NumArray"], axis: int) -> "NumArray": + # axis can be any integer, but most of the time it would only be 0 or 1 + if not all(isinstance(a, NumArray) for a in arrays): + raise TypeError("All elements in 'arrays' must be NumArray instances.") + if not all(a.dtype == arrays[0].dtype for a in arrays): + raise ValueError("dtype mismatch between arrays") + if not all(a.shape[1 - axis] == arrays[0].shape[1 - axis] for a in arrays): + raise ValueError("Arrays must have the same shape along the specified axis") + + if arrays[0].dtype == "float32": + return NumArray( + _rustynum.concatenate_f32([a.inner for a in arrays], axis), dtype="float32" + ) + elif arrays[0].dtype == "float64": + return NumArray( + _rustynum.concatenate_f64([a.inner for a in arrays], axis), dtype="float64" + ) + else: + raise ValueError("Unsupported dtype for concatenation") diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 463c63a..a397b01 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -5,11 +5,13 @@ use pyo3::wrap_pyfunction; use rustynum_rs::{NumArray32, NumArray64}; #[pyclass] +#[derive(Clone)] struct PyNumArray32 { inner: NumArray32, } #[pyclass] +#[derive(Clone)] struct PyNumArray64 { inner: NumArray64, } @@ -124,6 +126,24 @@ impl PyNumArray32 { inner: self.inner.reshape(&shape), }) } + + fn exp(&self) -> PyNumArray32 { + PyNumArray32 { + inner: self.inner.exp(), + } + } + + fn log(&self) -> PyNumArray32 { + PyNumArray32 { + inner: self.inner.log(), + } + } + + fn sigmoid(&self) -> PyNumArray32 { + PyNumArray32 { + inner: self.inner.sigmoid(), + } + } } #[pymethods] @@ -234,6 +254,24 @@ impl PyNumArray64 { inner: self.inner.reshape(&shape), }) } + + fn exp(&self) -> PyNumArray64 { + PyNumArray64 { + inner: self.inner.exp(), + } + } + + fn log(&self) -> PyNumArray64 { + PyNumArray64 { + inner: self.inner.log(), + } + } + + 
fn sigmoid(&self) -> PyNumArray64 { + PyNumArray64 { + inner: self.inner.sigmoid(), + } + } } #[pyfunction] @@ -313,6 +351,34 @@ fn max_f32(a: &PyNumArray32) -> PyResult { Ok(a.inner.max()) } +#[pyfunction] +fn exp_f32(a: &PyNumArray32) -> PyNumArray32 { + PyNumArray32 { + inner: a.inner.exp(), + } +} + +#[pyfunction] +fn log_f32(a: &PyNumArray32) -> PyNumArray32 { + PyNumArray32 { + inner: a.inner.log(), + } +} + +#[pyfunction] +fn sigmoid_f32(a: &PyNumArray32) -> PyNumArray32 { + PyNumArray32 { + inner: a.inner.sigmoid(), + } +} + +#[pyfunction] +fn concatenate_f32(arrays: Vec, axis: usize) -> PyResult { + let rust_arrays: Vec = arrays.iter().map(|array| array.inner.clone()).collect(); + let result = NumArray32::concatenate(&rust_arrays, axis); + Ok(PyNumArray32 { inner: result }) +} + #[pyfunction] fn zeros_f64(shape: Vec) -> PyResult { Python::with_gil(|py| { @@ -390,6 +456,34 @@ fn max_f64(a: &PyNumArray64) -> PyResult { Ok(a.inner.max()) } +#[pyfunction] +fn exp_f64(a: &PyNumArray64) -> PyNumArray64 { + PyNumArray64 { + inner: a.inner.exp(), + } +} + +#[pyfunction] +fn log_f64(a: &PyNumArray64) -> PyNumArray64 { + PyNumArray64 { + inner: a.inner.log(), + } +} + +#[pyfunction] +fn sigmoid_f64(a: &PyNumArray64) -> PyNumArray64 { + PyNumArray64 { + inner: a.inner.sigmoid(), + } +} + +#[pyfunction] +fn concatenate_f64(arrays: Vec, axis: usize) -> PyResult { + let rust_arrays: Vec = arrays.iter().map(|array| array.inner.clone()).collect(); + let result = NumArray64::concatenate(&rust_arrays, axis); + Ok(PyNumArray64 { inner: result }) +} + #[pymodule] fn _rustynum(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; @@ -403,6 +497,10 @@ fn _rustynum(py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(mean_f32, m)?)?; m.add_function(wrap_pyfunction!(min_f32, m)?)?; m.add_function(wrap_pyfunction!(max_f32, m)?)?; + m.add_function(wrap_pyfunction!(exp_f32, m)?)?; + m.add_function(wrap_pyfunction!(log_f32, m)?)?; + m.add_function(wrap_pyfunction!(sigmoid_f32, m)?)?; + m.add_function(wrap_pyfunction!(concatenate_f32, m)?)?; m.add_function(wrap_pyfunction!(zeros_f64, m)?)?; m.add_function(wrap_pyfunction!(ones_f64, m)?)?; @@ -413,6 +511,10 @@ fn _rustynum(py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(mean_f64, m)?)?; m.add_function(wrap_pyfunction!(min_f64, m)?)?; m.add_function(wrap_pyfunction!(max_f64, m)?)?; + m.add_function(wrap_pyfunction!(exp_f64, m)?)?; + m.add_function(wrap_pyfunction!(log_f64, m)?)?; + m.add_function(wrap_pyfunction!(sigmoid_f64, m)?)?; + m.add_function(wrap_pyfunction!(concatenate_f64, m)?)?; Ok(()) } diff --git a/bindings/python/tests/test_basics.py b/bindings/python/tests/test_basics.py index 873998c..af2466f 100644 --- a/bindings/python/tests/test_basics.py +++ b/bindings/python/tests/test_basics.py @@ -28,3 +28,55 @@ def test_linspace(): b = np.linspace(0, 10, 5, dtype="float32") assert a.tolist() == [0.0, 2.5, 5.0, 7.5, 10.0], "Linspace failed" assert np.allclose(a.tolist(), b, atol=1e-9), "Linspace failed" + + +def test_exp(): + a = rnp.NumArray([0.0, 1.0, 2.0, 3.0], dtype="float32") + b = np.exp(np.array([0.0, 1.0, 2.0, 3.0], dtype="float32")) + assert np.allclose(a.exp().tolist(), b, atol=1e-9), "Exp failed" + + +def test_log(): + a = rnp.NumArray([1.0, 2.0, 4.0, 8.0], dtype="float32") + b = np.log(np.array([1.0, 2.0, 4.0, 8.0], dtype="float32")) + assert np.allclose(a.log().tolist(), b, atol=1e-9), "Log failed" + + +def test_sigmoid(): + a = rnp.NumArray([0.0, 1.0, 2.0, 3.0], dtype="float32") + b 
= 1 / (1 + np.exp(-np.array([0.0, 1.0, 2.0, 3.0], dtype="float32"))) + assert np.allclose(a.sigmoid().tolist(), b, atol=1e-9), "Sigmoid failed" + + +def test_concatenate_along_axis_0(): + a = rnp.NumArray([[1.0, 2.0], [3.0, 4.0]], dtype="float32") + b = rnp.NumArray([[5.0, 6.0], [7.0, 8.0]], dtype="float32") + c = np.concatenate( + [ + np.array([[1.0, 2.0], [3.0, 4.0]], dtype="float32"), + np.array([[5.0, 6.0], [7.0, 8.0]], dtype="float32"), + ], + axis=0, + ) + + assert rnp.concatenate([a, b], axis=0).shape == c.shape, "Shape mismatch" + assert np.allclose( + rnp.concatenate([a, b], axis=0).tolist(), c, atol=1e-9 + ), "Concatenate failed" + + +def test_concatenate_along_axis_1(): + a = rnp.NumArray([[1.0, 2.0], [3.0, 4.0]], dtype="float32") + b = rnp.NumArray([[5.0, 6.0], [7.0, 8.0]], dtype="float32") + c = np.concatenate( + [ + np.array([[1.0, 2.0], [3.0, 4.0]], dtype="float32"), + np.array([[5.0, 6.0], [7.0, 8.0]], dtype="float32"), + ], + axis=1, + ) + + assert rnp.concatenate([a, b], axis=1).shape == c.shape, "Shape mismatch" + assert np.allclose( + rnp.concatenate([a, b], axis=1).tolist(), c, atol=1e-9 + ), "Concatenate failed" diff --git a/rustynum-rs/src/num_array/linalg.rs b/rustynum-rs/src/num_array/linalg.rs index 2a45a00..353a1e0 100644 --- a/rustynum-rs/src/num_array/linalg.rs +++ b/rustynum-rs/src/num_array/linalg.rs @@ -5,9 +5,9 @@ use super::num_array::{NumArray, NumArray32, NumArray64}; use std::iter::Sum; use crate::simd_ops::SimdOps; -use crate::traits::{FromU32, FromUsize, NumOps}; +use crate::traits::{ExpLog, FromU32, FromUsize, NumOps}; use std::fmt::Debug; -use std::ops::{Add, Div, Mul, Sub}; +use std::ops::{Add, Div, Mul, Neg, Sub}; /// Performs matrix-vector multiplication. /// @@ -35,6 +35,9 @@ where + PartialOrd + FromU32 + FromUsize + + FromUsize + + ExpLog + + Neg + NumOps + Debug, Ops: SimdOps, @@ -83,6 +86,9 @@ where + PartialOrd + FromU32 + FromUsize + + FromUsize + + ExpLog + + Neg + NumOps + Debug, Ops: SimdOps, @@ -133,6 +139,9 @@ where + PartialOrd + FromU32 + FromUsize + + FromUsize + + ExpLog + + Neg + NumOps + Debug, Ops: SimdOps, diff --git a/rustynum-rs/src/num_array/num_array.rs b/rustynum-rs/src/num_array/num_array.rs index 3f5b3a6..9fa3ae0 100644 --- a/rustynum-rs/src/num_array/num_array.rs +++ b/rustynum-rs/src/num_array/num_array.rs @@ -27,12 +27,12 @@ use std::fmt::Debug; use std::iter::Sum; use std::marker::PhantomData; -use std::ops::{Add, Div, Mul, Sub}; +use std::ops::{Add, Div, Mul, Neg, Sub}; use std::simd::{f32x16, f64x8}; use crate::num_array::linalg::matrix_multiply; use crate::simd_ops::SimdOps; -use crate::traits::{FromU32, FromUsize, NumOps}; +use crate::traits::{ExpLog, FromU32, FromUsize, NumOps}; pub type NumArray32 = NumArray; pub type NumArray64 = NumArray; @@ -73,7 +73,7 @@ where impl NumArray where - T: Copy + Debug, // Only require what's absolutely necessary for this operation + T: Copy + Debug + Neg, // Only require what's absolutely necessary for this operation { /// Creates a new 1D array from the given data. /// @@ -114,6 +114,8 @@ where + PartialOrd + FromU32 + FromUsize + + ExpLog + + Neg + Default + Debug, Ops: SimdOps, @@ -195,6 +197,105 @@ where } } + /// Concatenates multiple `NumArray` instances along the specified axis. + /// + /// # Parameters + /// * `arrays` - A slice of `NumArray` instances to concatenate. + /// * `axis` - The axis along which to concatenate. + /// + /// # Returns + /// A new `NumArray` instance resulting from the concatenation. 
+ /// + /// # Panics + /// Panics if the shapes of the arrays are incompatible for concatenation along the specified axis. + /// + /// # Example + /// ``` + /// use rustynum_rs::NumArray32; + /// + /// let a = NumArray32::new_with_shape(vec![1.0, 2.0, 3.0], vec![3]); + /// let b = NumArray32::new_with_shape(vec![4.0, 5.0], vec![2]); + /// let concatenated = NumArray32::concatenate(&[a, b], 0); + /// assert_eq!(concatenated.get_data(), &[1.0, 2.0, 3.0, 4.0, 5.0]); + /// ``` + pub fn concatenate(arrays: &[Self], axis: usize) -> Self { + // Ensure there is at least one array to concatenate + assert!( + !arrays.is_empty(), + "At least one array must be provided for concatenation." + ); + + // Determine the reference shape from the first array + let reference_shape = arrays[0].shape(); + + // Validate that all arrays have the same number of dimensions + let ndim = reference_shape.len(); + assert!( + axis < ndim, + "Concatenation axis {} is out of bounds for arrays with {} dimensions.", + axis, + ndim + ); + for array in arrays.iter() { + assert!( + array.shape().len() == ndim, + "All arrays must have the same number of dimensions." + ); + // Validate that shapes match on all axes except the concatenation axis + for (i, (&dim_ref, &dim_other)) in + reference_shape.iter().zip(array.shape().iter()).enumerate() + { + if i != axis { + assert!( + dim_ref == dim_other, + "All arrays must have the same shape except along the concatenation axis. Mismatch found at axis {}.", + i + ); + } + } + } + + // Compute the new shape + let mut new_shape = reference_shape.to_vec(); + let total_concat_dim: usize = arrays.iter().map(|array| array.shape()[axis]).sum(); + new_shape[axis] = total_concat_dim; + + // Compute elements_before_axis and elements_after_axis + let elements_before_axis: usize = reference_shape.iter().take(axis).product(); + let elements_after_axis: usize = reference_shape.iter().skip(axis + 1).product(); + + // Initialize the new data vector with the appropriate capacity + let total_size: usize = new_shape.iter().product(); + let mut concatenated_data = Vec::with_capacity(total_size); + + // Iterate over each outer slice and concatenate data from all arrays + for outer in 0..elements_before_axis { + for array in arrays.iter() { + let axis_size = array.shape()[axis]; + let slice_size = axis_size * elements_after_axis; + + // Calculate the start and end indices for the current slice + let start = outer * axis_size * elements_after_axis; + let end = start + slice_size; + + // Safety check to prevent out-of-bounds access + assert!( + end <= array.data.len(), + "Slice indices out of bounds. Attempted to access {}..{} in an array with length {}.", + start, + end, + array.data.len() + ); + + // Append the slice to the concatenated data + concatenated_data.extend_from_slice(&array.data[start..end]); + } + } + + // Create and return the new NumArray with the concatenated data and new shape + Self::new_with_shape(concatenated_data, new_shape) + } + /// Transposes a 2D matrix from row-major to column-major format. /// /// # Returns @@ -566,6 +667,80 @@ where Ops::max(&self.data) } + /// Applies the exponential function to each element of the `NumArray`. + /// + /// # Returns + /// A new `NumArray` instance where each element is the exponential of the corresponding element in the original array. 
+    ///
+    /// # Example
+    /// ```
+    /// use rustynum_rs::NumArray32;
+    ///
+    /// let array = NumArray32::new(vec![0.0, 1.0, 2.0]);
+    /// let exp_array = array.exp();
+    /// assert_eq!(exp_array.get_data(), &[1.0, 2.7182817, 7.389056]);
+    /// ```
+    pub fn exp(&self) -> Self {
+        let exp_data = self.data.iter().map(|&x| x.exp()).collect::<Vec<T>>();
+        Self::new_with_shape(exp_data, self.shape.clone())
+    }
+
+    /// Applies the natural logarithm to each element of the `NumArray`.
+    ///
+    /// # Returns
+    /// A new `NumArray` instance where each element is the natural logarithm of the corresponding element in the original array.
+    ///
+    /// # Panics
+    /// Panics if any element in the array is non-positive, as the logarithm is undefined for such values.
+    ///
+    /// # Example
+    /// ```
+    /// use rustynum_rs::NumArray32;
+    ///
+    /// let array = NumArray32::new(vec![1.0, 2.718282, 7.389056]);
+    /// let log_array = array.log();
+    /// assert_eq!(log_array.get_data(), &[0.0, 1.0, 2.0]);
+    /// ```
+    pub fn log(&self) -> Self {
+        // Ensure all elements are positive
+        for &x in &self.data {
+            assert!(
+                x > T::from_u32(0),
+                "Logarithm undefined for non-positive values."
+            );
+        }
+
+        let log_data = self.data.iter().map(|&x| x.log()).collect::<Vec<T>>();
+        Self::new_with_shape(log_data, self.shape.clone())
+    }
+
+    /// Applies the sigmoid function to each element of the `NumArray`.
+    ///
+    /// The sigmoid function is defined as `1 / (1 + exp(-x))` for each element `x`.
+    ///
+    /// # Returns
+    /// A new `NumArray` instance where each element is the sigmoid of the corresponding element in the original array.
+    ///
+    /// # Example
+    /// ```
+    /// use rustynum_rs::NumArray32;
+    ///
+    /// let array = NumArray32::new(vec![0.0, 2.0, -2.0]);
+    /// let sigmoid_array = array.sigmoid();
+    /// let expected = vec![0.5, 0.880797, 0.119203];
+    /// for (computed, &exp_val) in sigmoid_array.get_data().iter().zip(expected.iter()) {
+    ///     assert!((computed - exp_val).abs() < 1e-5, "Expected {}, got {}", exp_val, computed);
+    /// }
+    /// ```
+    pub fn sigmoid(&self) -> Self {
+        let sigmoid_data = self
+            .data
+            .iter()
+            .map(|&x| T::from_u32(1) / (T::from_u32(1) + (-x).exp()))
+            .collect::<Vec<T>>();
+        Self::new_with_shape(sigmoid_data, self.shape.clone())
+    }
+
     /// Normalizes the array.
/// /// # Returns @@ -855,6 +1030,106 @@ mod tests { assert_eq!(ones_array.get_data(), &[1.0, 1.0, 1.0, 1.0, 1.0, 1.0]); } + #[test] + fn test_concatenate_1d_arrays() { + let a = NumArray32::new(vec![1.0, 2.0, 3.0]); + let b = NumArray32::new(vec![4.0, 5.0]); + let c = NumArray32::new(vec![6.0]); + + let concatenated = NumArray32::concatenate(&[a.clone(), b.clone(), c.clone()], 0); + assert_eq!(concatenated.shape(), &[6]); + assert_eq!(concatenated.get_data(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + } + + #[test] + fn test_concatenate_2d_arrays_axis0() { + let a = NumArray32::new_with_shape(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]); + let b = NumArray32::new_with_shape(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]); + + let concatenated = NumArray32::concatenate(&[a.clone(), b.clone()], 0); + assert_eq!(concatenated.shape(), &[4, 2]); + assert_eq!( + concatenated.get_data(), + &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0] + ); + } + + #[test] + fn test_concatenate_2d_arrays_axis1() { + let a = NumArray32::new_with_shape(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]); + let b = NumArray32::new_with_shape(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]); + + let concatenated = NumArray32::concatenate(&[a.clone(), b.clone()], 1); + assert_eq!(concatenated.shape(), &[2, 4]); + assert_eq!( + concatenated.get_data(), + &[1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0] + ); + } + + #[test] + fn test_concatenate_multiple_2d_arrays() { + let a = NumArray32::new_with_shape(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]); + let b = NumArray32::new_with_shape(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]); + let c = NumArray32::new_with_shape(vec![9.0, 10.0, 11.0, 12.0], vec![2, 2]); + + let concatenated = NumArray32::concatenate(&[a.clone(), b.clone(), c.clone()], 0); + assert_eq!(concatenated.shape(), &[6, 2]); + assert_eq!( + concatenated.get_data(), + &[ + 1.0, 2.0, // a + 3.0, 4.0, 5.0, 6.0, // b + 7.0, 8.0, 9.0, 10.0, // c + 11.0, 12.0 + ] + ); + } + + #[test] + fn test_concatenate_incompatible_shapes() { + let a = NumArray32::new_with_shape(vec![1.0, 2.0, 3.0], vec![3, 1]); + let b = NumArray32::new_with_shape(vec![4.0, 5.0], vec![2, 1]); + + // Attempt to concatenate along axis 0 (rows) + let concatenated = NumArray32::concatenate(&[a.clone(), b.clone()], 0); + assert_eq!(concatenated.shape(), &[5, 1]); + assert_eq!(concatenated.get_data(), &[1.0, 2.0, 3.0, 4.0, 5.0]); + + // Attempt to concatenate along axis 1 (columns) should panic due to mismatched row sizes + let a = NumArray32::new_with_shape(vec![1.0, 2.0, 3.0], vec![3, 1]); + let b = NumArray32::new_with_shape(vec![4.0, 5.0, 6.0, 7.0], vec![4, 1]); + + let result = std::panic::catch_unwind(|| NumArray32::concatenate(&[a, b], 1)); + assert!( + result.is_err(), + "Concatenation should fail due to incompatible shapes." + ); + } + + #[test] + fn test_concatenate_empty_input() { + // Attempt to concatenate with an empty slice should panic + let result = std::panic::catch_unwind(|| NumArray32::concatenate(&[], 0)); + assert!( + result.is_err(), + "Concatenation should fail when no arrays are provided." + ); + } + + #[test] + fn test_concatenate_different_dimensions() { + let a = NumArray32::new(vec![1.0, 2.0, 3.0]); // Shape: [3] (1D) + let b = NumArray32::new_with_shape(vec![4.0, 5.0], vec![1, 2]); // Shape: [1, 2] (2D) + + // Attempt to concatenate arrays with different dimensions should panic + let result = std::panic::catch_unwind(|| NumArray32::concatenate(&[a, b], 0)); + assert!( + result.is_err(), + "Concatenation should fail due to differing dimensions." 
+        );
+    }
+
     #[test]
     fn test_matrix_transpose() {
         let matrix = NumArray32::new_with_shape(
@@ -1187,4 +1462,59 @@
         let column_slice = array.column_slice(1);
         assert_eq!(column_slice, vec![2.0, 5.0]);
     }
+
+    #[test]
+    fn test_exp_f32() {
+        let array = NumArray32::new(vec![0.0, 1.0, 2.0]);
+        let exp_array = array.exp();
+        // Using approximate values for floating-point comparisons
+        let expected = vec![1.0, 2.7182817, 7.389056];
+        for (computed, &exp_val) in exp_array.get_data().iter().zip(expected.iter()) {
+            assert!(
+                (computed - exp_val).abs() < 1e-5,
+                "Expected {}, got {}",
+                exp_val,
+                computed
+            );
+        }
+    }
+
+    #[test]
+    fn test_log_f32() {
+        let array = NumArray32::new(vec![1.0, 2.7182817, 7.389056]);
+        let log_array = array.log();
+        // Using approximate values for floating-point comparisons
+        let expected = vec![0.0, 1.0, 2.0];
+        for (computed, &log_val) in log_array.get_data().iter().zip(expected.iter()) {
+            assert!(
+                (computed - log_val).abs() < 1e-5,
+                "Expected {}, got {}",
+                log_val,
+                computed
+            );
+        }
+    }
+
+    #[test]
+    #[should_panic(expected = "Logarithm undefined for non-positive values.")]
+    fn test_log_f32_with_non_positive() {
+        let array = NumArray32::new(vec![1.0, -1.0, 0.0]);
+        let _ = array.log(); // Should panic
+    }
+
+    #[test]
+    fn test_sigmoid_f32() {
+        let array = NumArray32::new(vec![0.0, 2.0, -2.0]);
+        let sigmoid_array = array.sigmoid();
+        // Using approximate values for floating-point comparisons
+        let expected = vec![0.5, 0.880797, 0.119203];
+        for (computed, &exp_val) in sigmoid_array.get_data().iter().zip(expected.iter()) {
+            assert!(
+                (computed - exp_val).abs() < 1e-5,
+                "Expected {}, got {}",
+                exp_val,
+                computed
+            );
+        }
+    }
 }
diff --git a/rustynum-rs/src/num_array/operations.rs b/rustynum-rs/src/num_array/operations.rs
index 9fdcdee..5a8e3b6 100644
--- a/rustynum-rs/src/num_array/operations.rs
+++ b/rustynum-rs/src/num_array/operations.rs
@@ -1,11 +1,11 @@
 #[allow(unused_imports)]
 use super::num_array::{NumArray, NumArray32, NumArray64};
 
 use crate::simd_ops::SimdOps;
-use crate::traits::{FromU32, FromUsize, NumOps};
+use crate::traits::{ExpLog, FromU32, FromUsize, NumOps};
 use std::fmt::Debug;
 use std::iter::Sum;
-use std::ops::{Add, Div, Mul, Sub};
+use std::ops::{Add, Div, Mul, Neg, Sub};
 
 impl<T, Ops> Add for NumArray<T, Ops>
 where
@@ -20,6 +20,8 @@ where
         + PartialOrd
         + FromU32
         + FromUsize
+        + ExpLog
+        + Neg<Output = T>
         + Default
         + Debug,
     Ops: SimdOps<T>,
@@ -45,6 +47,8 @@ where
         + PartialOrd
         + FromU32
         + FromUsize
+        + ExpLog
+        + Neg<Output = T>
         + Default
         + Debug,
     Ops: SimdOps<T>, // Ensure Ops is appropriate for T
@@ -70,6 +74,8 @@ where
         + PartialOrd
         + FromU32
         + FromUsize
+        + ExpLog
+        + Neg<Output = T>
         + Default
         + Debug,
     Ops: SimdOps<T>,
@@ -100,6 +106,8 @@ where
         + PartialOrd
         + FromU32
         + FromUsize
+        + ExpLog
+        + Neg<Output = T>
         + Default
         + Debug,
     Ops: SimdOps<T>,
@@ -130,6 +138,8 @@ where
         + PartialOrd
         + FromU32
         + FromUsize
+        + ExpLog
+        + Neg<Output = T>
         + Default
         + Debug,
     Ops: SimdOps<T>,
@@ -155,6 +165,8 @@ where
         + PartialOrd
         + FromU32
         + FromUsize
+        + ExpLog
+        + Neg<Output = T>
         + Default
         + Debug,
     Ops: SimdOps<T>, // Ensure Ops is appropriate for T
@@ -180,6 +192,8 @@ where
         + PartialOrd
         + FromU32
         + FromUsize
+        + ExpLog
+        + Neg<Output = T>
         + Default
         + Debug,
     Ops: SimdOps<T>,
@@ -210,6 +224,8 @@ where
         + PartialOrd
         + FromU32
         + FromUsize
+        + ExpLog
+        + Neg<Output = T>
         + Default
         + Debug,
     Ops: SimdOps<T>,
@@ -240,6 +256,8 @@ where
         + PartialOrd
         + FromU32
         + FromUsize
+        + ExpLog
+        + Neg<Output = T>
         + Default
         + Debug,
     Ops: SimdOps<T>,
@@ -265,6 +283,8 @@ where
         + PartialOrd
         + FromU32
         + FromUsize
+        + ExpLog
+        + Neg<Output = T>
         + Default
         + Debug,
     Ops: SimdOps<T>, // Ensure Ops is appropriate for T
@@ -290,6 +310,8 @@ where
         + PartialOrd
         + FromU32
         + FromUsize
+        + ExpLog
+        + Neg<Output = T>
         + Default
         + Debug,
     Ops: SimdOps<T>,
@@ -320,6 +342,8 @@ where
         + PartialOrd
         + FromU32
         + FromUsize
+        + ExpLog
+        + Neg<Output = T>
         + Default
         + Debug,
     Ops: SimdOps<T>,
@@ -350,6 +374,8 @@ where
         + PartialOrd
         + FromU32
         + FromUsize
+        + ExpLog
+        + Neg<Output = T>
         + Default
         + Debug,
     Ops: SimdOps<T>,
@@ -375,6 +401,8 @@ where
         + PartialOrd
         + FromU32
         + FromUsize
+        + ExpLog
+        + Neg<Output = T>
         + Default
         + Debug,
     Ops: SimdOps<T>, // Ensure Ops is appropriate for T
@@ -400,6 +428,8 @@ where
         + PartialOrd
         + FromU32
         + FromUsize
+        + ExpLog
+        + Neg<Output = T>
         + Default
         + Debug,
     Ops: SimdOps<T>,
@@ -430,6 +460,8 @@ where
         + PartialOrd
         + FromU32
         + FromUsize
+        + ExpLog
+        + Neg<Output = T>
         + Default
         + Debug,
     Ops: SimdOps<T>,
diff --git a/rustynum-rs/src/traits.rs b/rustynum-rs/src/traits.rs
index 2e066a3..d3a274a 100644
--- a/rustynum-rs/src/traits.rs
+++ b/rustynum-rs/src/traits.rs
@@ -102,7 +102,6 @@ impl FromU32 for i64 {
     }
 }
 
-
 /// A trait for converting from `usize` to a specific type.
 pub trait FromUsize {
     /// Converts a `usize` value to the implementing type.
@@ -169,8 +168,6 @@ impl FromUsize for i64 {
     }
 }
 
-
-
 pub trait NumOps:
     Sized + Add + Mul + Sub + Div + Copy
 {
@@ -232,3 +229,36 @@ impl NumOps for i64 {
         0
     }
 }
+
+/// A trait that provides exponential and logarithmic operations.
+pub trait ExpLog: Sized {
+    /// Returns the exponential of the number.
+    fn exp(self) -> Self;
+
+    /// Returns the natural logarithm of the number.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the number is non-positive, as the logarithm is undefined for such values.
+    fn log(self) -> Self;
+}
+
+impl ExpLog for f32 {
+    fn exp(self) -> Self {
+        f32::exp(self)
+    }
+
+    fn log(self) -> Self {
+        f32::ln(self)
+    }
+}
+
+impl ExpLog for f64 {
+    fn exp(self) -> Self {
+        f64::exp(self)
+    }
+
+    fn log(self) -> Self {
+        f64::ln(self)
+    }
+}
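
Example usage of the Python API added by this patch (a minimal sketch, assuming the bindings are built and the package is importable as `rustynum`; it mirrors the calls exercised in `bindings/python/tests/test_basics.py`):

```python
import rustynum as rnp

a = rnp.NumArray([0.0, 1.0, 2.0, 3.0], dtype="float32")
b = rnp.NumArray([[1.0, 2.0], [3.0, 4.0]], dtype="float32")
c = rnp.NumArray([[5.0, 6.0], [7.0, 8.0]], dtype="float32")

# Element-wise exponential, natural log and sigmoid as free functions ...
print(rnp.exp(a).tolist())           # [1.0, 2.718..., 7.389..., 20.085...]
print(rnp.log(rnp.exp(a)).tolist())  # recovers a, up to float32 rounding
print(rnp.sigmoid(a).tolist())       # 1 / (1 + e**-x) for each element

# ... or as methods on NumArray
print(a.exp().tolist())

# Concatenation along rows (axis=0) and columns (axis=1)
print(rnp.concatenate([b, c], axis=0).tolist())  # shape [4, 2]
print(rnp.concatenate([b, c], axis=1).tolist())  # shape [2, 4]
```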