Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial draft for a specialized computed_assign for xfixed #2394

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 64 additions & 2 deletions include/xtensor/xassign.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,30 @@ namespace xt
static void assign_data(xexpression<E1>& e1, const xexpression<E2>& e2, bool trivial);
};

namespace detail{

template<class S>
struct is_fixed_shape : public std::false_type
{
};

template<std::size_t ... X>
struct is_fixed_shape<xt::fixed_shape< X ...>> : public std::true_type
{
};

template<class E>
using has_fixed_shape = is_fixed_shape<typename E::shape_type>;

template<class E>
using enable_if_has_fixed_shape_t = std::enable_if_t<has_fixed_shape<E>::value>;

template<class E>
using enable_if_has_no_fixed_shape_t = std::enable_if_t<!has_fixed_shape<E>::value>;

}


template <class Tag>
class xexpression_assigner : public xexpression_assigner_base<Tag>
{
Expand All @@ -84,7 +108,12 @@ namespace xt
template <class E1, class E2>
static void assign_xexpression(E1& e1, const E2& e2);

template <class E1, class E2>


template <class E1, class E2, typename detail::enable_if_has_no_fixed_shape_t<E1> * = nullptr>
static void computed_assign(xexpression<E1>& e1, const xexpression<E2>& e2);

template <class E1, class E2, typename detail::enable_if_has_fixed_shape_t<E1> * = nullptr>
static void computed_assign(xexpression<E1>& e1, const xexpression<E2>& e2);

template <class E1, class E2, class F>
Expand Down Expand Up @@ -415,7 +444,7 @@ namespace xt
}

template <class Tag>
template <class E1, class E2>
template <class E1, class E2, typename detail::enable_if_has_no_fixed_shape_t<E1> *>
inline void xexpression_assigner<Tag>::computed_assign(xexpression<E1>& e1, const xexpression<E2>& e2)
{
using shape_type = typename E1::shape_type;
Expand Down Expand Up @@ -444,6 +473,39 @@ namespace xt
}
}


template <class Tag>
template <class E1, class E2, typename detail::enable_if_has_fixed_shape_t<E1> *>
inline void xexpression_assigner<Tag>::computed_assign(xexpression<E1>& e1, const xexpression<E2>& e2)
{
using shape_type = typename E1::shape_type;

E1& de1 = e1.derived_cast();
const E2& de2 = e2.derived_cast();

if(de1.dimension() != de2.dimension())
{
// not sure if best error
throw_broadcast_error(de1.shape(), de2.shape());
}
else
{
auto && shape1 = de1.shape();
auto && shape2 = de2.shape();
if(std::equal(shape1.begin(), shape1.end(), shape2.begin()))
{
// the tests fail if just set it to true for the case
// when creating e2 itself involved broadcasting
base_type::assign_data(e1, e2, false/*trivial_broadcast*/);
}
else
{
// not sure if best error
throw_broadcast_error(de1.shape(), de2.shape());
}
}
}

template <class Tag>
template <class E1, class E2, class F>
inline void xexpression_assigner<Tag>::scalar_computed_assign(xexpression<E1>& e1, const E2& e2, F&& f)
Expand Down
140 changes: 139 additions & 1 deletion test/test_xassign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,19 @@
* The full license is in the file LICENSE, distributed with this software. *
****************************************************************************/



#if defined(_MSC_VER) && !defined(__clang__)
#define VS_SKIP_XFIXED 1
#endif


#include "gtest/gtest.h"
#include "xtensor/xarray.hpp"
#include "xtensor/xtensor.hpp"

#ifndef VS_SKIP_XFIXED
#include "xtensor/xfixed.hpp"
#endif
#include "xtensor/xassign.hpp"
#include "xtensor/xnoalias.hpp"
#include "test_common.hpp"
Expand All @@ -19,6 +28,19 @@
#include <vector>



// On VS2015, when compiling in x86 mode, alignas(T) leads to C2718
// when used for a function parameter, even indirectly. This means that
// we cannot pass parameters whose class is declared with alignas specifier
// or any type wrapping or inheriting from such a type.
// The xtensor_fixed class internally uses aligned_array which is declared as
// alignas(something_different_from_0), hence the workaround.
#if _MSC_VER < 1910 && !_WIN64
#define VS_X86_WORKAROUND 1
#endif



// a dummy shape *not derived* from std::vector but compatible
template<class T>
class my_vector
Expand Down Expand Up @@ -143,6 +165,122 @@ namespace xt
EXPECT_EQ(a.shape(0), 2);
EXPECT_EQ(a.shape(1), 3);
}
}

// test_fixed removed from MSVC x86 because of recurring ICE.
// Will be enabled again when the compiler is fixed

#ifndef VS_SKIP_XFIXED
#if (_MSC_VER < 1910 && _WIN64) || (_MSC_VER >= 1910 && !defined(DISABLE_VS2017)) || !defined(_MSC_VER)


