Skip to content

Commit 361048e

Browse files
committed
Decouple HIR from the IRContext
Instead of ir.Var, HIR now uses a separate hir.Value. It is numbered, not named, and is unique per function rather than per IRContext. This allows us to convert AST to HIR once and reuse it multiple times. We use this property to add an LRU cache to get_function_hir(). As a result, we lose the mapping between HIR and IR using variable names, which is occasionally useful for debugging. Hopefully, line numbers should suffice for this purpose. Also, there is no name mapping to the AST anyway. Giving up on this name-based mapping also simplifies the logic in _dispatch_call(): we no longer need to remap the variable names and can directly use the variable returned by the op implementation. We also get rid of two uses of assign(): there is now only one use remaining in store_var_impl(). Signed-off-by: Greg Bonik <[email protected]>
1 parent 3ce4742 commit 361048e

File tree

10 files changed

+476
-401
lines changed

10 files changed

+476
-401
lines changed

src/cuda/tile/_compile.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,9 @@ def wrapper(*args, **kwargs):
7373
return wrapper
7474

7575

76-
def _get_final_ir(pyfunc, args, tile_context) -> ir.Block:
76+
def _get_final_ir(pyfunc, args, tile_context) -> ir.Function:
7777
ir_ctx = ir.IRContext(tile_context)
78-
func_hir: hir.Function = get_function_hir(pyfunc, ir_ctx, entry_point=True)
78+
func_hir: hir.Function = get_function_hir(pyfunc, entry_point=True)
7979

8080
ir_args = _bind_kernel_arguments(func_hir.param_names, args, get_constant_annotations(pyfunc))
8181
func_body = hir2ir(func_hir, ir_args, ir_ctx)
@@ -94,7 +94,7 @@ def _get_final_ir(pyfunc, args, tile_context) -> ir.Block:
9494

9595
split_loops(func_body)
9696
dead_code_elimination_pass(func_body)
97-
return func_body
97+
return ir.Function(func_body, func_hir.desc.name, func_hir.body.loc)
9898

9999

