Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add min max axes operator is #19

Merged
merged 6 commits into from
Feb 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 0 additions & 78 deletions bindings/python/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 5 additions & 3 deletions bindings/python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ crate-type = ["cdylib"]


[dependencies]
numpy = "0.19.0"
pyo3 = { version = "0.19.2", features = ["extension-module"] }

# Adjust the path to point to your rustynum-rs crate relative to the bindings/python directory
rustynum-rs = { path = "../../rustynum-rs" }

[profile.release]
strip = true
lto = true
panic = "abort"
60 changes: 46 additions & 14 deletions bindings/python/rustynum/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,31 +346,63 @@ def mean(
)
return NumArray(result, dtype=self.dtype)

def min(self) -> float:
def min(
self, axes: Union[None, int, Sequence[int]] = None
) -> Union["NumArray", float]:
"""
Finds the minimum value in the NumArray.
Return the minimum along the specified axes.

Parameters:
axes: Optional; Axis or axes along which to find the minimum. If None,
the minimum of all elements is computed as a scalar.

Returns:
The minimum value as a float.
"""
return (
_rustynum.min_f32(self.inner)
A new NumArray with the minimum values along the specified axes,
or a scalar if no axes are given.
"""
if axes is None:
return (
_rustynum.min_f32(self.inner)
if self.dtype == "float32"
else _rustynum.min_f64(self.inner)
)

axes = [axes] if isinstance(axes, int) else axes
result = (
_rustynum.min_axes_f32(self.inner, axes)
if self.dtype == "float32"
else _rustynum.min_f64(self.inner)
else _rustynum.min_axes_f64(self.inner, axes)
)
return NumArray(result, dtype=self.dtype)

def max(self) -> float:
def max(
self, axes: Union[None, int, Sequence[int]] = None
) -> Union["NumArray", float]:
"""
Finds the maximum value in the NumArray.
Return the maximum along the specified axes.

Parameters:
axes: Optional; Axis or axes along which to find the maximum. If None,
the maximum of all elements is computed as a scalar.

Returns:
The maximum value as a float.
"""
return (
_rustynum.max_f32(self.inner)
A new NumArray with the maximum values along the specified axes,
or a scalar if no axes are given.
"""
if axes is None:
return (
_rustynum.max_f32(self.inner)
if self.dtype == "float32"
else _rustynum.max_f64(self.inner)
)

axes = [axes] if isinstance(axes, int) else axes
result = (
_rustynum.max_axes_f32(self.inner, axes)
if self.dtype == "float32"
else _rustynum.max_f64(self.inner)
else _rustynum.max_axes_f64(self.inner, axes)
)
return NumArray(result, dtype=self.dtype)

def __imul__(self, scalar: float) -> "NumArray":
"""
Expand Down
102 changes: 102 additions & 0 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,32 @@ impl PyNumArrayF32 {
inner: self.inner.sigmoid(),
}
}

fn min_axes(&self, axes: Option<&PyList>) -> PyResult<PyNumArrayF32> {
Python::with_gil(|py| {
let result = match axes {
Some(axes_list) => {
let axes_vec: Vec<usize> = axes_list.extract()?;
self.inner.min_axes(Some(&axes_vec))
}
None => self.inner.min_axes(None),
};
Ok(PyNumArrayF32 { inner: result })
})
}

fn max_axes(&self, axes: Option<&PyList>) -> PyResult<PyNumArrayF32> {
Python::with_gil(|py| {
let result = match axes {
Some(axes_list) => {
let axes_vec: Vec<usize> = axes_list.extract()?;
self.inner.max_axes(Some(&axes_vec))
}
None => self.inner.max_axes(None),
};
Ok(PyNumArrayF32 { inner: result })
})
}
}

#[pymethods]
Expand Down Expand Up @@ -320,6 +346,32 @@ impl PyNumArrayF64 {
inner: self.inner.sigmoid(),
}
}

fn min_axes(&self, axes: Option<&PyList>) -> PyResult<PyNumArrayF64> {
Python::with_gil(|py| {
let result = match axes {
Some(axes_list) => {
let axes_vec: Vec<usize> = axes_list.extract()?;
self.inner.min_axes(Some(&axes_vec))
}
None => self.inner.min_axes(None),
};
Ok(PyNumArrayF64 { inner: result })
})
}

fn max_axes(&self, axes: Option<&PyList>) -> PyResult<PyNumArrayF64> {
Python::with_gil(|py| {
let result = match axes {
Some(axes_list) => {
let axes_vec: Vec<usize> = axes_list.extract()?;
self.inner.max_axes(Some(&axes_vec))
}
None => self.inner.max_axes(None),
};
Ok(PyNumArrayF64 { inner: result })
})
}
}

#[pymethods]
Expand Down Expand Up @@ -746,11 +798,35 @@ fn min_f32(a: &PyNumArrayF32) -> PyResult<f32> {
Ok(a.inner.min())
}

#[pyfunction]
fn min_axes_f32(a: &PyNumArrayF32, axes: Option<&PyList>) -> PyResult<PyNumArrayF32> {
let result = match axes {
Some(axes_list) => {
let axes_vec: Vec<usize> = axes_list.extract()?; // Convert PyList to Vec<usize>
a.inner.min_axes(Some(&axes_vec))
}
None => a.inner.min_axes(None),
};
Ok(PyNumArrayF32 { inner: result })
}

#[pyfunction]
fn max_f32(a: &PyNumArrayF32) -> PyResult<f32> {
Ok(a.inner.max())
}

#[pyfunction]
fn max_axes_f32(a: &PyNumArrayF32, axes: Option<&PyList>) -> PyResult<PyNumArrayF32> {
let result = match axes {
Some(axes_list) => {
let axes_vec: Vec<usize> = axes_list.extract()?; // Convert PyList to Vec<usize>
a.inner.max_axes(Some(&axes_vec))
}
None => a.inner.max_axes(None),
};
Ok(PyNumArrayF32 { inner: result })
}

#[pyfunction]
fn exp_f32(a: &PyNumArrayF32) -> PyNumArrayF32 {
PyNumArrayF32 {
Expand Down Expand Up @@ -851,11 +927,35 @@ fn min_f64(a: &PyNumArrayF64) -> PyResult<f64> {
Ok(a.inner.min())
}

#[pyfunction]
fn min_axes_f64(a: &PyNumArrayF64, axes: Option<&PyList>) -> PyResult<PyNumArrayF64> {
let result = match axes {
Some(axes_list) => {
let axes_vec: Vec<usize> = axes_list.extract()?; // Convert PyList to Vec<usize>
a.inner.min_axes(Some(&axes_vec))
}
None => a.inner.min_axes(None),
};
Ok(PyNumArrayF64 { inner: result })
}

#[pyfunction]
fn max_f64(a: &PyNumArrayF64) -> PyResult<f64> {
Ok(a.inner.max())
}

#[pyfunction]
fn max_axes_f64(a: &PyNumArrayF64, axes: Option<&PyList>) -> PyResult<PyNumArrayF64> {
let result = match axes {
Some(axes_list) => {
let axes_vec: Vec<usize> = axes_list.extract()?; // Convert PyList to Vec<usize>
a.inner.max_axes(Some(&axes_vec))
}
None => a.inner.max_axes(None),
};
Ok(PyNumArrayF64 { inner: result })
}

#[pyfunction]
fn exp_f64(a: &PyNumArrayF64) -> PyNumArrayF64 {
PyNumArrayF64 {
Expand Down Expand Up @@ -898,7 +998,9 @@ fn _rustynum(py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(linspace_f32, m)?)?;
m.add_function(wrap_pyfunction!(mean_f32, m)?)?;
m.add_function(wrap_pyfunction!(min_f32, m)?)?;
m.add_function(wrap_pyfunction!(min_axes_f32, m)?)?;
m.add_function(wrap_pyfunction!(max_f32, m)?)?;
m.add_function(wrap_pyfunction!(max_axes_f32, m)?)?;
m.add_function(wrap_pyfunction!(exp_f32, m)?)?;
m.add_function(wrap_pyfunction!(log_f32, m)?)?;
m.add_function(wrap_pyfunction!(sigmoid_f32, m)?)?;
Expand Down
Loading
Loading