Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ GlobalParameters:
MaxWorkspaceSize: 13421772800
DataInitTypeAlpha: 1
DataInitTypeBeta: 0
DataInitTypeA: 12
DataInitTypeB: 13
BoundsCheck: 2
KeepBuildTmp: True
EnqueuesPerSync: 10
Expand Down
32 changes: 17 additions & 15 deletions projects/hipblaslt/tensilelite/client/include/Reference.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ namespace TensileLite
{
namespace Client
{
// Approximate equality check; PointwiseComparison calls it once per
// (reference, result) element pair.
// threshold is the largest allowed delta. The default of -1 asks the
// specialization to use its own built-in tolerance for that type.
template <typename T>
inline bool AlmostEqual(T a, T b);
inline bool AlmostEqual(T a, T b, double threshold = -1.0);

template <>
inline bool AlmostEqual(Half a, Half b)
inline bool AlmostEqual(Half a, Half b, double threshold)
{
Half absA = (a > 0) ? a : -a;
Half absB = (b > 0) ? b : -b;
Expand All @@ -59,7 +60,7 @@ namespace TensileLite
}

template <>
inline bool AlmostEqual(Float8 a, Float8 b)
inline bool AlmostEqual(Float8 a, Float8 b, double threshold)
{
Float8 absA = (a > static_cast<Float8>(0.0f)) ? a : static_cast<Float8>(0.0f) - a;
Float8 absB = (b > static_cast<Float8>(0.0f)) ? b : static_cast<Float8>(0.0f) - b;
Expand All @@ -69,7 +70,7 @@ namespace TensileLite
}

template <>
inline bool AlmostEqual(BFloat8 a, BFloat8 b)
inline bool AlmostEqual(BFloat8 a, BFloat8 b, double threshold)
{
BFloat8 absA = (a > static_cast<BFloat8>(0.0f)) ? a : static_cast<BFloat8>(0.0f) - a;
BFloat8 absB = (b > static_cast<BFloat8>(0.0f)) ? b : static_cast<BFloat8>(0.0f) - b;
Expand All @@ -79,7 +80,7 @@ namespace TensileLite
}

template <>
inline bool AlmostEqual(Float8_fnuz a, Float8_fnuz b)
inline bool AlmostEqual(Float8_fnuz a, Float8_fnuz b, double threshold)
{
Float8_fnuz absA = (a > static_cast<Float8_fnuz>(0.0f)) ? a : static_cast<Float8_fnuz>(0.0f) - a;
Float8_fnuz absB = (b > static_cast<Float8_fnuz>(0.0f)) ? b : static_cast<Float8_fnuz>(0.0f) - b;
Expand All @@ -89,7 +90,7 @@ namespace TensileLite
}

template <>
inline bool AlmostEqual(BFloat8_fnuz a, BFloat8_fnuz b)
inline bool AlmostEqual(BFloat8_fnuz a, BFloat8_fnuz b, double threshold)
{
BFloat8_fnuz absA = (a > static_cast<BFloat8_fnuz>(0.0f)) ? a : static_cast<BFloat8_fnuz>(0.0f) - a;
BFloat8_fnuz absB = (b > static_cast<BFloat8_fnuz>(0.0f)) ? b : static_cast<BFloat8_fnuz>(0.0f) - b;
Expand All @@ -99,7 +100,7 @@ namespace TensileLite
}

template <>
inline bool AlmostEqual(BFloat16 a, BFloat16 b)
inline bool AlmostEqual(BFloat16 a, BFloat16 b, double threshold)
{
BFloat16 absA = (a > static_cast<BFloat16>(0.0f)) ? a : static_cast<BFloat16>(0.0f) - a;
BFloat16 absB = (b > static_cast<BFloat16>(0.0f)) ? b : static_cast<BFloat16>(0.0f) - b;
Expand All @@ -109,41 +110,42 @@ namespace TensileLite
}

template <>
inline bool AlmostEqual(float a, float b)
inline bool AlmostEqual(float a, float b, double threshold)
{
threshold = threshold > 0.0 ? threshold : 0.0001; // default threshold for float
return std::fabs(a - b) / (std::fabs(a) + std::fabs(b) + 1)
< 0.0001; // 7 digits of precision - 2
< threshold; // 7 digits of precision - 2
}

template <>
inline bool AlmostEqual(double a, double b)
inline bool AlmostEqual(double a, double b, double threshold)
{
return std::fabs(a - b) / (std::fabs(a) + std::fabs(b) + 1)
< 0.000000000001; // 15 digits of precision - 2
}
template <>
inline bool AlmostEqual(int8_t a, int8_t b)
inline bool AlmostEqual(int8_t a, int8_t b, double threshold)
{
return a == b;
}
template <>
inline bool AlmostEqual(int a, int b)
inline bool AlmostEqual(int a, int b, double threshold)
{
return a == b;
}
template <>
inline bool AlmostEqual(unsigned int a, unsigned int b)
inline bool AlmostEqual(unsigned int a, unsigned int b, double threshold)
{
return a == b;
}
template <>
inline bool AlmostEqual(std::complex<float> a, std::complex<float> b)
inline bool AlmostEqual(std::complex<float> a, std::complex<float> b, double threshold)
{
return AlmostEqual(a.real(), b.real()) && AlmostEqual(a.imag(), b.imag());
}

template <>
inline bool AlmostEqual(std::complex<double> a, std::complex<double> b)
inline bool AlmostEqual(std::complex<double> a, std::complex<double> b, double threshold)
{
return AlmostEqual(a.real(), b.real()) && AlmostEqual(a.imag(), b.imag());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,17 @@ namespace TensileLite
void const* resPtr,
size_t maxElements,
bool isgpu,
size_t validationStride);
size_t validationStride,
double threshold);

template <typename ValidType>
bool checkResultsTyped(TensorDescriptor const& tensor,
ValidType const* reference,
ValidType const* result,
size_t maxElement,
bool isgpu,
size_t validationStride);
size_t validationStride,
double threshold);

