diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index f66ff6b6636f..749707cb2942 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -73,7 +73,7 @@ def check_thread_limits(target: Target, bdx: int, bdy: int, bdz: int, gdz: int): f"{target.kind} max num threads exceeded: {bdx}*{bdy}*{bdz}>{max_num_threads_per_block}" ) - if str(target.kind) == "webgpu": + if target.kind.name == "webgpu": # https://gpuweb.github.io/gpuweb/#dom-supported-limits-maxcomputeworkgroupsizez assert bdz <= 64, f"webgpu's threadIdx.z cannot exceed 64, but got bdz={bdz}" assert gdz == 1, f"webgpu's blockIdx.z should be 1, but got gdz={gdz}" @@ -623,7 +623,7 @@ def __init__( # pylint: disable=too-many-locals # pylint: enable=line-too-long ] - if str(target.kind) == "llvm": + if target.kind.name == "llvm": if attn_kind_single == "mla": raise ValueError("MLA is not supported in TIR kernels for now.") # pylint: disable=line-too-long @@ -1098,7 +1098,7 @@ def _get_prefill_kernel_config(h_kv, h_q, d, dtype, target: Target): # Otherwise we would exceed maxComputeWorkgroupStorageSize if ( - str(target.kind) == "webgpu" + target.kind.name == "webgpu" and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4 ): tile_z = 8 diff --git a/python/tvm/relax/frontend/nn/llm/tree_attn.py b/python/tvm/relax/frontend/nn/llm/tree_attn.py index c55aa3eceb22..2c7f66cf3707 100644 --- a/python/tvm/relax/frontend/nn/llm/tree_attn.py +++ b/python/tvm/relax/frontend/nn/llm/tree_attn.py @@ -331,7 +331,7 @@ def tree_attn(h_kv, h_q, d, dtype, rope_scaling: dict[str, Any], target: Target) # Otherwise we would exceed maxComputeWorkgroupStorageSize if ( - str(target.kind) == "webgpu" + target.kind.name == "webgpu" and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4 ): tile_z = 8 @@ -898,7 +898,7 @@ def tree_attn_with_paged_kv_cache( # Otherwise we would exceed maxComputeWorkgroupStorageSize if ( - str(target.kind) == "webgpu" + target.kind.name == "webgpu" and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4 ): tile_z = 8 diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 7586e0df1576..df78faa5968d 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -473,7 +473,7 @@ def enabled(target): Parameters ---------- - target : str + target : str or Dict[str, Any] or tvm.target.Target The target device type. Returns @@ -490,7 +490,7 @@ def enabled(target): if isinstance(target, dict): target = target.get("kind", "") elif hasattr(target, "kind"): - target = str(target.kind) + target = target.kind.name return _ffi_api.RuntimeEnabled(target) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index c5937eec4ec8..125fa9586b03 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -463,7 +463,7 @@ def device_enabled(target): Parameters ---------- - target : str + target : str or Dict[str, Any] or tvm.target.Target Target string to check against Returns @@ -485,7 +485,7 @@ def device_enabled(target): if isinstance(target, dict): target_kind = target["kind"] elif hasattr(target, "kind"): - target_kind = str(target.kind) + target_kind = target.kind.name else: target_kind = target return any(target_kind == t["target_kind"] for t in _get_targets() if t["is_runnable"])