Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add concatenate, exp, log and sigmoid #10

Merged
merged 1 commit into from
Sep 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,12 @@ RustyNum offers a variety of numerical operations and data types, with more feat
| Mean | `np.mean(a)` | `rnp.mean(a)` |
| Min | `np.min(a)` | `rnp.min(a)` |
| Max | `np.max(a)` | `rnp.max(a)` |
| Exp | `np.exp(a)` | `rnp.exp(a)` |
| Log | `np.log(a)` | `rnp.log(a)` |
| Sigmoid | `1 / (1 + np.exp(-a))` | `rnp.sigmoid(a)` |
| Dot Product | `np.dot(a, b)` | `rnp.dot(a, b)` |
| Reshape | `a.reshape((2, 3))` | `a.reshape([2, 3])` |
| Concatenate | `np.concatenate([a,b], axis=0)` | `rnp.concatenate([a,b], axis=0)` |
| Element-wise Add | `a + b` | `a + b` |
| Element-wise Sub | `a - b` | `a - b` |
| Element-wise Mul | `a * b` | `a * b` |
Expand Down Expand Up @@ -187,7 +191,7 @@ Planned Features:

- N-dimensional arrays
- Useful for filters, image processing, and machine learning
- Additional operations: concat, exp, sigmoid, log, median, argmin, argmax, sort, std, var, zeros, cumsum, interp
- Additional operations: median, argmin, argmax, sort, std, var, zeros, cumsum, interp
- Integer support
- Extended shaping and reshaping capabilities
- C++ and WASM bindings
Expand Down
106 changes: 106 additions & 0 deletions bindings/python/rustynum/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,58 @@ def tolist(self) -> Union[List[float], List[List[float]]]:
flat_list[i * shape[1] : (i + 1) * shape[1]] for i in range(shape[0])
]

def exp(self) -> "NumArray":
"""
Computes the exponential of all elements in the NumArray.

Returns:
A new NumArray with the exponential of all elements.
"""
return NumArray(self.inner.exp(), dtype=self.dtype)

def log(self) -> "NumArray":
"""
Computes the natural logarithm of all elements in the NumArray.

Returns:
A new NumArray with the natural logarithm of all elements.
"""
return NumArray(self.inner.log(), dtype=self.dtype)

def sigmoid(self) -> "NumArray":
"""
Computes the sigmoid of all elements in the NumArray.

Returns:
A new NumArray with the sigmoid of all elements.
"""
return NumArray(self.inner.sigmoid(), dtype=self.dtype)

def concatenate(self, other: "NumArray", axis: int) -> "NumArray":
"""
Concatenates the NumArray with another NumArray along the specified axis.

Parameters:
other: Another NumArray to concatenate with.
axis: Axis along which to concatenate.

Returns:
A new NumArray containing the concatenated data.
"""
if self.dtype != other.dtype:
raise ValueError("dtype mismatch between arrays")
if self.shape[1 - axis] != other.shape[1 - axis]:
raise ValueError("Arrays must have the same shape along the specified axis")

if self.dtype == "float32":
result = _rustynum.concatenate_f32([self.inner, other.inner], axis)
elif self.dtype == "float64":
result = _rustynum.concatenate_f64([self.inner, other.inner], axis)
else:
raise ValueError("Unsupported dtype for concatenation")

return NumArray(result, dtype=self.dtype)