void printTensors(ContractionProblemGemm const& problem,
ContractionInputs const& reference,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,20 @@ namespace TensileLite
class PointwiseComparison
{
public:
PointwiseComparison(bool printValids, size_t printMax, bool printReport)
PointwiseComparison(bool printValids, size_t printMax, bool printReport, double threshold = -1.0)
: m_printValids(printValids)
, m_printMax(printMax)
, m_doPrint(printMax > 0)
, m_printReport(printReport)
, m_threshold(threshold)
{
}

inline void
operator()(T referenceValue, T resultValue, size_t elemIndex, size_t elemNumber)
{
m_values++;
bool match = AlmostEqual(referenceValue, resultValue);
bool match = AlmostEqual(referenceValue, resultValue, m_threshold);
if(!match)
m_errors++;

Expand Down Expand Up @@ -150,6 +151,7 @@ namespace TensileLite
bool m_doPrint = false;
bool m_printReport = false;
bool m_failed = false;
double m_threshold = -1.0;
};

template <typename T>
Expand Down
59 changes: 42 additions & 17 deletions projects/hipblaslt/tensilelite/client/src/ReferenceValidator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ namespace TensileLite
void const* resPtr,
size_t maxElements,
bool isgpu,
size_t validationStride)
size_t validationStride,
double threshold)
{
bool rv = false;
switch(tensor.dataType())
Expand All @@ -220,7 +221,8 @@ namespace TensileLite
(float const*)resPtr,
maxElements,
isgpu,
validationStride);
validationStride,
threshold);
}
break;
case rocisa::DataType::Double:
Expand All @@ -230,7 +232,8 @@ namespace TensileLite
(double const*)resPtr,
maxElements,
isgpu,
validationStride);
validationStride,
threshold);
}
break;
case rocisa::DataType::ComplexFloat:
Expand All @@ -240,7 +243,8 @@ namespace TensileLite
(std::complex<float> const*)resPtr,
maxElements,
isgpu,
validationStride);
validationStride,
threshold);
}
break;
case rocisa::DataType::ComplexDouble:
Expand All @@ -250,7 +254,8 @@ namespace TensileLite
(std::complex<double> const*)resPtr,
maxElements,
isgpu,
validationStride);
validationStride,
threshold);
}
break;
case rocisa::DataType::Half:
Expand All @@ -260,7 +265,8 @@ namespace TensileLite
(Half const*)resPtr,
maxElements,
isgpu,
validationStride);
validationStride,
threshold);
}
break;
case rocisa::DataType::Float8:
Expand All @@ -270,7 +276,8 @@ namespace TensileLite
(Float8 const*)resPtr,
maxElements,
isgpu,
validationStride);
validationStride,
threshold);
}
break;
case rocisa::DataType::BFloat8:
Expand All @@ -280,7 +287,8 @@ namespace TensileLite
(BFloat8 const*)resPtr,
maxElements,
isgpu,
validationStride);
validationStride,
threshold);
}
break;
case rocisa::DataType::Float8_fnuz:
Expand All @@ -290,7 +298,8 @@ namespace TensileLite
(Float8_fnuz const*)resPtr,
maxElements,
isgpu,
validationStride);
validationStride,
threshold);
}
break;
case rocisa::DataType::BFloat8_fnuz:
Expand All @@ -300,7 +309,8 @@ namespace TensileLite
(BFloat8_fnuz const*)resPtr,
maxElements,
isgpu,
validationStride);
validationStride,
threshold);
}
break;
case rocisa::DataType::Int8x4:
Expand All @@ -315,7 +325,8 @@ namespace TensileLite
(int32_t const*)resPtr,
maxElements,
isgpu,
validationStride);
validationStride,
threshold);
}
break;
case rocisa::DataType::BFloat16:
Expand All @@ -325,7 +336,8 @@ namespace TensileLite
(BFloat16 const*)resPtr,
maxElements,
isgpu,
validationStride);
validationStride,
threshold);
}
break;
case rocisa::DataType::Int8:
Expand All @@ -335,7 +347,8 @@ namespace TensileLite
(int8_t const*)resPtr,
maxElements,
isgpu,
validationStride);
validationStride,
threshold);
}
break;
default:
Expand All @@ -360,6 +373,17 @@ namespace TensileLite
if(m_printAny)
printTensors(problem, reference, result);

