Skip to content

Commit

Permalink
refactor(expression): add function cast_tensor_expression for casting
Browse files Browse the repository at this point in the history
This function casts any `tensor_expression` to its child class, and it
also handles recursive casting to get the real expression that is stored
inside the layers of `tensor_expression`.
  • Loading branch information
amitsingh19975 committed Feb 13, 2022
1 parent d70a701 commit 8ec747c
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 80 deletions.
69 changes: 37 additions & 32 deletions include/boost/numeric/ublas/tensor/expression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,33 @@ static constexpr bool does_exp_need_cast_v = does_exp_need_cast< std::decay_t<T>
template<typename E, typename T>
struct does_exp_need_cast< tensor_expression<T,E> > : std::true_type{};

/**
* @brief It is a safer way of casting `tensor_expression` because it handles
* recursive expressions. Otherwise, in most of the cases, we try to access
* `operator()`, which requires a parameter argument, that is not supported
* by the `tensor_expression` class and might give an error if it is not casted
* properly.
*
* @tparam T type of the tensor
* @tparam E type of the child stored inside tensor_expression
* @param e tensor_expression that needs to be casted
* @return child of tensor_expression that is not tensor_expression
*/
template<typename T, typename E>
constexpr auto const& cast_tensor_exression(tensor_expression<T,E> const& e) noexcept{
auto const& res = e();
if constexpr(does_exp_need_cast_v<decltype(res)>)
return cast_tensor_exression(res);
else
return res;
}


// FIXME: remove it when template expression support for the old matrix and vector is removed
/// @brief No Op: Any expression other than `tensor_expression`.
template<typename E>
constexpr auto const& cast_tensor_exression(E const& e) noexcept{ return e; }

template<typename E, typename T>
constexpr auto is_tensor_expression_impl(tensor_expression<T,E> const*) -> std::true_type;

Expand Down Expand Up @@ -137,33 +164,15 @@ struct binary_tensor_expression
binary_tensor_expression(const binary_tensor_expression& l) = delete;
binary_tensor_expression& operator=(binary_tensor_expression const& l) noexcept = delete;

constexpr auto const& left_expr() const noexcept{ return cast_tensor_exression(el); }
constexpr auto const& right_expr() const noexcept{ return cast_tensor_exression(er); }

[[nodiscard]] inline
constexpr decltype(auto) operator()(size_type i) const
requires (does_exp_need_cast_v<expression_type_left> && does_exp_need_cast_v<expression_type_right>)
{
return op(el()(i), er()(i));
}

[[nodiscard]] inline
constexpr decltype(auto) operator()(size_type i) const
requires (does_exp_need_cast_v<expression_type_left> && !does_exp_need_cast_v<expression_type_right>)
{
return op(el()(i), er(i));
}

[[nodiscard]] inline
constexpr decltype(auto) operator()(size_type i) const
requires (!does_exp_need_cast_v<expression_type_left> && does_exp_need_cast_v<expression_type_right>)
{
return op(el(i), er()(i));
}

