Skip to content

Commit

Permalink
TreeNode specialization for handling more sophisticated multi-dimensi…
Browse files Browse the repository at this point in the history
…onal tree cuts than the current tree cuts
  • Loading branch information
paulbkoch committed Nov 3, 2024
1 parent b7900df commit edf8bed
Showing 1 changed file with 146 additions and 8 deletions.
154 changes: 146 additions & 8 deletions shared/libebm/TreeNode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ namespace DEFINED_ZONE_NAME {
static bool IsOverflowTreeNodeSize(const bool bHessian, const size_t cScores);
static size_t GetTreeNodeSize(const bool bHessian, const size_t cScores);

static bool IsOverflowTreeNodeMultiSize(const bool bHessian, const size_t cScores);
static size_t GetTreeNodeMultiSize(const bool bHessian, const size_t cScores);

template<bool bHessian, size_t cCompilerScores = 1> struct TreeNode final {
friend bool IsOverflowTreeNodeSize(const bool, const size_t);
friend size_t GetTreeNodeSize(const bool, const size_t);
Expand Down Expand Up @@ -72,27 +75,27 @@ template<bool bHessian, size_t cCompilerScores = 1> struct TreeNode final {
pBinLastOrChildren = pChildren;
}

inline FloatCalc AFTER_GetSplitGain() const {
inline FloatMain AFTER_GetSplitGain() const {
EBM_ASSERT(1 == m_debugProgressionStage);

const FloatCalc splitGain = m_UNION.m_afterGainCalc.m_splitGain;
const FloatMain splitGain = m_UNION.m_afterGainCalc.m_splitGain;

// our priority queue cannot handle NaN values so we filter them out before adding them
EBM_ASSERT(!std::isnan(splitGain));
EBM_ASSERT(!std::isinf(splitGain));
EBM_ASSERT(0 <= splitGain);
EBM_ASSERT(FloatMain{0} <= splitGain);

return splitGain;
}
inline void AFTER_SetSplitGain(const FloatCalc splitGain) {
inline void AFTER_SetSplitGain(const FloatMain splitGain) {
EBM_ASSERT(1 == m_debugProgressionStage);

// this is only called if there is a legal gain value. If the TreeNode cannot be split call AFTER_RejectSplit.

// our priority queue cannot handle NaN values so we filter them out before adding them
EBM_ASSERT(!std::isnan(splitGain));
EBM_ASSERT(!std::isinf(splitGain));
EBM_ASSERT(0 <= splitGain);
EBM_ASSERT(FloatMain{0} <= splitGain);

m_UNION.m_afterGainCalc.m_splitGain = splitGain;
}
Expand All @@ -113,12 +116,12 @@ template<bool bHessian, size_t cCompilerScores = 1> struct TreeNode final {
//
// We need to set the m_splitGain value then to something other than NaN to indicate that it was not split.

m_UNION.m_afterGainCalc.m_splitGain = 0;
m_UNION.m_afterGainCalc.m_splitGain = FloatMain{0};
}

inline void AFTER_SplitNode() {
EBM_ASSERT(1 == m_debugProgressionStage);
m_UNION.m_afterGainCalc.m_splitGain = std::numeric_limits<FloatCalc>::quiet_NaN();
m_UNION.m_afterGainCalc.m_splitGain = std::numeric_limits<FloatMain>::quiet_NaN();
}

inline bool AFTER_IsSplit() const {
Expand Down Expand Up @@ -155,7 +158,7 @@ template<bool bHessian, size_t cCompilerScores = 1> struct TreeNode final {
};

struct AfterGainCalc final {
FloatCalc m_splitGain;
FloatMain m_splitGain;
};

struct Deconstruct final {
Expand Down Expand Up @@ -185,6 +188,89 @@ static_assert(std::is_trivial<TreeNode<true>>::value && std::is_trivial<TreeNode
static_assert(std::is_pod<TreeNode<true>>::value && std::is_pod<TreeNode<false>>::value,
"We use a lot of C constructs, so disallow non-POD types in general");

template<bool bHessian, size_t cCompilerScores = 1> struct TreeNodeMulti final {
friend bool IsOverflowTreeNodeMultiSize(const bool, const size_t);
friend size_t GetTreeNodeMultiSize(const bool, const size_t);

TreeNodeMulti() = default; // preserve our POD status
~TreeNodeMulti() = default; // preserve our POD status
void* operator new(std::size_t) = delete; // we only use malloc/free in this library
void operator delete(void*) = delete; // we only use malloc/free in this library

inline void SetSplitGain(const FloatMain splitGain) {
// our priority queue cannot handle NaN values so we filter them out before adding them
EBM_ASSERT(!std::isnan(splitGain));
EBM_ASSERT(!std::isinf(splitGain));
EBM_ASSERT(FloatMain{0} <= splitGain);

m_splitGain = splitGain;
}
inline bool IsSplit() const { return std::isnan(m_splitGain); }
inline FloatMain GetSplitGain() const {
// our priority queue cannot handle NaN values so we filter them out before adding them
EBM_ASSERT(!std::isnan(m_splitGain));
EBM_ASSERT(!std::isinf(m_splitGain));
EBM_ASSERT(FloatMain{0} <= m_splitGain);

return m_splitGain;
}
inline void SplitNode() {
EBM_ASSERT(!IsSplit());
m_splitGain = std::numeric_limits<FloatMain>::quiet_NaN();
}

inline void SetDimensionIndex(const size_t iDimension) {
EBM_ASSERT(!IsSplit());
m_iDimension = iDimension;
}
inline size_t GetDimensionIndex() const { return m_iDimension; }

inline void SetSplitIndex(const size_t iSplit) {
EBM_ASSERT(!IsSplit());
m_iSplit = iSplit;
}
inline size_t GetSplitIndex() const { return m_iSplit; }

inline void SetParent(TreeNodeMulti* const pParent) {
EBM_ASSERT(!IsSplit());
m_pParent = pParent;
}
inline TreeNodeMulti* GetParent() { return m_pParent; }

inline void SetChildren(TreeNodeMulti* const pChildren) {
EBM_ASSERT(IsSplit());
m_pChildren = pChildren;
}
inline TreeNodeMulti* GetChildren() {
EBM_ASSERT(IsSplit());
return m_pChildren;
}

inline Bin<FloatMain, UIntMain, true, true, bHessian, cCompilerScores>* GetBin() { return &m_bin; }

template<size_t cNewCompilerScores> inline TreeNodeMulti<bHessian, cNewCompilerScores>* Upgrade() {
return reinterpret_cast<TreeNodeMulti<bHessian, cNewCompilerScores>*>(this);
}
inline TreeNodeMulti<bHessian, 1>* Downgrade() { return reinterpret_cast<TreeNodeMulti<bHessian, 1>*>(this); }

private:
FloatMain m_splitGain;
size_t m_iDimension;
size_t m_iSplit;
TreeNodeMulti* m_pParent;
TreeNodeMulti* m_pChildren;

// IMPORTANT: m_bin must be in the last position for the struct hack and this must be standard layout
Bin<FloatMain, UIntMain, true, true, bHessian, cCompilerScores> m_bin;
};
static_assert(
std::is_standard_layout<TreeNodeMulti<true>>::value && std::is_standard_layout<TreeNodeMulti<false>>::value,
"We use the struct hack in several places, so disallow non-standard_layout types in general");
static_assert(std::is_trivial<TreeNodeMulti<true>>::value && std::is_trivial<TreeNodeMulti<false>>::value,
"We use memcpy in several places, so disallow non-trivial types in general");
static_assert(std::is_pod<TreeNodeMulti<true>>::value && std::is_pod<TreeNodeMulti<false>>::value,
"We use a lot of C constructs, so disallow non-POD types in general");

inline static bool IsOverflowTreeNodeSize(const bool bHessian, const size_t cScores) {
const size_t cBytesPerBin = GetBinSize<FloatMain, UIntMain>(true, true, bHessian, cScores);

Expand All @@ -204,6 +290,25 @@ inline static bool IsOverflowTreeNodeSize(const bool bHessian, const size_t cSco
return false;
}

inline static bool IsOverflowTreeNodeMultiSize(const bool bHessian, const size_t cScores) {
const size_t cBytesPerBin = GetBinSize<FloatMain, UIntMain>(true, true, bHessian, cScores);

size_t cBytesTreeNodeMultiComponent;
if(bHessian) {
typedef TreeNodeMulti<true> OffsetType;
cBytesTreeNodeMultiComponent = offsetof(OffsetType, m_bin);
} else {
typedef TreeNodeMulti<false> OffsetType;
cBytesTreeNodeMultiComponent = offsetof(OffsetType, m_bin);
}

if(UNLIKELY(IsAddError(cBytesTreeNodeMultiComponent, cBytesPerBin))) {
return true;
}

return false;
}

inline static size_t GetTreeNodeSize(const bool bHessian, const size_t cScores) {
const size_t cBytesPerBin = GetBinSize<FloatMain, UIntMain>(true, true, bHessian, cScores);

Expand All @@ -219,23 +324,56 @@ inline static size_t GetTreeNodeSize(const bool bHessian, const size_t cScores)
return cBytesTreeNodeComponent + cBytesPerBin;
}

inline static size_t GetTreeNodeMultiSize(const bool bHessian, const size_t cScores) {
const size_t cBytesPerBin = GetBinSize<FloatMain, UIntMain>(true, true, bHessian, cScores);

size_t cBytesTreeNodeMultiComponent;
if(bHessian) {
typedef TreeNodeMulti<true> OffsetType;
cBytesTreeNodeMultiComponent = offsetof(OffsetType, m_bin);
} else {
typedef TreeNodeMulti<false> OffsetType;
cBytesTreeNodeMultiComponent = offsetof(OffsetType, m_bin);
}

return cBytesTreeNodeMultiComponent + cBytesPerBin;
}

template<bool bHessian, size_t cCompilerScores>
inline static TreeNode<bHessian, cCompilerScores>* IndexTreeNode(
TreeNode<bHessian, cCompilerScores>* const pTreeNode, const size_t iByte) {
return IndexByte(pTreeNode, iByte);
}

template<bool bHessian, size_t cCompilerScores>
inline static TreeNodeMulti<bHessian, cCompilerScores>* IndexTreeNodeMulti(
TreeNodeMulti<bHessian, cCompilerScores>* const pTreeNodeMulti, const size_t iByte) {
return IndexByte(pTreeNodeMulti, iByte);
}

template<bool bHessian, size_t cCompilerScores>
inline static TreeNode<bHessian, cCompilerScores>* GetLeftNode(TreeNode<bHessian, cCompilerScores>* const pChildren) {
return pChildren;
}

template<bool bHessian, size_t cCompilerScores>
inline static TreeNodeMulti<bHessian, cCompilerScores>* GetLeftNode(
TreeNodeMulti<bHessian, cCompilerScores>* const pChildren) {
return pChildren;
}

template<bool bHessian, size_t cCompilerScores>
inline static TreeNode<bHessian, cCompilerScores>* GetRightNode(
TreeNode<bHessian, cCompilerScores>* const pChildren, const size_t cBytesPerTreeNode) {
return IndexTreeNode(pChildren, cBytesPerTreeNode);
}

template<bool bHessian, size_t cCompilerScores>
inline static TreeNodeMulti<bHessian, cCompilerScores>* GetRightNode(
TreeNodeMulti<bHessian, cCompilerScores>* const pChildren, const size_t cBytesPerTreeNodeMulti) {
return IndexTreeNodeMulti(pChildren, cBytesPerTreeNodeMulti);
}

} // namespace DEFINED_ZONE_NAME

#endif // TREE_NODE_HPP

0 comments on commit edf8bed

Please sign in to comment.