Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/s_tir/transform/lower_thread_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -742,11 +742,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
bool IsWarpReduction(const std::vector<DataType>& types, int group_extent, int reduce_extent,
int contiguous_reduce_extent) {
if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") &&
(target_->kind->name != "metal")) {
(target_->kind->name != "metal") && (target_->kind->name != "webgpu")) {
return false;
}

need_warp_shuffle_mask_ = target_->kind->name != "metal";
need_warp_shuffle_mask_ = target_->kind->name != "metal" && target_->kind->name != "webgpu";
Comment on lines 744 to +749
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve maintainability, consider using std::unordered_set for checking the target kind. This makes it easier to add or remove supported targets in the future.

Suggested change
if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") &&
(target_->kind->name != "metal")) {
(target_->kind->name != "metal") && (target_->kind->name != "webgpu")) {
return false;
}
need_warp_shuffle_mask_ = target_->kind->name != "metal";
need_warp_shuffle_mask_ = target_->kind->name != "metal" && target_->kind->name != "webgpu";
const std::unordered_set<std::string> supported_targets = {"cuda", "rocm", "metal", "webgpu"};
if (!supported_targets.count(target_->kind->name)) {
return false;
}
const std::unordered_set<std::string> no_mask_targets = {"metal", "webgpu"};
need_warp_shuffle_mask_ = !no_mask_targets.count(target_->kind->name);


// rocm only supports 32 bit operands for shuffling at the moment
if ((target_->kind->name == "rocm") &&
Expand Down
7 changes: 6 additions & 1 deletion src/target/source/codegen_webgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ std::string CodeGenWebGPU::Finish() {
if (enable_fp16_) {
header_stream << "enable f16;\n\n";
}
if (enable_subgroups_) {
header_stream << "enable subgroups;\n\n";
}
return header_stream.str() + decl_stream.str() + this->fwd_decl_stream.str() + stream.str();
}

Expand All @@ -120,7 +123,9 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
}
}

CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {}
CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {
  // Cache whether the target opts into WGSL subgroup operations; when true,
  // Finish() prepends the "enable subgroups;" directive to the shader source.
  enable_subgroups_ = target_->GetAttr<Bool>("supports_subgroups").value_or(Bool(false));
}

runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_readonly_decl) {
// clear previous generated state.
Expand Down
2 changes: 2 additions & 0 deletions src/target/source/codegen_webgpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class CodeGenWebGPU final : public CodeGenC {

// whether enable fp16
bool enable_fp16_{false};
// whether enable subgroups
bool enable_subgroups_{false};

/*! \brief the header stream for function label and enable directive if any, goes before any other
* declaration */
Expand Down
59 changes: 59 additions & 0 deletions src/target/source/intrin_rule_webgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,30 @@ namespace intrin {

using tirx::FLowerIntrinsic;

// warp-level primitives. Follows implementation in intrin_rule_metal.cc
struct WebGPUWarpIntrinsic {
  /*!
   * \brief Map a generic TIR warp-shuffle builtin to its WebGPU subgroup op.
   * \param t Operand data type (unused here; kept for dispatch-signature parity
   *          with the other backends' warp-intrinsic functors).
   * \param orig_op One of tvm_warp_shuffle / tvm_warp_shuffle_up /
   *                tvm_warp_shuffle_down.
   * \return The corresponding tirx.webgpu.subgroup_shuffle* op.
   */
  Op operator()([[maybe_unused]] DataType t, const Op& orig_op) const {
    if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
      return Op::Get("tirx.webgpu.subgroup_shuffle");
    } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
      return Op::Get("tirx.webgpu.subgroup_shuffle_up");
    } else {
      // Only the three shuffle builtins are ever dispatched here.
      TVM_FFI_ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
      return Op::Get("tirx.webgpu.subgroup_shuffle_down");
    }
  }
};

// Lower a generic tvm_warp_shuffle* call into the WebGPU subgroup builtin
// selected by the functor T, dropping the operands WGSL does not take.
template <typename T>
static PrimExpr DispatchWebGPUShuffle(const PrimExpr& e) {
  const CallNode* node = e.as<CallNode>();
  TVM_FFI_ICHECK(node != nullptr);
  // The generic shuffle carries: mask, value, warp_id, width, warp_size.
  TVM_FFI_ICHECK_EQ(node->args.size(), 5);
  // The subgroup builtins take only (value, lane/delta); cast the lane/delta
  // operand to u32 to match their signature, preserving vector lanes.
  const PrimExpr& value = node->args[1];
  const PrimExpr& lane = node->args[2];
  PrimExpr lane_u32 = Cast(DataType::UInt(32, lane.dtype().lanes()), lane);
  ffi::Array<PrimExpr> subgroup_args;
  subgroup_args.push_back(value);
  subgroup_args.push_back(lane_u32);
  return Call(node->dtype, T()(node->dtype, Downcast<Op>(node->op)), subgroup_args);
}

// See full list of builtin: https://www.w3.org/TR/WGSL/#builtin-functions

struct ReturnAbs {
Expand Down Expand Up @@ -113,6 +137,41 @@ TVM_REGISTER_OP("tirx.trunc")
// extra dispatch
TVM_REGISTER_OP("tirx.erf").set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchFastErf);

// warp-level primitives. Follows implementation in intrin_rule_metal.cc
// Attach lowering rules so generic warp-shuffle builtins become WebGPU
// subgroup calls during intrinsic lowering for the webgpu target.
TVM_REGISTER_OP("tirx.tvm_warp_shuffle")
    .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic",
                               DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);