[[nodiscard]] inline
constexpr decltype(auto) operator()(size_type i) const {
return op(el(i), er(i));
constexpr decltype(auto) operator()(size_type i) const {
return op(left_expr()(i), right_expr()(i));
}

private:
expression_type_left el;
expression_type_right er;
binary_operation op;
Expand Down Expand Up @@ -211,19 +220,15 @@ struct unary_tensor_expression
constexpr unary_tensor_expression() = delete;
unary_tensor_expression(unary_tensor_expression const& l) = delete;
unary_tensor_expression& operator=(unary_tensor_expression const& l) noexcept = delete;

[[nodiscard]] inline constexpr
decltype(auto) operator()(size_type i) const
requires does_exp_need_cast_v<expression_type>
{
return op(e()(i));
}

constexpr auto const& expr() const noexcept{ return cast_tensor_exression(e); }

[[nodiscard]] inline constexpr
decltype(auto) operator()(size_type i) const {
return op(e(i));
decltype(auto) operator()(size_type i) const {
return op(expr()(i));
}

private:
expression_type e;
unary_operation op;
};
Expand Down
50 changes: 33 additions & 17 deletions include/boost/numeric/ublas/tensor/expression_evaluation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,20 @@ constexpr auto& retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& exp
static_assert(has_tensor_types_v<T,binary_tensor_expression<T,EL,ER,OP>>,
"Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");

auto const& lexpr = expr.left_expr();
auto const& rexpr = expr.right_expr();

if constexpr ( same_exp<T,EL> )
return expr.el.extents();
return lexpr.extents();

else if constexpr ( same_exp<T,ER> )
return expr.er.extents();
return rexpr.extents();

else if constexpr ( has_tensor_types_v<T,EL> )
return retrieve_extents(expr.el);
return retrieve_extents(lexpr);

else if constexpr ( has_tensor_types_v<T,ER> )
return retrieve_extents(expr.er);
return retrieve_extents(rexpr);
}

#ifdef _MSC_VER
Expand All @@ -164,12 +167,14 @@ constexpr auto& retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)

static_assert(has_tensor_types_v<T,unary_tensor_expression<T,E,OP>>,
"Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");

auto const& uexpr = expr.expr();

if constexpr ( same_exp<T,E> )
return expr.e.extents();
return uexpr.extents();

else if constexpr ( has_tensor_types_v<T,E> )
return retrieve_extents(expr.e);
return retrieve_extents(uexpr);
}

} // namespace boost::numeric::ublas::detail
Expand Down Expand Up @@ -221,20 +226,23 @@ constexpr auto all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const& exp
using ::operator==;
using ::operator!=;

auto const& lexpr = expr.left_expr();
auto const& rexpr = expr.right_expr();

if constexpr ( same_exp<T,EL> )
if(e != expr.el.extents())
if(e != lexpr.extents())
return false;

if constexpr ( same_exp<T,ER> )
if(e != expr.er.extents())
if(e != rexpr.extents())
return false;

if constexpr ( has_tensor_types_v<T,EL> )
if(!all_extents_equal(expr.el, e))
if(!all_extents_equal(lexpr, e))
return false;

if constexpr ( has_tensor_types_v<T,ER> )
if(!all_extents_equal(expr.er, e))
if(!all_extents_equal(rexpr, e))
return false;

return true;
Expand All @@ -250,12 +258,14 @@ constexpr auto all_extents_equal(unary_tensor_expression<T,E,OP> const& expr, ex

using ::operator==;

auto const& uexpr = expr.expr();

if constexpr ( same_exp<T,E> )
if(e != expr.e.extents())
if(e != uexpr.extents())
return false;

if constexpr ( has_tensor_types_v<T,E> )
if(!all_extents_equal(expr.e, e))
if(!all_extents_equal(uexpr, e))
return false;

return true;
Expand All @@ -281,9 +291,11 @@ inline void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type>
if(!all_extents_equal(expr, lhs.extents() ))
throw std::runtime_error("Error in boost::numeric::ublas::tensor_core: expression contains tensors with different shapes.");

#pragma omp parallel for
auto const& rhs = cast_tensor_exression(expr);

#pragma omp parallel for
for(auto i = 0u; i < lhs.size(); ++i)
lhs(i) = expr()(i);
lhs(i) = rhs(i);
}