100100
def _bind_kernel_arguments(param_names: tuple[str, ...],
@@ -143,7 +143,7 @@ def _log_mlir(bytecode_buf):
143143
print(f"Lowering\n==== TILEIR MLIR module ====\n\n{text}", file=sys.stderr)
144144

145145

146-
def _compiler_crash_dump(func_body,
146+
def _compiler_crash_dump(func_ir: ir.Function,
147147
bytecode_generator,
148148
error_msg,
149149
compiler_flags,
@@ -161,13 +161,13 @@ def _compiler_crash_dump(func_body,
161161
bytecode_generator(writer, anonymize_debug_attr=True)
162162

163163
artifacts = {
164-
f"{func_body.name}.bytecode": bytes(bytecode_buf),
165-
f"{func_body.name}.cutileir": f"{func_body.to_string(include_loc=False)}\n",
164+
f"{func_ir.name}.bytecode": bytes(bytecode_buf),
165+
f"{func_ir.name}.cutileir": f"{func_ir.to_string(include_loc=False)}\n",
166166
"debug_info.txt": debug_info,
167167
}
168168

169169
timestamp = datetime.datetime.now().timestamp()
170-
zip_filename = os.path.abspath(f"crash_dump_{func_body.name}_{timestamp}.zip")
170+
zip_filename = os.path.abspath(f"crash_dump_{func_ir.name}_{timestamp}.zip")
171171
print(f"Dumping crash artifacts to {zip_filename}\n", file=sys.stderr)
172172

173173
with zipfile.ZipFile(zip_filename, "w") as z:
@@ -180,17 +180,17 @@ def compile_tile(pyfunc,
180180
args,
181181
compiler_options: CompilerOptions,
182182
context: TileContext = default_tile_context) -> TileLibrary:
183-
func_body = _get_final_ir(pyfunc, args, context)
183+
func_ir = _get_final_ir(pyfunc, args, context)
184184

185185
if 'CUTILEIR' in context.config.log_keys:
186-
code = (f"==== CuTile IR for {func_body.name}==== \n\n"
187-
f"{func_body.to_string(include_loc=False)}\n\n")
186+
code = (f"==== CuTile IR for {func_ir.name}==== \n\n"
187+
f"{func_ir.to_string(include_loc=False)}\n\n")
188188
print(f'\n{code}', file=sys.stderr)
189189

190190
sm_arch = get_sm_arch()
191191

192192
bytecode_generator = functools.partial(generate_bytecode_for_kernel,
193-
func_body, compiler_options, sm_arch)
193+
func_ir, compiler_options, sm_arch)
194194

195195
bytecode_buf = bytearray()
196196
with bc.write_bytecode(num_functions=1, buf=bytecode_buf) as writer:
@@ -202,9 +202,9 @@ def compile_tile(pyfunc,
202202
if CUDA_TILE_DUMP_BYTECODE is not None:
203203
if not os.path.exists(CUDA_TILE_DUMP_BYTECODE):
204204
os.makedirs(CUDA_TILE_DUMP_BYTECODE)
205-
base_filename = os.path.basename(func_body.loc.filename.split(".")[0])
205+
base_filename = os.path.basename(func_ir.loc.filename.split(".")[0])
206206
path = os.path.join(CUDA_TILE_DUMP_BYTECODE,
207-
f"{base_filename}.ln{func_body.loc.line}.cutile")
207+
f"{base_filename}.ln{func_ir.loc.line}.cutile")
208208
print(f"Dumping TILEIR bytecode to file: {path}", file=sys.stderr)
209209
with open(path, "wb") as f:
210210
f.write(bytecode_buf)
@@ -216,9 +216,9 @@ def compile_tile(pyfunc,
216216
mlir_text = bytecode_to_mlir_text(bytecode_buf)
217217
if not os.path.exists(CUDA_TILE_DUMP_TILEIR):
218218
os.makedirs(CUDA_TILE_DUMP_TILEIR)
219-
base_filename = os.path.basename(func_body.loc.filename.split(".")[0])
219+
base_filename = os.path.basename(func_ir.loc.filename.split(".")[0])
220220
path = os.path.join(
221-
CUDA_TILE_DUMP_TILEIR, f"{base_filename}.ln{func_body.loc.line}.cuda_tile.mlir"
221+
CUDA_TILE_DUMP_TILEIR, f"{base_filename}.ln{func_ir.loc.line}.cuda_tile.mlir"
222222
)
223223
print(f"Dumping TILEIR MLIR module to file:{path}", file=sys.stderr)
224224
with open(path, "w") as f:
@@ -228,7 +228,7 @@ def compile_tile(pyfunc,
228228
"This is currently not a public feature.", file=sys.stderr)
229229

230230
# Compile MLIR module and generate cubin
231-
with tempfile.NamedTemporaryFile(suffix='.bytecode', prefix=func_body.name,
231+
with tempfile.NamedTemporaryFile(suffix='.bytecode', prefix=func_ir.name,
232232
dir=context.config.temp_dir, delete=False) as f:
233233
f.write(bytecode_buf)
234234
f.flush()
@@ -238,12 +238,12 @@ def compile_tile(pyfunc,
238238
timeout_sec=context.config.compiler_timeout_sec)
239239
except TileCompilerError as e:
240240
if context.config.enable_crash_dump:
241-
_compiler_crash_dump(func_body, bytecode_generator, e.message,
241+
_compiler_crash_dump(func_ir, bytecode_generator, e.message,
242242
e.compiler_flags, e.compiler_version)
243243

244244
raise e
245245

246-
return TileLibrary(func_body.name, cubin_file, bytecode_buf, func_body)
246+
return TileLibrary(func_ir.name, cubin_file, bytecode_buf, func_ir.body)
247247

248248

249249
# Adapter between compile_tile() and kernel/TileDispatcher

src/cuda/tile/_ir/hir.py

Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,58 @@
1515

1616

1717
import enum
18+
import threading
1819
from dataclasses import dataclass
1920
from textwrap import indent
2021
from typing import Any, Set, Mapping
2122

2223
from cuda.tile._exception import Loc, FunctionDesc
23-
from cuda.tile._ir.ir import Var
24+
25+
26+
@dataclass(frozen=True)
class Value:
    """A numbered HIR value; formatted as ``%<id>``."""

    id: int

    def __str__(self):
        return f"%{self.id}"


# Cache Value objects for reuse
_value_cache = []
_value_cache_lock = threading.Lock()


def make_value(id: int) -> Value:
    """Return a `Value` with the given id, reusing a cached instance when possible.

    Ids in ``[0, 2000)`` are served from a module-level cache that grows on
    demand. Any other id (2000 and above, or negative) gets a fresh,
    uncached `Value`. The result always compares equal to ``Value(id)``.
    """
    if id < 0:
        # A negative index would wrap around in the list lookup below and
        # silently return the cached Value of a *different* id.
        return Value(id)

    try:
        # Fast path
        return _value_cache[id]
    except IndexError:
        pass

    if id >= 2000:
        # Don't cache too many objects
        return Value(id)

    # Add 100 objects at a time to avoid triggering the slow path every time.
    # The range starts at the current length (read under the lock), so a
    # concurrent caller that already grew the cache makes this a no-op.
    with _value_cache_lock:
        _value_cache.extend([Value(j) for j in range(len(_value_cache), id + 100)])
        return _value_cache[id]
2454

2555

2656
# An "Operand" is a value that can be used as a function's argument, or as the function itself.
27-
# There are two kinds of Operands: variables and constants. Using a `Var` instance as an Operand
28-
# signals that this Operand is a variable, e.g. a result of a previous call or a kernel parameter.
29-
# An object of any other type means that it is an immediate constant.
30-
Operand = Var | Any
57+
# There are two kinds of Operands:
58+
# - Using a `Value` instance as an Operand signals that this Operand is a result
59+
# of a previous call, or a kernel parameter.
60+
# - An object of any other type means that it is an immediate constant.
61+
Operand = Value | Any
3162

3263

3364
ModuleType = type(enum)
3465

3566

3667
@dataclass
3768
class Call:
38-
results: tuple[Var, ...]
69+
result: Value | None
3970
callee: Operand
4071
args: tuple[Operand, ...]
4172
kwargs: tuple[tuple[str, Operand], ...]
@@ -45,24 +76,16 @@ def __str__(self):
4576
opfmt = _OperandFormatter([])
4677
loc_str = f" # Line {self.loc.line}"
4778
if self.callee is identity:
48-
return f"{_lhs_var_str(self.results[0])} = {opfmt(self.args[0])}{loc_str}"
79+
return f"{self.result} = {opfmt(self.args[0])}{loc_str}"
80+
lhs_str = "" if self.result is None else f"{self.result} = "
4981
callee_str = opfmt(self.callee)
50-
results_str = ", ".join(_lhs_var_str(r) for r in self.results)
51-
lhs_str = f"{results_str} = " if results_str else ""
5282
args_and_kwargs = (*(opfmt(a) for a in self.args),
5383
*(f"{k}={opfmt(v)}" for k, v in self.kwargs))
5484
args_str = ", ".join(args_and_kwargs)
5585
blocks_str = "".join(indent(f"\n{b}", " ") for b in opfmt.blocks)
5686
return f"{lhs_str}{callee_str}({args_str}){loc_str}{blocks_str}"
5787

5888

59-
def _lhs_var_str(var: Var):
60-
ty = var.try_get_type()
61-
if ty is None:
62-
return var.name
63-
return f"{var.name}: {ty}"
64-
65-
6689
class Jump(enum.Enum):
6790
END_BRANCH = "end_branch"
6891
CONTINUE = "continue"
@@ -72,26 +95,27 @@ class Jump(enum.Enum):
7295

7396
@dataclass
7497
class Block:
75-
name: str
76-
params: tuple[Var, ...]
98+
block_id: int
99+
params: tuple[Value, ...]
77100
calls: list[Call]
78-
results: tuple[Operand, ...]
101+
have_result: bool
102+
result: Operand
79103
jump: Jump | None
80104
jump_loc: Loc
81105
stored_names: Set[str]
82106
loc: Loc
83107

84108
def __str__(self):
85-
params_str = ", ".join(p.name for p in self.params)
109+
params_str = ", ".join(str(p) for p in self.params)
86110
calls_str = "".join(f"\n{c}" for c in self.calls)
87111
if self.jump is not None:
88112
calls_str += "\n" + self.jump_str()
89113
calls_str = indent(calls_str, " ")
90-
return f"^{self.name}({params_str}):{calls_str}"
114+
return f"^{self.block_id}({params_str}):{calls_str}"
91115

92116
def jump_str(self):
    """Format this block's jump as ``<jump>[ <result>] # Line <n>``."""
    opfmt = _OperandFormatter([])
    # Keep a separating space in front of the result operand; without it the
    # operand is glued to the jump keyword (e.g. "end_branch%3").
    # NOTE(review): "no result" is detected via `result is None`; the
    # `have_result` field may be the intended check -- confirm.
    results_str = "" if self.result is None else f" {opfmt(self.result)}"
    return f"{self.jump._value_}{results_str} # Line {self.jump_loc.line}"
96120

97121

@@ -102,20 +126,21 @@ class Function:
102126
param_names: tuple[str, ...]
103127
param_locs: tuple[Loc, ...]
104128
frozen_globals: Mapping[str, Any]
129+
value_id_upper_bound: int
105130

106131

107132
@dataclass
108133
class _OperandFormatter:
109134
blocks: list["Block"]
110135

111136
def __call__(self, x: Operand) -> str:
112-
if isinstance(x, Var):
113-
return x.name
137+
if isinstance(x, Value):
138+
return str(x)
114139
elif isinstance(x, ModuleType):
115140
return str(f"<mod:{x.__name__}>")
116141
elif isinstance(x, Block):
117142
self.blocks.append(x)
118-
return f"^{x.name}"
143+
return f"^{x.block_id}"
119144
elif callable(x):
120145
return f"<fn:{x.__name__}>"
121146
else:

0 commit comments

Comments
 (0)