TVM_REGISTER_OP("tirx.tvm_warp_shuffle_up")
    .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic",
                               DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);

TVM_REGISTER_OP("tirx.tvm_warp_shuffle_down")
    .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic",
                               DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);

// Register low-level builtin ops.
// TGlobalSymbol carries the WGSL builtin name (e.g. "subgroupShuffle") that
// codegen emits for the call; all three take (value, lane-or-delta).
TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle")
    .set_num_inputs(2)
    .add_argument("var", "Expr", "The variable to sync.")
    .add_argument("lane", "Expr", "The source thread id.")
    .set_attr<TGlobalSymbol>("TGlobalSymbol", "subgroupShuffle")
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle_up")
    .set_num_inputs(2)
    .add_argument("var", "Expr", "The variable to sync.")
    .add_argument("delta", "Expr", "The source lane id offset to be added.")
    .set_attr<TGlobalSymbol>("TGlobalSymbol", "subgroupShuffleUp")
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle_down")
    .set_num_inputs(2)
    .add_argument("var", "Expr", "The variable to sync.")
    .add_argument("delta", "Expr", "The source lane id offset to be subtracted.")
    .set_attr<TGlobalSymbol>("TGlobalSymbol", "subgroupShuffleDown")
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

} // namespace intrin
} // namespace codegen
} // namespace tvm
33 changes: 33 additions & 0 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,41 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
// Tags
.set_default_keys({"vulkan", "gpu"});

/*!
* \brief Update WebGPU target attributes for subgroup-enabled lowering.
 * Runtime routing on the WebLLM side guarantees subgroup size == 32 and
 * maxComputeInvocationsPerWorkgroup >= 1024.
* This is intentionally constrained for the subgroup-enabled WASM variant.
* When supports_subgroups is true, canonicalize thread_warp_size to 32 so
* TIR lowering can emit subgroup shuffle reductions.
*/
ffi::Map<ffi::String, ffi::Any> UpdateWebGPUAttrs(ffi::Map<ffi::String, ffi::Any> target) {
  // Read the user-provided subgroup capability flag, defaulting to off.
  bool has_subgroups = false;
  if (target.count("supports_subgroups")) {
    has_subgroups = Downcast<Bool>(target.at("supports_subgroups"));
  }

  // Reject an explicit warp size above 1 unless subgroups are enabled,
  // since the shuffle-based lowering it triggers needs subgroup support.
  if (target.count("thread_warp_size")) {
    const int64_t warp_size = Downcast<Integer>(target.at("thread_warp_size"))->value;
    TVM_FFI_ICHECK(has_subgroups || warp_size <= 1)
        << "WebGPU target with thread_warp_size=" << warp_size
        << " requires supports_subgroups=true";
  }

  // Canonicalize: subgroup-enabled targets always lower with warp size 32.
  if (has_subgroups) {
    target.Set("thread_warp_size", int64_t(32));
  }
  return target;
}

