11from enum import IntEnum
2- from typing import Mapping
32from types import MappingProxyType
43import numpy as np
54import torch
@@ -13,33 +12,31 @@ class DTypeEnum(IntEnum):
1312 UINT8 = 5
1413
1514
16- dtype_to_enum : Mapping [torch .dtype | type [np .generic ] | np .dtype , DTypeEnum ] = (
17- MappingProxyType (
18- {
19- torch .float32 : DTypeEnum .FLOAT32 ,
20- torch .float64 : DTypeEnum .FLOAT64 ,
21- torch .int32 : DTypeEnum .INT32 ,
22- torch .int64 : DTypeEnum .INT64 ,
23- torch .uint8 : DTypeEnum .UINT8 ,
24- # torch
25- np .float32 : DTypeEnum .FLOAT32 ,
26- np .float64 : DTypeEnum .FLOAT64 ,
27- np .int32 : DTypeEnum .INT32 ,
28- np .int64 : DTypeEnum .INT64 ,
29- np .uint8 : DTypeEnum .UINT8 ,
30- # numpy generic
31- np .dtype (np .float32 ): DTypeEnum .FLOAT32 ,
32- np .dtype (np .float64 ): DTypeEnum .FLOAT64 ,
33- np .dtype (np .int32 ): DTypeEnum .INT32 ,
34- np .dtype (np .int64 ): DTypeEnum .INT64 ,
35- np .dtype (np .uint8 ): DTypeEnum .UINT8 ,
36- # numpy dtype
37- }
38- )
15+ dtype_to_enum = MappingProxyType (
16+ {
17+ torch .float32 : DTypeEnum .FLOAT32 ,
18+ torch .float64 : DTypeEnum .FLOAT64 ,
19+ torch .int32 : DTypeEnum .INT32 ,
20+ torch .int64 : DTypeEnum .INT64 ,
21+ torch .uint8 : DTypeEnum .UINT8 ,
22+ # torch
23+ np .float32 : DTypeEnum .FLOAT32 ,
24+ np .float64 : DTypeEnum .FLOAT64 ,
25+ np .int32 : DTypeEnum .INT32 ,
26+ np .int64 : DTypeEnum .INT64 ,
27+ np .uint8 : DTypeEnum .UINT8 ,
28+ # numpy generic
29+ np .dtype (np .float32 ): DTypeEnum .FLOAT32 ,
30+ np .dtype (np .float64 ): DTypeEnum .FLOAT64 ,
31+ np .dtype (np .int32 ): DTypeEnum .INT32 ,
32+ np .dtype (np .int64 ): DTypeEnum .INT64 ,
33+ np .dtype (np .uint8 ): DTypeEnum .UINT8 ,
34+ # numpy dtype
35+ }
3936)
4037
4138
42- enum_to_torch_dtype : Mapping [ DTypeEnum , torch . dtype ] = MappingProxyType (
39+ enum_to_torch_dtype = MappingProxyType (
4340 {
4441 DTypeEnum .FLOAT32 : torch .float32 ,
4542 DTypeEnum .FLOAT64 : torch .float64 ,
0 commit comments