Skip to content

Commit ebc4915

Browse files
committed
Add ct.reduce()
Signed-off-by: Greg Bonik <[email protected]>
1 parent 8c3c4b6 commit ebc4915

File tree

11 files changed

+564
-128
lines changed

11 files changed

+564
-128
lines changed

changelog.d/custom-reduction.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- Added support for custom reduction via `ct.reduce()`.

src/cuda/tile/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@
120120
pow,
121121
printf,
122122
prod,
123+
reduce,
123124
reshape,
124125
rsqrt,
125126
scatter,
@@ -250,6 +251,7 @@
250251
"pow",
251252
"printf",
252253
"prod",
254+
"reduce",
253255
"reshape",
254256
"rsqrt",
255257
"scatter",

src/cuda/tile/_ir/ir.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from cuda.tile._exception import (
2222
TileTypeError, Loc, TileInternalError
2323
)
24+
from .. import TileSyntaxError
2425
from .._cext import TileContext
2526
from .._context import TileContextConfig
2627

@@ -317,9 +318,20 @@ def terminator(cls):
317318
return cls
318319

319320

320-
def has_side_effects(cls):
321-
cls._has_side_effects = True
322-
return cls
321+
class MemoryEffect(enum.IntEnum):
    """Kind of memory effect an operation class is tagged with.

    The integer values are meaningful: they encode the relative strength
    of the effects, so members can be compared directly
    (``NONE < LOAD < STORE``).
    """

    NONE = 0   # weakest
    LOAD = 1
    STORE = 2  # strongest
328+
329+
330+
def memory_effect(eff: MemoryEffect):
    """Build a class decorator that tags the decorated class with `eff`.

    The returned decorator stores `eff` in the class attribute
    ``memory_effect`` and hands the same class object back, so it can be
    stacked with other class decorators.
    """
    def _tag_class(op_cls):
        setattr(op_cls, "memory_effect", eff)
        return op_cls

    return _tag_class
323335

324336

325337
def has_multiple_results(cls):
@@ -399,19 +411,24 @@ def finalize_loopvar_type(self, body_var: Var):
399411

400412

401413
class Builder:
402-
def __init__(self, ctx: IRContext, loc: Loc):
414+
def __init__(self, ctx: IRContext, loc: Loc, reduction_body: bool = False):
403415
self.ir_ctx = ctx
404416
self.is_terminated = False
405417
self._loc = loc
406418
self._ops = []
407419
self._entered = False
408420
self._prev_builder = None
409421
self._var_map: Dict[str, Var] = dict()
422+
self.reduction_body = reduction_body
410423

411424
def add_operation(self, op_class,
412425
result_ty: Type | None | Tuple[Type | None, ...],
413426
attrs_and_operands,
414427
result: Var | Sequence[Var] | None = None) -> Var | Tuple[Var, ...]:
428+
if self.reduction_body and op_class.memory_effect != MemoryEffect.NONE:
429+
raise TileSyntaxError("Operations with memory effects are not supported"
430+
" inside reduction body")
431+
415432
assert not self.is_terminated
416433
force_type = False
417434
if isinstance(result_ty, tuple):
@@ -504,10 +521,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):
504521

505522

506523
@contextmanager
507-
def nested_block(loc: Loc):
524+
def nested_block(loc: Loc, reduction_body: bool = False):
508525
prev_builder = Builder.get_current()
509526
block = Block(prev_builder.ir_ctx, loc=loc)
510-
with Builder(prev_builder.ir_ctx, loc) as builder:
527+
with Builder(prev_builder.ir_ctx, loc,
528+
reduction_body=reduction_body or prev_builder.reduction_body) as builder:
511529
yield block
512530
block.extend(builder.ops)
513531

@@ -520,7 +538,7 @@ class _CurrentBuilder(threading.local):
520538

521539

