44import operator
55import time
66#from .accelerated_layout import flat_transpose
7- from hptt import tensorTransposeAndUpdate
7+ import hptt
88import torch
99
1010from abc import ABC
@@ -13,50 +13,7 @@ def my_transpose(dest, source, axes):
1313 if axes == list (range (len (axes ))):
1414 dest [:] = source
1515 else :
16- #print(axes, source.shape, dest.shape, source.size, dest.size)
17- #assert tuple(source.shape[a] for a in axes) == dest.shape
18- #print(source.flags)
19- #print(dest.flags)
20- #assert source.flags['C_CONTIGUOUS'] or source.flags['F_CONTIGUOUS']
21- #assert dest.flags['C_CONTIGUOUS'] or dest.flags['F_CONTIGUOUS']
22- tensorTransposeAndUpdate (tuple (axes ), 1.0 , source , 1.0 , dest )
23- #s = time.time()
24- #dest[:] = source.transpose(axes)
25- #print("NumpyTranspose : ", time.time() - s)
26-
27- #s = time.time()
28- #new_shape = [source.shape[0]]
29- #idx = 0
30- #new_axes = [axes[0]]
31- #for i,a in enumerate(axes[1:],1):
32- # if a == axes[i-1]+1:
33- # new_shape[-1] *= source.shape[i]
34- # else:
35- # new_shape.append(source.shape[i])
36- # new_axes.append(a)
37- #if len(new_shape) == 1:
38- # dest[:] = source
39- #else:
40- # new_axes = np.argsort(np.argsort(new_axes))
41- # dim1, dim2 = next((i,a) for i,a in enumerate(new_axes) if i!=a)
42- # print(new_shape, axes, new_axes)
43- # #tensorTransposeAndUpdate(tuple(axes), 1.0, source, 1.0, dest)
44- # #flat_transpose(dest.ravel(), source.ravel(), source.shape, dest.shape, tuple(axes))
45- # #dest[:] = np.transpose(source, axes)
46- # print(source.shape)
47- # print(dest.shape)
48- # dest[:] = torch.transpose(torch.from_numpy(source), int(dim1), int(dim2)).contiguous()
49- #print("Torch : ", time.time() - s)
50- #s = time.time()
51- #dest[:] = np.ascontiguousarray(source.transpose(axes))
52- #print("NumpyTransposeContig : ", time.time() - s)
53- ##print(source.shape, source.strides)
54- ##print(dest.shape, dest.strides)
55- ##print(source.flags)
56- ##print(dest.flags)
57- #s = time.time()
58- #tensorTransposeAndUpdate(tuple(axes), 1.0, source, 1.0, dest)
59- #print("HPTT Transpose : ", time.time() - s)
16+ hptt .tensorTransposeAndUpdate (axes , 1.0 , source , 0.0 , dest )
6017
6118
6219class Layout :
0 commit comments