Skip to content

Commit 6ef3218

Browse files
AlexVlxarsenm
andauthored
[SPIRV] Add support for bfloat16 atomics via the SPV_INTEL_16bit_atomics extension (#166257)
This enables support for atomic RMW ops (add, sub, min and max to be precise) with `bfloat16` operands, via the [SPV_INTEL_16bit_atomics extension](#20009). It's logically a successor to #166031 (I should've used a stack), but I'm putting it up for early review. --------- Co-authored-by: Matt Arsenault <[email protected]>
1 parent c3b31ba commit 6ef3218

File tree

7 files changed

+103
-14
lines changed

7 files changed

+103
-14
lines changed

llvm/docs/SPIRVUsage.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ Below is a list of supported SPIR-V extensions, sorted alphabetically by their e
167167
- Adds atomic add instruction on floating-point numbers.
168168
* - ``SPV_EXT_shader_atomic_float_min_max``
169169
- Adds atomic min and max instruction on floating-point numbers.
170+
* - ``SPV_INTEL_16bit_atomics``
171+
- Extends the SPV_EXT_shader_atomic_float_add and SPV_EXT_shader_atomic_float_min_max to support addition, minimum and maximum on 16-bit `bfloat16` floating-point numbers in memory.
170172
* - ``SPV_INTEL_2d_block_io``
171173
- Adds additional subgroup block prefetch, load, load transposed, load transformed and store instructions to read two-dimensional blocks of data from a two-dimensional region of memory, or to write two-dimensional blocks of data to a two dimensional region of memory.
172174
* - ``SPV_INTEL_arbitrary_precision_integers``

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3482,7 +3482,7 @@ bool IRTranslator::translateAtomicCmpXchg(const User &U,
34823482

34833483
bool IRTranslator::translateAtomicRMW(const User &U,
34843484
MachineIRBuilder &MIRBuilder) {
3485-
if (containsBF16Type(U))
3485+
if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
34863486
return false;
34873487

34883488
const AtomicRMWInst &I = cast<AtomicRMWInst>(U);

llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
2929
SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float16_add},
3030
{"SPV_EXT_shader_atomic_float_min_max",
3131
SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float_min_max},
32+
{"SPV_INTEL_16bit_atomics",
33+
SPIRV::Extension::Extension::SPV_INTEL_16bit_atomics},
3234
{"SPV_EXT_arithmetic_fence",
3335
SPIRV::Extension::Extension::SPV_EXT_arithmetic_fence},
3436
{"SPV_EXT_demote_to_helper_invocation",

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,6 +1058,13 @@ static void addOpTypeImageReqs(const MachineInstr &MI,
10581058
}
10591059
}
10601060

1061+
static bool isBFloat16Type(const SPIRVType *TypeDef) {
1062+
return TypeDef && TypeDef->getNumOperands() == 3 &&
1063+
TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
1064+
TypeDef->getOperand(1).getImm() == 16 &&
1065+
TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
1066+
}
1067+
10611068
// Add requirements for handling atomic float instructions
10621069
#define ATOM_FLT_REQ_EXT_MSG(ExtName) \
10631070
"The atomic float instruction requires the following SPIR-V " \
@@ -1081,11 +1088,21 @@ static void AddAtomicFloatRequirements(const MachineInstr &MI,
10811088
Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
10821089
switch (BitWidth) {
10831090
case 16:
1084-
if (!ST.canUseExtension(
1085-
SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
1086-
report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
1087-
Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
1088-
Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
1091+
if (isBFloat16Type(TypeDef)) {
1092+
if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics))
1093+
report_fatal_error(
1094+
"The atomic bfloat16 instruction requires the following SPIR-V "
1095+
"extension: SPV_INTEL_16bit_atomics",
1096+
false);
1097+
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics);
1098+
Reqs.addCapability(SPIRV::Capability::AtomicBFloat16AddINTEL);
1099+
} else {
1100+
if (!ST.canUseExtension(
1101+
SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
1102+
report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
1103+
Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
1104+
Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
1105+
}
10891106
break;
10901107
case 32:
10911108
Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
@@ -1104,7 +1121,17 @@ static void AddAtomicFloatRequirements(const MachineInstr &MI,
11041121
Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
11051122
switch (BitWidth) {
11061123
case 16:
1107-
Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
1124+
if (isBFloat16Type(TypeDef)) {
1125+
if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics))
1126+
report_fatal_error(
1127+
"The atomic bfloat16 instruction requires the following SPIR-V "
1128+
"extension: SPV_INTEL_16bit_atomics",
1129+
false);
1130+
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics);
1131+
Reqs.addCapability(SPIRV::Capability::AtomicBFloat16MinMaxINTEL);
1132+
} else {
1133+
Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
1134+
}
11081135
break;
11091136
case 32:
11101137
Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
@@ -1328,13 +1355,6 @@ void addPrintfRequirements(const MachineInstr &MI,
13281355
}
13291356
}
13301357

1331-
static bool isBFloat16Type(const SPIRVType *TypeDef) {
1332-
return TypeDef && TypeDef->getNumOperands() == 3 &&
1333-
TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
1334-
TypeDef->getOperand(1).getImm() == 16 &&
1335-
TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
1336-
}
1337-
13381358
void addInstrRequirements(const MachineInstr &MI,
13391359
SPIRV::ModuleAnalysisInfo &MAI,
13401360
const SPIRVSubtarget &ST) {

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,7 @@ defm SPV_INTEL_predicated_io : ExtensionOperand<127, [EnvOpenCL]>;
389389
defm SPV_KHR_maximal_reconvergence : ExtensionOperand<128, [EnvVulkan]>;
390390
defm SPV_INTEL_bfloat16_arithmetic
391391
: ExtensionOperand<129, [EnvVulkan, EnvOpenCL]>;
392+
defm SPV_INTEL_16bit_atomics : ExtensionOperand<130, [EnvVulkan, EnvOpenCL]>;
392393

393394
//===----------------------------------------------------------------------===//
394395
// Multiclass used to define Capabilities enum values and at the same time
@@ -566,9 +567,11 @@ defm FloatControls2
566567
defm AtomicFloat32AddEXT : CapabilityOperand<6033, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
567568
defm AtomicFloat64AddEXT : CapabilityOperand<6034, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
568569
defm AtomicFloat16AddEXT : CapabilityOperand<6095, 0, 0, [SPV_EXT_shader_atomic_float16_add], []>;
570+
defm AtomicBFloat16AddINTEL : CapabilityOperand<6255, 0, 0, [SPV_INTEL_16bit_atomics], []>;
569571
defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
570572
defm AtomicFloat32MinMaxEXT : CapabilityOperand<5612, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
571573
defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
574+
defm AtomicBFloat16MinMaxINTEL : CapabilityOperand<6256, 0, 0, [SPV_INTEL_16bit_atomics], []>;
572575
defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variable_length_array], []>;
573576
defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
574577
defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR1
2+
; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_EXT_shader_atomic_float_add,+SPV_KHR_bfloat16 %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR2
3+
4+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_EXT_shader_atomic_float_add,+SPV_INTEL_16bit_atomics,+SPV_KHR_bfloat16,+SPV_INTEL_bfloat16_arithmetic %s -o - | FileCheck %s
5+
6+
; CHECK-ERROR1: LLVM ERROR: The atomic float instruction requires the following SPIR-V extension: SPV_EXT_shader_atomic_float_add
7+
; CHECK-ERROR2: LLVM ERROR: The atomic bfloat16 instruction requires the following SPIR-V extension: SPV_INTEL_16bit_atomics
8+
9+
; CHECK: Capability BFloat16TypeKHR
10+
; CHECK: Capability AtomicBFloat16AddINTEL
11+
; CHECK: Extension "SPV_KHR_bfloat16"
12+
; CHECK: Extension "SPV_EXT_shader_atomic_float_add"
13+
; CHECK: Extension "SPV_INTEL_16bit_atomics"
14+
; CHECK-DAG: %[[TyBF16:[0-9]+]] = OpTypeFloat 16 0
15+
; CHECK-DAG: %[[TyBF16Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyBF16]]
16+
; CHECK-DAG: %[[TyInt32:[0-9]+]] = OpTypeInt 32 0
17+
; CHECK-DAG: %[[ConstBF16:[0-9]+]] = OpConstant %[[TyBF16]] 16936{{$}}
18+
; CHECK-DAG: %[[Const0:[0-9]+]] = OpConstantNull %[[TyBF16]]
19+
; CHECK-DAG: %[[BF16Ptr:[0-9]+]] = OpVariable %[[TyBF16Ptr]] CrossWorkgroup %[[Const0]]
20+
; CHECK-DAG: %[[ScopeAllSvmDevices:[0-9]+]] = OpConstantNull %[[TyInt32]]
21+
; CHECK-DAG: %[[MemSeqCst:[0-9]+]] = OpConstant %[[TyInt32]] 16{{$}}
22+
; CHECK: OpAtomicFAddEXT %[[TyBF16]] %[[BF16Ptr]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstBF16]]
23+
; CHECK: %[[NegatedConstBF16:[0-9]+]] = OpFNegate %[[TyBF16]] %[[ConstBF16]]
24+
; CHECK: OpAtomicFAddEXT %[[TyBF16]] %[[BF16Ptr]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[NegatedConstBF16]]
25+
26+
27+
@f = common dso_local local_unnamed_addr addrspace(1) global bfloat 0.000000e+00, align 8
28+
29+
define dso_local spir_func void @test1() local_unnamed_addr {
30+
entry:
31+
%addval = atomicrmw fadd ptr addrspace(1) @f, bfloat 42.000000e+00 seq_cst
32+
%subval = atomicrmw fsub ptr addrspace(1) @f, bfloat 42.000000e+00 seq_cst
33+
ret void
34+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_EXT_shader_atomic_float_min_max,+SPV_KHR_bfloat16 %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
2+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_EXT_shader_atomic_float_min_max,+SPV_INTEL_16bit_atomics,+SPV_KHR_bfloat16 %s -o - | FileCheck %s
3+
4+
; CHECK-ERROR: LLVM ERROR: The atomic bfloat16 instruction requires the following SPIR-V extension: SPV_INTEL_16bit_atomics
5+
6+
; CHECK: Capability AtomicBFloat16MinMaxINTEL
7+
; CHECK: Extension "SPV_KHR_bfloat16"
8+
; CHECK: Extension "SPV_EXT_shader_atomic_float_min_max"
9+
; CHECK: Extension "SPV_INTEL_16bit_atomics"
10+
; CHECK-DAG: %[[TyBF16:[0-9]+]] = OpTypeFloat 16 0
11+
; CHECK-DAG: %[[TyBF16Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyBF16]]
12+
; CHECK-DAG: %[[TyInt32:[0-9]+]] = OpTypeInt 32 0
13+
; CHECK-DAG: %[[ConstBF16:[0-9]+]] = OpConstant %[[TyBF16]] 16936{{$}}
14+
; CHECK-DAG: %[[Const0:[0-9]+]] = OpConstantNull %[[TyBF16]]
15+
; CHECK-DAG: %[[BF16Ptr:[0-9]+]] = OpVariable %[[TyBF16Ptr]] CrossWorkgroup %[[Const0]]
16+
; CHECK-DAG: %[[ScopeAllSvmDevices:[0-9]+]] = OpConstantNull %[[TyInt32]]
17+
; CHECK-DAG: %[[MemSeqCst:[0-9]+]] = OpConstant %[[TyInt32]] 16{{$}}
18+
; CHECK: OpAtomicFMinEXT %[[TyBF16]] %[[BF16Ptr]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstBF16]]
19+
; CHECK: OpAtomicFMaxEXT %[[TyBF16]] %[[BF16Ptr]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstBF16]]
20+
21+
@f = common dso_local local_unnamed_addr addrspace(1) global bfloat 0.000000e+00, align 8
22+
23+
define spir_func void @test1() {
24+
entry:
25+
%minval = atomicrmw fmin ptr addrspace(1) @f, bfloat 42.0e+00 seq_cst
26+
%maxval = atomicrmw fmax ptr addrspace(1) @f, bfloat 42.0e+00 seq_cst
27+
ret void
28+
}

0 commit comments

Comments
 (0)