Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 20 additions & 23 deletions tools/clang/unittests/HLSLExec/LinAlgTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,6 @@ static std::string buildCompilerArgs(const MatrixParams &Params,
std::stringstream SS;
SS << "-HV 202x";
SS << " -DCOMP_TYPE=" << static_cast<int>(Params.CompType);
SS << " -DCOMP_TYPE_F16=" << 8;
SS << " -DCOMP_TYPE_F32=" << 9;
SS << " -DM_DIM=" << Params.M;
SS << " -DN_DIM=" << Params.N;
SS << " -DUSE=" << static_cast<int>(Params.Use);
Expand All @@ -99,6 +97,17 @@ static std::string buildCompilerArgs(const MatrixParams &Params,
SS << " -DLAYOUT=" << static_cast<int>(Params.Layout);
SS << " -DELEM_SIZE=" << elementSize(Params.CompType);
SS << " -DNUMTHREADS=" << Params.NumThreads;
switch (Params.CompType) {
case ComponentType::F16:
SS << " -DELEM_TYPE=half";
break;
case ComponentType::F32:
SS << " -DELEM_TYPE=float";
break;
default:
SS << " -DELEM_TYPE=uint";
break;
}
if (Params.EmulateTest)
SS << " -DEMULATE_TEST";
if (Params.Enable16Bit)
Expand Down Expand Up @@ -389,7 +398,7 @@ static const char LoadStoreShader[] = R"(
[numthreads(NUMTHREADS, 1, 1)]
void main() {
for (uint I = 0; I < M_DIM*N_DIM; ++I) {
Output.Store(I*ELEM_SIZE, Input.Load(I*ELEM_SIZE));
Output.Store<ELEM_TYPE>(I*ELEM_SIZE, Input.Load<ELEM_TYPE>(I*ELEM_SIZE));
}
}
#endif
Expand Down Expand Up @@ -477,15 +486,9 @@ static const char SplatStoreShader[] = R"(
#else
[numthreads(NUMTHREADS, 1, 1)]
void main() {
#if COMP_TYPE == COMP_TYPE_F32
float fill = FILL_VALUE;
#elif COMP_TYPE == COMP_TYPE_F16
half fill = FILL_VALUE;
#else
uint fill = FILL_VALUE;
#endif
ELEM_TYPE fill = FILL_VALUE;
for (uint I = 0; I < M_DIM*N_DIM; ++I) {
Output.Store(I*ELEM_SIZE, fill);
Output.Store<ELEM_TYPE>(I*ELEM_SIZE, fill);
}
}
#endif
Expand Down Expand Up @@ -567,34 +570,28 @@ static const char ElementAccessShader[] = R"(
for (uint I = 0; I < __builtin_LinAlg_MatrixLength(Mat); ++I) {
uint2 Coord = __builtin_LinAlg_MatrixGetCoordinate(Mat, I);
uint Offset = coordToByteOffset(Coord);
#if COMP_TYPE == COMP_TYPE_F32
float Elem;
__builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
Output.Store(Offset, asuint(Elem));
#else
uint Elem;
__builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
Output.Store(Offset, Elem);
#endif
ELEM_TYPE Elem;
__builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
Output.Store<ELEM_TYPE>(Offset, Elem);
}

// Save the matrix length that this thread saw. The length is written
// to the output right after the matrix, offset by the thread index
uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadIndex * sizeof(uint));
uint Len = __builtin_LinAlg_MatrixLength(Mat);
Output.Store(LenIdx, Len);
Output.Store<uint>(LenIdx, Len);
}
#else
[numthreads(NUMTHREADS, 1, 1)]
void main(uint threadIndex : SV_GroupIndex) {
uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadIndex * sizeof(uint));
Output.Store(LenIdx, M_DIM * N_DIM / NUMTHREADS);
Output.Store<uint>(LenIdx, M_DIM * N_DIM / NUMTHREADS);

if (threadIndex != 0)
return;

for (uint I = 0; I < M_DIM*N_DIM; ++I) {
Output.Store(I*ELEM_SIZE, Input.Load(I*ELEM_SIZE));
Output.Store<ELEM_TYPE>(I*ELEM_SIZE, Input.Load<ELEM_TYPE>(I*ELEM_SIZE));
}
}
#endif
Expand Down
Loading