|
15 | 15 | from cuda.tile._exception import TileTypeError |
16 | 16 | from cuda.tile._ir.ops_utils import get_dtype |
17 | 17 |
|
18 | | -from .ir import Var |
19 | 18 | from .typing_support import datatype, get_signature |
| 19 | +from .ir import Var, TupleValue |
20 | 20 | from .type import TupleTy, TileTy, DTypeSpec, EnumTy, StringTy, ArrayTy, SliceType, \ |
21 | | - ListTy, PointerTy, LooselyTypedScalar, RangeIterType |
| 21 | + ListTy, PointerTy, LooselyTypedScalar, RangeIterType, FunctionTy, ClosureTy, BoundMethodTy, \ |
| 22 | + DTypeConstructor |
22 | 23 | from .. import _datatype |
23 | 24 |
|
24 | 25 |
|
@@ -46,18 +47,31 @@ def decorate(func): |
46 | 47 |
|
47 | 48 | func_sig = get_signature(func) |
48 | 49 | _verify_params_match(stub_sig, func_sig) |
49 | | - |
50 | | - @functools.wraps(func) |
51 | | - def wrapper(*args, **kwargs): |
52 | | - # Memorize the stub and the args so that we can automatically |
53 | | - # provide context for error messages. |
54 | | - old = _current_stub.stub_and_args |
55 | | - _current_stub.stub_and_args = (stub, stub_sig, func_sig, args, kwargs) |
56 | | - try: |
57 | | - return func(*args, **kwargs) |
58 | | - finally: |
59 | | - _current_stub.stub_and_args = old |
60 | | - wrapper._is_coroutine = inspect.iscoroutinefunction(func) |
| 50 | + is_coroutine = inspect.iscoroutinefunction(func) |
| 51 | + if is_coroutine: |
| 52 | + @functools.wraps(func) |
| 53 | + async def wrapper(*args, **kwargs): |
| 54 | + # Memorize the stub and the args so that we can automatically |
| 55 | + # provide context for error messages. |
| 56 | + old = _current_stub.stub_and_args |
| 57 | + _current_stub.stub_and_args = (stub, stub_sig, func_sig, args, kwargs) |
| 58 | + try: |
| 59 | + return await func(*args, **kwargs) |
| 60 | + finally: |
| 61 | + _current_stub.stub_and_args = old |
| 62 | + else: |
| 63 | + @functools.wraps(func) |
| 64 | + def wrapper(*args, **kwargs): |
| 65 | + # Memorize the stub and the args so that we can automatically |
| 66 | + # provide context for error messages. |
| 67 | + old = _current_stub.stub_and_args |
| 68 | + _current_stub.stub_and_args = (stub, stub_sig, func_sig, args, kwargs) |
| 69 | + try: |
| 70 | + return func(*args, **kwargs) |
| 71 | + finally: |
| 72 | + _current_stub.stub_and_args = old |
| 73 | + |
| 74 | + wrapper._is_coroutine = is_coroutine |
61 | 75 | op_implementations[stub] = wrapper |
62 | 76 | return orig_func |
63 | 77 |
|
@@ -96,6 +110,36 @@ def require_constant_bool(var: Var) -> bool: |
96 | 110 | return var.get_constant() |
97 | 111 |
|
98 | 112 |
|
| 113 | +def require_constant_scalar(var: Var) -> bool | int | float: |
| 114 | + ty = var.get_type() |
| 115 | + if not isinstance(ty, DType): |
| 116 | + raise _make_type_error(f"Expected a scalar constant, but given value has type {ty}", var) |
| 117 | + if not var.is_constant(): |
| 118 | + raise _make_type_error(f"Expected a constant, but given value has non-constant type {ty}", |
| 119 | + var) |
| 120 | + ret = var.get_constant() |
| 121 | + assert isinstance(ret, bool | int | float) |
| 122 | + return ret |
| 123 | + |
| 124 | + |
| 125 | +def require_constant_scalar_tuple(var: Var) -> tuple[bool | int | float, ...]: |
| 126 | + ty = require_tuple_type(var) |
| 127 | + ret = [] |
| 128 | + tuple_val = var.get_aggregate() |
| 129 | + assert isinstance(tuple_val, TupleValue) |
| 130 | + for i, (item_ty, item) in enumerate(zip(ty.value_types, tuple_val.items, strict=True)): |
| 131 | + if not isinstance(item_ty, DType): |
| 132 | + raise _make_type_error(f"Expected a tuple of scalar constants," |
| 133 | + f" but item at position #{i} has type {item_ty}", var) |
| 134 | + if not item.is_constant(): |
| 135 | + raise _make_type_error(f"Expected a tuple of scalar constants," |
| 136 | + f" but item at position #{i} has non-constant type {ty}", var) |
| 137 | + value = item.get_constant() |
| 138 | + assert isinstance(value, bool | int | float) |
| 139 | + ret.append(value) |
| 140 | + return tuple(ret) |
| 141 | + |
| 142 | + |
99 | 143 | def require_optional_constant_bool(var: Var) -> Optional[bool]: |
100 | 144 | if var.is_constant() and var.get_constant() is None: |
101 | 145 | return None |
@@ -244,6 +288,16 @@ def require_tile_type(var: Var) -> TileTy: |
244 | 288 | return ty |
245 | 289 |
|
246 | 290 |
|
| 291 | +def require_tile_or_tile_tuple_type(var: Var) -> TileTy | TupleTy: |
| 292 | + ty = var.get_type() |
| 293 | + if isinstance(ty, TileTy): |
| 294 | + return ty |
| 295 | + if isinstance(ty, TupleTy) and all(isinstance(x, TileTy) for x in ty.value_types): |
| 296 | + return ty |
| 297 | + raise _make_type_error(f"Expected a tile or a tuple of tiles, but given value has type {ty}", |
| 298 | + var) |
| 299 | + |
| 300 | + |
247 | 301 | def require_tile_or_scalar_type(var: Var) -> TileTy | DType | PointerTy: |
248 | 302 | ty = var.get_type() |
249 | 303 | if not isinstance(ty, TileTy | DType | PointerTy): |
@@ -359,6 +413,13 @@ def require_index_or_index_tuple_type(var: Var, |
359 | 413 | return ty |
360 | 414 |
|
361 | 415 |
|
| 416 | +def require_callable_type(var: Var) -> FunctionTy | BoundMethodTy | ClosureTy | DTypeConstructor: |
| 417 | + ty = var.get_type() |
| 418 | + if not isinstance(ty, FunctionTy | BoundMethodTy | ClosureTy | DTypeConstructor): |
| 419 | + raise _make_type_error(f"Expected a callable object, but given value has type {ty}", var) |
| 420 | + return ty |
| 421 | + |
| 422 | + |
362 | 423 | class PrintfValidator: |
363 | 424 | # c-format string has the following: %[flags][width][.precision][length]specifier |
364 | 425 | # we only support a subset which makes sense in the tile context |
|
0 commit comments