Skip to content

Commit

Permalink
Merge pull request xtensor-stack#2768 from BioDataAnalysis/option_to_…
Browse files Browse the repository at this point in the history
…disable_temporary_object_in_assignment

Adding the ability to enable memory overlap check in assignment to avoid unneeded temporary memory allocation
  • Loading branch information
JohanMabille committed Mar 18, 2024
2 parents a17f3de + 4507f14 commit d9c3782
Show file tree
Hide file tree
Showing 6 changed files with 263 additions and 0 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ target_link_libraries(xtensor INTERFACE xtl)

OPTION(XTENSOR_ENABLE_ASSERT "xtensor bound check" OFF)
OPTION(XTENSOR_CHECK_DIMENSION "xtensor dimension check" OFF)
OPTION(XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS "xtensor force the use of temporary memory when assigning instead of an automatic overlap check" ON)
OPTION(BUILD_TESTS "xtensor test suite" OFF)
OPTION(BUILD_BENCHMARK "xtensor benchmark" OFF)
OPTION(DOWNLOAD_GTEST "build gtest from downloaded sources" OFF)
Expand All @@ -219,6 +220,10 @@ if(XTENSOR_CHECK_DIMENSION)
add_definitions(-DXTENSOR_ENABLE_CHECK_DIMENSION)
endif()

if(XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS)
add_definitions(-DXTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS)
endif()

if(DEFAULT_COLUMN_MAJOR)
add_definitions(-DXTENSOR_DEFAULT_LAYOUT=layout_type::column_major)
endif()
Expand Down
23 changes: 23 additions & 0 deletions include/xtensor/xbroadcast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,29 @@ namespace xt
return linear_end(c.expression());
}

/*************************************
* overlapping_memory_checker_traits *
*************************************/

template <class E>
struct overlapping_memory_checker_traits<
E,
std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xbroadcast, E>::value>>
{
static bool check_overlap(const E& expr, const memory_range& dst_range)
{
if (expr.size() == 0)
{
return false;
}
else
{
using ChildE = std::decay_t<decltype(expr.expression())>;
return overlapping_memory_checker_traits<ChildE>::check_overlap(expr.expression(), dst_range);
}
}
};

/**
* @class xbroadcast
* @brief Broadcasted xexpression to a specified shape.
Expand Down
36 changes: 36 additions & 0 deletions include/xtensor/xfunction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,42 @@ namespace xt
{
};

/*************************************
* overlapping_memory_checker_traits *
*************************************/

template <class E>
struct overlapping_memory_checker_traits<
E,
std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xfunction, E>::value>>
{
template <std::size_t I = 0, class... T, std::enable_if_t<(I == sizeof...(T)), int> = 0>
static bool check_tuple(const std::tuple<T...>&, const memory_range&)
{
return false;
}

template <std::size_t I = 0, class... T, std::enable_if_t<(I < sizeof...(T)), int> = 0>
static bool check_tuple(const std::tuple<T...>& t, const memory_range& dst_range)
{
using ChildE = std::decay_t<decltype(std::get<I>(t))>;
return overlapping_memory_checker_traits<ChildE>::check_overlap(std::get<I>(t), dst_range)
|| check_tuple<I + 1>(t, dst_range);
}

static bool check_overlap(const E& expr, const memory_range& dst_range)
{
if (expr.size() == 0)
{
return false;
}
else
{
return check_tuple(expr.arguments(), dst_range);
}
}
};

/*************
* xfunction *
*************/
Expand Down
15 changes: 15 additions & 0 deletions include/xtensor/xgenerator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,21 @@ namespace xt
using size_type = std::size_t;
};

/*************************************
* overlapping_memory_checker_traits *
*************************************/

template <class E>
struct overlapping_memory_checker_traits<
E,
std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xgenerator, E>::value>>
{
static bool check_overlap(const E&, const memory_range&)
{
return false;
}
};

