Skip to content

Commit 9ac36e6

Browse files
Merge pull request #102 from tgorni/devel
Updated generator for fortran.c
2 parents f2e0d4c + 16edcd9 commit 9ac36e6

File tree

5 files changed

+176
-23
lines changed

5 files changed

+176
-23
lines changed

devel_tools/fortran_wraper_generator/basics.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,29 @@ def add_cudasync(cw):
5757
fun.add_code("cudaDeviceSynchronize();")
5858
cw.add_function_definition(fun)
5959

60+
def add_cublas_handle_init(cw):
61+
fun = cs.Function("cublas_handle_init_",
62+
"void",
63+
arguments = [ cs.Variable("handle",
64+
"long int *")])
65+
fun.add_code(("cublasHandle_t *h;",
66+
"h = (cublasHandle_t*) malloc(sizeof(cublasHandle_t));",
67+
"cublasCreate(h);",
68+
"*handle = (long int) h;",
69+
))
70+
cw.add_function_definition(fun)
71+
72+
def add_cublas_handle_destroy(cw):
73+
fun = cs.Function("cublas_handle_destroy_",
74+
"void",
75+
arguments = [ cs.Variable("handle",
76+
"long int *")])
77+
fun.add_code(("cublasHandle_t *h = (cublasHandle_t *) *handle;",
78+
"cublasDestroy(*h);",
79+
"free(h);",
80+
))
81+
cw.add_function_definition(fun)
82+
6083
def add_cusolver_handle_init(cw):
6184
fun = cs.Function("cusolver_handle_init_",
6285
"void",
@@ -80,10 +103,65 @@ def add_cusolver_handle_destroy(cw):
80103
))
81104
cw.add_function_definition(fun)
82105

106+
def add_cublas_types(cw):
107+
108+
fun = cs.Function("cublas_op_type",
109+
"cublasOperation_t",
110+
arguments = [ cs.Variable("blas_op_type", "const char *")])
111+
fun.add_code(("if (*blas_op_type == 'N') return CUBLAS_OP_N;",
112+
"if (*blas_op_type == 'n') return CUBLAS_OP_N;",
113+
"if (*blas_op_type == 'T') return CUBLAS_OP_T;",
114+
"if (*blas_op_type == 't') return CUBLAS_OP_T;",
115+
"if (*blas_op_type == 'C') return CUBLAS_OP_C;",
116+
"if (*blas_op_type == 'c') return CUBLAS_OP_C;",
117+
"printf(\"WARNING: unrecognized blas_op_type\\n\");",
118+
"return -1;"))
119+
cw.add_function_definition(fun)
120+
121+
fun = cs.Function("cublas_side_type",
122+
"cublasSideMode_t",
123+
arguments = [ cs.Variable("blas_side_type", "const char *")])
124+
125+
fun.add_code(("if (*blas_side_type == 'R') return CUBLAS_SIDE_RIGHT;",
126+
"if (*blas_side_type == 'r') return CUBLAS_SIDE_RIGHT;",
127+
"if (*blas_side_type == 'L') return CUBLAS_SIDE_LEFT;",
128+
"if (*blas_side_type == 'l') return CUBLAS_SIDE_LEFT;",
129+
"printf(\"WARNING: unrecognized blas_side_type\\n\");",
130+
"return -1;"))
131+
132+
cw.add_function_definition(fun)
133+
134+
fun = cs.Function("cublas_fill_type",
135+
"cublasFillMode_t",
136+
arguments = [ cs.Variable("blas_fill_type", "const char *")])
137+
138+
fun.add_code(("if (*blas_fill_type == 'U') return CUBLAS_FILL_MODE_UPPER;",
139+
"if (*blas_fill_type == 'u') return CUBLAS_FILL_MODE_UPPER;",
140+
"if (*blas_fill_type == 'L') return CUBLAS_FILL_MODE_LOWER;",
141+
"if (*blas_fill_type == 'l') return CUBLAS_FILL_MODE_LOWER;",
142+
"printf(\"WARNING: unrecognized blas_fill_type\\n\");",
143+
"return -1;"))
144+
145+
cw.add_function_definition(fun)
146+
147+
fun = cs.Function("cublas_diag_type",
148+
"cublasDiagType_t",
149+
arguments = [ cs.Variable("blas_diag_type", "const char *")])
150+
151+
fun.add_code(("if (*blas_diag_type == 'U') return CUBLAS_DIAG_UNIT;",
152+
"if (*blas_diag_type == 'u') return CUBLAS_DIAG_UNIT;",
153+
"if (*blas_diag_type == 'N') return CUBLAS_DIAG_NON_UNIT;",
154+
"if (*blas_diag_type == 'n') return CUBLAS_DIAG_NON_UNIT;",
155+
"printf(\"WARNING: unrecognized blas_diag_type\\n\");",
156+
"return -1;"))
157+
158+
cw.add_function_definition(fun)
159+
83160
def default_includes(cw):
84161
cw.include("<stdlib.h>")
85162
cw.include("<stddef.h>")
86163
cw.include("<ctype.h>")
164+
cw.include("<stdio.h>")
87165

