Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2261,7 +2261,7 @@ namespace TensileLite
template <>
inline bool DataInitialization::isBadOutput<BFloat16>(BFloat16 value)
{
return std::isinf(value);
return std::isinf(static_cast<float>(value));
}

template <>
Expand Down
99 changes: 52 additions & 47 deletions projects/hipblaslt/tensilelite/client/include/Reference.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,94 +34,99 @@ namespace TensileLite
{
namespace Client
{
// AlmostEqual tolerance constants per type.
// Formula: |a - b| < tolerance * (|a| + |b| + 1)
constexpr float AlmostEqualTolerance_Half = 0.01f;
constexpr float AlmostEqualTolerance_BFloat16 = 0.1f;
// tolerance * epsilon = 2 * 0.0625; 2*eps needed for SR
constexpr float AlmostEqualTolerance_Float8 = 0.125f;
// tolerance * epsilon = 2 * 0.125; 2*eps needed for SR
constexpr float AlmostEqualTolerance_BFloat8 = 0.25f;
// 7 digits precision - 2
constexpr float AlmostEqualTolerance_Float = 0.0001f;
// 15 digits precision - 2
constexpr double AlmostEqualTolerance_Double = 1e-12;

// threshold is largest allowed delta. -1 uses default for each type
template <typename T>
inline bool AlmostEqual(T a, T b, double threshold = -1.0);

template <>
inline bool AlmostEqual(Half a, Half b, double threshold)
{
Half absA = (a > 0) ? a : -a;
Half absB = (b > 0) ? b : -b;
// this avoids NaN when inf is compared against inf in the alternative code
// path
if(static_cast<float>(absA) == std::numeric_limits<float>::infinity()
|| // numeric_limits is yet to
// support _Float16 type
// properly;
static_cast<float>(absB)
== std::numeric_limits<float>::infinity()) // however promoting it to
// float works just as fine
{
return a == b;
}
Half absDiff = (a - b > 0) ? a - b : b - a;
return absDiff / (absA + absB + 1) < 0.01;
float fa = static_cast<float>(a);
float fb = static_cast<float>(b);
float absDiff = std::fabs(fa - fb);
return fa == fb
|| absDiff < AlmostEqualTolerance_Half * (std::fabs(fa) + std::fabs(fb) + 1.0f);
}

template <>
inline bool AlmostEqual(Float8 a, Float8 b, double threshold)
{
Float8 absA = (a > static_cast<Float8>(0.0f)) ? a : static_cast<Float8>(0.0f) - a;
Float8 absB = (b > static_cast<Float8>(0.0f)) ? b : static_cast<Float8>(0.0f) - b;
Float8 absDiff = (a - b > static_cast<Float8>(0.0f)) ? a - b : b - a;
return absDiff / (absA + absB + static_cast<Float8>(1.0f)) < static_cast<Float8>(
0.125f); // tolerance * eps = 2 * 0.0625; 2*eps needed for SR
float fa = static_cast<float>(a);
float fb = static_cast<float>(b);
float absDiff = std::fabs(fa - fb);
return fa == fb
|| absDiff < AlmostEqualTolerance_Float8 * (std::fabs(fa) + std::fabs(fb) + 1.0f);
}

template <>
inline bool AlmostEqual(BFloat8 a, BFloat8 b, double threshold)
{
BFloat8 absA = (a > static_cast<BFloat8>(0.0f)) ? a : static_cast<BFloat8>(0.0f) - a;
BFloat8 absB = (b > static_cast<BFloat8>(0.0f)) ? b : static_cast<BFloat8>(0.0f) - b;
BFloat8 absDiff = (a - b > static_cast<BFloat8>(0.0f)) ? a - b : b - a;
return absDiff / (absA + absB + static_cast<BFloat8>(1.0f)) < static_cast<BFloat8>(
0.25f); // tolerance * epsilon = 2 * 0.125; 2*eps needed for SR
float fa = static_cast<float>(a);
float fb = static_cast<float>(b);
float absDiff = std::fabs(fa - fb);

return fa == fb
|| absDiff < AlmostEqualTolerance_BFloat8 * (std::fabs(fa) + std::fabs(fb) + 1.0f);
}

template <>
inline bool AlmostEqual(Float8_fnuz a, Float8_fnuz b, double threshold)
{
Float8_fnuz absA = (a > static_cast<Float8_fnuz>(0.0f)) ? a : static_cast<Float8_fnuz>(0.0f) - a;
Float8_fnuz absB = (b > static_cast<Float8_fnuz>(0.0f)) ? b : static_cast<Float8_fnuz>(0.0f) - b;
Float8_fnuz absDiff = (a - b > static_cast<Float8_fnuz>(0.0f)) ? a - b : b - a;
return absDiff / (absA + absB + static_cast<Float8_fnuz>(1.0f)) < static_cast<Float8_fnuz>(
0.125f); // tolerance * eps = 2 * 0.0625; 2*eps needed for SR
float fa = static_cast<float>(a);
float fb = static_cast<float>(b);
float absDiff = std::fabs(fa - fb);
return fa == fb
|| absDiff < AlmostEqualTolerance_Float8 * (std::fabs(fa) + std::fabs(fb) + 1.0f);
}

template <>
inline bool AlmostEqual(BFloat8_fnuz a, BFloat8_fnuz b, double threshold)
{
BFloat8_fnuz absA = (a > static_cast<BFloat8_fnuz>(0.0f)) ? a : static_cast<BFloat8_fnuz>(0.0f) - a;
BFloat8_fnuz absB = (b > static_cast<BFloat8_fnuz>(0.0f)) ? b : static_cast<BFloat8_fnuz>(0.0f) - b;
BFloat8_fnuz absDiff = (a - b > static_cast<BFloat8_fnuz>(0.0f)) ? a - b : b - a;
return absDiff / (absA + absB + static_cast<BFloat8_fnuz>(1.0f)) < static_cast<BFloat8_fnuz>(
0.25f); // tolerance * epsilon = 2 * 0.125; 2*eps needed for SR
float fa = static_cast<float>(a);
float fb = static_cast<float>(b);
float absDiff = std::fabs(fa - fb);
return fa == fb
|| absDiff < AlmostEqualTolerance_BFloat8 * (std::fabs(fa) + std::fabs(fb) + 1.0f);
}

template <>
template <>
inline bool AlmostEqual(BFloat16 a, BFloat16 b, double threshold)
{
BFloat16 absA = (a > static_cast<BFloat16>(0.0f)) ? a : static_cast<BFloat16>(0.0f) - a;
BFloat16 absB = (b > static_cast<BFloat16>(0.0f)) ? b : static_cast<BFloat16>(0.0f) - b;
BFloat16 absDiff = (a - b > static_cast<BFloat16>(0.0f)) ? a - b : b - a;
return absDiff / (absA + absB + static_cast<BFloat16>(1.0f))
< static_cast<BFloat16>(0.1f);
float fa = static_cast<float>(a);
float fb = static_cast<float>(b);
float absDiff = std::fabs(fa - fb);
return fa == fb
|| absDiff < AlmostEqualTolerance_BFloat16 * (std::fabs(fa) + std::fabs(fb) + 1.0f);
}

template <>
inline bool AlmostEqual(float a, float b, double threshold)
{
threshold = threshold > 0.0 ? threshold : 0.0001; // default threshold for float
return std::fabs(a - b) / (std::fabs(a) + std::fabs(b) + 1)
< threshold; // 7 digits of precision - 2
float tol = (threshold > 0.0) ? static_cast<float>(threshold) : AlmostEqualTolerance_Float;
float absDiff = std::fabs(a - b);
return a == b
|| absDiff < tol * (std::fabs(a) + std::fabs(b) + 1);
}

template <>
inline bool AlmostEqual(double a, double b, double threshold)
{
return std::fabs(a - b) / (std::fabs(a) + std::fabs(b) + 1)
< 0.000000000001; // 15 digits of precision - 2
double absDiff = std::fabs(a - b);
return a == b
|| absDiff < AlmostEqualTolerance_Double * (std::fabs(a) + std::fabs(b) + 1);
}
template <>
inline bool AlmostEqual(int8_t a, int8_t b, double threshold)
Expand Down
Loading
Loading