Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TensorStorage improvements #107

Merged
merged 4 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/kornia-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ version.workspace = true
[dependencies]

# external
arrow-buffer = "52.0.0"
arrow-buffer = "52.2.0"
serde = { version = "1", features = ["derive"] }
thiserror = "1"

Expand Down
1 change: 1 addition & 0 deletions crates/kornia-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub mod serde;
pub mod storage;

pub use crate::allocator::{CpuAllocator, TensorAllocator};
pub use crate::storage::SafeTensorType;
pub use crate::tensor::{Tensor, TensorError};

/// Type alias for a 1-dimensional tensor.
Expand Down
10 changes: 7 additions & 3 deletions crates/kornia-core/src/serde.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
use crate::{allocator::TensorAllocator, storage::TensorStorage, Tensor};
use crate::{
allocator::TensorAllocator,
storage::{SafeTensorType, TensorStorage},
Tensor,
};

use serde::ser::SerializeStruct;
use serde::Deserialize;

impl<T, const N: usize, A: TensorAllocator> serde::Serialize for Tensor<T, N, A>
where
T: serde::Serialize + arrow_buffer::ArrowNativeType + std::panic::RefUnwindSafe,
T: serde::Serialize + SafeTensorType,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
Expand All @@ -22,7 +26,7 @@ where
impl<'de, T, const N: usize, A: TensorAllocator + Default> serde::Deserialize<'de>
for Tensor<T, N, A>
where
T: serde::Deserialize<'de> + arrow_buffer::ArrowNativeType + std::panic::RefUnwindSafe,
T: serde::Deserialize<'de> + SafeTensorType,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
Expand Down
57 changes: 45 additions & 12 deletions crates/kornia-core/src/storage.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
use super::allocator::{TensorAllocator, TensorAllocatorError};
use arrow_buffer::{ArrowNativeType, Buffer};
use std::marker::PhantomData;
use arrow_buffer::{Buffer, ScalarBuffer};
use std::sync::Arc;
use std::{alloc::Layout, ptr::NonNull};

/// Marker trait for scalar types that may back a tensor's storage.
///
/// A type qualifies when arrow can treat it as a native value
/// (`ArrowNativeType`) and it is safe to reference across unwind
/// boundaries (`RefUnwindSafe`).
pub trait SafeTensorType: arrow_buffer::ArrowNativeType + std::panic::RefUnwindSafe {}

/// Implement `SafeTensorType` for every supported primitive scalar type.
macro_rules! impl_safe_tensor_type {
    ($($ty:ty),* $(,)?) => {
        $(impl SafeTensorType for $ty {})*
    };
}

impl_safe_tensor_type!(u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);

/// represents a contiguous memory region that can be shared with other buffers and across thread boundaries.
///
/// NOTE: https://docs.rs/arrow/latest/arrow/buffer/struct.Buffer.html
/// NOTE: https://docs.rs/arrow-buffer/latest/arrow_buffer/buffer/struct.ScalarBuffer.html
///
/// # Safety
///
Expand All @@ -16,18 +30,19 @@ use std::{alloc::Layout, ptr::NonNull};
///
/// * `data` - The buffer containing the tensor storage.
/// * `alloc` - The allocator used to allocate the tensor storage.
/// * `marker` - The marker type for the tensor storage.
pub struct TensorStorage<T: ArrowNativeType, A: TensorAllocator> {
pub struct TensorStorage<T, A: TensorAllocator>
where
T: SafeTensorType,
{
/// The buffer containing the tensor storage.
pub data: Buffer,
data: ScalarBuffer<T>,
alloc: A,
marker: PhantomData<T>,
}

/// Implement the `TensorStorage` struct.
impl<T, A: TensorAllocator> TensorStorage<T, A>
where
T: ArrowNativeType + std::panic::RefUnwindSafe,
T: SafeTensorType,
{
/// Creates a new tensor storage with the given length and allocator.
///
Expand All @@ -54,9 +69,8 @@ where
};

Ok(Self {
data: buffer,
data: buffer.into(),
alloc,
marker: PhantomData,
})
}

Expand All @@ -82,9 +96,8 @@ where

// create tensor storage
let storage = Self {
data: buffer,
data: buffer.into(),
alloc,
marker: PhantomData,
};

Ok(storage)
Expand All @@ -94,6 +107,26 @@ where
pub fn alloc(&self) -> &A {
&self.alloc
}

/// Returns the number of elements of type `T` held in the tensor storage.
pub fn len(&self) -> usize {
    self.data.len()
}

/// Returns `true` if the tensor storage contains no elements.
pub fn is_empty(&self) -> bool {
    self.data.is_empty()
}