/**
* @class xgenerator
* @brief Multidimensional function operating on indices.
Expand Down
37 changes: 37 additions & 0 deletions include/xtensor/xsemantic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,29 @@ namespace xt
template <class E, class R = void>
using disable_xcontainer_semantics = typename std::enable_if<!has_container_semantics<E>::value, R>::type;


template <class D>
class xview_semantic;

template <class E>
struct overlapping_memory_checker_traits<
E,
std::enable_if_t<!has_memory_address<E>::value && is_crtp_base_of<xview_semantic, E>::value>>
{
static bool check_overlap(const E& expr, const memory_range& dst_range)
{
if (expr.size() == 0)
{
return false;
}
else
{
using ChildE = std::decay_t<decltype(expr.expression())>;
return overlapping_memory_checker_traits<ChildE>::check_overlap(expr.expression(), dst_range);
}
}
};

/**
* @class xview_semantic
* @brief Implementation of the xsemantic_base interface for
Expand Down Expand Up @@ -598,8 +621,22 @@ namespace xt
template <class E>
inline auto xsemantic_base<D>::operator=(const xexpression<E>& e) -> derived_type&
{
#ifdef XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS
temporary_type tmp(e);
return this->derived_cast().assign_temporary(std::move(tmp));
#else
auto&& this_derived = this->derived_cast();
auto memory_checker = make_overlapping_memory_checker(this_derived);
if (memory_checker.check_overlap(e.derived_cast()))
{
temporary_type tmp(e);
return this_derived.assign_temporary(std::move(tmp));
}
else
{
return this->assign(e);
}
#endif
}

/**************************************
Expand Down
147 changes: 147 additions & 0 deletions include/xtensor/xutils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,20 @@ namespace xt
using type = T;
};

/***************************************
* is_specialization_of implementation *
***************************************/

template <template <class...> class TT, class T>
struct is_specialization_of : std::false_type
{
};

template <template <class...> class TT, class... Ts>
struct is_specialization_of<TT, TT<Ts...>> : std::true_type
{
};

/*******************************
* remove_class implementation *
*******************************/
Expand Down Expand Up @@ -860,6 +874,139 @@ namespace xt
{
};

/*************************************
* overlapping_memory_checker_traits *
*************************************/

template <class T, class Enable = void>
struct has_memory_address : std::false_type
{
};

template <class T>
struct has_memory_address<T, void_t<decltype(std::addressof(*std::declval<T>().begin()))>> : std::true_type
{
};

struct memory_range
{
// Checking pointer overlap is more correct in integer values,
// for more explanation check https://devblogs.microsoft.com/oldnewthing/20170927-00/?p=97095
const uintptr_t m_first = 0;
const uintptr_t m_last = 0;

explicit memory_range() = default;

template <class T>
explicit memory_range(T* first, T* last)
: m_first(reinterpret_cast<uintptr_t>(last < first ? last : first))
, m_last(reinterpret_cast<uintptr_t>(last < first ? first : last))
{
}

template <class T>
bool overlaps(T* first, T* last) const
{
if (first <= last)
{
return reinterpret_cast<uintptr_t>(first) <= m_last
&& reinterpret_cast<uintptr_t>(last) >= m_first;
}
else
{
return reinterpret_cast<uintptr_t>(last) <= m_last
&& reinterpret_cast<uintptr_t>(first) >= m_first;
}
}
};

template <class E, class Enable = void>
struct overlapping_memory_checker_traits
{
static bool check_overlap(const E&, const memory_range&)
{
return true;
}
};

template <class E>
struct overlapping_memory_checker_traits<E, std::enable_if_t<has_memory_address<E>::value>>
{
static bool check_overlap(const E& expr, const memory_range& dst_range)
{
if (expr.size() == 0)
{
return false;
}
else
{
return dst_range.overlaps(std::addressof(*expr.begin()), std::addressof(*expr.rbegin()));
}
}
};

struct overlapping_memory_checker_base
{
memory_range m_dst_range;

explicit overlapping_memory_checker_base() = default;

explicit overlapping_memory_checker_base(memory_range dst_memory_range)
: m_dst_range(std::move(dst_memory_range))
{
}

template <class E>
bool check_overlap(const E& expr) const
{
if (!m_dst_range.m_first || !m_dst_range.m_last)
{
return false;
}
else
{
return overlapping_memory_checker_traits<E>::check_overlap(expr, m_dst_range);
}
}
};

template <class Dst, class Enable = void>
struct overlapping_memory_checker : overlapping_memory_checker_base
{
explicit overlapping_memory_checker(const Dst&)
: overlapping_memory_checker_base()
{
}
};

template <class Dst>
struct overlapping_memory_checker<Dst, std::enable_if_t<has_memory_address<Dst>::value>>
: overlapping_memory_checker_base
{
explicit overlapping_memory_checker(const Dst& aDst)
: overlapping_memory_checker_base(
[&]()
{
if (aDst.size() == 0)
{
return memory_range();
}
else
{
return memory_range(std::addressof(*aDst.begin()), std::addressof(*aDst.rbegin()));
}
}()
)
{
}
};

template <class Dst>
auto make_overlapping_memory_checker(const Dst& a_dst)
{
return overlapping_memory_checker<Dst>(a_dst);
}

/********************
* rebind_container *
********************/
Expand Down

0 comments on commit d9c3782

Please sign in to comment.