diff --git a/include/boost/numeric/ublas/tensor/expression.hpp b/include/boost/numeric/ublas/tensor/expression.hpp index d33c385a3..74ef08e1a 100644 --- a/include/boost/numeric/ublas/tensor/expression.hpp +++ b/include/boost/numeric/ublas/tensor/expression.hpp @@ -43,6 +43,33 @@ static constexpr bool does_exp_need_cast_v = does_exp_need_cast< std::decay_t template struct does_exp_need_cast< tensor_expression > : std::true_type{}; +/** + * @brief It is a safer way of casting `tensor_expression` because it handles + * recursive expressions. Otherwise, in most of the cases, we try to access + * `operator()`, which requires a parameter argument, that is not supported + * by the `tensor_expression` class and might give an error if it is not casted + * properly. + * + * @tparam T type of the tensor + * @tparam E type of the child stored inside tensor_expression + * @param e tensor_expression that needs to be casted + * @return child of tensor_expression that is not tensor_expression + */ +template +constexpr auto const& cast_tensor_exression(tensor_expression const& e) noexcept{ + auto const& res = e(); + if constexpr(does_exp_need_cast_v) + return cast_tensor_exression(res); + else + return res; +} + + +// FIXME: remove it when template expression support for the old matrix and vector is removed +/// @brief No Op: Any expression other than `tensor_expression`. +template +constexpr auto const& cast_tensor_exression(E const& e) noexcept{ return e; } + template constexpr auto is_tensor_expression_impl(tensor_expression const*) -> std::true_type; @@ -137,33 +164,15 @@ struct binary_tensor_expression binary_tensor_expression(const binary_tensor_expression& l) = delete; binary_tensor_expression& operator=(binary_tensor_expression const& l) noexcept = delete; + constexpr auto const& left_expr() const noexcept{ return cast_tensor_exression(el); } + constexpr auto const& right_expr() const noexcept{ return cast_tensor_exression(er); } [[nodiscard]] inline - constexpr decltype(auto) operator()(size_type i) const - requires (does_exp_need_cast_v && does_exp_need_cast_v) - { - return op(el()(i), er()(i)); - } - - [[nodiscard]] inline - constexpr decltype(auto) operator()(size_type i) const - requires (does_exp_need_cast_v && !does_exp_need_cast_v) - { - return op(el()(i), er(i)); - } - - [[nodiscard]] inline - constexpr decltype(auto) operator()(size_type i) const - requires (!does_exp_need_cast_v && does_exp_need_cast_v) - { - return op(el(i), er()(i)); - } - - [[nodiscard]] inline - constexpr decltype(auto) operator()(size_type i) const { - return op(el(i), er(i)); + constexpr decltype(auto) operator()(size_type i) const { + return op(left_expr()(i), right_expr()(i)); } +private: expression_type_left el; expression_type_right er; binary_operation op; @@ -211,19 +220,15 @@ struct unary_tensor_expression constexpr unary_tensor_expression() = delete; unary_tensor_expression(unary_tensor_expression const& l) = delete; unary_tensor_expression& operator=(unary_tensor_expression const& l) noexcept = delete; - - [[nodiscard]] inline constexpr - decltype(auto) operator()(size_type i) const - requires does_exp_need_cast_v - { - return op(e()(i)); - } + + constexpr auto const& expr() const noexcept{ return cast_tensor_exression(e); } [[nodiscard]] inline constexpr - decltype(auto) operator()(size_type i) const { - return op(e(i)); + decltype(auto) operator()(size_type i) const { + return op(expr()(i)); } +private: expression_type e; unary_operation op; }; diff --git a/include/boost/numeric/ublas/tensor/expression_evaluation.hpp b/include/boost/numeric/ublas/tensor/expression_evaluation.hpp index 37e9f1e48..b18203ce2 100644 --- a/include/boost/numeric/ublas/tensor/expression_evaluation.hpp +++ b/include/boost/numeric/ublas/tensor/expression_evaluation.hpp @@ -134,17 +134,20 @@ constexpr auto& retrieve_extents(binary_tensor_expression const& exp static_assert(has_tensor_types_v>, "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors."); + auto const& lexpr = expr.left_expr(); + auto const& rexpr = expr.right_expr(); + if constexpr ( same_exp ) - return expr.el.extents(); + return lexpr.extents(); else if constexpr ( same_exp ) - return expr.er.extents(); + return rexpr.extents(); else if constexpr ( has_tensor_types_v ) - return retrieve_extents(expr.el); + return retrieve_extents(lexpr); else if constexpr ( has_tensor_types_v ) - return retrieve_extents(expr.er); + return retrieve_extents(rexpr); } #ifdef _MSC_VER @@ -164,12 +167,14 @@ constexpr auto& retrieve_extents(unary_tensor_expression const& expr) static_assert(has_tensor_types_v>, "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors."); + + auto const& uexpr = expr.expr(); if constexpr ( same_exp ) - return expr.e.extents(); + return uexpr.extents(); else if constexpr ( has_tensor_types_v ) - return retrieve_extents(expr.e); + return retrieve_extents(uexpr); } } // namespace boost::numeric::ublas::detail @@ -221,20 +226,23 @@ constexpr auto all_extents_equal(binary_tensor_expression const& exp using ::operator==; using ::operator!=; + auto const& lexpr = expr.left_expr(); + auto const& rexpr = expr.right_expr(); + if constexpr ( same_exp ) - if(e != expr.el.extents()) + if(e != lexpr.extents()) return false; if constexpr ( same_exp ) - if(e != expr.er.extents()) + if(e != rexpr.extents()) return false; if constexpr ( has_tensor_types_v ) - if(!all_extents_equal(expr.el, e)) + if(!all_extents_equal(lexpr, e)) return false; if constexpr ( has_tensor_types_v ) - if(!all_extents_equal(expr.er, e)) + if(!all_extents_equal(rexpr, e)) return false; return true; @@ -250,12 +258,14 @@ constexpr auto all_extents_equal(unary_tensor_expression const& expr, ex using ::operator==; + auto const& uexpr = expr.expr(); + if constexpr ( same_exp ) - if(e != expr.e.extents()) + if(e != uexpr.extents()) return false; if constexpr ( has_tensor_types_v ) - if(!all_extents_equal(expr.e, e)) + if(!all_extents_equal(uexpr, e)) return false; return true; @@ -281,9 +291,11 @@ inline void eval(tensor_type& lhs, tensor_expression if(!all_extents_equal(expr, lhs.extents() )) throw std::runtime_error("Error in boost::numeric::ublas::tensor_core: expression contains tensors with different shapes."); -#pragma omp parallel for + auto const& rhs = cast_tensor_exression(expr); + + #pragma omp parallel for for(auto i = 0u; i < lhs.size(); ++i) - lhs(i) = expr()(i); + lhs(i) = rhs(i); } /** @brief Evaluates expression for a tensor_core @@ -310,9 +322,11 @@ inline void eval(tensor_type& lhs, tensor_expression if(!all_extents_equal( expr, lhs.extents() )) throw std::runtime_error("Error in boost::numeric::ublas::tensor_core: expression contains tensors with different shapes."); + auto const& rhs = cast_tensor_exression(expr); + #pragma omp parallel for for(auto i = 0u; i < lhs.size(); ++i) - fn(lhs(i), expr()(i)); + fn(lhs(i), rhs(i)); } @@ -347,7 +363,7 @@ inline void eval(tensor_type& lhs, tensor_expression template inline void eval(tensor_type& lhs, unary_fn const& fn) { -#pragma omp parallel for + #pragma omp parallel for for(auto i = 0u; i < lhs.size(); ++i) fn(lhs(i)); } diff --git a/include/boost/numeric/ublas/tensor/multiplication.hpp b/include/boost/numeric/ublas/tensor/multiplication.hpp index 6a9c0613b..9fc5abb95 100644 --- a/include/boost/numeric/ublas/tensor/multiplication.hpp +++ b/include/boost/numeric/ublas/tensor/multiplication.hpp @@ -389,7 +389,7 @@ void mtv(SizeType const m, template void mtm(PointerOut c, SizeType const*const nc, SizeType const*const wc, PointerIn1 a, SizeType const*const na, SizeType const*const wa, - PointerIn2 b, SizeType const*const nb, SizeType const*const wb) + PointerIn2 b, [[maybe_unused]] SizeType const*const nb, SizeType const*const wb) { // C(i,j) = A(i,k) * B(k,j) diff --git a/test/tensor/tensor/test_tensor_binary_expression.cpp b/test/tensor/tensor/test_tensor_binary_expression.cpp index 3bb1b436c..f03a61b0c 100644 --- a/test/tensor/tensor/test_tensor_binary_expression.cpp +++ b/test/tensor/tensor/test_tensor_binary_expression.cpp @@ -46,8 +46,8 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_dynamic, auto uexpr1 = ublas::detail::make_unary_tensor_expression( t, uplus1 ); auto uexpr2 = ublas::detail::make_unary_tensor_expression( t, uplus2 ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr1.e) >, tensor_t > ) ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr2.e) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr1.expr()) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr2.expr()) >, tensor_t > ) ); for(auto i = 0ul; i < t.size(); ++i){ BOOST_CHECK_EQUAL( uexpr1(i), uplus1(t(i)) ); @@ -59,8 +59,8 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_dynamic, auto bexpr_uexpr = ublas::detail::make_binary_tensor_expression( uexpr1, uexpr2, bplus ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.el.e) >, tensor_t > ) ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.er.e) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.left_expr().expr()) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.right_expr().expr()) >, tensor_t > ) ); for(auto i = 0ul; i < t.size(); ++i){ @@ -69,10 +69,10 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_dynamic, auto bexpr_bexpr_uexpr = ublas::detail::make_binary_tensor_expression( bexpr_uexpr, t, bminus ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.el.el.e) >, tensor_t > ) ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.el.er.e) >, tensor_t > ) ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.er) >, tensor_t > ) ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.er) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.left_expr().left_expr().expr()) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.left_expr().right_expr().expr()) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.right_expr()) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.right_expr()) >, tensor_t > ) ); for(auto i = 0ul; i < t.size(); ++i){ BOOST_CHECK_EQUAL( bexpr_bexpr_uexpr(i), bminus(bexpr_uexpr(i),t(i)) ); @@ -113,8 +113,8 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_static_rank, auto uexpr1 = ublas::detail::make_unary_tensor_expression( t, uplus1 ); auto uexpr2 = ublas::detail::make_unary_tensor_expression( t, uplus2 ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr1.e) >, tensor_t > ) ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr2.e) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr1.expr()) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr2.expr()) >, tensor_t > ) ); for(auto i = 0ul; i < t.size(); ++i){ BOOST_CHECK_EQUAL( uexpr1(i), uplus1(t(i)) ); @@ -126,8 +126,8 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_static_rank, auto bexpr_uexpr = ublas::detail::make_binary_tensor_expression( uexpr1, uexpr2, bplus ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.el.e) >, tensor_t > ) ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.er.e) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.left_expr().expr()) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.right_expr().expr()) >, tensor_t > ) ); for(auto i = 0ul; i < t.size(); ++i){ @@ -136,10 +136,10 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_static_rank, auto bexpr_bexpr_uexpr = ublas::detail::make_binary_tensor_expression( bexpr_uexpr, t, bminus ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.el.el.e) >, tensor_t > ) ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.el.er.e) >, tensor_t > ) ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.er) >, tensor_t > ) ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.er) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.left_expr().left_expr().expr()) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.left_expr().right_expr().expr()) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.right_expr()) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.right_expr()) >, tensor_t > ) ); for(auto i = 0ul; i < t.size(); ++i){ BOOST_CHECK_EQUAL( bexpr_bexpr_uexpr(i), bminus(bexpr_uexpr(i),t(i)) ); @@ -180,8 +180,8 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_static, auto uexpr1 = ublas::detail::make_unary_tensor_expression( t, uplus1 ); auto uexpr2 = ublas::detail::make_unary_tensor_expression( t, uplus2 ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr1.e) >, tensor_t > ) ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr2.e) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr1.expr()) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr2.expr()) >, tensor_t > ) ); for(auto i = 0ul; i < t.size(); ++i){ BOOST_CHECK_EQUAL( uexpr1(i), uplus1(t(i)) ); @@ -193,8 +193,8 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_static, auto bexpr_uexpr = ublas::detail::make_binary_tensor_expression( uexpr1, uexpr2, bplus ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.el.e) >, tensor_t > ) ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.er.e) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.left_expr().expr()) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_uexpr.right_expr().expr()) >, tensor_t > ) ); for(auto i = 0ul; i < t.size(); ++i){ @@ -203,10 +203,10 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_static, auto bexpr_bexpr_uexpr = ublas::detail::make_binary_tensor_expression( bexpr_uexpr, t, bminus ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.el.el.e) >, tensor_t > ) ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.el.er.e) >, tensor_t > ) ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.er) >, tensor_t > ) ); - BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.er) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.left_expr().left_expr().expr()) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.left_expr().right_expr().expr()) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.right_expr()) >, tensor_t > ) ); + BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(bexpr_bexpr_uexpr.right_expr()) >, tensor_t > ) ); for(auto i = 0ul; i < t.size(); ++i){ BOOST_CHECK_EQUAL( bexpr_bexpr_uexpr(i), bminus(bexpr_uexpr(i),t(i)) ); diff --git a/test/tensor/tensor/test_tensor_unary_expression.cpp b/test/tensor/tensor/test_tensor_unary_expression.cpp index 789f8de7d..1dcab3ac7 100644 --- a/test/tensor/tensor/test_tensor_unary_expression.cpp +++ b/test/tensor/tensor/test_tensor_unary_expression.cpp @@ -52,11 +52,11 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_dynamic, BOOST_CHECK_EQUAL( uexpr_uexpr(i), uplus1(uplus1(t(i))) ); } - const auto & uexpr_e = uexpr.e; + const auto & uexpr_e = uexpr.expr(); BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr_e) >, tensor_t > ) ); - const auto & uexpr_uexpr_e_e = uexpr_uexpr.e.e; + const auto & uexpr_uexpr_e_e = uexpr_uexpr.expr().expr(); BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr_uexpr_e_e) >, tensor_t > ) ); } @@ -101,11 +101,11 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_static_rank, BOOST_CHECK_EQUAL( uexpr_uexpr(i), uplus1(uplus1(t(i))) ); } - const auto & uexpr_e = uexpr.e; + const auto & uexpr_e = uexpr.expr(); BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr_e) >, tensor_t > ) ); - const auto & uexpr_uexpr_e_e = uexpr_uexpr.e.e; + const auto & uexpr_uexpr_e_e = uexpr_uexpr.expr().expr(); BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr_uexpr_e_e) >, tensor_t > ) ); } @@ -150,11 +150,11 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(test_tensor_static, BOOST_CHECK_EQUAL( uexpr_uexpr(i), uplus1(uplus1(t(i))) ); } - const auto & uexpr_e = uexpr.e; + const auto & uexpr_e = uexpr.expr(); BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr_e) >, tensor_t > ) ); - const auto & uexpr_uexpr_e_e = uexpr_uexpr.e.e; + const auto & uexpr_uexpr_e_e = uexpr_uexpr.expr().expr(); BOOST_CHECK( ( std::is_same_v< std::decay_t< decltype(uexpr_uexpr_e_e) >, tensor_t > ) ); }