Skip to content

Commit 60dfed7

Browse files
committed
feat: implement hivm.address space for gpu.global
1 parent a071d87 commit 60dfed7

File tree

4 files changed

+303
-218
lines changed

4 files changed

+303
-218
lines changed

examples/simple/gpu_mem.desc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
fn add<n: nat, r: prv>(
2+
a: &r shrd gpu.global [i16; 16],
3+
b: &r shrd gpu.global [i16; 16],
4+
c: &r uniq gpu.global [i16; 16]
5+
) -[grid: gpu.grid<X<1>, X<16>>]-> () {
6+
// a = b + c;
7+
()
8+
}

src/codegen/mlir/mod.rs

Lines changed: 77 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,44 +13,100 @@ use melior::{
1313
use crate::ast::CompilUnit;
1414

1515
pub fn gen(comp_unit: &CompilUnit, _idx_checks: bool) -> String {
16-
let context = create_context();
17-
let location = Location::unknown(&context);
18-
let module = Module::new(location);
19-
let mut builder = MlirBuilder::new(&context, module);
16+
// Check if we need HIVM address spaces
17+
if needs_hivm_address_space(comp_unit) {
18+
to_mlir::types::generate_mlir_string_with_hivm(comp_unit)
19+
} else {
20+
let context = create_context();
21+
let location = Location::unknown(&context);
22+
let module = Module::new(location);
23+
let mut builder = MlirBuilder::new(&context, module);
2024

21-
// Two-pass build so that calls know callee result types
22-
builder.build_items_two_pass(comp_unit);
25+
// Two-pass build so that calls know callee result types
26+
builder.build_items_two_pass(comp_unit);
2327

24-
// Dump the module to string
25-
builder.module().as_operation().to_string()
28+
// Dump the module to string
29+
builder.module().as_operation().to_string()
30+
}
2631
}
2732

2833
pub fn gen_checked(comp_unit: &CompilUnit, _idx_checks: bool) -> Result<String, String> {
29-
let context = create_context();
30-
let location = Location::unknown(&context);
31-
let module = Module::new(location);
32-
let mut builder = MlirBuilder::new(&context, module);
34+
// Check if we need HIVM address spaces
35+
if needs_hivm_address_space(comp_unit) {
36+
Ok(to_mlir::types::generate_mlir_string_with_hivm(comp_unit))
37+
} else {
38+
let context = create_context();
39+
let location = Location::unknown(&context);
40+
let module = Module::new(location);
41+
let mut builder = MlirBuilder::new(&context, module);
42+
43+
// Two-pass build so that calls know callee result types
44+
builder.build_items_two_pass(comp_unit);
45+
46+
// Verify the module
47+
if !builder.module().as_operation().verify() {
48+
return Err("MLIR module verification failed".to_string());
49+
}
3350

34-
// Two-pass build so that calls know callee result types
35-
builder.build_items_two_pass(comp_unit);
51+
// Dump the module to string
52+
Ok(builder.module().as_operation().to_string())
53+
}
54+
}
3655

37-
// Verify the module
38-
if !builder.module().as_operation().verify() {
39-
return Err("MLIR module verification failed".to_string());
56+
/// Check if the compilation unit needs HIVM address spaces
57+
fn needs_hivm_address_space(comp_unit: &CompilUnit) -> bool {
58+
for item in &comp_unit.items {
59+
if let crate::ast::Item::FunDef(fun) = item {
60+
// Only check the main function or functions that are not HIVM placeholders
61+
if fun.ident.name == "main".into() || !is_hivm_placeholder_function(fun) {
62+
for param in &fun.param_decls {
63+
if let Some(ty) = &param.ty {
64+
if has_gpu_memory(ty) {
65+
return true;
66+
}
67+
}
68+
}
69+
}
70+
}
4071
}
72+
false
73+
}
4174

42-
// Dump the module to string
43-
Ok(builder.module().as_operation().to_string())
75+
/// Check if a function is a HIVM placeholder function
76+
fn is_hivm_placeholder_function(fun: &crate::ast::FunDef) -> bool {
77+
fun.ident.name.starts_with("hivm_")
78+
}
79+
80+
/// Check if a type has GPU memory qualifiers
81+
fn has_gpu_memory(ty: &crate::ast::Ty) -> bool {
82+
match &ty.ty {
83+
crate::ast::TyKind::Data(data_ty) => {
84+
match &data_ty.dty {
85+
crate::ast::DataTyKind::At(_, mem) => {
86+
matches!(mem, crate::ast::Memory::GpuGlobal | crate::ast::Memory::GpuShared | crate::ast::Memory::GpuLocal)
87+
},
88+
crate::ast::DataTyKind::Ref(ref_dty) => {
89+
matches!(ref_dty.mem, crate::ast::Memory::GpuGlobal | crate::ast::Memory::GpuShared | crate::ast::Memory::GpuLocal)
90+
},
91+
_ => false,
92+
}
93+
},
94+
_ => false,
95+
}
4496
}
4597

4698
pub fn create_context() -> Context {
4799
let registry = DialectRegistry::new();
48100
register_all_dialects(&registry);
49101

50-
// Custom dialects (hivm, annotation, symbol) are loaded via dialects.rs module
51-
52102
let context = Context::new();
103+
104+
// Allow unregistered dialects to handle HIVM dialect
105+
context.set_allow_unregistered_dialects(true);
106+
53107
context.append_dialect_registry(&registry);
54108
context.load_all_available_dialects();
109+
55110
context
56111
}
112+

0 commit comments

Comments
 (0)