Skip to content

Commit d31c4ec

Browse files
committed
add new driver test to check for valid broadcast impl
1 parent 03eab5e commit d31c4ec

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
@Internal @Builtin("SubgroupLocalInvocationId")
2+
var input u32 subgroup_local_id;
3+
4+
fn subgroup_min varying u32(varying u32 x) {
5+
return (ext_instr["spirv.core", /* OpGroupNonUniformUMin */ 354, varying u32](3, 0, x));
6+
}
7+
8+
fn subgroup_shuffle varying u32(varying u32 x, varying u32 i) {
9+
return (ext_instr["spirv.core", /* OpGroupNonUniformShuffle */ 345, varying u32](3, x, i));
10+
}
11+
12+
fn subgroup_elect_first varying bool() {
13+
val tid = subgroup_local_id;
14+
//return (tid == subgroup_min(tid));
15+
return (ext_instr["spirv.core", 333, varying bool](3));
16+
}
17+
18+
@Alias type mask_t = u64;
19+
20+
fn subgroup_active_mask uniform mask_t() {
21+
return (ext_instr["spirv.core", 339, uniform mask_t](3, true));
22+
}
23+
24+
fn subgroup_broadcast_first uniform u64(varying u64 x) {
25+
return (ext_instr["spirv.core", 338, uniform u64](3, x));
26+
}
27+
28+
@Exported @EntryPoint("Compute") @WorkgroupSize(32, 1, 1) fn main() {
29+
val tid = subgroup_local_id;
30+
val x = convert[u64](tid) * u64 33333333 + u64 111111111;
31+
val y = subgroup_broadcast_first(x);
32+
debug_printf("tid = %d x = %lu, y = %lu\n", tid, x, y);
33+
34+
return ();
35+
}

0 commit comments

Comments
 (0)