2424from .frame import Frame
2525from .window import Window
2626
# Keys accepted in the ``learnable`` argument of ModifiedDiscreteCosineTransform:
# "basis" selects the transform matrix, "window" selects the window layer.
LEARNABLES = ("basis", "window")
28+
2729
2830class ModifiedDiscreteCosineTransform (BaseFunctionalModule ):
2931 """This module is a simple cascade of framing, windowing, and modified DCT.
@@ -36,12 +38,22 @@ class ModifiedDiscreteCosineTransform(BaseFunctionalModule):
3638 window : ['sine', 'vorbis', 'kbd', 'rectangular']
3739 The window type.
3840
41+ learnable : bool or list[str]
42+ Indicates whether the parameters are learnable. If a boolean, it specifies
43+ whether all parameters are learnable. If a list, it contains the keys of the
44+ learnable parameters, which can only be "basis" and "window".
45+
3946 """
4047
def __init__(
    self,
    frame_length: int,
    window: str = "sine",
    learnable: bool | list[str] = False,
) -> None:
    """Build the framing, windowing, and transform layers.

    Parameters
    ----------
    frame_length : int
        The frame length.

    window : ['sine', 'vorbis', 'kbd', 'rectangular']
        The window type.

    learnable : bool or list[str]
        Indicates whether the parameters are learnable. If a boolean, it
        specifies whether all parameters are learnable. If a list, it contains
        the keys of the learnable parameters, which can only be "basis" and
        "window".

    """
    super().__init__()

    # NOTE: locals() is captured here, so no extra local variable may be
    # introduced above this line; doing so would change what get_values()
    # receives. (Presumably full=True also forwards ``learnable`` to
    # _precompute -- confirm against get_values.)
    self.values, layers, _ = self._precompute(*get_values(locals(), full=True))
    # _precompute returns the layers as a plain tuple; wrapping them in a
    # ModuleList registers them as submodules of this module.
    self.layers = nn.ModuleList(layers)
4658
4759 def forward (self , x : torch .Tensor ) -> torch .Tensor :
@@ -84,16 +96,29 @@ def _takes_input_size() -> bool:
8496 return False
8597
@staticmethod
def _check(learnable: bool | list[str]) -> None:
    """Validate the ``learnable`` argument.

    Parameters
    ----------
    learnable : bool or list[str]
        Either a boolean or a list/tuple whose items are all members of
        ``LEARNABLES``.

    Raises
    ------
    ValueError
        If a list/tuple contains a key outside ``LEARNABLES``, or if the value
        is neither a list/tuple nor a boolean.

    """
    if isinstance(learnable, (tuple, list)):
        if any(x not in LEARNABLES for x in learnable):
            raise ValueError("An unsupported key is found in learnable.")
    elif not isinstance(learnable, bool):
        raise ValueError("learnable must be boolean or list.")
89105
90106 @staticmethod
91107 def _precompute (
92- frame_length : int , window : str , transform : str = "cosine" , module : bool = True
108+ frame_length : int ,
109+ window : str ,
110+ learnable : bool | list [str ] = False ,
111+ transform : str = "cosine" ,
112+ module : bool = True ,
93113 ) -> Precomputed :
94- ModifiedDiscreteCosineTransform ._check ()
114+ ModifiedDiscreteCosineTransform ._check (learnable )
95115 frame_period = frame_length // 2
96116
117+ if learnable is True :
118+ learnable = LEARNABLES
119+ elif learnable is False :
120+ learnable = ()
121+
97122 frame = get_layer (
98123 module ,
99124 Frame ,
@@ -110,6 +135,7 @@ def _precompute(
110135 out_length = None ,
111136 window = window ,
112137 norm = "none" ,
138+ learnable = "window" in learnable ,
113139 ),
114140 )
115141 mdt = get_layer (
@@ -119,6 +145,7 @@ def _precompute(
119145 length = frame_length ,
120146 window = window ,
121147 transform = transform ,
148+ learnable = "basis" in learnable ,
122149 ),
123150 )
124151 return (frame_period ,), (frame , window_ , mdt ), None
@@ -150,15 +177,27 @@ class ModifiedDiscreteTransform(BaseFunctionalModule):
150177 transform : ['cosine', 'sine']
151178 The transform type.
152179
180+ learnable : bool
181+ Whether to make the DCT matrix learnable.
182+
153183 """
154184
def __init__(
    self,
    length: int,
    window: str,
    transform: str = "cosine",
    learnable: bool = False,
) -> None:
    """Precompute the transform matrix and store it as ``W``.

    Parameters
    ----------
    length : int
        The input length; ``forward`` checks the last input dimension
        against it.

    window : str
        The window type, forwarded to ``_precompute``.

    transform : ['cosine', 'sine']
        The transform type.

    learnable : bool
        If True, ``W`` is stored as a trainable parameter; otherwise it is
        stored as a buffer.

    """
    super().__init__()

    self.in_dim = length

    # NOTE: locals() is captured here, so no extra local variable may be
    # introduced above this line.
    _, _, tensors = self._precompute(*get_values(locals()))
    # The matrix is kept under the name "W" in either case, so it reaches
    # _forward as the same keyword whether it is a parameter or a buffer.
    if learnable:
        self.W = nn.Parameter(tensors[0])
    else:
        self.register_buffer("W", tensors[0])
162201
163202 def forward (self , x : torch .Tensor ) -> torch .Tensor :
164203 """Apply MDCT/MDST to the input.
@@ -175,7 +214,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
175214
176215 """
177216 check_size (x .size (- 1 ), self .in_dim , "dimension of input" )
178- return self ._forward (x , ** self ._buffers )
217+ return self ._forward (x , ** self ._buffers , ** self . _parameters )
179218
180219 @staticmethod
181220 def _func (x : torch .Tensor , * args , ** kwargs ) -> torch .Tensor :
0 commit comments