-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[WebGPU] Add gating logic for subgroup shuffle primitives #18823
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
097d05f
f119bbd
b1e3688
07d011c
3298e94
b1139a9
e9697fe
d95827a
397ac1b
89d6142
9a3edc9
4fb4cce
3c2ab40
53409cb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -742,11 +742,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { | |||||||||||||||||||||||||||||||
| bool IsWarpReduction(const std::vector<DataType>& types, int group_extent, int reduce_extent, | ||||||||||||||||||||||||||||||||
| int contiguous_reduce_extent) { | ||||||||||||||||||||||||||||||||
| if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") && | ||||||||||||||||||||||||||||||||
| (target_->kind->name != "metal")) { | ||||||||||||||||||||||||||||||||
| (target_->kind->name != "metal") && (target_->kind->name != "webgpu")) { | ||||||||||||||||||||||||||||||||
| return false; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| need_warp_shuffle_mask_ = target_->kind->name != "metal"; | ||||||||||||||||||||||||||||||||
| need_warp_shuffle_mask_ = target_->kind->name != "metal" && target_->kind->name != "webgpu"; | ||||||||||||||||||||||||||||||||
|
Comment on lines
744
to
+749
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. To improve maintainability, consider using
Suggested change
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| // rocm only supports 32 bit operands for shuffling at the moment | ||||||||||||||||||||||||||||||||
| if ((target_->kind->name == "rocm") && | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -427,8 +427,41 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) | |
| // Tags | ||
| .set_default_keys({"vulkan", "gpu"}); | ||
|
|
||
| /*! | ||
| * \brief Update WebGPU target attributes for subgroup-enabled lowering. | ||
| * Runtime routing on the WebLLM side guarantees subgroup size == 32. | ||
| * Runtime routing on the WebLLM side guarantees | ||
| * maxComputeInvocationsPerWorkgroup >= 1024. | ||
| * This is intentionally constrained for the subgroup-enabled WASM variant. | ||
| * When supports_subgroups is true, canonicalize thread_warp_size to 32 so | ||
| * TIR lowering can emit subgroup shuffle reductions. | ||
| */ | ||
| ffi::Map<ffi::String, ffi::Any> UpdateWebGPUAttrs(ffi::Map<ffi::String, ffi::Any> target) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we move this function to line ~258, since we have a lot of UpdateXXXAttrs there? |
||
| bool subgroups = false; | ||
| if (target.count("supports_subgroups")) { | ||
| subgroups = Downcast<Bool>(target.at("supports_subgroups")); | ||
| } | ||
|
|
||
| if (target.count("thread_warp_size")) { | ||
| int64_t thread_warp_size = Downcast<Integer>(target.at("thread_warp_size"))->value; | ||
| TVM_FFI_ICHECK(subgroups || thread_warp_size <= 1) | ||
| << "WebGPU target with thread_warp_size=" << thread_warp_size | ||
| << " requires supports_subgroups=true"; | ||
| } | ||
|
|
||
| if (subgroups) { | ||
| target.Set("thread_warp_size", int64_t(32)); | ||
| } | ||
| return target; | ||
| } | ||
|
|
||
// WebGPU target kind. supports_subgroups gates subgroup shuffle lowering;
// the canonicalizer (UpdateWebGPUAttrs) forces thread_warp_size to 32 when
// subgroup support is declared.
TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU)
    .add_attr_option<int64_t>("max_num_threads", refl::DefaultValue(256))
    .add_attr_option<bool>("supports_subgroups", refl::DefaultValue(false))
    // thread_warp_size=1: is_subwarp_reduction and is_multiwarp_reduction returns false, so no
    // subgroup ops are emitted.
    .add_attr_option<int64_t>("thread_warp_size", refl::DefaultValue(1))
    .set_target_canonicalizer(UpdateWebGPUAttrs)
    .set_default_keys({"webgpu", "gpu"});
|
|
||
| TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -40,6 +40,8 @@ | |
| #include "../../src/runtime/metadata.h" | ||
| #include "../../src/runtime/workspace_pool.h" | ||
| #include "../../src/support/bytes_io.h" | ||
| #include "3rdparty/tvm-ffi/src/ffi/extra/json_parser.cc" | ||
| #include "3rdparty/tvm-ffi/src/ffi/extra/json_writer.cc" | ||
|
Comment on lines
+43
to
+44
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just curious: are these two lines related to this PR, or are they for some other purpose?
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Without this I encountered a failure at WebGPU runtime in the browser, when loading the compiled model WASM. Upon further digging I found that the break came from the refactor that replaced the earlier JSON serialization with the tvm-ffi JSON implementation. These two includes are there to pull the JSON implementation into the same WASM compilation unit.
||
|
|
||
| namespace tvm { | ||
| namespace runtime { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.