Skip to content

Commit 6c3739a

Browse files
authored
Merge pull request #7 from Liberatedwinner/patch-1
2 parents edf2300 + ea0f66d commit 6c3739a

File tree

1 file changed: 24 additions (+24) and 28 deletions (-28).

mamba_transformer/model.py

Lines changed: 24 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -12,10 +12,10 @@
1212
class RMSNorm(nn.Module):
1313
def __init__(self, dim: int):
1414
super().__init__()
15-
self.scale = dim**-0.5
15+
self.scale = dim ** (-0.5)
1616
self.g = nn.Parameter(torch.ones(dim))
1717

18-
def forward(self, x: Tensor):
18+
def forward(self, x: Tensor) -> Tensor:
1919
return F.normalize(x, dim=-1) * self.scale * self.g
2020

2121

@@ -97,6 +97,7 @@ def forward(self, x: Tensor) -> Tensor:
9797
x, _, _ = self.attn(x)
9898
x = self.norm(x)
9999
x = self.ffn(x)
100+
100101
return x
101102

102103

@@ -172,33 +173,28 @@ def __init__(
172173
self.transformer_depth = transformer_depth
173174
self.mamba_depth = mamba_depth
174175

175-
self.mamba_blocks = nn.ModuleList([])
176-
self.transformer_blocks = nn.ModuleList([])
177-
self.ffn_blocks = nn.ModuleList([])
178-
179-
self.mamba_blocks.append(
176+
# Mamba, Transformer, and ffn blocks
177+
self.mamba_blocks = nn.ModuleList([
180178
MambaBlock(dim, mamba_depth, d_state, *args, **kwargs)
181-
)
182-
183-
# Transformer and ffn blocks
184-
for _ in range(depth):
185-
self.ffn_blocks.append(
186-
FeedForward(dim, dim, ff_mult, *args, **kwargs)
187-
)
188-
189-
for _ in range(transformer_depth):
190-
self.transformer_blocks.append(
191-
TransformerBlock(
192-
dim,
193-
heads,
194-
dim_head,
195-
dropout,
196-
ff_mult,
197-
use_linear_attn,
198-
*args,
199-
**kwargs,
200-
)
201-
)
179+
for _ in range(mamba_depth)
180+
])
181+
self.transformer_blocks = nn.ModuleList([
182+
TransformerBlock(
183+
dim,
184+
heads,
185+
dim_head,
186+
dropout,
187+
ff_mult,
188+
use_linear_attn,
189+
*args,
190+
**kwargs,
191+
) for _ in range(transformer_depth)
192+
])
193+
194+
self.ffn_blocks = nn.ModuleList([
195+
FeedForward(dim, dim, ff_mult, *args, **kwargs)
196+
for _ in range(depth)
197+
])
202198

203199
# Layernorm
204200
self.norm = nn.LayerNorm(dim)

0 commit comments

Comments
 (0)