Skip to content

Commit

Permalink
implement Tensor into_vec (#125)
Browse files Browse the repository at this point in the history
* implement Tensor into_vec

* temporary fix
  • Loading branch information
edgarriba authored Sep 5, 2024
1 parent e182ddb commit dd35ed0
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 8 deletions.
68 changes: 60 additions & 8 deletions crates/kornia-core/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,20 @@ where
///
/// The vector must have the correct length and alignment.
pub fn from_vec(vec: Vec<T>, alloc: A) -> Self {
// NOTE: this is a temporary solution until we have a custom allocator for the buffer
// create immutable buffer from vec
let buffer = unsafe {
// SAFETY: `vec` is properly aligned and has the correct length.
Buffer::from_custom_allocation(
NonNull::new_unchecked(vec.as_ptr() as *mut u8),
vec.len() * std::mem::size_of::<T>(),
Arc::new(vec),
)
};
// let _buffer = unsafe {
// // SAFETY: `vec` is properly aligned and has the correct length.
// Buffer::from_custom_allocation(
// NonNull::new_unchecked(vec.as_ptr() as *mut u8),
// vec.len() * std::mem::size_of::<T>(),
// Arc::new(vec),
// )
// };

// create immutable buffer from vec
// NOTE: this is a temporary solution until we have a custom allocator for the buffer
let buffer = Buffer::from_vec(vec);

// create tensor storage
Self {
Expand All @@ -98,6 +103,36 @@ where
}
}

/// Converts the tensor storage into a `Vec<T>`.
///
/// NOTE: useful for safe zero copies.
///
/// This method attempts to convert the internal buffer of the tensor storage into a `Vec<T>`.
/// If the conversion fails (e.g., due to reference counting issues), it constructs a new `Vec<T>`
/// by copying the data from the raw pointer.
///
/// # Safety
///
/// This method is safe to call, but it may involve unsafe operations internally when
/// constructing a new Vec from raw parts if the initial conversion fails.
///
/// # Performance
///
/// In the best case, this operation is O(1) when the internal buffer can be directly converted.
/// In the worst case, it's O(n) where n is the number of elements, as it may need to copy all data.
pub fn into_vec(self) -> Vec<T> {
match self.data.into_inner().into_vec() {
Ok(vec) => vec,
Err(buf) => unsafe {
std::slice::from_raw_parts(
buf.as_ptr() as *const T,
buf.len() / std::mem::size_of::<T>(),
)
.to_vec()
},
}
}

/// Returns the allocator used to allocate the tensor storage.
#[inline]
pub fn alloc(&self) -> &A {
Expand Down Expand Up @@ -313,4 +348,21 @@ mod tests {

Ok(())
}

#[test]
fn test_tensor_storage_into_vec() {
let allocator = CpuAllocator;
let original_vec = vec![1, 2, 3, 4, 5];
let original_vec_ptr = original_vec.as_ptr();
let original_vec_capacity = original_vec.capacity();

let storage = TensorStorage::<i32, _>::from_vec(original_vec, allocator);

// Convert the storage back to a vector
let result_vec = storage.into_vec();

// check NO copy
assert_eq!(result_vec.capacity(), original_vec_capacity);
assert!(std::ptr::eq(result_vec.as_ptr(), original_vec_ptr));
}
}
9 changes: 9 additions & 0 deletions crates/kornia-core/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,15 @@ where
self.storage.as_mut_ptr()
}

/// Consumes the tensor and returns the underlying vector.
///
/// This method destroys the tensor and returns ownership of the underlying data.
/// The returned vector will have a length equal to the total number of elements in the tensor.
///
pub fn into_vec(self) -> Vec<T> {
self.storage.into_vec()
}

/// Creates a new `Tensor` with the given shape and data.
///
/// # Arguments
Expand Down

0 comments on commit dd35ed0

Please sign in to comment.