@@ -13,44 +13,100 @@ use melior::{
1313use crate :: ast:: CompilUnit ;
1414
1515pub fn gen ( comp_unit : & CompilUnit , _idx_checks : bool ) -> String {
16- let context = create_context ( ) ;
17- let location = Location :: unknown ( & context) ;
18- let module = Module :: new ( location) ;
19- let mut builder = MlirBuilder :: new ( & context, module) ;
16+ // Check if we need HIVM address spaces
17+ if needs_hivm_address_space ( comp_unit) {
18+ to_mlir:: types:: generate_mlir_string_with_hivm ( comp_unit)
19+ } else {
20+ let context = create_context ( ) ;
21+ let location = Location :: unknown ( & context) ;
22+ let module = Module :: new ( location) ;
23+ let mut builder = MlirBuilder :: new ( & context, module) ;
2024
21- // Two-pass build so that calls know callee result types
22- builder. build_items_two_pass ( comp_unit) ;
25+ // Two-pass build so that calls know callee result types
26+ builder. build_items_two_pass ( comp_unit) ;
2327
24- // Dump the module to string
25- builder. module ( ) . as_operation ( ) . to_string ( )
28+ // Dump the module to string
29+ builder. module ( ) . as_operation ( ) . to_string ( )
30+ }
2631}
2732
2833pub fn gen_checked ( comp_unit : & CompilUnit , _idx_checks : bool ) -> Result < String , String > {
29- let context = create_context ( ) ;
30- let location = Location :: unknown ( & context) ;
31- let module = Module :: new ( location) ;
32- let mut builder = MlirBuilder :: new ( & context, module) ;
34+ // Check if we need HIVM address spaces
35+ if needs_hivm_address_space ( comp_unit) {
36+ Ok ( to_mlir:: types:: generate_mlir_string_with_hivm ( comp_unit) )
37+ } else {
38+ let context = create_context ( ) ;
39+ let location = Location :: unknown ( & context) ;
40+ let module = Module :: new ( location) ;
41+ let mut builder = MlirBuilder :: new ( & context, module) ;
42+
43+ // Two-pass build so that calls know callee result types
44+ builder. build_items_two_pass ( comp_unit) ;
45+
46+ // Verify the module
47+ if !builder. module ( ) . as_operation ( ) . verify ( ) {
48+ return Err ( "MLIR module verification failed" . to_string ( ) ) ;
49+ }
3350
34- // Two-pass build so that calls know callee result types
35- builder. build_items_two_pass ( comp_unit) ;
51+ // Dump the module to string
52+ Ok ( builder. module ( ) . as_operation ( ) . to_string ( ) )
53+ }
54+ }
3655
37- // Verify the module
38- if !builder. module ( ) . as_operation ( ) . verify ( ) {
39- return Err ( "MLIR module verification failed" . to_string ( ) ) ;
56+ /// Check if the compilation unit needs HIVM address spaces
57+ fn needs_hivm_address_space ( comp_unit : & CompilUnit ) -> bool {
58+ for item in & comp_unit. items {
59+ if let crate :: ast:: Item :: FunDef ( fun) = item {
60+ // Only check the main function or functions that are not HIVM placeholders
61+ if fun. ident . name == "main" . into ( ) || !is_hivm_placeholder_function ( fun) {
62+ for param in & fun. param_decls {
63+ if let Some ( ty) = & param. ty {
64+ if has_gpu_memory ( ty) {
65+ return true ;
66+ }
67+ }
68+ }
69+ }
70+ }
4071 }
72+ false
73+ }
4174
42- // Dump the module to string
43- Ok ( builder. module ( ) . as_operation ( ) . to_string ( ) )
75+ /// Check if a function is a HIVM placeholder function
76+ fn is_hivm_placeholder_function ( fun : & crate :: ast:: FunDef ) -> bool {
77+ fun. ident . name . starts_with ( "hivm_" )
78+ }
79+
80+ /// Check if a type has GPU memory qualifiers
81+ fn has_gpu_memory ( ty : & crate :: ast:: Ty ) -> bool {
82+ match & ty. ty {
83+ crate :: ast:: TyKind :: Data ( data_ty) => {
84+ match & data_ty. dty {
85+ crate :: ast:: DataTyKind :: At ( _, mem) => {
86+ matches ! ( mem, crate :: ast:: Memory :: GpuGlobal | crate :: ast:: Memory :: GpuShared | crate :: ast:: Memory :: GpuLocal )
87+ } ,
88+ crate :: ast:: DataTyKind :: Ref ( ref_dty) => {
89+ matches ! ( ref_dty. mem, crate :: ast:: Memory :: GpuGlobal | crate :: ast:: Memory :: GpuShared | crate :: ast:: Memory :: GpuLocal )
90+ } ,
91+ _ => false ,
92+ }
93+ } ,
94+ _ => false ,
95+ }
4496}
4597
4698pub fn create_context ( ) -> Context {
4799 let registry = DialectRegistry :: new ( ) ;
48100 register_all_dialects ( & registry) ;
49101
50- // Custom dialects (hivm, annotation, symbol) are loaded via dialects.rs module
51-
52102 let context = Context :: new ( ) ;
103+
104+ // Allow unregistered dialects to handle HIVM dialect
105+ context. set_allow_unregistered_dialects ( true ) ;
106+
53107 context. append_dialect_registry ( & registry) ;
54108 context. load_all_available_dialects ( ) ;
109+
55110 context
56111}
112+
0 commit comments