/// Returns the tensor data as an immutable slice of `T`.
///
/// # Returns
///
/// A borrowed view over every element in the storage.
pub fn as_slice(&self) -> &[T] {
    // `ScalarBuffer<T>` implements `Deref<Target = [T]>`, so the raw-pointer
    // reconstruction (and its `unsafe` block) in the original is unnecessary —
    // a plain reborrow yields the same slice safely.
    &self.data
}

/// Returns the tensor data as a mutable slice of `T`.
pub fn as_mut_slice(&mut self) -> &mut [T] {
    // NOTE(review): this casts the buffer's const pointer to `*mut T`. Arrow
    // buffers are reference-counted and can be shared across clones; mutating
    // a buffer whose allocation is aliased elsewhere would be undefined
    // behavior. This presumably relies on the storage uniquely owning its
    // buffer — TODO confirm no outstanding clone of the underlying `Buffer`
    // can exist when this is called.
    unsafe { std::slice::from_raw_parts_mut(self.data.as_ptr() as *mut T, self.len()) }
}
}

#[cfg(test)]
Expand Down
25 changes: 10 additions & 15 deletions crates/kornia-core/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use thiserror::Error;

use super::{
allocator::{CpuAllocator, TensorAllocator, TensorAllocatorError},
storage::TensorStorage,
storage::{SafeTensorType, TensorStorage},
};

/// An error type for tensor operations.
Expand Down Expand Up @@ -42,7 +42,7 @@ fn get_strides_from_shape<const N: usize>(shape: [usize; N]) -> [usize; N] {

/// A data structure to represent a multi-dimensional tensor.
///
/// NOTE: internally the data is stored as an arrow::Buffer which represents a contiguous memory
/// NOTE: internally the data is stored as an arrow::ScalarBuffer which represents a contiguous memory
/// region that can be shared with other buffers and across thread boundaries.
///
/// # Attributes
Expand All @@ -61,7 +61,7 @@ fn get_strides_from_shape<const N: usize>(shape: [usize; N]) -> [usize; N] {
/// assert_eq!(t.shape, [2, 2]);
pub struct Tensor<T, const N: usize, A: TensorAllocator = CpuAllocator>
where
T: arrow_buffer::ArrowNativeType,
T: SafeTensorType,
{
/// The storage of the tensor.
pub storage: TensorStorage<T, A>,
Expand All @@ -74,7 +74,7 @@ where
/// Implementation of the Tensor struct.
impl<T, const N: usize, A> Tensor<T, N, A>
where
T: arrow_buffer::ArrowNativeType + std::panic::RefUnwindSafe,
T: SafeTensorType,
A: TensorAllocator,
{
/// Create a new `Tensor` with uninitialized data.
Expand Down Expand Up @@ -104,8 +104,7 @@ where
///
/// A slice containing the data of the tensor.
pub fn as_slice(&self) -> &[T] {
let slice = self.storage.data.typed_data::<T>();
slice
self.storage.as_slice()
}

/// Get the data of the tensor as a mutable slice.
Expand All @@ -114,11 +113,7 @@ where
///
/// A mutable slice containing the data of the tensor.
pub fn as_slice_mut(&mut self) -> &mut [T] {
// convert the data to a typed slice
let slice = self.storage.data.typed_data::<T>();

// TODO: verify if there is a better way to do this
unsafe { std::slice::from_raw_parts_mut(slice.as_ptr() as *mut T, slice.len()) }
self.storage.as_mut_slice()
}

/// Creates a new `Tensor` with the given shape and data.
Expand Down Expand Up @@ -250,7 +245,7 @@ where

/// Returns the number of elements in the tensor.
pub fn numel(&self) -> usize {
self.storage.data.len()
self.storage.len()
}

/// Get the offset of the element at the given index.
Expand Down Expand Up @@ -372,7 +367,7 @@ where
shape: [usize; M],
) -> Result<Tensor<T, M, A>, TensorError> {
let numel = shape.iter().product::<usize>();
if numel != self.storage.data.len() {
if numel != self.storage.len() {
Err(TensorError::InvalidShape(numel))?;
}

Expand Down Expand Up @@ -649,7 +644,7 @@ where
pub fn cast<U>(&self) -> Result<Tensor<U, N>, TensorError>
where
T: Copy + Into<U>,
U: arrow_buffer::ArrowNativeType + std::panic::RefUnwindSafe,
U: SafeTensorType,
{
let data: Vec<U> = self.as_slice().iter().map(|x| (*x).into()).collect();
let storage = TensorStorage::from_vec(data, CpuAllocator)?;
Expand All @@ -663,7 +658,7 @@ where

impl<T, const N: usize, A> Clone for Tensor<T, N, A>
where
T: arrow_buffer::ArrowNativeType + std::panic::RefUnwindSafe,
T: SafeTensorType,
A: TensorAllocator,
{
fn clone(&self) -> Self {
Expand Down
Loading