Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clarify splines names #382

Merged
merged 14 commits into from
Apr 19, 2024
4 changes: 2 additions & 2 deletions benchmarks/splines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,11 @@ static void characteristics_advection(benchmark::State& state)
DDimY>
spline_evaluator(periodic_extrapolation, periodic_extrapolation);
ddc::Chunk coef_alloc(
spline_builder.spline_domain(),
spline_builder.batched_spline_domain(),
ddc::KokkosAllocator<double, Kokkos::DefaultExecutionSpace::memory_space>());
ddc::ChunkSpan coef = coef_alloc.span_view();
ddc::Chunk feet_coords_alloc(
spline_builder.vals_domain(),
spline_builder.batched_interpolation_domain(),
ddc::KokkosAllocator<
ddc::Coordinate<X, Y>,
Kokkos::DefaultExecutionSpace::memory_space>());
Expand Down
4 changes: 2 additions & 2 deletions examples/characteristics_advection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,13 +239,13 @@ int main(int argc, char** argv)
//! [instantiate intermediate chunks]
// Instantiate chunk of spline coefs to receive output of spline_builder
ddc::Chunk coef_alloc(
spline_builder.spline_domain(),
spline_builder.batched_spline_domain(),
ddc::DeviceAllocator<double>());
ddc::ChunkSpan coef = coef_alloc.span_view();

// Instantiate chunk to receive feet coords
ddc::Chunk feet_coords_alloc(
spline_builder.vals_domain(),
spline_builder.batched_interpolation_domain(),
ddc::DeviceAllocator<ddc::Coordinate<X>>());
ddc::ChunkSpan feet_coords = feet_coords_alloc.span_view();
//! [instantiate intermediate chunks]
Expand Down
79 changes: 45 additions & 34 deletions include/ddc/kernels/splines/spline_builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class SplineBuilder
*/
using interpolation_domain_type = ddc::DiscreteDomain<interpolation_mesh_type>;

using vals_domain_type = ddc::DiscreteDomain<IDimX...>;
using batched_interpolation_domain_type = ddc::DiscreteDomain<IDimX...>;

using batch_domain_type =
typename ddc::detail::convert_type_seq_to_discrete_domain<ddc::type_seq_remove_t<
Expand All @@ -89,20 +89,20 @@ class SplineBuilder
using spline_dim_type
= std::conditional_t<std::is_same_v<Tag, interpolation_mesh_type>, bsplines_type, Tag>;

using spline_domain_type =
using batched_spline_domain_type =
typename ddc::detail::convert_type_seq_to_discrete_domain<ddc::type_seq_replace_t<
ddc::detail::TypeSeq<IDimX...>,
ddc::detail::TypeSeq<interpolation_mesh_type>,
ddc::detail::TypeSeq<bsplines_type>>>;

using spline_tr_domain_type =
using batched_spline_tr_domain_type =
typename ddc::detail::convert_type_seq_to_discrete_domain<ddc::type_seq_merge_t<
ddc::detail::TypeSeq<bsplines_type>,
ddc::type_seq_remove_t<
ddc::detail::TypeSeq<IDimX...>,
ddc::detail::TypeSeq<interpolation_mesh_type>>>>;

using derivs_domain_type =
using batched_derivs_domain_type =
typename ddc::detail::convert_type_seq_to_discrete_domain<ddc::type_seq_replace_t<
ddc::detail::TypeSeq<IDimX...>,
ddc::detail::TypeSeq<interpolation_mesh_type>,
Expand Down Expand Up @@ -152,7 +152,7 @@ class SplineBuilder
static constexpr ddc::BoundCond s_bc_xmax = BcXmax;

private:
vals_domain_type m_vals_domain;
batched_interpolation_domain_type m_batched_interpolation_domain;

int m_offset;

Expand All @@ -165,10 +165,10 @@ class SplineBuilder
int compute_offset(interpolation_domain_type const& interpolation_domain);

explicit SplineBuilder(
vals_domain_type const& vals_domain,
batched_interpolation_domain_type const& batched_interpolation_domain,
std::optional<int> cols_per_chunk = std::nullopt,
std::optional<unsigned int> preconditionner_max_block_size = std::nullopt)
: m_vals_domain(vals_domain)
: m_batched_interpolation_domain(batched_interpolation_domain)
, m_offset(compute_offset(interpolation_domain()))
, m_dx((ddc::discrete_space<BSplines>().rmax() - ddc::discrete_space<BSplines>().rmin())
/ ddc::discrete_space<BSplines>().ncells())
Expand Down Expand Up @@ -212,9 +212,9 @@ class SplineBuilder
*/
SplineBuilder& operator=(SplineBuilder&& x) = default;

vals_domain_type vals_domain() const noexcept
batched_interpolation_domain_type batched_interpolation_domain() const noexcept
{
return m_vals_domain;
return m_batched_interpolation_domain;
}

