Skip to content

Commit

Permalink
rust - fix CeedOperatorFieldGet*
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremylt committed Sep 9, 2024
1 parent 204a29b commit 80db90e
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 43 deletions.
7 changes: 7 additions & 0 deletions rust/libceed/src/basis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,13 @@ impl<'a> Basis<'a> {
})
}

pub(crate) fn from_raw(ptr: bind_ceed::CeedBasis) -> crate::Result<Self> {
Ok(Self {
ptr,
_lifeline: PhantomData,
})
}

pub fn create_tensor_H1_Lagrange(
ceed: &crate::Ceed,
dim: usize,
Expand Down
7 changes: 7 additions & 0 deletions rust/libceed/src/elem_restriction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,13 @@ impl<'a> ElemRestriction<'a> {
})
}

pub(crate) fn from_raw(ptr: bind_ceed::CeedElemRestriction) -> crate::Result<Self> {
Ok(Self {
ptr,
_lifeline: PhantomData,
})
}

pub fn create_oriented(
ceed: &crate::Ceed,
nelem: usize,
Expand Down
129 changes: 86 additions & 43 deletions rust/libceed/src/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,49 @@ use crate::prelude::*;
#[derive(Debug)]
pub struct OperatorField<'a> {
pub(crate) ptr: bind_ceed::CeedOperatorField,
pub(crate) vector: crate::Vector<'a>,
pub(crate) elem_restriction: crate::ElemRestriction<'a>,
pub(crate) basis: crate::Basis<'a>,
_lifeline: PhantomData<&'a ()>,
}

