Skip to content

Commit

Permalink
refactor: put all binding methods in StoragePtr class (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
manoflearning authored Aug 21, 2024
1 parent 7ff6c7d commit 32e40b2
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 138 deletions.
14 changes: 0 additions & 14 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,5 @@ use pyo3::prelude::*;
#[pyo3::pymodule]
fn cranberry(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<storage_ptr::StoragePtr>()?;
m.add_function(wrap_pyfunction!(storage_ptr::storage_full, m)?)?;
m.add_function(wrap_pyfunction!(storage_ptr::storage_from_vec, m)?)?;
m.add_function(wrap_pyfunction!(storage_ptr::storage_clone, m)?)?;
m.add_function(wrap_pyfunction!(storage_ptr::storage_drop, m)?)?;
m.add_function(wrap_pyfunction!(storage_ptr::storage_neg, m)?)?;
m.add_function(wrap_pyfunction!(storage_ptr::storage_sqrt, m)?)?;
m.add_function(wrap_pyfunction!(storage_ptr::storage_exp, m)?)?;
m.add_function(wrap_pyfunction!(storage_ptr::storage_log, m)?)?;
m.add_function(wrap_pyfunction!(storage_ptr::storage_add, m)?)?;
m.add_function(wrap_pyfunction!(storage_ptr::storage_sub, m)?)?;
m.add_function(wrap_pyfunction!(storage_ptr::storage_mul, m)?)?;
m.add_function(wrap_pyfunction!(storage_ptr::storage_div, m)?)?;
m.add_function(wrap_pyfunction!(storage_ptr::storage_sum, m)?)?;
m.add_function(wrap_pyfunction!(storage_ptr::storage_max, m)?)?;
Ok(())
}
249 changes: 125 additions & 124 deletions src/storage_ptr.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use uuid::Uuid;

use crate::{device::Device, storage::Storage};
use crate::storage::Storage;

#[pyo3::pyclass]
#[derive(PartialEq)]
Expand Down Expand Up @@ -33,127 +33,128 @@ impl StoragePtr {
}
}

