diff --git a/tools/clang/unittests/HLSLExec/LinAlgTests.cpp b/tools/clang/unittests/HLSLExec/LinAlgTests.cpp index 6221f7f1df..2e4ce65d57 100644 --- a/tools/clang/unittests/HLSLExec/LinAlgTests.cpp +++ b/tools/clang/unittests/HLSLExec/LinAlgTests.cpp @@ -89,8 +89,6 @@ static std::string buildCompilerArgs(const MatrixParams &Params, std::stringstream SS; SS << "-HV 202x"; SS << " -DCOMP_TYPE=" << static_cast(Params.CompType); - SS << " -DCOMP_TYPE_F16=" << 8; - SS << " -DCOMP_TYPE_F32=" << 9; SS << " -DM_DIM=" << Params.M; SS << " -DN_DIM=" << Params.N; SS << " -DUSE=" << static_cast(Params.Use); @@ -99,6 +97,17 @@ static std::string buildCompilerArgs(const MatrixParams &Params, SS << " -DLAYOUT=" << static_cast(Params.Layout); SS << " -DELEM_SIZE=" << elementSize(Params.CompType); SS << " -DNUMTHREADS=" << Params.NumThreads; + switch (Params.CompType) { + case ComponentType::F16: + SS << " -DELEM_TYPE=half"; + break; + case ComponentType::F32: + SS << " -DELEM_TYPE=float"; + break; + default: + SS << " -DELEM_TYPE=uint"; + break; + } if (Params.EmulateTest) SS << " -DEMULATE_TEST"; if (Params.Enable16Bit) @@ -389,7 +398,7 @@ static const char LoadStoreShader[] = R"( [numthreads(NUMTHREADS, 1, 1)] void main() { for (uint I = 0; I < M_DIM*N_DIM; ++I) { - Output.Store(I*ELEM_SIZE, Input.Load(I*ELEM_SIZE)); + Output.Store(I*ELEM_SIZE, Input.Load(I*ELEM_SIZE)); } } #endif @@ -477,15 +486,9 @@ static const char SplatStoreShader[] = R"( #else [numthreads(NUMTHREADS, 1, 1)] void main() { -#if COMP_TYPE == COMP_TYPE_F32 - float fill = FILL_VALUE; -#elif COMP_TYPE == COMP_TYPE_F16 - half fill = FILL_VALUE; -#else - uint fill = FILL_VALUE; -#endif + ELEM_TYPE fill = FILL_VALUE; for (uint I = 0; I < M_DIM*N_DIM; ++I) { - Output.Store(I*ELEM_SIZE, fill); + Output.Store(I*ELEM_SIZE, fill); } } #endif @@ -567,34 +570,28 @@ static const char ElementAccessShader[] = R"( for (uint I = 0; I < __builtin_LinAlg_MatrixLength(Mat); ++I) { uint2 Coord = __builtin_LinAlg_MatrixGetCoordinate(Mat, I); uint Offset = coordToByteOffset(Coord); -#if COMP_TYPE == COMP_TYPE_F32 - float Elem; - __builtin_LinAlg_MatrixGetElement(Elem, Mat, I); - Output.Store(Offset, asuint(Elem)); -#else - uint Elem; - __builtin_LinAlg_MatrixGetElement(Elem, Mat, I); - Output.Store(Offset, Elem); -#endif + ELEM_TYPE Elem; + __builtin_LinAlg_MatrixGetElement(Elem, Mat, I); + Output.Store(Offset, Elem); } // Save the matrix length that this thread saw. The length is written // to the output right after the matrix, offset by the thread index uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadIndex * sizeof(uint)); uint Len = __builtin_LinAlg_MatrixLength(Mat); - Output.Store(LenIdx, Len); + Output.Store(LenIdx, Len); } #else [numthreads(NUMTHREADS, 1, 1)] void main(uint threadIndex : SV_GroupIndex) { uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadIndex * sizeof(uint)); - Output.Store(LenIdx, M_DIM * N_DIM / NUMTHREADS); + Output.Store(LenIdx, M_DIM * N_DIM / NUMTHREADS); if (threadIndex != 0) return; for (uint I = 0; I < M_DIM*N_DIM; ++I) { - Output.Store(I*ELEM_SIZE, Input.Load(I*ELEM_SIZE)); + Output.Store(I*ELEM_SIZE, Input.Load(I*ELEM_SIZE)); } } #endif