Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions tools/clang/unittests/HLSLExec/HlslTestDataTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@

#include <ostream>

#include <windows.h>

#include <DirectXMath.h>
#include <DirectXPackedVector.h>

#include "HlslTestUtils.h"
#include "dxc/Support/Global.h"

// These types bridge the gap between C++ and HLSL type representations.
Expand Down Expand Up @@ -460,6 +463,101 @@ struct HLSLMin16Uint_t {
uint32_t Val;
};

enum class ValidationType {
Epsilon,
Ulp,
};

template <typename T>
inline bool doValuesMatch(T A, T B, double Tolerance, ValidationType) {
if (Tolerance == 0.0)
return A == B;

T Diff = A > B ? A - B : B - A;
return Diff <= Tolerance;
}

inline bool doValuesMatch(HLSLBool_t A, HLSLBool_t B, double, ValidationType) {
return A == B;
}

inline bool doValuesMatch(HLSLHalf_t A, HLSLHalf_t B, double Tolerance,
ValidationType ValidationType) {
switch (ValidationType) {
case ValidationType::Epsilon:
return CompareHalfEpsilon(A.Val, B.Val, static_cast<float>(Tolerance));
case ValidationType::Ulp:
return CompareHalfULP(A.Val, B.Val, static_cast<float>(Tolerance));
default:
hlsl_test::LogErrorFmt(
L"Invalid ValidationType. Expecting Epsilon or ULP.");
return false;
}
}

// Min precision float comparison: convert to half and compare in fp16 space.
// This reuses the same tolerance values as HLSLHalf_t. Min precision is at
// least 16-bit, so fp16 tolerances are an upper bound for all cases.
inline bool doValuesMatch(HLSLMin16Float_t A, HLSLMin16Float_t B,
double Tolerance, ValidationType ValidationType) {
auto HalfA = DirectX::PackedVector::XMConvertFloatToHalf(A.Val);
auto HalfB = DirectX::PackedVector::XMConvertFloatToHalf(B.Val);
switch (ValidationType) {
case ValidationType::Epsilon:
return CompareHalfEpsilon(HalfA, HalfB, static_cast<float>(Tolerance));
case ValidationType::Ulp:
return CompareHalfULP(HalfA, HalfB, static_cast<float>(Tolerance));
default:
hlsl_test::LogErrorFmt(
L"Invalid ValidationType. Expecting Epsilon or ULP.");
return false;
}
}

inline bool doValuesMatch(HLSLMin16Int_t A, HLSLMin16Int_t B, double,
ValidationType) {
return A == B;
}

inline bool doValuesMatch(HLSLMin16Uint_t A, HLSLMin16Uint_t B, double,
ValidationType) {
return A == B;
}

inline bool doValuesMatch(float A, float B, double Tolerance,
ValidationType ValidationType) {
switch (ValidationType) {
case ValidationType::Epsilon:
return CompareFloatEpsilon(A, B, static_cast<float>(Tolerance));
case ValidationType::Ulp: {
// Tolerance is in ULPs. Convert to int for the comparison.
const int IntTolerance = static_cast<int>(Tolerance);
return CompareFloatULP(A, B, IntTolerance);
};
default:
hlsl_test::LogErrorFmt(
L"Invalid ValidationType. Expecting Epsilon or ULP.");
return false;
}
}

inline bool doValuesMatch(double A, double B, double Tolerance,
ValidationType ValidationType) {
switch (ValidationType) {
case ValidationType::Epsilon:
return CompareDoubleEpsilon(A, B, Tolerance);
case ValidationType::Ulp: {
// Tolerance is in ULPs. Convert to int64_t for the comparison.
const int64_t IntTolerance = static_cast<int64_t>(Tolerance);
return CompareDoubleULP(A, B, IntTolerance);
};
default:
hlsl_test::LogErrorFmt(
L"Invalid ValidationType. Expecting Epsilon or ULP.");
return false;
}
}

} // namespace HLSLTestDataTypes

#endif
125 changes: 66 additions & 59 deletions tools/clang/unittests/HLSLExec/LinAlgTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ using hlsl::DXIL::LinalgMatrixLayout;
using hlsl::DXIL::MatrixScope;
using hlsl::DXIL::MatrixUse;

using HLSLTestDataTypes::doValuesMatch;
using HLSLTestDataTypes::HLSLHalf_t;
using HLSLTestDataTypes::ValidationType;

