Skip to content

Commit

Permalink
Very ad hoc "values" method
Browse files Browse the repository at this point in the history
  • Loading branch information
kojix2 committed Sep 28, 2023
1 parent 291bf0e commit 6530685
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions ext/candle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,19 @@ impl PyTensor {
))
}

// FXIME: Do not use `to_f64` here.
fn values(&self) -> PyResult<Vec<f64>> {
let values = self
.0
.to_dtype(DType::F64)
.map_err(wrap_err)?
.flatten_all()
.map_err(wrap_err)?
.to_vec1()
.map_err(wrap_err)?;
Ok(values)
}

/// Gets the tensor's shape.
/// &RETURNS&: Tuple[int]
fn shape(&self) -> Vec<usize> {
Expand Down Expand Up @@ -699,6 +712,7 @@ fn init(ruby: &Ruby) -> PyResult<()> {
rb_tensor.define_singleton_method("randn", function!(PyTensor::randn, 1))?;
rb_tensor.define_singleton_method("ones", function!(PyTensor::ones, 1))?;
rb_tensor.define_singleton_method("zeros", function!(PyTensor::zeros, 1))?;
rb_tensor.define_method("values", method!(PyTensor::values, 0))?;
rb_tensor.define_method("shape", method!(PyTensor::shape, 0))?;
rb_tensor.define_method("stride", method!(PyTensor::stride, 0))?;
rb_tensor.define_method("dtype", method!(PyTensor::dtype, 0))?;
Expand Down

0 comments on commit 6530685

Please sign in to comment.