Skip to content

Commit

Permalink
Fix computation of abs(b) for accuracy evaluation (llvm#1257)
Browse files Browse the repository at this point in the history
* Fix abs(b) and clean up in `omTensorAreTwoOmtsClose`

Signed-off-by: Haruki Imai <imaihal@jp.ibm.com>

Co-authored-by: Ettore Tiotto <etiotto@ca.ibm.com>
Co-authored-by: Tung D. Le <tungld@gmail.com>
  • Loading branch information
3 people committed Mar 30, 2022
1 parent 1c731f4 commit 53128f7
Showing 1 changed file with 14 additions and 22 deletions.
36 changes: 14 additions & 22 deletions src/Runtime/OMTensor.inc
Original file line number Diff line number Diff line change
Expand Up @@ -610,31 +610,25 @@ inline bool omTensorAreTwoOmtsClose(
}

// Compute difference, verify it's within tolerable range.
// rtol * b + atol >= abs(a - b)
auto anum = omTensorGetNumElems(a);
auto bnum = omTensorGetNumElems(b);
auto rtolT = omTensorCreateWithShape<T>(aShape);
assert(rtolT && "failed to allocate memory");
auto atolT = omTensorCreateWithShape<T>(aShape);
assert(atolT && "failed to allocate memory");
for (const auto &idx : omTensorComputeIndexSet(rtolT)) {
omTensorGetElem<T>(rtolT, idx) = rtol;
omTensorGetElem<T>(atolT, idx) = atol;
}
std::vector<T> absoluteDiff(anum);
std::vector<T> eqAllclose(anum);
// rtol * abs(b) + atol >= abs(a - b)
auto nElems = omTensorGetNumElems(a);
std::vector<T> absoluteDiff(nElems);
std::vector<T> absb(nElems);
std::vector<T> eqAllclose(nElems);
std::vector<T> rtolT(nElems, rtol);
std::vector<T> atolT(nElems, atol);
// abs(a - b)
std::transform((T *)a->_alignedPtr, (T *)a->_alignedPtr + anum,
std::transform((T *)a->_alignedPtr, (T *)a->_alignedPtr + nElems,
(T *)b->_alignedPtr, absoluteDiff.begin(), std::minus<>());
std::transform(absoluteDiff.begin(), absoluteDiff.end(), absoluteDiff.begin(),
static_cast<T (*)(T)>(&std::abs));
// rtol * abs(b)
std::transform((T *)b->_alignedPtr, (T *)b->_alignedPtr + bnum,
(T *)b->_alignedPtr, static_cast<T (*)(T)>(&std::abs));
std::transform((T *)b->_alignedPtr, (T *)b->_alignedPtr + bnum,
(T *)rtolT->_alignedPtr, eqAllclose.begin(), std::multiplies<>());
// rtol * abs(b) + atol
std::transform(eqAllclose.begin(), eqAllclose.end(), (T *)atolT->_alignedPtr,
std::transform((T *)b->_alignedPtr, (T *)b->_alignedPtr + nElems,
absb.begin(), static_cast<T (*)(T)>(&std::abs));
std::transform(rtolT.begin(), rtolT.end(), absb.begin(), eqAllclose.begin(),
std::multiplies<>());
// (rtol * abs(b)) + atol
std::transform(eqAllclose.begin(), eqAllclose.end(), atolT.begin(),
eqAllclose.begin(), std::plus<>());
// (rtol * abs(b) + atol) - abs(a - b)
std::transform(eqAllclose.begin(), eqAllclose.end(), absoluteDiff.begin(),
Expand All @@ -659,8 +653,6 @@ inline bool omTensorAreTwoOmtsClose(
}
}
}
omTensorDestroy(rtolT);
omTensorDestroy(atolT);
return satisfied;
}

Expand Down

0 comments on commit 53128f7

Please sign in to comment.