// WebGPU target. UpdateWebGPUAttrs canonicalizes the attributes below at
// target-construction time (see its definition above).
TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU)
    .add_attr_option<int64_t>("max_num_threads", refl::DefaultValue(256))
    // Opt-in flag for WGSL subgroup shuffles; default off for portability.
    .add_attr_option<bool>("supports_subgroups", refl::DefaultValue(false))
    // thread_warp_size=1: is_subwarp_reduction and is_multiwarp_reduction returns false, so no
    // subgroup ops are emitted.
    .add_attr_option<int64_t>("thread_warp_size", refl::DefaultValue(1))
    .set_target_canonicalizer(UpdateWebGPUAttrs)
    .set_default_keys({"webgpu", "gpu"});

TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -406,5 +406,106 @@ def main(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32
assert "tvm_storage_sync" in After_script


def test_webgpu_warp_reduce():
    """Single-warp allreduce on a subgroup-enabled webgpu target lowers to
    subgroup shuffles with no shared-memory staging or barriers."""
    transform = tvm.s_tir.transform.LowerThreadAllreduce()

    @I.ir_module
    class Before:
        @T.prim_func(private=True)
        def main(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
            # supports_subgroups=True canonicalizes thread_warp_size to 32,
            # which enables the warp-shuffle reduction path in the pass.
            T.func_attr(
                {
                    "target": T.target(
                        {
                            "kind": "webgpu",
                            "supports_subgroups": True,
                            "host": "llvm",
                        }
                    ),
                }
            )
            A_flat = T.decl_buffer(4096, data=A.data)

            for i in range(128):
                # 32 threads == exactly one warp/subgroup.
                threadIdx_x = T.launch_thread("threadIdx.x", 32)

                reduce_data = T.alloc_buffer((1,), "float32", scope="local")
                reduce = T.decl_buffer(1, data=reduce_data.data, scope="local")

                with T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret("handle", T.uint64(0)),
                ):
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        A_flat[0],
                        T.bool(True),
                        reduce[0],
                        threadIdx_x,
                    )
                if threadIdx_x == 0:
                    B[i] = reduce[0]

    After = transform(Before)
    assert After is not None
    After_script = After.script()
    # The reduction must be expressed entirely with warp shuffles...
    assert "tvm_warp_shuffle_down" in After_script
    assert "tvm_warp_shuffle(" in After_script
    # ...a single warp needs no cross-warp staging barrier...
    assert "tvm_storage_sync" not in After_script
    # ...and webgpu emits no shuffle-mask constant (need_warp_shuffle_mask_
    # is false for this target).
    assert "T.uint32(" not in After_script