TEST(xassign, fixed_shape)
{
// matching shape 1D
{
xt::xtensor_fixed<int, xt::xshape<2>> a = {2,3};
xt::xtensor_fixed<int, xt::xshape<2>> b = {3,4};

xt::noalias(a) += b;

EXPECT_EQ(a(0), 5);
EXPECT_EQ(a(1), 7);
}
//matching shape 2D
{
xt::xtensor_fixed<int, xt::xshape<2,2>> aa = {{1,2},{3,4}};
xt::xtensor_fixed<int, xt::xshape<2,2>> a(aa);
xt::xtensor_fixed<int, xt::xshape<2,2>> b = {{3,4},{5,6}};
xt::noalias(a) += b;

EXPECT_EQ(a(0,0), aa(0,0) + b(0,0));
EXPECT_EQ(a(0,1), aa(0,1) + b(0,1));
EXPECT_EQ(a(1,0), aa(1,0) + b(1,0));
EXPECT_EQ(a(1,1), aa(1,1) + b(1,1));
}
// b is broadcasted with matching dimension (first axis is singleton)
{
xt::xtensor_fixed<int, xt::xshape<2,2>> aa = {{1,2},{3,4}};
xt::xtensor_fixed<int, xt::xshape<2,2>> a(aa);
xt::xarray<int> b = {{5,6}};
EXPECT_EQ(b.shape(0),1);
EXPECT_EQ(b.shape(1),2);
xt::noalias(a) += b;

EXPECT_EQ(a(0,0), aa(0,0) + b(0,0));
EXPECT_EQ(a(0,1), aa(0,1) + b(0,1));
EXPECT_EQ(a(1,0), aa(1,0) + b(0,0));
EXPECT_EQ(a(1,1), aa(1,1) + b(0,1));
}
// b is broadcasted with matching dimension (second axis is singleton)
{
xt::xtensor_fixed<int, xt::xshape<2,2>> aa = {{1,2},{3,4}};
xt::xtensor_fixed<int, xt::xshape<2,2>> a(aa);
xt::xtensor_fixed<int, xt::xshape<2,1>> b = {{5},{6}};
EXPECT_EQ(b.shape(0),2);
EXPECT_EQ(b.shape(1),1);
xt::noalias(a) += b;

EXPECT_EQ(a(0,0), aa(0,0) + b.at(0,0));
EXPECT_EQ(a(0,1), aa(0,1) + b.at(0,0));
EXPECT_EQ(a(1,0), aa(1,0) + b.at(1,0));
EXPECT_EQ(a(1,1), aa(1,1) + b.at(1,0));
}
// b is broadcasted with matching dimension (first axis is singleton)
{
xt::xtensor_fixed<int, xt::xshape<2,2>> aa = {{1,2},{3,4}};
xt::xtensor_fixed<int, xt::xshape<2,2>> a(aa);
xt::xtensor_fixed<int, xt::xshape<1,2>> b = {{3,4}};
EXPECT_EQ(b.shape(0),1);
EXPECT_EQ(b.shape(1),2);
xt::noalias(a) += b;

EXPECT_EQ(a(0,0), aa(0,0) + b(0,0));
EXPECT_EQ(a(0,1), aa(0,1) + b(0,1));
EXPECT_EQ(a(1,0), aa(1,0) + b(0,0));
EXPECT_EQ(a(1,1), aa(1,1) + b(0,1));
}
// broadcast with non matching dimensions
{
xt::xtensor_fixed<int, xt::xshape<2,2>> aa = {{1,2},{3,4}};
xt::xtensor_fixed<int, xt::xshape<2,2>> a(aa);
xt::xtensor_fixed<int, xt::xshape<2>> b = {3,4};

xt::noalias(a) += b;

EXPECT_EQ(a(0,0), aa(0,0) + b(0));
EXPECT_EQ(a(0,1), aa(0,1) + b(1));
EXPECT_EQ(a(1,0), aa(1,0) + b(0));
EXPECT_EQ(a(1,1), aa(1,1) + b(1));
}
}
TEST(xassign, fixed_raises)
{
// cannot broadcast a itself on assignment
{
xt::xtensor_fixed<int, xt::xshape<2>> a = {2,3};
xt::xtensor_fixed<int, xt::xshape<2,2>> b = {{3,4},{3,4}};

EXPECT_THROW(xt::noalias(a) += b, xt::broadcast_error);
}

// cannot broadcast a itself on assignment
{
xt::xtensor_fixed<int, xt::xshape<1,2>> a = {{3,4}};
xt::xtensor_fixed<int, xt::xshape<2,2>> b = {{3,4},{3,4}};

EXPECT_THROW(xt::noalias(a) += b, xt::broadcast_error);
}

// cannot broadcast a itself on assignment
{
xt::xtensor_fixed<int, xt::xshape<2,1>> a = {{3},{4}};
xt::xtensor_fixed<int, xt::xshape<2,2>> b = {{3,4},{3,4}};

EXPECT_THROW(xt::noalias(a) += b, xt::broadcast_error);
}
}
#endif
#endif // VS_SKIP_XFIXED

}