def zeros(shape: List[int], dtype: str = "float32") -> "NumArray":
"""
Expand Down Expand Up @@ -531,3 +583,57 @@ def dot(a: "NumArray", b: "NumArray") -> Union[float, "NumArray"]:
return NumArray([out], dtype="float32").item()
else:
raise TypeError("Both arguments must be NumArray instances.")


def exp(a: "NumArray") -> "NumArray":
if isinstance(a, NumArray):
return a.exp()
elif isinstance(a, (int, float)):
return NumArray([a], dtype="float32").exp()
else:
raise TypeError(
"Unsupported operand type for exp: '{}'".format(type(a).__name__)
)


def log(a: "NumArray") -> "NumArray":
if isinstance(a, NumArray):
return a.log()
elif isinstance(a, (int, float)):
return NumArray([a], dtype="float32").log()
else:
raise TypeError(
"Unsupported operand type for log: '{}'".format(type(a).__name__)
)


def sigmoid(a: "NumArray") -> "NumArray":
if isinstance(a, NumArray):
return a.sigmoid()
elif isinstance(a, (int, float)):
return NumArray([a], dtype="float32").sigmoid()
else:
raise TypeError(
"Unsupported operand type for sigmoid: '{}'".format(type(a).__name__)
)


def concatenate(arrays: List["NumArray"], axis: int) -> "NumArray":
# axis can be any integer, but most of the time it would only be 0 or 1
if not all(isinstance(a, NumArray) for a in arrays):
raise TypeError("All elements in 'arrays' must be NumArray instances.")
if not all(a.dtype == arrays[0].dtype for a in arrays):
raise ValueError("dtype mismatch between arrays")
if not all(a.shape[1 - axis] == arrays[0].shape[1 - axis] for a in arrays):
raise ValueError("Arrays must have the same shape along the specified axis")

if arrays[0].dtype == "float32":
return NumArray(
_rustynum.concatenate_f32([a.inner for a in arrays], axis), dtype="float32"
)
elif arrays[0].dtype == "float64":
return NumArray(
_rustynum.concatenate_f64([a.inner for a in arrays], axis), dtype="float64"
)
else:
raise ValueError("Unsupported dtype for concatenation")
102 changes: 102 additions & 0 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ use pyo3::wrap_pyfunction;
use rustynum_rs::{NumArray32, NumArray64};

#[pyclass]
#[derive(Clone)]
struct PyNumArray32 {
inner: NumArray32,
}

#[pyclass]
#[derive(Clone)]
struct PyNumArray64 {
inner: NumArray64,
}
Expand Down Expand Up @@ -124,6 +126,24 @@ impl PyNumArray32 {
inner: self.inner.reshape(&shape),
})
}

fn exp(&self) -> PyNumArray32 {
PyNumArray32 {
inner: self.inner.exp(),
}
}

fn log(&self) -> PyNumArray32 {
PyNumArray32 {
inner: self.inner.log(),
}
}

fn sigmoid(&self) -> PyNumArray32 {
PyNumArray32 {
inner: self.inner.sigmoid(),
}
}
}

#[pymethods]
Expand Down Expand Up @@ -234,6 +254,24 @@ impl PyNumArray64 {
inner: self.inner.reshape(&shape),
})
}

fn exp(&self) -> PyNumArray64 {
PyNumArray64 {
inner: self.inner.exp(),
}
}

fn log(&self) -> PyNumArray64 {
PyNumArray64 {
inner: self.inner.log(),
}
}

fn sigmoid(&self) -> PyNumArray64 {
PyNumArray64 {
inner: self.inner.sigmoid(),
}
}
}

#[pyfunction]
Expand Down Expand Up @@ -313,6 +351,34 @@ fn max_f32(a: &PyNumArray32) -> PyResult<f32> {
Ok(a.inner.max())
}

#[pyfunction]
fn exp_f32(a: &PyNumArray32) -> PyNumArray32 {
PyNumArray32 {
inner: a.inner.exp(),
}
}

#[pyfunction]
fn log_f32(a: &PyNumArray32) -> PyNumArray32 {
PyNumArray32 {
inner: a.inner.log(),
}
}

#[pyfunction]
fn sigmoid_f32(a: &PyNumArray32) -> PyNumArray32 {
PyNumArray32 {
inner: a.inner.sigmoid(),
}
}

#[pyfunction]
fn concatenate_f32(arrays: Vec<PyNumArray32>, axis: usize) -> PyResult<PyNumArray32> {
let rust_arrays: Vec<NumArray32> = arrays.iter().map(|array| array.inner.clone()).collect();
let result = NumArray32::concatenate(&rust_arrays, axis);
Ok(PyNumArray32 { inner: result })
}

#[pyfunction]
fn zeros_f64(shape: Vec<usize>) -> PyResult<PyNumArray64> {
Python::with_gil(|py| {
Expand Down Expand Up @@ -390,6 +456,34 @@ fn max_f64(a: &PyNumArray64) -> PyResult<f64> {
Ok(a.inner.max())
}

#[pyfunction]
fn exp_f64(a: &PyNumArray64) -> PyNumArray64 {
PyNumArray64 {
inner: a.inner.exp(),
}
}

#[pyfunction]
fn log_f64(a: &PyNumArray64) -> PyNumArray64 {
PyNumArray64 {
inner: a.inner.log(),
}
}

#[pyfunction]
fn sigmoid_f64(a: &PyNumArray64) -> PyNumArray64 {
PyNumArray64 {
inner: a.inner.sigmoid(),
}
}

#[pyfunction]
fn concatenate_f64(arrays: Vec<PyNumArray64>, axis: usize) -> PyResult<PyNumArray64> {
let rust_arrays: Vec<NumArray64> = arrays.iter().map(|array| array.inner.clone()).collect();
let result = NumArray64::concatenate(&rust_arrays, axis);
Ok(PyNumArray64 { inner: result })
}

#[pymodule]
fn _rustynum(py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PyNumArray32>()?;
Expand All @@ -403,6 +497,10 @@ fn _rustynum(py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(mean_f32, m)?)?;
m.add_function(wrap_pyfunction!(min_f32, m)?)?;
m.add_function(wrap_pyfunction!(max_f32, m)?)?;
m.add_function(wrap_pyfunction!(exp_f32, m)?)?;
m.add_function(wrap_pyfunction!(log_f32, m)?)?;
m.add_function(wrap_pyfunction!(sigmoid_f32, m)?)?;
m.add_function(wrap_pyfunction!(concatenate_f32, m)?)?;

m.add_function(wrap_pyfunction!(zeros_f64, m)?)?;
m.add_function(wrap_pyfunction!(ones_f64, m)?)?;
Expand All @@ -413,6 +511,10 @@ fn _rustynum(py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(mean_f64, m)?)?;
m.add_function(wrap_pyfunction!(min_f64, m)?)?;
m.add_function(wrap_pyfunction!(max_f64, m)?)?;
m.add_function(wrap_pyfunction!(exp_f64, m)?)?;
m.add_function(wrap_pyfunction!(log_f64, m)?)?;
m.add_function(wrap_pyfunction!(sigmoid_f64, m)?)?;
m.add_function(wrap_pyfunction!(concatenate_f64, m)?)?;

Ok(())
}
52 changes: 52 additions & 0 deletions bindings/python/tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,55 @@ def test_linspace():
b = np.linspace(0, 10, 5, dtype="float32")
assert a.tolist() == [0.0, 2.5, 5.0, 7.5, 10.0], "Linspace failed"
assert np.allclose(a.tolist(), b, atol=1e-9), "Linspace failed"


def test_exp():
a = rnp.NumArray([0.0, 1.0, 2.0, 3.0], dtype="float32")
b = np.exp(np.array([0.0, 1.0, 2.0, 3.0], dtype="float32"))
assert np.allclose(a.exp().tolist(), b, atol=1e-9), "Exp failed"


def test_log():
a = rnp.NumArray([1.0, 2.0, 4.0, 8.0], dtype="float32")
b = np.log(np.array([1.0, 2.0, 4.0, 8.0], dtype="float32"))
assert np.allclose(a.log().tolist(), b, atol=1e-9), "Log failed"


def test_sigmoid():
a = rnp.NumArray([0.0, 1.0, 2.0, 3.0], dtype="float32")
b = 1 / (1 + np.exp(-np.array([0.0, 1.0, 2.0, 3.0], dtype="float32")))
assert np.allclose(a.sigmoid().tolist(), b, atol=1e-9), "Sigmoid failed"


def test_concatenate_along_axis_0():
a = rnp.NumArray([[1.0, 2.0], [3.0, 4.0]], dtype="float32")
b = rnp.NumArray([[5.0, 6.0], [7.0, 8.0]], dtype="float32")
c = np.concatenate(
[
np.array([[1.0, 2.0], [3.0, 4.0]], dtype="float32"),
np.array([[5.0, 6.0], [7.0, 8.0]], dtype="float32"),
],
axis=0,
)

assert rnp.concatenate([a, b], axis=0).shape == c.shape, "Shape mismatch"
assert np.allclose(
rnp.concatenate([a, b], axis=0).tolist(), c, atol=1e-9
), "Concatenate failed"


def test_concatenate_along_axis_1():
a = rnp.NumArray([[1.0, 2.0], [3.0, 4.0]], dtype="float32")
b = rnp.NumArray([[5.0, 6.0], [7.0, 8.0]], dtype="float32")
c = np.concatenate(
[
np.array([[1.0, 2.0], [3.0, 4.0]], dtype="float32"),
np.array([[5.0, 6.0], [7.0, 8.0]], dtype="float32"),
],
axis=1,
)

assert rnp.concatenate([a, b], axis=1).shape == c.shape, "Shape mismatch"
assert np.allclose(
rnp.concatenate([a, b], axis=1).tolist(), c, atol=1e-9
), "Concatenate failed"
13 changes: 11 additions & 2 deletions rustynum-rs/src/num_array/linalg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ use super::num_array::{NumArray, NumArray32, NumArray64};
use std::iter::Sum;

use crate::simd_ops::SimdOps;
use crate::traits::{FromU32, FromUsize, NumOps};
use crate::traits::{ExpLog, FromU32, FromUsize, NumOps};
use std::fmt::Debug;
use std::ops::{Add, Div, Mul, Sub};
use std::ops::{Add, Div, Mul, Neg, Sub};

/// Performs matrix-vector multiplication.
///
Expand Down Expand Up @@ -35,6 +35,9 @@ where
+ PartialOrd
+ FromU32
+ FromUsize
+ FromUsize
+ ExpLog
+ Neg<Output = T>
+ NumOps
+ Debug,
Ops: SimdOps<T>,
Expand Down Expand Up @@ -83,6 +86,9 @@ where
+ PartialOrd
+ FromU32
+ FromUsize
+ FromUsize
+ ExpLog
+ Neg<Output = T>
+ NumOps
+ Debug,
Ops: SimdOps<T>,
Expand Down Expand Up @@ -133,6 +139,9 @@ where
+ PartialOrd
+ FromU32
+ FromUsize
+ FromUsize
+ ExpLog
+ Neg<Output = T>
+ NumOps
+ Debug,
Ops: SimdOps<T>,
Expand Down
Loading