88166
def default_definitions(cw):
89167
cw.add_line("typedef size_t devptr_t;")

devel_tools/fortran_wraper_generator/datapack.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,13 @@
131131
"type": "char",
132132
"offload": False,
133133
"const": True,
134+
"deprecated_cast": "cublas_op_type",
134135
},
135136
{"name": "transb",
136137
"type": "char",
137138
"offload": False,
138139
"const": True,
140+
"deprecated_cast": "cublas_op_type",
139141
},
140142
{"name": "M",
141143
"type": "int",
@@ -203,11 +205,13 @@
203205
"type": "char",
204206
"offload": False,
205207
"const": True,
208+
"deprecated_cast": "cublas_op_type",
206209
},
207210
{"name": "transb",
208211
"type": "char",
209212
"offload": False,
210213
"const": True,
214+
"deprecated_cast": "cublas_op_type",
211215
},
212216
{"name": "M",
213217
"type": "int",
@@ -275,11 +279,13 @@
275279
"type": "char",
276280
"offload": False,
277281
"const": True,
282+
"deprecated_cast": "cublas_op_type",
278283
},
279284
{"name": "transb",
280285
"type": "char",
281286
"offload": False,
282287
"const": True,
288+
"deprecated_cast": "cublas_op_type",
283289
},
284290
{"name": "M",
285291
"type": "int",
@@ -347,6 +353,7 @@
347353
"type": "char",
348354
"offload": False,
349355
"const": True,
356+
"deprecated_cast": "cublas_op_type",
350357
},
351358
{"name": "M",
352359
"type": "int",
@@ -409,6 +416,7 @@
409416
"type": "char",
410417
"offload": False,
411418
"const": True,
419+
"deprecated_cast": "cublas_op_type",
412420
},
413421
{"name": "M",
414422
"type": "int",
@@ -471,6 +479,7 @@
471479
"type": "char",
472480
"offload": False,
473481
"const": True,
482+
"deprecated_cast": "cublas_op_type",
474483
},
475484
{"name": "M",
476485
"type": "int",
@@ -533,21 +542,25 @@
533542
"type": "char",
534543
"offload": False,
535544
"const": True,
545+
"deprecated_cast": "cublas_side_type",
536546
},
537547
{"name": "uplo",
538548
"type": "char",
539549
"offload": False,
540550
"const": True,
551+
"deprecated_cast": "cublas_fill_type",
541552
},
542553
{"name": "transa",
543554
"type": "char",
544555
"offload": False,
545556
"const": True,
557+
"deprecated_cast": "cublas_op_type",
546558
},
547559
{"name": "diag",
548560
"type": "char",
549561
"offload": False,
550562
"const": True,
563+
"deprecated_cast": "cublas_diag_type",
551564
},
552565
{"name": "M",
553566
"type": "int",
@@ -594,21 +607,25 @@
594607
"type": "char",
595608
"offload": False,
596609
"const": True,
610+
"deprecated_cast": "cublas_side_type",
597611
},
598612
{"name": "uplo",
599613
"type": "char",
600614
"offload": False,
601615
"const": True,
616+
"deprecated_cast": "cublas_fill_type",
602617
},
603618
{"name": "transa",
604619
"type": "char",
605620
"offload": False,
606621
"const": True,
622+
"deprecated_cast": "cublas_op_type",
607623
},
608624
{"name": "diag",
609625
"type": "char",
610626
"offload": False,
611627
"const": True,
628+
"deprecated_cast": "cublas_diag_type",
612629
},
613630
{"name": "M",
614631
"type": "int",
@@ -651,6 +668,30 @@
651668

652669
"""
653670
671+
###################################################################
672+
673+
cublas_v2 part starts here it is generated from v1
674+
675+
###################################################################
676+
677+
"""
678+
679+
cublas_v2 = [ entry for entry in cublas ]
680+
681+
for ii, entry in enumerate(cublas_v2):
682+
if entry["library"] == "cublas":
683+
entry["args"].insert(0, {"name": "handle",
684+
"type": "long int",
685+
"offload": False,
686+
"const": True,
687+
"recast": {"to": "cublasHandle_t"}})
688+
for yy, entry_arg in enumerate(entry["args"]):
689+
if entry_arg["name"] in ("alpha", "beta"):
690+
entry["args"][yy]["norefpass"] = True
691+
cublas_v2[ii] = entry
692+
693+
"""
694+
654695
###################################################################
655696
656697
cusolver part starts here

devel_tools/fortran_wraper_generator/includer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def add_this(cw, d):
3434
if not x["const"]:
3535
qualifiers.remove("const")
3636
name = x["name"]
37+
3738
type_ = x["type"] + " *"
3839

3940
v = cs.Variable(name, type_, qualifiers = qualifiers)
@@ -53,7 +54,9 @@ def add_this(cw, d):
5354
return_value = x["name"]
5455
except:
5556
call_name = x["name"]
56-
if x["offload"]:
57+
if "deprecated_cast" in x:
58+
call_name = f"{x['deprecated_cast']}({call_name})"
59+
elif x["offload"]:
5760
call_name = call_name + "_"
5861
elif "recast" in x:
5962
call_name = "* " + call_name + "__"

devel_tools/fortran_wraper_generator/make_wrapper.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,13 @@
3939

4040
add_complex(cw)
4141

42-
b_cublas = True
42+
b_cublas = False
43+
b_cublas_v2 = True
4344
b_cusolver = True
4445

46+
if b_cublas and b_cublas_v2:
47+
raise RuntimeError("b_cublas and b_cublas_v2 cannot be both True, choose one of them")
48+
4549
if b_cublas:
4650
from datapack import cublas
4751
from basics import add_cudasync
@@ -57,18 +61,45 @@
5761
for entry in cublas:
5862
add_this(cw, entry)
5963

60-
if b_cublas:
61-
from datapack import cusolver
64+
if b_cublas_v2:
65+
from datapack import cublas_v2
66+
from basics import ( add_cudasync,
67+
add_cublas_handle_init,
68+
add_cublas_handle_destroy,
69+
add_cublas_types,
70+
)
6271

63-
cw.start_if_def(f"_CUSOLVER")
64-
cw.include('"cusolverDn.h"')
72+
cw.start_if_def(f"_CUBLAS")
73+
cw.include("<cublas_v2.h>")
6574
cw.end_if_def()
6675

67-
cw.start_if_def(f"_CUSOLVER")
76+
cw.start_if_def(f"_CUBLAS")
77+
add_cublas_handle_init(cw)
78+
add_cublas_handle_destroy(cw)
79+
cw.end_if_def()
80+
81+
cw.start_if_def(f"_CUBLAS")
82+
add_cudasync(cw)
83+
cw.end_if_def()
84+
85+
cw.start_if_def(f"_CUBLAS")
86+
add_cublas_types(cw)
87+
cw.end_if_def()
88+
89+
for entry in cublas_v2:
90+
add_this(cw, entry)
91+
92+
if b_cusolver:
93+
from datapack import cusolver
6894
from basics import ( add_cusolver_handle_init,
6995
add_cusolver_handle_destroy,
7096
)
7197

98+
cw.start_if_def(f"_CUSOLVER")
99+
cw.include('"cusolverDn.h"')
100+
cw.end_if_def()
101+
102+
cw.start_if_def(f"_CUSOLVER")
72103
add_cusolver_handle_init(cw)
73104
add_cusolver_handle_destroy(cw)
74105
cw.end_if_def()

0 commit comments

Comments
 (0)