def test_webgpu_multi_warp_reduce():
    """Multi-warp allreduce on webgpu: shuffles within each warp plus a
    synchronized staging step across warps."""
    transform = tvm.s_tir.transform.LowerThreadAllreduce()

    @I.ir_module
    class Before:
        @T.prim_func(private=True)
        def main(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32")):
            T.func_attr(
                {
                    "target": T.target(
                        {
                            "kind": "webgpu",
                            "max_num_threads": 1024,
                            "supports_subgroups": True,
                            "host": "llvm",
                        }
                    ),
                }
            )
            blockIdx_x = T.launch_thread("blockIdx.x", 1)
            cross_thread_B = T.alloc_buffer((1,), "float32", scope="local")
            threadIdx_z = T.launch_thread("threadIdx.z", 1)
            threadIdx_y = T.launch_thread("threadIdx.y", 2)
            # 128 threads span multiple 32-wide warps, forcing the
            # multi-warp reduction path.
            threadIdx_x = T.launch_thread("threadIdx.x", 128)
            cross_thread_B_1 = T.decl_buffer((1,), data=cross_thread_B.data, scope="local")
            with T.attr(
                T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
                "reduce_scope",
                T.reinterpret("handle", T.uint64(0)),
            ):
                A_1 = T.decl_buffer((256,), data=A.data)
                T.tvm_thread_allreduce(
                    T.uint32(1),
                    A_1[threadIdx_y * 128 + threadIdx_x],
                    T.bool(True),
                    cross_thread_B_1[0],
                    threadIdx_x,
                )
            if threadIdx_x == 0:
                B_1 = T.decl_buffer((2,), data=B.data)
                B_1[threadIdx_y] = cross_thread_B_1[0]

    After = transform(Before)
    assert After is not None
    After_script = After.script()
    # Intra-warp reduction still uses subgroup shuffles...
    assert "tvm_warp_shuffle_down" in After_script
    # ...while cross-warp staging requires a storage sync...
    assert "tvm_storage_sync" in After_script
    # ...with the staging buffer marked volatile...
    assert "\"tirx.volatile\": T.bool(True)" in After_script
    # ...and no shuffle-mask constant is emitted for webgpu.
    assert "T.uint32(" not in After_script


if __name__ == "__main__":
tvm.testing.main()
20 changes: 20 additions & 0 deletions tests/python/target/test_target_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,5 +426,25 @@ def test_cli_string_rejected():
Target("llvm -mcpu=cortex-a53")


def test_webgpu_target_subgroup_attrs():
    """WebGPU target: default attributes and supports_subgroups canonicalization."""
    # Without any options the target is subgroup-free with warp size 1.
    default_target = Target({"kind": "webgpu"})
    assert default_target.attrs["thread_warp_size"] == 1
    assert default_target.attrs["supports_subgroups"] == 0

    # Opting into subgroups canonicalizes the warp size to 32.
    subgroup_target = Target({"kind": "webgpu", "supports_subgroups": True})
    assert subgroup_target.attrs["supports_subgroups"] == 1
    assert subgroup_target.attrs["thread_warp_size"] == 32

    # A warp size above 1 is rejected unless subgroups are enabled.
    bad_configs = (
        {"kind": "webgpu", "thread_warp_size": 32},
        {"kind": "webgpu", "thread_warp_size": 32, "supports_subgroups": False},
    )
    for bad in bad_configs:
        with pytest.raises(tvm.TVMError, match="requires supports_subgroups=true"):
            Target(bad)


if __name__ == "__main__":
tvm.testing.main()
2 changes: 2 additions & 0 deletions web/emcc/webgpu_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
#include "../../src/runtime/metadata.h"
#include "../../src/runtime/workspace_pool.h"
#include "../../src/support/bytes_io.h"
#include "3rdparty/tvm-ffi/src/ffi/extra/json_parser.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/json_writer.cc"
Comment on lines +43 to +44
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious: are these two lines related to this PR, or they are for some other purpose?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this I encountered a failure at WebGPU runtime in the browser, when loading the compiled model WASM.

Upon further digging I found that the break came from the refactor that replaced dmlc::JSONWriter with tvm::ffi::json::Stringify in webgpu_runtime.cc. dmlc was header-only, but tvm::ffi::json splits declaration and implementation, so in the WASM unity build the Stringify implementation was not compiled into the WebGPU runtime unit and showed up as a missing runtime symbol.

These two includes are there to pull the JSON implementation into the same WASM compilation unit.


namespace tvm {
namespace runtime {
Expand Down
3 changes: 3 additions & 0 deletions web/src/webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ export async function detectGPUDevice(powerPreference: "low-power" | "high-perfo
if (adapter.features.has("shader-f16")) {
requiredFeatures.push("shader-f16");
}
if (adapter.features.has("subgroups")) {
requiredFeatures.push("subgroups");
}
// requestAdapterInfo() is deprecated, causing requestAdapterInfo to raise
// issue when building. However, it is still needed for older browsers, hence `as any`.
const adapterInfo = adapter.info || await (adapter as any).requestAdapterInfo();
Expand Down