// -----------------------------------------------------------------------------
// Implementations
// -----------------------------------------------------------------------------
impl<'a> OperatorField<'a> {
pub(crate) fn from_raw(
ptr: bind_ceed::CeedOperatorField,
ceed: crate::Ceed,
) -> crate::Result<Self> {
let vector = {
let mut vector_ptr = std::ptr::null_mut();
let ierr = unsafe { bind_ceed::CeedOperatorFieldGetVector(ptr, &mut vector_ptr) };
ceed.check_error(ierr)?;
crate::Vector::from_raw(vector_ptr)?
};
let elem_restriction = {
let mut elem_restriction_ptr = std::ptr::null_mut();
let ierr = unsafe {
bind_ceed::CeedOperatorFieldGetElemRestriction(ptr, &mut elem_restriction_ptr)
};
ceed.check_error(ierr)?;
crate::ElemRestriction::from_raw(elem_restriction_ptr)?
};
let basis = {
let mut basis_ptr = std::ptr::null_mut();
let ierr = unsafe { bind_ceed::CeedOperatorFieldGetBasis(ptr, &mut basis_ptr) };
ceed.check_error(ierr)?;
crate::Basis::from_raw(basis_ptr)?
};
Ok(Self {
ptr,
vector,
elem_restriction,
basis,
_lifeline: PhantomData,
})
}

/// Get the name of an OperatorField
///
/// ```
Expand Down Expand Up @@ -110,24 +146,21 @@ impl<'a> OperatorField<'a> {
/// inputs[1].elem_restriction().is_none(),
/// "Incorrect field ElemRestriction"
/// );
///
/// let outputs = op.outputs()?;
///
/// assert!(
/// outputs[0].elem_restriction().is_some(),
/// "Incorrect field ElemRestriction"
/// );
/// # Ok(())
/// # }
/// ```
pub fn elem_restriction(&self) -> ElemRestrictionOpt {
let mut ptr = std::ptr::null_mut();
unsafe {
bind_ceed::CeedOperatorFieldGetElemRestriction(self.ptr, &mut ptr);
}
if ptr == unsafe { bind_ceed::CEED_ELEMRESTRICTION_NONE } {
if self.elem_restriction.ptr == unsafe { bind_ceed::CEED_ELEMRESTRICTION_NONE } {
ElemRestrictionOpt::None
} else {
let slice = unsafe {
std::slice::from_raw_parts(
&ptr as *const bind_ceed::CeedElemRestriction as *const crate::ElemRestriction,
1 as usize,
)
};
ElemRestrictionOpt::Some(&slice[0])
ElemRestrictionOpt::Some(&self.elem_restriction)
}
}

Expand Down Expand Up @@ -172,20 +205,10 @@ impl<'a> OperatorField<'a> {
/// # }
/// ```
pub fn basis(&self) -> BasisOpt {
let mut ptr = std::ptr::null_mut();
unsafe {
bind_ceed::CeedOperatorFieldGetBasis(self.ptr, &mut ptr);
}
if ptr == unsafe { bind_ceed::CEED_BASIS_NONE } {
if self.basis.ptr == unsafe { bind_ceed::CEED_BASIS_NONE } {
BasisOpt::None
} else {
let slice = unsafe {
std::slice::from_raw_parts(
&ptr as *const bind_ceed::CeedBasis as *const crate::Basis,
1 as usize,
)
};
BasisOpt::Some(&slice[0])
BasisOpt::Some(&self.basis)
}
}

Expand Down Expand Up @@ -222,26 +245,20 @@ impl<'a> OperatorField<'a> {
///
/// assert!(inputs[0].vector().is_active(), "Incorrect field Vector");
/// assert!(inputs[1].vector().is_none(), "Incorrect field Vector");
///
/// let outputs = op.outputs()?;
///
/// assert!(outputs[0].vector().is_active(), "Incorrect field Vector");
/// # Ok(())
/// # }
/// ```
pub fn vector(&self) -> VectorOpt {
let mut ptr = std::ptr::null_mut();
unsafe {
bind_ceed::CeedOperatorFieldGetVector(self.ptr, &mut ptr);
}
if ptr == unsafe { bind_ceed::CEED_VECTOR_ACTIVE } {
if self.vector.ptr == unsafe { bind_ceed::CEED_VECTOR_ACTIVE } {
VectorOpt::Active
} else if ptr == unsafe { bind_ceed::CEED_VECTOR_NONE } {
} else if self.vector.ptr == unsafe { bind_ceed::CEED_VECTOR_NONE } {
VectorOpt::None
} else {
let slice = unsafe {
std::slice::from_raw_parts(
&ptr as *const bind_ceed::CeedVector as *const crate::Vector,
1 as usize,
)
};
VectorOpt::Some(&slice[0])
VectorOpt::Some(&self.vector)
}
}
}
Expand Down Expand Up @@ -814,7 +831,7 @@ impl<'a> Operator<'a> {
/// # Ok(())
/// # }
/// ```
pub fn inputs(&self) -> crate::Result<&[crate::OperatorField]> {
pub fn inputs(&self) -> crate::Result<Vec<crate::OperatorField>> {
// Get array of raw C pointers for inputs
let mut num_inputs = 0;
let mut inputs_ptr = std::ptr::null_mut();
Expand All @@ -831,11 +848,24 @@ impl<'a> Operator<'a> {
// Convert raw C pointers to fixed length slice
let inputs_slice = unsafe {
std::slice::from_raw_parts(
inputs_ptr as *const crate::OperatorField,
inputs_ptr as *mut bind_ceed::CeedOperatorField,
num_inputs as usize,
)
};
Ok(inputs_slice)
// And finally build vec
let ceed = {
let mut ptr = std::ptr::null_mut();
let mut ptr_copy = std::ptr::null_mut();
unsafe {
bind_ceed::CeedOperatorGetCeed(self.op_core.ptr, &mut ptr);
bind_ceed::CeedReferenceCopy(ptr, &mut ptr_copy); // refcount
}
crate::Ceed { ptr }
};
let inputs = (0..num_inputs as usize)
.map(|i| crate::OperatorField::from_raw(inputs_slice[i], ceed.clone()))
.collect::<crate::Result<Vec<_>>>()?;
Ok(inputs)
}

/// Get a slice of Operator outputs
Expand Down Expand Up @@ -873,7 +903,7 @@ impl<'a> Operator<'a> {
/// # Ok(())
/// # }
/// ```
pub fn outputs(&self) -> crate::Result<&[crate::OperatorField]> {
pub fn outputs(&self) -> crate::Result<Vec<crate::OperatorField>> {
// Get array of raw C pointers for outputs
let mut num_outputs = 0;
let mut outputs_ptr = std::ptr::null_mut();
Expand All @@ -890,11 +920,24 @@ impl<'a> Operator<'a> {
// Convert raw C pointers to fixed length slice
let outputs_slice = unsafe {
std::slice::from_raw_parts(
outputs_ptr as *const crate::OperatorField,
outputs_ptr as *mut bind_ceed::CeedOperatorField,
num_outputs as usize,
)
};
Ok(outputs_slice)
// And finally build vec
let ceed = {
let mut ptr = std::ptr::null_mut();
let mut ptr_copy = std::ptr::null_mut();
unsafe {
bind_ceed::CeedOperatorGetCeed(self.op_core.ptr, &mut ptr);
bind_ceed::CeedReferenceCopy(ptr, &mut ptr_copy); // refcount
}
crate::Ceed { ptr }
};
let outputs = (0..num_outputs as usize)
.map(|i| crate::OperatorField::from_raw(outputs_slice[i], ceed.clone()))
.collect::<crate::Result<Vec<_>>>()?;
Ok(outputs)
}

/// Check if Operator is setup correctly
Expand Down

0 comments on commit 80db90e

Please sign in to comment.