@@ -104,7 +104,9 @@ def __init__(
         super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        bias = self.bias.float() if self.bias is not None else None
+        x = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).to(x.dtype)
         return x
 
 
@@ -146,7 +148,9 @@ def __init__(
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 2, 3, 1)
-        x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        bias = self.bias.float() if self.bias is not None else None
+        x = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).to(x.dtype)
         x = x.permute(0, 3, 1, 2)
         return x
 
@@ -282,7 +286,8 @@ def reset_parameters(self) -> None:
         nn.init.ones_(self.weight)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = rms_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        x = rms_norm(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
         return x
 
 
@@ -381,7 +386,8 @@ def reset_parameters(self) -> None:
         nn.init.ones_(self.weight)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = rms_norm2d(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        x = rms_norm2d(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
         return x
 
 
@@ -470,7 +476,8 @@ def reset_parameters(self) -> None:
         nn.init.ones_(self.weight)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = simple_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        x = simple_norm(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
         return x
 
 
@@ -562,6 +569,7 @@ def reset_parameters(self) -> None:
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 2, 3, 1)
-        x = simple_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        x = simple_norm(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
         x = x.permute(0, 3, 1, 2)
         return x