-
Notifications
You must be signed in to change notification settings - Fork 0
/
f32.rs
124 lines (111 loc) · 3.54 KB
/
f32.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
use numpy::{PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3, PyReadonlyArrayDyn};
use pyo3::basic::CompareOp;
use pyo3::types::PyType;
use pyo3::{prelude::*, PyMappingProtocol, PyNumberProtocol, PyObjectProtocol};
use crate::monomorphs::RaggedBufferI64;
use crate::ragged_buffer::RaggedBuffer;
use super::IndicesOrInt;
/// Python-visible (`#[pyclass]`) newtype around a `RaggedBuffer<f32>`.
///
/// The wrapped buffer is the public tuple field `.0`; the Python-facing methods
/// and protocol implementations on this type forward to it.
#[pyclass]
#[derive(Clone)]
pub struct RaggedBufferF32(pub RaggedBuffer<f32>);
#[pymethods]
impl RaggedBufferF32 {
#[new]
pub fn new(features: usize) -> Self {
RaggedBufferF32(RaggedBuffer::new(features))
}
#[classmethod]
fn from_array(_cls: &PyType, array: PyReadonlyArray3<f32>) -> Self {
RaggedBufferF32(RaggedBuffer::from_array(array))
}
#[classmethod]
fn from_flattened(
_cls: &PyType,
flattened: PyReadonlyArray2<f32>,
lengths: PyReadonlyArray1<i64>,
) -> Self {
RaggedBufferF32(RaggedBuffer::from_flattened(flattened, lengths))
}
fn push(&mut self, features: PyReadonlyArrayDyn<f32>) {
self.0.push(features);
}
fn clear(&mut self) {
self.0.clear();
}
fn as_array<'a>(
&self,
py: Python<'a>,
) -> &'a numpy::PyArray<f32, numpy::ndarray::Dim<[usize; 2]>> {
self.0.as_array(py)
}
fn extend(&mut self, other: &RaggedBufferF32) -> PyResult<()> {
self.0.extend(&other.0)
}
fn size0(&self) -> usize {
self.0.size0()
}
fn size1(&self, py: Python, i: Option<usize>) -> PyResult<PyObject> {
match i {
Some(i) => self.0.size1(i).map(|s| s.into_py(py)),
None => Ok(self.0.lengths(py).into_py(py)),
}
}
fn size2(&self) -> usize {
self.0.size2()
}
fn indices(&self, dim: usize) -> PyResult<RaggedBufferI64> {
Ok(RaggedBufferI64(self.0.indices(dim)?))
}
fn flat_indices(&self) -> PyResult<RaggedBufferI64> {
Ok(RaggedBufferI64(self.0.flat_indices()?))
}
// TODO: eliminate copy
#[classmethod]
fn cat(_cls: &PyType, buffers: Vec<RaggedBufferF32>, dim: usize) -> PyResult<Self> {
Ok(RaggedBufferF32(RaggedBuffer::cat(
&buffers.iter().map(|b| &b.0).collect::<Vec<_>>(),
dim,
)?))
}
}
#[pyproto]
impl PyObjectProtocol for RaggedBufferF32 {
fn __str__(&self) -> PyResult<String> {
self.0.__str__()
}
fn __repr__(&self) -> PyResult<String> {
self.0.__str__()
}
fn __richcmp__(&self, other: RaggedBufferF32, op: CompareOp) -> PyResult<bool> {
match op {
CompareOp::Eq => Ok(self.0 == other.0),
CompareOp::Ne => Ok(self.0 != other.0),
_ => Err(pyo3::exceptions::PyTypeError::new_err(
"Only == and != are supported",
)),
}
}
}
/// Right-hand operand accepted by `RaggedBufferF32::__add__`: either another
/// `RaggedBufferF32` or a plain `f32`.
///
/// NOTE(review): `#[derive(FromPyObject)]` presumably tries the variants in
/// declaration order when extracting from a Python object — confirm against the
/// pyo3 documentation before reordering them.
#[derive(FromPyObject)]
pub enum RaggedBufferF32OrF32 {
/// A whole buffer operand.
RB(RaggedBufferF32),
/// A single scalar operand.
Scalar(f32),
}
#[pyproto]
impl PyNumberProtocol for RaggedBufferF32 {
    /// `lhs + rhs`, where `rhs` is either another buffer or an `f32` scalar.
    ///
    /// A buffer operand goes through `RaggedBuffer::add` (whose error is propagated);
    /// a scalar operand goes through `RaggedBuffer::add_scalar`.
    fn __add__(lhs: RaggedBufferF32, rhs: RaggedBufferF32OrF32) -> PyResult<RaggedBufferF32> {
        let sum = match rhs {
            RaggedBufferF32OrF32::Scalar(value) => lhs.0.add_scalar(value),
            RaggedBufferF32OrF32::RB(other) => lhs.0.add(&other.0)?,
        };
        Ok(RaggedBufferF32(sum))
    }
}
#[pyproto]
impl<'p> PyMappingProtocol for RaggedBufferF32 {
    /// `buffer[index]`, where `index` is either a set of indices or a single integer.
    ///
    /// Indices go through `RaggedBuffer::swizzle` (whose error is propagated);
    /// an integer goes through `RaggedBuffer::get`.
    fn __getitem__(&self, index: IndicesOrInt<'p>) -> PyResult<RaggedBufferF32> {
        let selected = match index {
            IndicesOrInt::Int(position) => self.0.get(position),
            IndicesOrInt::Indices(indices) => self.0.swizzle(indices)?,
        };
        Ok(RaggedBufferF32(selected))
    }
}