#[pyo3::pyfunction]
pub fn storage_full(value: f32, size: usize, device: &str) -> StoragePtr {
println!("{:?}", Device::from_str(device));
StoragePtr::from_storage(&Storage::new(value, size, device))
}
#[pyo3::pyfunction]
pub fn storage_from_vec(vec: Vec<f32>, device: &str) -> StoragePtr {
println!("{:?}", Device::from_str(device));
StoragePtr::from_storage(&Storage::from_vec(vec, device))
}
#[pyo3::pyfunction]
pub fn storage_clone(storage_ptr: &mut StoragePtr) -> StoragePtr {
storage_ptr.get_storage_mut().incref();
StoragePtr::new(storage_ptr.ptr)
}
#[pyo3::pyfunction]
pub fn storage_drop(storage_ptr: &mut StoragePtr) {
storage_ptr.get_storage_mut().decref();
storage_ptr.ptr = 0;
}
#[pyo3::pyfunction]
pub fn storage_neg(a: &StoragePtr, b: &mut StoragePtr, idx_a: usize, idx_b: usize, size: usize) {
Storage::neg(a.get_storage(), b.get_storage_mut(), idx_a, idx_b, size);
}
#[pyo3::pyfunction]
pub fn storage_sqrt(a: &StoragePtr, b: &mut StoragePtr, idx_a: usize, idx_b: usize, size: usize) {
Storage::sqrt(a.get_storage(), b.get_storage_mut(), idx_a, idx_b, size)
}
#[pyo3::pyfunction]
pub fn storage_exp(a: &StoragePtr, b: &mut StoragePtr, idx_a: usize, idx_b: usize, size: usize) {
Storage::exp(a.get_storage(), b.get_storage_mut(), idx_a, idx_b, size);
}
#[pyo3::pyfunction]
pub fn storage_log(a: &StoragePtr, b: &mut StoragePtr, idx_a: usize, idx_b: usize, size: usize) {
Storage::log(a.get_storage(), b.get_storage_mut(), idx_a, idx_b, size);
}
#[pyo3::pyfunction]
pub fn storage_add(
a: &StoragePtr,
b: &StoragePtr,
c: &mut StoragePtr,
idx_a: usize,
idx_b: usize,
idx_c: usize,
size: usize,
) {
Storage::add(
a.get_storage(),
b.get_storage(),
c.get_storage_mut(),
idx_a,
idx_b,
idx_c,
size,
);
}
#[pyo3::pyfunction]
pub fn storage_sub(
a: &StoragePtr,
b: &StoragePtr,
c: &mut StoragePtr,
idx_a: usize,
idx_b: usize,
idx_c: usize,
size: usize,
) {
Storage::sub(
a.get_storage(),
b.get_storage(),
c.get_storage_mut(),
idx_a,
idx_b,
idx_c,
size,
);
}
#[pyo3::pyfunction]
pub fn storage_mul(
a: &StoragePtr,
b: &StoragePtr,
c: &mut StoragePtr,
idx_a: usize,
idx_b: usize,
idx_c: usize,
size: usize,
) {
Storage::mul(
a.get_storage(),
b.get_storage(),
c.get_storage_mut(),
idx_a,
idx_b,
idx_c,
size,
);
}
#[pyo3::pyfunction]
pub fn storage_div(
a: &StoragePtr,
b: &StoragePtr,
c: &mut StoragePtr,
idx_a: usize,
idx_b: usize,
idx_c: usize,
size: usize,
) {
Storage::div(
a.get_storage(),
b.get_storage(),
c.get_storage_mut(),
idx_a,
idx_b,
idx_c,
size,
);
}
#[pyo3::pyfunction]
pub fn storage_sum(a: &StoragePtr, b: &mut StoragePtr, idx_a: usize, idx_b: usize, size: usize) {
Storage::sum(a.get_storage(), b.get_storage_mut(), idx_a, idx_b, size);
}
#[pyo3::pyfunction]
pub fn storage_max(a: &StoragePtr, b: &mut StoragePtr, idx_a: usize, idx_b: usize, size: usize) {
Storage::max(a.get_storage(), b.get_storage_mut(), idx_a, idx_b, size);
#[pyo3::pymethods]
impl StoragePtr {
#[staticmethod]
fn full(value: f32, size: usize, device: &str) -> StoragePtr {
StoragePtr::from_storage(&Storage::new(value, size, device))
}
#[staticmethod]
fn from_vec(vec: Vec<f32>, device: &str) -> StoragePtr {
StoragePtr::from_storage(&Storage::from_vec(vec, device))
}
#[staticmethod]
fn clone(storage_ptr: &mut StoragePtr) -> StoragePtr {
storage_ptr.get_storage_mut().incref();
StoragePtr::new(storage_ptr.ptr)
}
#[staticmethod]
fn drop(storage_ptr: &mut StoragePtr) {
storage_ptr.get_storage_mut().decref();
storage_ptr.ptr = 0;
}
#[staticmethod]
fn neg(a: &StoragePtr, b: &mut StoragePtr, idx_a: usize, idx_b: usize, size: usize) {
Storage::neg(a.get_storage(), b.get_storage_mut(), idx_a, idx_b, size);
}
#[staticmethod]
fn sqrt(a: &StoragePtr, b: &mut StoragePtr, idx_a: usize, idx_b: usize, size: usize) {
Storage::sqrt(a.get_storage(), b.get_storage_mut(), idx_a, idx_b, size)
}
#[staticmethod]
fn exp(a: &StoragePtr, b: &mut StoragePtr, idx_a: usize, idx_b: usize, size: usize) {
Storage::exp(a.get_storage(), b.get_storage_mut(), idx_a, idx_b, size);
}
#[staticmethod]
fn log(a: &StoragePtr, b: &mut StoragePtr, idx_a: usize, idx_b: usize, size: usize) {
Storage::log(a.get_storage(), b.get_storage_mut(), idx_a, idx_b, size);
}
#[staticmethod]
fn add(
a: &StoragePtr,
b: &StoragePtr,
c: &mut StoragePtr,
idx_a: usize,
idx_b: usize,
idx_c: usize,
size: usize,
) {
Storage::add(
a.get_storage(),
b.get_storage(),
c.get_storage_mut(),
idx_a,
idx_b,
idx_c,
size,
);
}
#[staticmethod]
fn sub(
a: &StoragePtr,
b: &StoragePtr,
c: &mut StoragePtr,
idx_a: usize,
idx_b: usize,
idx_c: usize,
size: usize,
) {
Storage::sub(
a.get_storage(),
b.get_storage(),
c.get_storage_mut(),
idx_a,
idx_b,
idx_c,
size,
);
}
#[staticmethod]
fn mul(
a: &StoragePtr,
b: &StoragePtr,
c: &mut StoragePtr,
idx_a: usize,
idx_b: usize,
idx_c: usize,
size: usize,
) {
Storage::mul(
a.get_storage(),
b.get_storage(),
c.get_storage_mut(),
idx_a,
idx_b,
idx_c,
size,
);
}
#[staticmethod]
fn div(
a: &StoragePtr,
b: &StoragePtr,
c: &mut StoragePtr,
idx_a: usize,
idx_b: usize,
idx_c: usize,
size: usize,
) {
Storage::div(
a.get_storage(),
b.get_storage(),
c.get_storage_mut(),
idx_a,
idx_b,
idx_c,
size,
);
}
#[staticmethod]
fn sum(a: &StoragePtr, b: &mut StoragePtr, idx_a: usize, idx_b: usize, size: usize) {
Storage::sum(a.get_storage(), b.get_storage_mut(), idx_a, idx_b, size);
}
#[staticmethod]
fn max(a: &StoragePtr, b: &mut StoragePtr, idx_a: usize, idx_b: usize, size: usize) {
Storage::max(a.get_storage(), b.get_storage_mut(), idx_a, idx_b, size);
}
}

0 comments on commit 32e40b2

Please sign in to comment.