11import ast
2+ from collections .abc import Callable
23from dataclasses import dataclass , field
3- from typing import cast
4+ from typing import Any , cast
45
56import xdsl .dialects .builtin as builtin
67import xdsl .dialects .cf as cf
@@ -193,26 +194,28 @@ def visit_BinOp(self, node: ast.BinOp) -> None:
193194 )
194195
195196 def visit_Call (self , node : ast .Call ) -> None :
196- # Resolve function
197- assert isinstance (node .func , ast .Name )
198- func_name = node .func .id
199- source_func = self .type_converter .globals .get (func_name , None )
200- if source_func is None :
201- raise CodeGenerationException (
202- self .file ,
203- node .lineno ,
204- node .col_offset ,
205- f"Function '{ func_name } ' is not defined in scope." ,
206- )
207- ir_op = self .type_converter .function_registry .get_operation_constructor (
208- source_func
209- )
197+ match node .func :
198+ case ast .Name ():
199+ source_kind = "function"
200+ source , source_name = self ._call_source_function (node )
201+ case ast .Attribute ():
202+ source_kind = "classmethod"
203+ source , source_name = self ._call_source_classmethod (node )
204+ case _:
205+ raise CodeGenerationException (
206+ self .file ,
207+ node .lineno ,
208+ node .col_offset ,
209+ "Unsupported call expression." ,
210+ )
211+
212+ ir_op = self .type_converter .function_registry .get_operation_constructor (source )
210213 if ir_op is None :
211214 raise CodeGenerationException (
212215 self .file ,
213216 node .lineno ,
214217 node .col_offset ,
215- f"Function '{ func_name } ' is not registered." ,
218+ f"{ source_kind . capitalize () } '{ source_name } ' is not registered." ,
216219 )
217220
218221 # Resolve arguments
@@ -224,7 +227,7 @@ def visit_Call(self, node: ast.Call) -> None:
224227 self .file ,
225228 node .lineno ,
226229 node .col_offset ,
227- "Function arguments must be declared variables." ,
230+ f" { source_kind . capitalize () } arguments must be declared variables." ,
228231 )
229232 args .append (arg_op := symref .FetchOp (arg .id , self .symbol_table [arg .id ]))
230233 self .inserter .insert_op (arg_op )
@@ -240,7 +243,7 @@ def visit_Call(self, node: ast.Call) -> None:
240243 self .file ,
241244 node .lineno ,
242245 node .col_offset ,
243- "Function arguments must be declared variables." ,
246+ f" { source_kind . capitalize () } arguments must be declared variables." ,
244247 )
245248 assert keyword .arg is not None
246249 kwargs [keyword .arg ] = symref .FetchOp (
@@ -250,6 +253,50 @@ def visit_Call(self, node: ast.Call) -> None:
250253
251254 self .inserter .insert_op (ir_op (* args , ** kwargs ))
252255
256+ # Get called function for a call expression.
257+ def _call_source_function (self , node : ast .Call ) -> tuple [Callable [..., Any ], str ]:
258+ assert isinstance (node .func , ast .Name )
259+
260+ func_name = node .func .id
261+ func = self .type_converter .globals .get (func_name , None )
262+ if func is None :
263+ raise CodeGenerationException (
264+ self .file ,
265+ node .lineno ,
266+ node .col_offset ,
267+ f"Function '{ func_name } ' is not defined in scope." ,
268+ )
269+ return func , func_name
270+
271+ # Get called classmethod for a call expression.
272+ def _call_source_classmethod (
273+ self , node : ast .Call
274+ ) -> tuple [Callable [..., Any ], str ]:
275+ assert isinstance (node .func , ast .Attribute )
276+ assert isinstance (node .func .value , ast .Name )
277+
278+ class_name = node .func .value .id
279+ method_name = node .func .attr
280+ classmethod_name = f"{ class_name } .{ method_name } "
281+
282+ source_class = self .type_converter .globals .get (class_name , None )
283+ if source_class is None :
284+ raise CodeGenerationException (
285+ self .file ,
286+ node .lineno ,
287+ node .col_offset ,
288+ f"Class '{ class_name } ' is not defined in scope." ,
289+ )
290+ classmethod_ = getattr (source_class , method_name , None )
291+ if classmethod_ is None :
292+ raise CodeGenerationException (
293+ self .file ,
294+ node .lineno ,
295+ node .col_offset ,
296+ f"Method '{ method_name } ' is not defined on class '{ class_name } '." ,
297+ )
298+ return classmethod_ , classmethod_name
299+
253300 def visit_Compare (self , node : ast .Compare ) -> None :
254301 # Allow a single comparison only.
255302 if len (node .comparators ) != 1 or len (node .ops ) != 1 :
0 commit comments