Skip to content

Commit

Permalink
created new struct dedicated for tensor, changed modules pathes for c…
Browse files Browse the repository at this point in the history
…hannel_data
  • Loading branch information
«ratal» committed Mar 6, 2024
1 parent 5388d52 commit a30b933
Show file tree
Hide file tree
Showing 24 changed files with 2,094 additions and 1,864 deletions.
912 changes: 455 additions & 457 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions src/channel_data.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
pub mod arrow_helpers;
pub mod channel_data;
pub mod complex_arrow;
pub mod tensor_arrow;
File renamed without changes.
636 changes: 270 additions & 366 deletions src/mdfreader/channel_data.rs → src/channel_data/channel_data.rs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use arrow::{
array::{ArrayBuilder, BooleanBufferBuilder, PrimitiveArray, PrimitiveBuilder},
buffer::{BooleanBuffer, MutableBuffer},
datatypes::{ArrowPrimitiveType, DataType, Float32Type, Float64Type},
datatypes::{ArrowPrimitiveType, Float32Type, Float64Type},
};

/// Complex
Expand All @@ -10,7 +10,6 @@ use arrow::{
pub struct ComplexArrow<T: ArrowPrimitiveType> {
null_buffer_builder: Option<BooleanBuffer>,
values_builder: PrimitiveBuilder<T>,
data_type: DataType,
len: usize,
}

Expand All @@ -22,7 +21,6 @@ impl<T: ArrowPrimitiveType> ComplexArrow<T> {
Self {
null_buffer_builder: None,
values_builder: PrimitiveBuilder::with_capacity(capacity * 2),
data_type: T::DATA_TYPE,
len: 0,
}
}
Expand All @@ -32,7 +30,6 @@ impl<T: ArrowPrimitiveType> ComplexArrow<T> {
Self {
null_buffer_builder: None,
values_builder,
data_type: T::DATA_TYPE,
len: length,
}
}
Expand All @@ -51,7 +48,6 @@ impl<T: ArrowPrimitiveType> ComplexArrow<T> {
Self {
null_buffer_builder,
values_builder: primitive_builder,
data_type: T::DATA_TYPE,
len: length,
}
}
Expand All @@ -64,16 +60,9 @@ impl<T: ArrowPrimitiveType> ComplexArrow<T> {
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn data_type(&self) -> DataType {
self.data_type.clone()
}
pub fn values_slice(&self) -> &[T::Native] {
self.values_builder.values_slice()
}
/// Returns the current values buffer as a mutable slice
pub fn values_slice_mut(&mut self) -> &mut [T::Native] {
self.values_builder.values_slice_mut()
}
pub fn nulls(&self) -> Option<&BooleanBuffer> {
self.null_buffer_builder.as_ref()
}
Expand Down Expand Up @@ -155,7 +144,6 @@ impl Clone for ComplexArrow<Float32Type> {
.finish_cloned()
.into_builder()
.expect("failed getting builder from Primitive array"),
data_type: self.data_type.clone(),
len: self.len.clone(),
}
}
Expand All @@ -170,7 +158,6 @@ impl Clone for ComplexArrow<Float64Type> {
.finish_cloned()
.into_builder()
.expect("failed getting builder from Primitive array"),
data_type: self.data_type.clone(),
len: self.len.clone(),
}
}
Expand Down
194 changes: 194 additions & 0 deletions src/channel_data/tensor_arrow.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
use arrow::{
array::{ArrayBuilder, BooleanBufferBuilder, PrimitiveArray, PrimitiveBuilder},
buffer::{BooleanBuffer, MutableBuffer},
datatypes::{
ArrowPrimitiveType, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
UInt16Type, UInt32Type, UInt64Type, UInt8Type,
},
};

/// Complex
#[derive(Debug)]
pub struct TensorArrow<T: ArrowPrimitiveType> {
null_buffer_builder: Option<BooleanBuffer>,
values_builder: PrimitiveBuilder<T>,
len: usize,
shape: Vec<usize>,
order: Order,
}

/// Order of the array, Row or Column Major (first)
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum Order {
#[default]
RowMajor,
ColumnMajor,
}

