12 | 12 | class RMSNorm(nn.Module): |
13 | 13 |     def __init__(self, dim: int): |
14 | 14 |         super().__init__() |
15 | | -         self.scale = dim**-0.5 |
| 15 | +         self.scale = dim ** (-0.5) |
16 | 16 |         self.g = nn.Parameter(torch.ones(dim)) |
17 | 17 |
18 | | -     def forward(self, x: Tensor): |
| 18 | +     def forward(self, x: Tensor) -> Tensor: |
19 | 19 |         return F.normalize(x, dim=-1) * self.scale * self.g |
20 | 20 |
21 | 21 |
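Not part of the diff, but for quick reference: the reformatted RMSNorm can be exercised on its own roughly as in the sketch below, assuming the file's usual torch / nn / F / Tensor imports and an illustrative input shape. One hedge worth stating: F.normalize rescales the last dimension to unit L2 norm, so a textbook RMSNorm would multiply by dim ** 0.5; the sketch keeps dim ** (-0.5) exactly as the diff writes it.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class RMSNorm(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        # Scale kept as in the diff; a conventional RMSNorm would use dim ** 0.5
        self.scale = dim ** (-0.5)
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        # Unit-normalize the last dim, then apply the scalar scale and learned gain
        return F.normalize(x, dim=-1) * self.scale * self.g

x = torch.randn(2, 16, 64)   # (batch, seq_len, dim), illustrative shape
out = RMSNorm(64)(x)         # output has the same shape as x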
@@ -97,6 +97,7 @@ def forward(self, x: Tensor) -> Tensor: |
97 | 97 |         x, _, _ = self.attn(x) |
98 | 98 |         x = self.norm(x) |
99 | 99 |         x = self.ffn(x) |
| 100 | + |
100 | 101 |         return x |
101 | 102 |
102 | 103 |
@@ -172,33 +173,28 @@ def __init__( |
172 | 173 |         self.transformer_depth = transformer_depth |
173 | 174 |         self.mamba_depth = mamba_depth |
174 | 175 |
175 | | -         self.mamba_blocks = nn.ModuleList([]) |
176 | | -         self.transformer_blocks = nn.ModuleList([]) |
177 | | -         self.ffn_blocks = nn.ModuleList([]) |
178 | | - |
179 | | -         self.mamba_blocks.append( |
| 176 | +         # Mamba, Transformer, and ffn blocks |
| 177 | +         self.mamba_blocks = nn.ModuleList([ |
180 | 178 |             MambaBlock(dim, mamba_depth, d_state, *args, **kwargs) |
181 | | -         ) |
182 | | - |
183 | | -         # Transformer and ffn blocks |
184 | | -         for _ in range(depth): |
185 | | -             self.ffn_blocks.append( |
186 | | -                 FeedForward(dim, dim, ff_mult, *args, **kwargs) |
187 | | -             ) |
188 | | - |
189 | | -         for _ in range(transformer_depth): |
190 | | -             self.transformer_blocks.append( |
191 | | -                 TransformerBlock( |
192 | | -                     dim, |
193 | | -                     heads, |
194 | | -                     dim_head, |
195 | | -                     dropout, |
196 | | -                     ff_mult, |
197 | | -                     use_linear_attn, |
198 | | -                     *args, |
199 | | -                     **kwargs, |
200 | | -                 ) |
201 | | -             ) |
| 179 | +             for _ in range(mamba_depth) |
| 180 | +         ]) |
| 181 | +         self.transformer_blocks = nn.ModuleList([ |
| 182 | +             TransformerBlock( |
| 183 | +                 dim, |
| 184 | +                 heads, |
| 185 | +                 dim_head, |
| 186 | +                 dropout, |
| 187 | +                 ff_mult, |
| 188 | +                 use_linear_attn, |
| 189 | +                 *args, |
| 190 | +                 **kwargs, |
| 191 | +             ) for _ in range(transformer_depth) |
| 192 | +         ]) |
| 193 | + |
| 194 | +         self.ffn_blocks = nn.ModuleList([ |
| 195 | +             FeedForward(dim, dim, ff_mult, *args, **kwargs) |
| 196 | +             for _ in range(depth) |
| 197 | +         ]) |
202 | 198 |
203 | 199 |         # Layernorm |
204 | 200 |         self.norm = nn.LayerNorm(dim) |
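A note on this hunk: besides the restructuring, it changes behavior slightly, since the old code appended exactly one MambaBlock while the comprehension now builds mamba_depth of them; transformer_blocks and ffn_blocks keep their previous counts (transformer_depth and depth). The comprehension style itself is interchangeable with append-in-a-loop as far as module registration goes. A minimal sketch of that equivalence, using a hypothetical stand-in Block rather than the file's actual MambaBlock / TransformerBlock / FeedForward constructors:

import torch.nn as nn

class Block(nn.Module):
    # Hypothetical stand-in for MambaBlock / TransformerBlock / FeedForward
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

depth = 3

# Old style: start with an empty ModuleList and append inside a loop
blocks_loop = nn.ModuleList([])
for _ in range(depth):
    blocks_loop.append(Block(64))

# New style: build the list up front with a comprehension
blocks_comp = nn.ModuleList([Block(64) for _ in range(depth)])

# Both register every Block as a child module, so counts and parameters match
assert len(blocks_loop) == len(blocks_comp) == depth
assert sum(p.numel() for p in blocks_loop.parameters()) == \
       sum(p.numel() for p in blocks_comp.parameters())

Either way, nn.ModuleList registers each element as a child module, so the blocks appear in the parent's parameters() and state dict; the comprehension form just makes each branch's depth explicit at construction time.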