522540
class Operation:
523-
_has_side_effects = False
541+
memory_effect = MemoryEffect.NONE
524542
_multiple_results = False
525543

526544
def __init__(
@@ -588,10 +606,6 @@ def all_inputs(self) -> Iterator[Var]:
588606
def is_terminator(self) -> bool:
589607
return self._is_terminator
590608

591-
@property
592-
def has_side_effects(self) -> bool:
593-
return self._has_side_effects
594-
595609
def _add_operand(self, name: str, var: Var | Tuple[Var, ...]):
596610
if isinstance(var, Var) and var.is_aggregate() and self.op != "assign":
597611
# Don't allow aggregate values as operands, except for arrays and lists.

src/cuda/tile/_ir/op_impl.py

Lines changed: 75 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
from cuda.tile._exception import TileTypeError
1616
from cuda.tile._ir.ops_utils import get_dtype
1717

18-
from .ir import Var
1918
from .typing_support import datatype, get_signature
19+
from .ir import Var, TupleValue
2020
from .type import TupleTy, TileTy, DTypeSpec, EnumTy, StringTy, ArrayTy, SliceType, \
21-
ListTy, PointerTy, LooselyTypedScalar, RangeIterType
21+
ListTy, PointerTy, LooselyTypedScalar, RangeIterType, FunctionTy, ClosureTy, BoundMethodTy, \
22+
DTypeConstructor
2223
from .. import _datatype
2324

2425

@@ -46,18 +47,31 @@ def decorate(func):
4647

4748
func_sig = get_signature(func)
4849
_verify_params_match(stub_sig, func_sig)
49-
50-
@functools.wraps(func)
51-
def wrapper(*args, **kwargs):
52-
# Memorize the stub and the args so that we can automatically
53-
# provide context for error messages.
54-
old = _current_stub.stub_and_args
55-
_current_stub.stub_and_args = (stub, stub_sig, func_sig, args, kwargs)
56-
try:
57-
return func(*args, **kwargs)
58-
finally:
59-
_current_stub.stub_and_args = old
60-
wrapper._is_coroutine = inspect.iscoroutinefunction(func)
50+
is_coroutine = inspect.iscoroutinefunction(func)
51+
if is_coroutine:
52+
@functools.wraps(func)
53+
async def wrapper(*args, **kwargs):
54+
# Memorize the stub and the args so that we can automatically
55+
# provide context for error messages.
56+
old = _current_stub.stub_and_args
57+
_current_stub.stub_and_args = (stub, stub_sig, func_sig, args, kwargs)
58+
try:
59+
return await func(*args, **kwargs)
60+
finally:
61+
_current_stub.stub_and_args = old
62+
else:
63+
@functools.wraps(func)
64+
def wrapper(*args, **kwargs):
65+
# Memorize the stub and the args so that we can automatically
66+
# provide context for error messages.
67+
old = _current_stub.stub_and_args
68+
_current_stub.stub_and_args = (stub, stub_sig, func_sig, args, kwargs)
69+
try:
70+
return func(*args, **kwargs)
71+
finally:
72+
_current_stub.stub_and_args = old
73+
74+
wrapper._is_coroutine = is_coroutine
6175
op_implementations[stub] = wrapper
6276
return orig_func
6377

@@ -96,6 +110,36 @@ def require_constant_bool(var: Var) -> bool:
96110
return var.get_constant()
97111

98112

113+
def require_constant_scalar(var: Var) -> bool | int | float:
    """Return the constant Python value carried by a scalar `var`.

    Two conditions are checked, in order: the type of `var` must be a
    `DType`, and `var` must be a constant. If either check fails, the
    error object produced by `_make_type_error` is raised.
    """
    scalar_ty = var.get_type()
    if not isinstance(scalar_ty, DType):
        message = f"Expected a scalar constant, but given value has type {scalar_ty}"
        raise _make_type_error(message, var)

    if var.is_constant():
        value = var.get_constant()
        # Internal invariant: a constant of scalar type holds a plain Python scalar.
        assert isinstance(value, bool | int | float)
        return value

    message = f"Expected a constant, but given value has non-constant type {scalar_ty}"
    raise _make_type_error(message, var)
123+
124+
125+
def require_constant_scalar_tuple(var: Var) -> tuple[bool | int | float, ...]:
    """Return the constant Python values of a tuple of scalar constants.

    `var` must have a tuple type (checked by `require_tuple_type`). Each
    item must have a `DType` type and be a constant. The first item that
    violates either condition raises the error built by
    `_make_type_error`; the message names the item's position and the
    item's own type.

    Returns:
        The item constants, in tuple order.
    """
    ty = require_tuple_type(var)
    tuple_val = var.get_aggregate()
    assert isinstance(tuple_val, TupleValue)

    ret = []
    # strict=True: the type's item list and the aggregate's item list must agree in length.
    for i, (item_ty, item) in enumerate(zip(ty.value_types, tuple_val.items, strict=True)):
        if not isinstance(item_ty, DType):
            raise _make_type_error(f"Expected a tuple of scalar constants,"
                                   f" but item at position #{i} has type {item_ty}", var)
        if not item.is_constant():
            # Report the offending item's type, not the type of the whole tuple.
            raise _make_type_error(f"Expected a tuple of scalar constants,"
                                   f" but item at position #{i} has non-constant type {item_ty}",
                                   var)
        value = item.get_constant()
        # Internal invariant: a constant of scalar type holds a plain Python scalar.
        assert isinstance(value, bool | int | float)
        ret.append(value)
    return tuple(ret)
141+
142+
99143
def require_optional_constant_bool(var: Var) -> Optional[bool]:
100144
if var.is_constant() and var.get_constant() is None:
101145
return None
@@ -244,6 +288,16 @@ def require_tile_type(var: Var) -> TileTy:
244288
return ty
245289

246290

291+
def require_tile_or_tile_tuple_type(var: Var) -> TileTy | TupleTy:
    """Return the type of `var` if it is a tile or a tuple made only of tiles.

    Accepted types are a `TileTy`, or a `TupleTy` whose `value_types` are
    all `TileTy`. Any other type raises the error built by
    `_make_type_error`.
    """
    ty = var.get_type()
    is_tile = isinstance(ty, TileTy)
    is_tile_tuple = (
        not is_tile
        and isinstance(ty, TupleTy)
        and all(isinstance(item_ty, TileTy) for item_ty in ty.value_types)
    )
    if not (is_tile or is_tile_tuple):
        raise _make_type_error(
            f"Expected a tile or a tuple of tiles, but given value has type {ty}", var)
    return ty
299+
300+
247301
def require_tile_or_scalar_type(var: Var) -> TileTy | DType | PointerTy:
248302
ty = var.get_type()
249303
if not isinstance(ty, TileTy | DType | PointerTy):
@@ -359,6 +413,13 @@ def require_index_or_index_tuple_type(var: Var,
359413
return ty
360414

361415

416+
def require_callable_type(var: Var) -> FunctionTy | BoundMethodTy | ClosureTy | DTypeConstructor:
    """Return the type of `var` if it is one of the accepted callable kinds.

    Accepted kinds are `FunctionTy`, `BoundMethodTy`, `ClosureTy` and
    `DTypeConstructor`. Any other type raises the error built by
    `_make_type_error`.
    """
    callable_kinds = FunctionTy | BoundMethodTy | ClosureTy | DTypeConstructor
    ty = var.get_type()
    if isinstance(ty, callable_kinds):
        return ty
    raise _make_type_error(f"Expected a callable object, but given value has type {ty}", var)
421+
422+
362423
class PrintfValidator:
363424
# c-format string has the following: %[flags][width][.precision][length]specifier
364425
# we only support a subset which makes sense in the tile context

0 commit comments

Comments
 (0)