impl<T: ArrowPrimitiveType> TensorArrow<T> {
pub fn new() -> Self {
Self::with_capacity(1024, vec![1], Order::RowMajor)
}
pub fn with_capacity(capacity: usize, shape: Vec<usize>, order: Order) -> Self {
Self {
null_buffer_builder: None,
values_builder: PrimitiveBuilder::with_capacity(capacity * 2),
len: 0,
shape,
order,
}
}
pub fn new_from_buffer(values_buffer: MutableBuffer, shape: Vec<usize>, order: Order) -> Self {
let length = values_buffer.len() / shape.iter().product::<usize>();
let values_builder = PrimitiveBuilder::new_from_buffer(values_buffer, None);
Self {
null_buffer_builder: None,
values_builder,
len: length,
shape,
order,
}
}
pub fn new_from_primitive(
primitive_builder: PrimitiveBuilder<T>,
null_buffer: Option<&BooleanBuffer>,
shape: Vec<usize>,
order: Order,
) -> Self {
let length = primitive_builder.len() / shape.iter().product::<usize>();
match null_buffer {
Some(null_buffer_builder) => {
assert_eq!(
null_buffer_builder.len() * shape.iter().product::<usize>(),
primitive_builder.len()
)
}
None => {}
};
let null_buffer_builder = null_buffer.map(|buffer| buffer.clone());
Self {
null_buffer_builder,
values_builder: primitive_builder,
len: length,
shape,
order,
}
}
pub fn values(&mut self) -> &mut PrimitiveBuilder<T> {
&mut self.values_builder
}
pub fn len(&self) -> usize {
self.len
}
pub fn shape(&self) -> &Vec<usize> {
&self.shape
}
pub fn ndim(&self) -> usize {
self.shape.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn order(&self) -> &Order {
&self.order
}
pub fn values_slice(&self) -> &[T::Native] {
self.values_builder.values_slice()
}
pub fn values_slice_mut(&mut self) -> &mut [T::Native] {
self.values_builder.values_slice_mut()
}
pub fn nulls(&self) -> Option<&BooleanBuffer> {
self.null_buffer_builder.as_ref()
}
pub fn finish_cloned(&self) -> PrimitiveArray<T> {
self.values_builder.finish_cloned()
}
pub fn finish(&mut self) -> PrimitiveArray<T> {
self.values_builder.finish()
}
pub fn set_validity(&mut self, mask: &mut BooleanBufferBuilder) {
self.null_buffer_builder = Some(mask.finish());
}
}

impl<T: ArrowPrimitiveType> Default for TensorArrow<T> {
fn default() -> Self {
Self::new()
}
}

#[macro_export]
macro_rules! tensor_arrow_peq {
($type:tt) => {
impl PartialEq for TensorArrow<$type> {
fn eq(&self, other: &Self) -> bool {
if self.values_builder.finish_cloned() == other.values_builder.finish_cloned() {
match &self.null_buffer_builder {
Some(buffer) => match &other.null_buffer_builder {
Some(other_buffer) => buffer == other_buffer,
None => false,
},
None => {
if other.null_buffer_builder.is_none() {
true
} else {
false
}
}
}
} else {
false
}
}

fn ne(&self, other: &Self) -> bool {
!self.eq(other)
}
}
};
}

tensor_arrow_peq!(Int8Type);
tensor_arrow_peq!(UInt8Type);
tensor_arrow_peq!(Int16Type);
tensor_arrow_peq!(UInt16Type);
tensor_arrow_peq!(Int32Type);
tensor_arrow_peq!(UInt32Type);
tensor_arrow_peq!(Int64Type);
tensor_arrow_peq!(UInt64Type);
tensor_arrow_peq!(Float32Type);
tensor_arrow_peq!(Float64Type);

#[macro_export]
macro_rules! tensor_arrow_clone {
($type:tt) => {
impl Clone for TensorArrow<$type> {
fn clone(&self) -> Self {
Self {
null_buffer_builder: self.null_buffer_builder.clone(),
values_builder: self
.values_builder
.finish_cloned()
.into_builder()
.expect("failed getting builder from Primitive array"),
len: self.len.clone(),
shape: self.shape.clone(),
order: self.order.clone(),
}
}
}
};
}

tensor_arrow_clone!(Int8Type);
tensor_arrow_clone!(UInt8Type);
tensor_arrow_clone!(Int16Type);
tensor_arrow_clone!(UInt16Type);
tensor_arrow_clone!(Int32Type);
tensor_arrow_clone!(UInt32Type);
tensor_arrow_clone!(Int64Type);
tensor_arrow_clone!(UInt64Type);
tensor_arrow_clone!(Float32Type);
tensor_arrow_clone!(Float64Type);
Loading

0 comments on commit a30b933

Please sign in to comment.