/**
Expand All @@ -227,15 +227,15 @@ class SplineBuilder
*/
interpolation_domain_type interpolation_domain() const noexcept
{
return interpolation_domain_type(vals_domain());
return interpolation_domain_type(batched_interpolation_domain());
}

batch_domain_type batch_domain() const noexcept
{
return ddc::remove_dims_of(vals_domain(), interpolation_domain());
return ddc::remove_dims_of(batched_interpolation_domain(), interpolation_domain());
}

ddc::DiscreteDomain<bsplines_type> bsplines_domain() const noexcept // TODO : clarify name
ddc::DiscreteDomain<bsplines_type> spline_domain() const noexcept // TODO : clarify name
blegouix marked this conversation as resolved.
Show resolved Hide resolved
{
return ddc::discrete_space<bsplines_type>().full_domain();
}
Expand All @@ -248,31 +248,31 @@ class SplineBuilder
*
* @return The domain for the splines.
*/
spline_domain_type spline_domain() const noexcept
batched_spline_domain_type batched_spline_domain() const noexcept
{
return ddc::replace_dim_of<
interpolation_mesh_type,
bsplines_type>(vals_domain(), bsplines_domain());
bsplines_type>(batched_interpolation_domain(), spline_domain());
}

spline_tr_domain_type spline_tr_domain() const noexcept
batched_spline_tr_domain_type spline_tr_domain() const noexcept
{
return spline_tr_domain_type(bsplines_domain(), batch_domain());
return batched_spline_tr_domain_type(spline_domain(), batch_domain());
}

derivs_domain_type derivs_xmin_domain() const noexcept
batched_derivs_domain_type derivs_xmin_domain() const noexcept
{
return ddc::replace_dim_of<interpolation_mesh_type, deriv_type>(
vals_domain(),
batched_interpolation_domain(),
ddc::DiscreteDomain<deriv_type>(
ddc::DiscreteElement<deriv_type>(1),
ddc::DiscreteVector<deriv_type>(s_nbc_xmin)));
}

derivs_domain_type derivs_xmax_domain() const noexcept
batched_derivs_domain_type derivs_xmax_domain() const noexcept
{
return ddc::replace_dim_of<interpolation_mesh_type, deriv_type>(
vals_domain(),
batched_interpolation_domain(),
ddc::DiscreteDomain<deriv_type>(
ddc::DiscreteElement<deriv_type>(1),
ddc::DiscreteVector<deriv_type>(s_nbc_xmax)));
Expand Down Expand Up @@ -309,15 +309,20 @@ class SplineBuilder
*/
template <class Layout>
void operator()(
ddc::ChunkSpan<double, spline_domain_type, Layout, memory_space> spline,
ddc::ChunkSpan<double const, vals_domain_type, Layout, memory_space> vals,
std::optional<
ddc::ChunkSpan<double const, derivs_domain_type, Layout, memory_space>> const
derivs_xmin
ddc::ChunkSpan<double, batched_spline_domain_type, Layout, memory_space> spline,
ddc::ChunkSpan<double const, batched_interpolation_domain_type, Layout, memory_space>
vals,
std::optional<ddc::ChunkSpan<
double const,
batched_derivs_domain_type,
Layout,
memory_space>> const derivs_xmin
= std::nullopt,
std::optional<
ddc::ChunkSpan<double const, derivs_domain_type, Layout, memory_space>> const
derivs_xmax
std::optional<ddc::ChunkSpan<
double const,
batched_derivs_domain_type,
Layout,
memory_space>> const derivs_xmax
= std::nullopt) const;

private:
Expand Down Expand Up @@ -629,12 +634,18 @@ void SplineBuilder<
Solver,
IDimX...>::
operator()(
ddc::ChunkSpan<double, spline_domain_type, Layout, memory_space> spline,
ddc::ChunkSpan<double const, vals_domain_type, Layout, memory_space> vals,
std::optional<ddc::ChunkSpan<double const, derivs_domain_type, Layout, memory_space>> const
derivs_xmin,
std::optional<ddc::ChunkSpan<double const, derivs_domain_type, Layout, memory_space>> const
derivs_xmax) const
ddc::ChunkSpan<double, batched_spline_domain_type, Layout, memory_space> spline,
ddc::ChunkSpan<double const, batched_interpolation_domain_type, Layout, memory_space> vals,
std::optional<ddc::ChunkSpan<
double const,
batched_derivs_domain_type,
Layout,
memory_space>> const derivs_xmin,
std::optional<ddc::ChunkSpan<
double const,
batched_derivs_domain_type,
Layout,
memory_space>> const derivs_xmax) const
{
assert(vals.template extent<interpolation_mesh_type>()
== ddc::discrete_space<bsplines_type>().nbasis() - s_nbe_xmin - s_nbe_xmax);
Expand Down
Loading