using VariantCompType = std::variant<std::vector<float>, std::vector<int32_t>,
std::vector<HLSLHalf_t>>;
Expand Down Expand Up @@ -111,10 +113,8 @@ static bool verifyFloatBuffer(const float *Actual, const float *Expected,
float Tolerance = 0.0f) {
bool Success = true;
for (size_t I = 0; I < Count; I++) {
float Diff = Actual[I] - Expected[I];
if (Diff < 0)
Diff = -Diff;
if (Diff > Tolerance) {
if (!doValuesMatch(Actual[I], Expected[I], Tolerance,
ValidationType::Epsilon)) {
hlsl_test::LogErrorFmt(L"Mismatch at index %zu: actual=%f, expected=%f",
I, static_cast<double>(Actual[I]),
static_cast<double>(Expected[I]));
Expand All @@ -132,7 +132,7 @@ static bool verifyIntBuffer(const int32_t *Actual, const int32_t *Expected,
size_t Count, bool Verbose) {
bool Success = true;
for (size_t I = 0; I < Count; I++) {
if (Actual[I] != Expected[I]) {
if (!doValuesMatch(Actual[I], Expected[I], 0.0, ValidationType::Epsilon)) {
hlsl_test::LogErrorFmt(L"Mismatch at index %zu: actual=%d, expected=%d",
I, Actual[I], Expected[I]);
Success = false;
Expand All @@ -149,10 +149,8 @@ static bool verifyHalfBuffer(const HLSLHalf_t *Actual,
bool Verbose, HLSLHalf_t Tolerance = 0.0f) {
bool Success = true;
for (size_t I = 0; I < Count; I++) {
HLSLHalf_t Diff = Actual[I] - Expected[I];
if (Diff < 0.0f)
Diff = -Diff;
if (Diff > Tolerance) {
if (!doValuesMatch(Actual[I], Expected[I], Tolerance,
ValidationType::Epsilon)) {
hlsl_test::LogErrorFmt(L"Mismatch at index %zu: actual=%f, expected=%f",
I, static_cast<float>(Actual[I]),
static_cast<float>(Expected[I]));
Expand Down Expand Up @@ -254,21 +252,10 @@ static VariantCompType makeExpected(ComponentType CompType, size_t NumElements,
return std::vector<float>();
}

static bool shouldSkipBecauseSM610Unsupported(ID3D12Device *Device) {
// Never skip in an HLK environment
#ifdef _HLK_CONF
return false;
#endif

// Don't skip if a device is available
if (Device)
return false;

// Skip GPU execution
static void logCompiledButSkipping() {
hlsl_test::LogCommentFmt(
L"Shader compiled OK; skipping execution (no SM 6.10 device)");
WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped);
return true;
}

class DxilConf_SM610_LinAlg {
Expand Down Expand Up @@ -299,49 +286,34 @@ class DxilConf_SM610_LinAlg {
TEST_METHOD(ElementAccess_Wave_16x16_F16);

private:
bool createDevice();
D3D_SHADER_MODEL createDevice();

CComPtr<ID3D12Device> D3DDevice;
dxc::SpecificDllLoader DxcSupport;
bool VerboseLogging = false;
bool EmulateTest = false;
bool Initialized = false;
bool CompileOnly = false;
std::optional<D3D12SDKSelector> D3D12SDK;

WEX::TestExecution::SetVerifyOutput VerifyOutput{
WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures};
};

/// Creates the device and setups the test scenario with the following variants
/// HLK build: Require SM6.10 supported fail otherwise
/// Non-HLK, no SM6.10 support: Compile shaders, then exit with skip
/// Non-HLK, SM6.10 support: Compile shaders and run full test
bool DxilConf_SM610_LinAlg::createDevice() {
bool FailIfRequirementsNotMet = false;
#ifdef _HLK_CONF
FailIfRequirementsNotMet = true;
#endif
/// Attempts to create a device. If shaders are being emulated then a SM6.8
/// device is attempted. Otherwise a SM6.10 device is attempted
D3D_SHADER_MODEL DxilConf_SM610_LinAlg::createDevice() {
if (EmulateTest) {
if (D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_8, false))
return D3D_SHADER_MODEL_6_8;

const bool SkipUnsupported = FailIfRequirementsNotMet;
if (!D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_10,
SkipUnsupported)) {
if (FailIfRequirementsNotMet) {
hlsl_test::LogErrorFmt(
L"Device creation failed, resulting in test failure, since "
L"FailIfRequirementsNotMet is set. The expectation is that this "
L"test will only be executed if something has previously "
L"determined that the system meets the requirements of this "
L"test.");
return false;
}
return D3D_SHADER_MODEL_NONE;
}

if (EmulateTest) {
hlsl_test::LogWarningFmt(L"EmulateTest flag set. Tests are NOT REAL");
return D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_8, false);
}
if (D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_10, false))
return D3D_SHADER_MODEL_6_10;

return true;
return D3D_SHADER_MODEL_NONE;
}

bool DxilConf_SM610_LinAlg::setupClass() {
Expand All @@ -354,7 +326,26 @@ bool DxilConf_SM610_LinAlg::setupClass() {
VerboseLogging);
WEX::TestExecution::RuntimeParameters::TryGetValue(L"EmulateTest",
EmulateTest);
return createDevice();
D3D_SHADER_MODEL SupportedSM = createDevice();

if (EmulateTest) {
hlsl_test::LogWarningFmt(L"EmulateTest flag set. Tests are NOT REAL");
if (SupportedSM != D3D_SHADER_MODEL_6_8) {
hlsl_test::LogErrorFmt(
L"Device creation failed. Expected a driver supporting SM6.8");
return false;
}
}

#ifdef _HLK_CONF
if (SupportedSM != D3D_SHADER_MODEL_6_10) {
hlsl_test::LogErrorFmt(
L"Device creation failed. Expected a driver supporting SM6.10");
return false;
}
#endif

CompileOnly = SupportedSM == D3D_SHADER_MODEL_NONE;
}

return true;
Expand All @@ -366,11 +357,17 @@ bool DxilConf_SM610_LinAlg::setupMethod() {
if (D3DDevice && D3DDevice->GetDeviceRemovedReason() == S_OK)
return true;

// Device is expected to be null. No point in recreating it
if (CompileOnly)
return true;

hlsl_test::LogCommentFmt(L"Device was lost!");
D3DDevice.Release();

hlsl_test::LogCommentFmt(L"Recreating device");
return createDevice();

// !CompileOnly implies we expect it to succeeded
return createDevice() != D3D_SHADER_MODEL_NONE;
}

static const char LoadStoreShader[] = R"(
Expand Down Expand Up @@ -400,7 +397,8 @@ static const char LoadStoreShader[] = R"(

static void runLoadStoreRoundtrip(ID3D12Device *Device,
dxc::SpecificDllLoader &DxcSupport,
const MatrixParams &Params, bool Verbose) {
const MatrixParams &Params, bool Verbose,
bool CompileOnly) {
const size_t NumElements = Params.totalElements();
const size_t BufferSize = Params.totalBytes();

Expand All @@ -417,8 +415,10 @@ static void runLoadStoreRoundtrip(ID3D12Device *Device,
// Always verify the shader compiles.
compileShader(DxcSupport, LoadStoreShader, Target.c_str(), Args, Verbose);

if (shouldSkipBecauseSM610Unsupported(Device))
if (CompileOnly) {
logCompiledButSkipping();
return;
}

auto Expected = makeExpected(Params.CompType, NumElements, 1, true);

Expand Down Expand Up @@ -457,7 +457,8 @@ void DxilConf_SM610_LinAlg::LoadStoreRoundtrip_Wave_16x16_F16() {
Params.NumThreads = 4;
Params.Enable16Bit = true;
Params.EmulateTest = EmulateTest;
runLoadStoreRoundtrip(D3DDevice, DxcSupport, Params, VerboseLogging);
runLoadStoreRoundtrip(D3DDevice, DxcSupport, Params, VerboseLogging,
CompileOnly);
}

static const char SplatStoreShader[] = R"(
Expand Down Expand Up @@ -493,7 +494,7 @@ static const char SplatStoreShader[] = R"(
static void runSplatStore(ID3D12Device *Device,
dxc::SpecificDllLoader &DxcSupport,
const MatrixParams &Params, float FillValue,
bool Verbose) {
bool Verbose, bool CompileOnly) {
const size_t NumElements = Params.totalElements();
const size_t BufferSize = Params.totalBytes();
std::string Target = "cs_6_10";
Expand All @@ -508,8 +509,10 @@ static void runSplatStore(ID3D12Device *Device,
// Always verify the shader compiles.
compileShader(DxcSupport, SplatStoreShader, Target.c_str(), Args, Verbose);

if (shouldSkipBecauseSM610Unsupported(Device))
if (CompileOnly) {
logCompiledButSkipping();
return;
}

auto Expected = makeExpected(Params.CompType, NumElements, FillValue, false);

Expand Down Expand Up @@ -538,7 +541,8 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_16x16_F16() {
Params.NumThreads = 4;
Params.Enable16Bit = true;
Params.EmulateTest = EmulateTest;
runSplatStore(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging);
runSplatStore(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging,
CompileOnly);
}

static const char ElementAccessShader[] = R"(
Expand Down Expand Up @@ -598,7 +602,8 @@ static const char ElementAccessShader[] = R"(

static void runElementAccess(ID3D12Device *Device,
dxc::SpecificDllLoader &DxcSupport,
const MatrixParams &Params, bool Verbose) {
const MatrixParams &Params, bool Verbose,
bool CompileOnly) {
const size_t NumElements = Params.totalElements();
const size_t NumThreads = Params.NumThreads;
const size_t InputBufSize = Params.totalBytes();
Expand All @@ -621,8 +626,10 @@ static void runElementAccess(ID3D12Device *Device,

compileShader(DxcSupport, ElementAccessShader, Target.c_str(), Args, Verbose);

if (shouldSkipBecauseSM610Unsupported(Device))
if (CompileOnly) {
logCompiledButSkipping();
return;
}

auto Expected = makeExpected(Params.CompType, NumElements, 1, true);

Expand Down Expand Up @@ -673,7 +680,7 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_16x16_F16() {
Params.NumThreads = 4;
Params.Enable16Bit = true;
Params.EmulateTest = EmulateTest;
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging);
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging, CompileOnly);
}

} // namespace LinAlg
Loading
Loading