/** @brief Evaluates expression for a tensor_core
Expand All @@ -310,9 +322,11 @@ inline void eval(tensor_type& lhs, tensor_expression<other_tensor_type, derived_
throw std::runtime_error("Error in boost::numeric::ublas::tensor_core: expression contains tensors with different shapes.");
}

auto const& rhs = cast_tensor_exression(expr);

#pragma omp parallel for
for(auto i = 0u; i < lhs.size(); ++i)
lhs(i) = expr()(i);
lhs(i) = rhs(i);
}

/** @brief Evaluates expression for a tensor_core
Expand All @@ -330,9 +344,11 @@ inline void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type>
if(!all_extents_equal( expr, lhs.extents() ))
throw std::runtime_error("Error in boost::numeric::ublas::tensor_core: expression contains tensors with different shapes.");

auto const& rhs = cast_tensor_exression(expr);

#pragma omp parallel for
for(auto i = 0u; i < lhs.size(); ++i)
fn(lhs(i), expr()(i));
fn(lhs(i), rhs(i));
}


Expand All @@ -347,7 +363,7 @@ inline void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type>
template<class tensor_type, class unary_fn>
inline void eval(tensor_type& lhs, unary_fn const& fn)
{
#pragma omp parallel for
#pragma omp parallel for
for(auto i = 0u; i < lhs.size(); ++i)
fn(lhs(i));
}
Expand Down
2 changes: 1 addition & 1 deletion include/boost/numeric/ublas/tensor/multiplication.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ void mtv(SizeType const m,
template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
void mtm(PointerOut c, SizeType const*const nc, SizeType const*const wc,
PointerIn1 a, SizeType const*const na, SizeType const*const wa,
PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
PointerIn2 b, [[maybe_unused]] SizeType const*const nb, SizeType const*const wb)
{

// C(i,j) = A(i,k) * B(k,j)
Expand Down
48 changes: 24 additions & 24 deletions test/tensor/tensor/test_tensor_binary_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_dynamic,
auto uexpr1 = ublas::detail::make_unary_tensor_expression<tensor_t>( t, uplus1 );
auto uexpr2 = ublas::detail::make_unary_tensor_expression<tensor_t>( t, uplus2 );

BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr1.e) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr2.e) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr1.expr()) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr2.expr()) >, tensor_t > ) );

for(auto i = 0ul; i < t.size(); ++i){
BOOST_CHECK_EQUAL( uexpr1(i), uplus1(t(i)) );
Expand All @@ -59,8 +59,8 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_dynamic,

auto bexpr_uexpr = ublas::detail::make_binary_tensor_expression<tensor_t>( uexpr1, uexpr2, bplus );

BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.el.e) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.er.e) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.left_expr().expr()) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.right_expr().expr()) >, tensor_t > ) );


for(auto i = 0ul; i < t.size(); ++i){
Expand All @@ -69,10 +69,10 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_dynamic,

auto bexpr_bexpr_uexpr = ublas::detail::make_binary_tensor_expression<tensor_t>( bexpr_uexpr, t, bminus );

BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.el.el.e) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.el.er.e) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.er) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.er) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.left_expr().left_expr().expr()) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.left_expr().right_expr().expr()) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.right_expr()) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.right_expr()) >, tensor_t > ) );

for(auto i = 0ul; i < t.size(); ++i){
BOOST_CHECK_EQUAL( bexpr_bexpr_uexpr(i), bminus(bexpr_uexpr(i),t(i)) );
Expand Down Expand Up @@ -113,8 +113,8 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_static_rank,
auto uexpr1 = ublas::detail::make_unary_tensor_expression<tensor_t>( t, uplus1 );
auto uexpr2 = ublas::detail::make_unary_tensor_expression<tensor_t>( t, uplus2 );

BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr1.e) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr2.e) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr1.expr()) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr2.expr()) >, tensor_t > ) );

for(auto i = 0ul; i < t.size(); ++i){
BOOST_CHECK_EQUAL( uexpr1(i), uplus1(t(i)) );
Expand All @@ -126,8 +126,8 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_static_rank,

auto bexpr_uexpr = ublas::detail::make_binary_tensor_expression<tensor_t>( uexpr1, uexpr2, bplus );

BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.el.e) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.er.e) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.left_expr().expr()) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.right_expr().expr()) >, tensor_t > ) );


for(auto i = 0ul; i < t.size(); ++i){
Expand All @@ -136,10 +136,10 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_static_rank,

auto bexpr_bexpr_uexpr = ublas::detail::make_binary_tensor_expression<tensor_t>( bexpr_uexpr, t, bminus );

BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.el.el.e) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.el.er.e) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.er) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.er) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.left_expr().left_expr().expr()) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.left_expr().right_expr().expr()) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.right_expr()) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.right_expr()) >, tensor_t > ) );

for(auto i = 0ul; i < t.size(); ++i){
BOOST_CHECK_EQUAL( bexpr_bexpr_uexpr(i), bminus(bexpr_uexpr(i),t(i)) );
Expand Down Expand Up @@ -180,8 +180,8 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_static,
auto uexpr1 = ublas::detail::make_unary_tensor_expression<tensor_t>( t, uplus1 );
auto uexpr2 = ublas::detail::make_unary_tensor_expression<tensor_t>( t, uplus2 );

BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr1.e) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr2.e) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr1.expr()) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr2.expr()) >, tensor_t > ) );

for(auto i = 0ul; i < t.size(); ++i){
BOOST_CHECK_EQUAL( uexpr1(i), uplus1(t(i)) );
Expand All @@ -193,8 +193,8 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_static,

auto bexpr_uexpr = ublas::detail::make_binary_tensor_expression<tensor_t>( uexpr1, uexpr2, bplus );

BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.el.e) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.er.e) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.left_expr().expr()) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.right_expr().expr()) >, tensor_t > ) );


for(auto i = 0ul; i < t.size(); ++i){
Expand All @@ -203,10 +203,10 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_static,

auto bexpr_bexpr_uexpr = ublas::detail::make_binary_tensor_expression<tensor_t>( bexpr_uexpr, t, bminus );

BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.el.el.e) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.el.er.e) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.er) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.er) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.left_expr().left_expr().expr()) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.left_expr().right_expr().expr()) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.right_expr()) >, tensor_t > ) );
BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.right_expr()) >, tensor_t > ) );

for(auto i = 0ul; i < t.size(); ++i){
BOOST_CHECK_EQUAL( bexpr_bexpr_uexpr(i), bminus(bexpr_uexpr(i),t(i)) );
Expand Down
12 changes: 6 additions & 6 deletions test/tensor/tensor/test_tensor_unary_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_dynamic,
BOOST_CHECK_EQUAL( uexpr_uexpr(i), uplus1(uplus1(t(i))) );
}

const auto & uexpr_e = uexpr.e;
const auto & uexpr_e = uexpr.expr();

BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr_e) >, tensor_t > ) );

const auto & uexpr_uexpr_e_e = uexpr_uexpr.e.e;
const auto & uexpr_uexpr_e_e = uexpr_uexpr.expr().expr();

BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr_uexpr_e_e) >, tensor_t > ) );
}
Expand Down Expand Up @@ -101,11 +101,11 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_static_rank,
BOOST_CHECK_EQUAL( uexpr_uexpr(i), uplus1(uplus1(t(i))) );
}

const auto & uexpr_e = uexpr.e;
const auto & uexpr_e = uexpr.expr();

BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr_e) >, tensor_t > ) );

const auto & uexpr_uexpr_e_e = uexpr_uexpr.e.e;
const auto & uexpr_uexpr_e_e = uexpr_uexpr.expr().expr();

BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr_uexpr_e_e) >, tensor_t > ) );
}
Expand Down Expand Up @@ -150,11 +150,11 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_static,
BOOST_CHECK_EQUAL( uexpr_uexpr(i), uplus1(uplus1(t(i))) );
}

const auto & uexpr_e = uexpr.e;
const auto & uexpr_e = uexpr.expr();

BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr_e) >, tensor_t > ) );

const auto & uexpr_uexpr_e_e = uexpr_uexpr.e.e;
const auto & uexpr_uexpr_e_e = uexpr_uexpr.expr().expr();

BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr_uexpr_e_e) >, tensor_t > ) );
}
Expand Down

0 comments on commit 8ec747c

Please sign in to comment.