Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/tvm/relax/frontend/nn/llm/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def check_thread_limits(target: Target, bdx: int, bdy: int, bdz: int, gdz: int):
f"{target.kind} max num threads exceeded: {bdx}*{bdy}*{bdz}>{max_num_threads_per_block}"
)

if str(target.kind) == "webgpu":
if target.kind.name == "webgpu":
# https://gpuweb.github.io/gpuweb/#dom-supported-limits-maxcomputeworkgroupsizez
assert bdz <= 64, f"webgpu's threadIdx.z cannot exceed 64, but got bdz={bdz}"
assert gdz == 1, f"webgpu's blockIdx.z should be 1, but got gdz={gdz}"
Expand Down Expand Up @@ -623,7 +623,7 @@ def __init__( # pylint: disable=too-many-locals
# pylint: enable=line-too-long
]

if str(target.kind) == "llvm":
if target.kind.name == "llvm":
if attn_kind_single == "mla":
raise ValueError("MLA is not supported in TIR kernels for now.")
# pylint: disable=line-too-long
Expand Down Expand Up @@ -1098,7 +1098,7 @@ def _get_prefill_kernel_config(h_kv, h_q, d, dtype, target: Target):

# Otherwise we would exceed maxComputeWorkgroupStorageSize
if (
str(target.kind) == "webgpu"
target.kind.name == "webgpu"
and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4
):
tile_z = 8
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relax/frontend/nn/llm/tree_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def tree_attn(h_kv, h_q, d, dtype, rope_scaling: dict[str, Any], target: Target)

# Otherwise we would exceed maxComputeWorkgroupStorageSize
if (
str(target.kind) == "webgpu"
target.kind.name == "webgpu"
and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4
):
tile_z = 8
Expand Down Expand Up @@ -898,7 +898,7 @@ def tree_attn_with_paged_kv_cache(

# Otherwise we would exceed maxComputeWorkgroupStorageSize
if (
str(target.kind) == "webgpu"
target.kind.name == "webgpu"
and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4
):
tile_z = 8
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/runtime/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def enabled(target):

Parameters
----------
target : str
target : str or Dict[str, Any] or tvm.target.Target
The target device type.

Returns
Expand All @@ -490,7 +490,7 @@ def enabled(target):
if isinstance(target, dict):
target = target.get("kind", "")
elif hasattr(target, "kind"):
target = str(target.kind)
target = target.kind.name
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using target.kind.name is less robust than the previous str(target.kind). In TVM, some objects like tvm.runtime.Module have a kind attribute that is a string. For such objects, hasattr(target, "kind") will be true, but target.kind.name will raise an AttributeError because strings do not have a .name attribute. Since str(TargetKind) is equivalent to TargetKind.name, it is safer to use getattr(target.kind, "name", str(target.kind)) to maintain compatibility with non-Target objects.

Suggested change
target = target.kind.name
target = getattr(target.kind, "name", str(target.kind))

return _ffi_api.RuntimeEnabled(target)


Expand Down
4 changes: 2 additions & 2 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def device_enabled(target):

Parameters
----------
target : str
target : str or Dict[str, Any] or tvm.target.Target
Target string to check against

Returns
Expand All @@ -485,7 +485,7 @@ def device_enabled(target):
if isinstance(target, dict):
target_kind = target["kind"]
elif hasattr(target, "kind"):
target_kind = str(target.kind)
target_kind = target.kind.name
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the change in tvm.runtime.enabled, using target.kind.name here can cause an AttributeError if target is an object where kind is a string (e.g., a tvm.runtime.Module). Using getattr with a fallback to str() ensures robustness while still preferring the .name attribute if present.

Suggested change
target_kind = target.kind.name
target_kind = getattr(target.kind, "name", str(target.kind))

else:
target_kind = target
return any(target_kind == t["target_kind"] for t in _get_targets() if t["is_runnable"])
Expand Down
Loading