Skip to content

Commit 49f41c0

Browse files
committed
Fix call
1 parent 3ccf7c5 commit 49f41c0

File tree

1 file changed

+2
-45
lines changed

1 file changed

+2
-45
lines changed

pygyro/model/layout.py

Lines changed: 2 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import operator
55
import time
66
#from .accelerated_layout import flat_transpose
7-
from hptt import tensorTransposeAndUpdate
7+
import hptt
88
import torch
99

1010
from abc import ABC
@@ -13,50 +13,7 @@ def my_transpose(dest, source, axes):
1313
if axes == list(range(len(axes))):
1414
dest[:] = source
1515
else:
16-
#print(axes, source.shape, dest.shape, source.size, dest.size)
17-
#assert tuple(source.shape[a] for a in axes) == dest.shape
18-
#print(source.flags)
19-
#print(dest.flags)
20-
#assert source.flags['C_CONTIGUOUS'] or source.flags['F_CONTIGUOUS']
21-
#assert dest.flags['C_CONTIGUOUS'] or dest.flags['F_CONTIGUOUS']
22-
tensorTransposeAndUpdate(tuple(axes), 1.0, source, 1.0, dest)
23-
#s = time.time()
24-
#dest[:] = source.transpose(axes)
25-
#print("NumpyTranspose : ", time.time() - s)
26-
27-
#s = time.time()
28-
#new_shape = [source.shape[0]]
29-
#idx = 0
30-
#new_axes = [axes[0]]
31-
#for i,a in enumerate(axes[1:],1):
32-
# if a == axes[i-1]+1:
33-
# new_shape[-1] *= source.shape[i]
34-
# else:
35-
# new_shape.append(source.shape[i])
36-
# new_axes.append(a)
37-
#if len(new_shape) == 1:
38-
# dest[:] = source
39-
#else:
40-
# new_axes = np.argsort(np.argsort(new_axes))
41-
# dim1, dim2 = next((i,a) for i,a in enumerate(new_axes) if i!=a)
42-
# print(new_shape, axes, new_axes)
43-
# #tensorTransposeAndUpdate(tuple(axes), 1.0, source, 1.0, dest)
44-
# #flat_transpose(dest.ravel(), source.ravel(), source.shape, dest.shape, tuple(axes))
45-
# #dest[:] = np.transpose(source, axes)
46-
# print(source.shape)
47-
# print(dest.shape)
48-
# dest[:] = torch.transpose(torch.from_numpy(source), int(dim1), int(dim2)).contiguous()
49-
#print("Torch : ", time.time() - s)
50-
#s = time.time()
51-
#dest[:] = np.ascontiguousarray(source.transpose(axes))
52-
#print("NumpyTransposeContig : ", time.time() - s)
53-
##print(source.shape, source.strides)
54-
##print(dest.shape, dest.strides)
55-
##print(source.flags)
56-
##print(dest.flags)
57-
#s = time.time()
58-
#tensorTransposeAndUpdate(tuple(axes), 1.0, source, 1.0, dest)
59-
#print("HPTT Transpose : ", time.time() - s)
16+
hptt.tensorTransposeAndUpdate(axes, 1.0, source, 0.0, dest)
6017

6118

6219
class Layout:

0 commit comments

Comments
 (0)