Skip to content

Commit d7c990a

Browse files
authored
frontend: (pyast) support classmethods (#5595)
1 parent af616a9 commit d7c990a

File tree

3 files changed

+174
-27
lines changed

3 files changed

+174
-27
lines changed
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# RUN: python %s | filecheck %s
2+
3+
from ctypes import c_int32
4+
5+
from xdsl.dialects import arith, builtin
6+
from xdsl.frontend.pyast.context import PyASTContext
7+
from xdsl.frontend.pyast.utils.exceptions import CodeGenerationException
8+
9+
10+
class Adder:
11+
@classmethod
12+
def add_i32(cls, operand1: c_int32, operand2: c_int32) -> c_int32: ...
13+
14+
15+
ctx = PyASTContext()
16+
ctx.register_type(c_int32, builtin.i32)
17+
ctx.register_function(Adder.add_i32, arith.AddiOp)
18+
19+
20+
@ctx.parse_program
21+
def test_add(x: c_int32, y: c_int32) -> c_int32:
22+
return Adder.add_i32(x, operand2=y)
23+
24+
25+
print(test_add.module)
26+
# CHECK-NEXT: builtin.module {
27+
# CHECK-NEXT: func.func @test_add(%x : i32, %y : i32) -> i32 {
28+
# CHECK-NEXT: %0 = arith.addi %x, %y : i32
29+
# CHECK-NEXT: func.return %0 : i32
30+
# CHECK-NEXT: }
31+
# CHECK-NEXT: }
32+
33+
34+
# CHECK-NEXT: Classmethod arguments must be declared variables.
35+
@ctx.parse_program
36+
def test_args():
37+
return Adder.add_i32(1, 2) # pyright: ignore[reportArgumentType]
38+
39+
40+
try:
41+
test_args.module
42+
except CodeGenerationException as e:
43+
print(e.msg)
44+
45+
46+
# ================================================= #
47+
# Disable the desymref pass for the remaining tests #
48+
# ================================================= #
49+
ctx.post_transforms = []
50+
51+
52+
# CHECK-NEXT: Classmethod arguments must be declared variables.
53+
@ctx.parse_program
54+
def test_more_args():
55+
return Adder.add_i32(operand1=1, operand2=2) # pyright: ignore[reportArgumentType]
56+
57+
58+
try:
59+
test_more_args.module
60+
except CodeGenerationException as e:
61+
print(e.msg)
62+
63+
64+
class Class:
65+
@classmethod
66+
def method(cls):
67+
pass
68+
69+
70+
# CHECK-NEXT: Classmethod 'Class.method' is not registered.
71+
@ctx.parse_program
72+
def test_unregistered():
73+
return Class.method() # noqa: F821
74+
75+
76+
try:
77+
test_unregistered.module
78+
except CodeGenerationException as e:
79+
print(e.msg)
80+
81+
82+
# CHECK-NEXT: Method 'method' is not defined on class 'Class'.
83+
@ctx.parse_program
84+
def test_missing_method():
85+
return Class.method() # noqa: F821
86+
87+
88+
del Class.method
89+
90+
try:
91+
test_missing_method.module
92+
except CodeGenerationException as e:
93+
print(e.msg)
94+
95+
96+
# CHECK-NEXT: Class 'Class' is not defined in scope.
97+
@ctx.parse_program
98+
def test_missing_class():
99+
return Class.method() # noqa: F821
100+
101+
102+
del Class
103+
104+
try:
105+
test_missing_class.module
106+
except CodeGenerationException as e:
107+
print(e.msg)

tests/filecheck/frontend/pypdl/pdl.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,16 @@
2727

2828
from xdsl.dialects import arith, builtin, pdl
2929
from xdsl.frontend import pypdl
30-
from xdsl.ir import Operation
3130
from xdsl.rewriter import Rewriter
3231

33-
34-
def erase_op(operation: Operation) -> None:
35-
"""Shim to avoid `Expr` AST node required for methods."""
36-
return Rewriter.erase_op(operation)
37-
38-
3932
ctx = pypdl.PyPDLContext()
4033
ctx.register_type(arith.ConstantOp, pdl.OperationType())
41-
ctx.register_function(erase_op, pdl.EraseOp)
34+
ctx.register_function(Rewriter.erase_op, pdl.EraseOp)
4235

4336

4437
@ctx.parse_program
4538
def constant_replace(matched_operation: arith.ConstantOp):
46-
erase_op(matched_operation)
39+
Rewriter.erase_op(matched_operation)
4740

4841

4942
# Check that the DSL correctly rewrites on the xDSL data structures

xdsl/frontend/pyast/code_generation.py

Lines changed: 65 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import ast
2+
from collections.abc import Callable
23
from dataclasses import dataclass, field
3-
from typing import cast
4+
from typing import Any, cast
45

56
import xdsl.dialects.builtin as builtin
67
import xdsl.dialects.cf as cf
@@ -193,26 +194,28 @@ def visit_BinOp(self, node: ast.BinOp) -> None:
193194
)
194195