auto k = problem.transA() ? problem.a().sizes().at(0) : problem.a().sizes().at(1);
bool isTF32 = (problem.f32XdlMathOp() == rocisa::DataType::XFloat32);
bool isTF32x1 = (problem.computeInputType() == rocisa::DataType::BFloat16
&& problem.computeType() == rocisa::DataType::Float);
double threshold = -1.0;
if (isTF32) {
threshold = 0.01 * sqrt(double(k));
} else if (isTF32x1) {
threshold = 0.3 * sqrt(double(k));
}

for(size_t i = 0; i < problem.tensors().size(); i++)
{
auto& tensor = problem.tensors()[i];
Expand Down Expand Up @@ -460,9 +484,9 @@ namespace TensileLite
std::cout << "Validating tensor " << tensor.getName() << ", cpu pointer "
<< refPtr << ", gpu pointer " << resPtr
<< ", size = " << result.maxElements[i] << std::endl;

rv &= checkResults(
tensor, refPtr, resPtr, result.maxElements[i], result.gpu, validationStride);
tensor, refPtr, resPtr, result.maxElements[i], result.gpu, validationStride, threshold);
}
return rv;
}
Expand Down Expand Up @@ -649,9 +673,10 @@ namespace TensileLite
ValidType const* result,
size_t maxElement,
bool isgpu,
size_t validationStride)
size_t validationStride,
double threshold)
{
PointwiseComparison<ValidType> compareValid(m_printValids, m_printMax, m_printMax > 0);
PointwiseComparison<ValidType> compareValid(m_printValids, m_printMax, m_printMax > 0, threshold);
InvalidComparison<ValidType> compareInvalid(m_printMax, m_printMax > 0);

size_t elementsToCopy = tensor.totalAllocatedElements();
Expand Down
Loading