195196
def visit_Call(self, node: ast.Call) -> None:
196-
# Resolve function
197-
assert isinstance(node.func, ast.Name)
198-
func_name = node.func.id
199-
source_func = self.type_converter.globals.get(func_name, None)
200-
if source_func is None:
201-
raise CodeGenerationException(
202-
self.file,
203-
node.lineno,
204-
node.col_offset,
205-
f"Function '{func_name}' is not defined in scope.",
206-
)
207-
ir_op = self.type_converter.function_registry.get_operation_constructor(
208-
source_func
209-
)
197+
match node.func:
198+
case ast.Name():
199+
source_kind = "function"
200+
source, source_name = self._call_source_function(node)
201+
case ast.Attribute():
202+
source_kind = "classmethod"
203+
source, source_name = self._call_source_classmethod(node)
204+
case _:
205+
raise CodeGenerationException(
206+
self.file,
207+
node.lineno,
208+
node.col_offset,
209+
"Unsupported call expression.",
210+
)
211+
212+
ir_op = self.type_converter.function_registry.get_operation_constructor(source)
210213
if ir_op is None:
211214
raise CodeGenerationException(
212215
self.file,
213216
node.lineno,
214217
node.col_offset,
215-
f"Function '{func_name}' is not registered.",
218+
f"{source_kind.capitalize()} '{source_name}' is not registered.",
216219
)
217220

218221
# Resolve arguments
@@ -224,7 +227,7 @@ def visit_Call(self, node: ast.Call) -> None:
224227
self.file,
225228
node.lineno,
226229
node.col_offset,
227-
"Function arguments must be declared variables.",
230+
f"{source_kind.capitalize()} arguments must be declared variables.",
228231
)
229232
args.append(arg_op := symref.FetchOp(arg.id, self.symbol_table[arg.id]))
230233
self.inserter.insert_op(arg_op)
@@ -240,7 +243,7 @@ def visit_Call(self, node: ast.Call) -> None:
240243
self.file,
241244
node.lineno,
242245
node.col_offset,
243-
"Function arguments must be declared variables.",
246+
f"{source_kind.capitalize()} arguments must be declared variables.",
244247
)
245248
assert keyword.arg is not None
246249
kwargs[keyword.arg] = symref.FetchOp(
@@ -250,6 +253,50 @@ def visit_Call(self, node: ast.Call) -> None:
250253

251254
self.inserter.insert_op(ir_op(*args, **kwargs))
252255

256+
# Get called function for a call expression.
257+
def _call_source_function(self, node: ast.Call) -> tuple[Callable[..., Any], str]:
258+
assert isinstance(node.func, ast.Name)
259+
260+
func_name = node.func.id
261+
func = self.type_converter.globals.get(func_name, None)
262+
if func is None:
263+
raise CodeGenerationException(
264+
self.file,
265+
node.lineno,
266+
node.col_offset,
267+
f"Function '{func_name}' is not defined in scope.",
268+
)
269+
return func, func_name
270+
271+
# Get called classmethod for a call expression.
272+
def _call_source_classmethod(
273+
self, node: ast.Call
274+
) -> tuple[Callable[..., Any], str]:
275+
assert isinstance(node.func, ast.Attribute)
276+
assert isinstance(node.func.value, ast.Name)
277+
278+
class_name = node.func.value.id
279+
method_name = node.func.attr
280+
classmethod_name = f"{class_name}.{method_name}"
281+
282+
source_class = self.type_converter.globals.get(class_name, None)
283+
if source_class is None:
284+
raise CodeGenerationException(
285+
self.file,
286+
node.lineno,
287+
node.col_offset,
288+
f"Class '{class_name}' is not defined in scope.",
289+
)
290+
classmethod_ = getattr(source_class, method_name, None)
291+
if classmethod_ is None:
292+
raise CodeGenerationException(
293+
self.file,
294+
node.lineno,
295+
node.col_offset,
296+
f"Method '{method_name}' is not defined on class '{class_name}'.",
297+
)
298+
return classmethod_, classmethod_name
299+
253300
def visit_Compare(self, node: ast.Compare) -> None:
254301
# Allow a single comparison only.
255302
if len(node.comparators) != 1 or len(node.ops) != 1:

0 commit comments

Comments
 (0)