Commit b490ca8

Author: Kye (committed)
[FEAT][LinearAttention]
1 parent 4708d72 commit b490ca8

File tree: 5 files changed (+86, -17 lines)


README.md

Lines changed: 5 additions & 2 deletions
@@ -37,11 +37,14 @@ model = MambaTransformer(
     ff_mult=4,  # Multiplier for the feed-forward layer dimension
     return_embeddings=False,  # Whether to return the embeddings,
     transformer_depth=2,  # Number of transformer blocks
-    mamba_depth=10,  # Number of Mamba blocks
+    mamba_depth=10,  # Number of Mamba blocks,
+    use_linear_attn=True,  # Whether to use linear attention
 )
 
 # Pass the input tensor through the model and print the output shape
-print(model(x).shape)
+out = model(x)
+
+print(out.shape)
 
 
 # to train
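
For reference, a self-contained version of the updated README usage might look like the sketch below. Only the arguments from ff_mult onward are visible in this hunk; the num_tokens, dim, heads, dim_head, and dropout arguments, the shape of x, and all concrete values are illustrative assumptions, not part of this commit.

import torch
from mamba_transformer import MambaTransformer

# Token IDs of shape (batch, sequence_length); values are placeholders
x = torch.randint(0, 100, (1, 10))

model = MambaTransformer(
    num_tokens=100,           # assumed vocabulary size (not shown in this hunk)
    dim=512,                  # assumed model dimension (not shown in this hunk)
    heads=8,                  # assumed number of attention heads
    dim_head=64,              # assumed dimension per head
    dropout=0.1,              # assumed dropout rate
    ff_mult=4,                # Multiplier for the feed-forward layer dimension
    return_embeddings=False,  # Whether to return the embeddings
    transformer_depth=2,      # Number of transformer blocks
    mamba_depth=10,           # Number of Mamba blocks
    use_linear_attn=True,     # Whether to use linear attention
)

out = model(x)
print(out.shape)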

example.py

Lines changed: 5 additions & 2 deletions
@@ -16,8 +16,11 @@
     ff_mult=4,  # Multiplier for the feed-forward layer dimension
     return_embeddings=False,  # Whether to return the embeddings,
     transformer_depth=2,  # Number of transformer blocks
-    mamba_depth=10,  # Number of Mamba blocks
+    mamba_depth=10,  # Number of Mamba blocks,
+    use_linear_attn=True,  # Whether to use linear attention
 )
 
 # Pass the input tensor through the model and print the output shape
-print(model(x).shape)
+out = model(x)
+
+print(out.shape)

mamba_transformer/__init__.py

Lines changed: 5 additions & 2 deletions
@@ -1,13 +1,16 @@
+from mamba_transformer.blocks import LinearAttention
+
 from mamba_transformer.model import (
     RMSNorm,
-    MultiQueryTransformerBlock,
+    TransformerBlock,
     MambaTransformerblock,
     MambaTransformer,
 )
 
 __all__ = [
+    "LinearAttention",
     "RMSNorm",
-    "MultiQueryTransformerBlock",
+    "TransformerBlock",
     "MambaTransformerblock",
     "MambaTransformer",
 ]
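
With this change, LinearAttention is re-exported at the package root and the renamed TransformerBlock replaces MultiQueryTransformerBlock in the public API, so any downstream import of the old name must be updated. A minimal import check:

# The old name MultiQueryTransformerBlock is no longer exported after this commit
from mamba_transformer import (
    LinearAttention,
    TransformerBlock,
    MambaTransformerblock,
    MambaTransformer,
)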

mamba_transformer/blocks.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+from torch import nn, einsum
+
+from einops import rearrange
+
+from zeta.utils import exists
+
+# linear attention
+
+
+class LinearAttention(nn.Module):
+    def __init__(self, dim, *, heads=4, dim_head=64, dropout=0.0):
+        super().__init__()
+        inner_dim = heads * dim_head
+        self.heads = heads
+        self.scale = dim_head**-0.5
+
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim), nn.Dropout(dropout)
+        )
+
+    def forward(self, x, mask=None):
+        h = self.heads
+        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
+        q, k, v = map(
+            lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h),
+            (q, k, v),
+        )
+
+        q = q * self.scale
+        q, k = q.softmax(dim=-1), k.softmax(dim=-2)
+
+        if exists(mask):
+            k.masked_fill_(mask, 0.0)
+
+        context = einsum("b n d, b n e -> b d e", q, k)
+        out = einsum("b d e, b n d -> b n e", context, v)
+        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+        return self.to_out(out)
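
The new LinearAttention block contracts the softmax-normalized queries and keys over the sequence dimension into a small per-head context matrix and then applies that matrix to the values, so the cost grows linearly with sequence length rather than quadratically. A quick smoke test of the module as added here, assuming torch, einops, and zeta are installed and using illustrative sizes:

import torch
from mamba_transformer.blocks import LinearAttention

# Sizes below are illustrative; dim must match the last dimension of x
attn = LinearAttention(dim=512, heads=8, dim_head=64, dropout=0.1)

x = torch.randn(2, 1024, 512)  # (batch, sequence_length, dim)
out = attn(x)
print(out.shape)  # torch.Size([2, 1024, 512])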

mamba_transformer/model.py

Lines changed: 32 additions & 11 deletions
@@ -1,7 +1,12 @@
 import torch
 from torch import nn, Tensor
-from zeta.nn import MambaBlock, FeedForward, MultiQueryAttention
+from zeta.nn import (
+    MambaBlock,
+    FeedForward,
+    MultiQueryAttention,
+)
 import torch.nn.functional as F
+from mamba_transformer.blocks import LinearAttention
 
 
 class RMSNorm(nn.Module):
@@ -14,9 +19,9 @@ def forward(self, x: Tensor):
         return F.normalize(x, dim=-1) * self.scale * self.g
 
 
-class MultiQueryTransformerBlock(nn.Module):
+class TransformerBlock(nn.Module):
     """
-    MultiQueryTransformerBlock is a module that represents a single block of the Multi-Query Transformer.
+    TransformerBlock is a module that represents a single block of the Multi-Query Transformer.
     It consists of a multi-query attention layer, a feed-forward network, and layer normalization.
 
     Args:
@@ -38,7 +43,7 @@ class MultiQueryTransformerBlock(nn.Module):
 
     Methods:
         forward(x: Tensor) -> Tensor:
-            Performs a forward pass of the MultiQueryTransformerBlock.
+            Performs a forward pass of the TransformerBlock.
 
     """
 
@@ -49,6 +54,7 @@ def __init__(
         dim_head: int,
         dropout: float = 0.1,
         ff_mult: int = 4,
+        use_linear_attn: bool = False,
         *args,
         **kwargs,
     ):
@@ -58,17 +64,23 @@ def __init__(
         self.dim_head = dim_head
         self.dropout = dropout
         self.ff_mult = ff_mult
+        self.use_linear_attn = use_linear_attn
 
         self.attn = MultiQueryAttention(dim, heads, *args, **kwargs)
 
+        # Linear Attention
+        self.linear_attn = LinearAttention(
+            dim=dim, heads=heads, dim_head=dim_head, dropout=dropout
+        )
+
         self.ffn = FeedForward(dim, dim, ff_mult, *args, **kwargs)
 
         # Normalization
         self.norm = nn.LayerNorm(dim)
 
     def forward(self, x: Tensor) -> Tensor:
         """
-        Performs a forward pass of the MultiQueryTransformerBlock.
+        Performs a forward pass of the TransformerBlock.
 
         Args:
             x (Tensor): The input tensor.
@@ -77,9 +89,14 @@ def forward(self, x: Tensor) -> Tensor:
             Tensor: The output tensor.
 
         """
-        x, _, _ = self.attn(x)
-        x = self.norm(x)
-        x = self.ffn(x)
+        if self.use_linear_attn:
+            x = self.linear_attn(x)
+            x = self.norm(x)
+            x = self.ffn(x)
+        else:
+            x, _, _ = self.attn(x)
+            x = self.norm(x)
+            x = self.ffn(x)
         return x
 
 
@@ -106,7 +123,7 @@ class MambaTransformerblock(nn.Module):
         dropout (float): The dropout rate.
         ff_mult (int): The multiplier for the feed-forward network dimension.
         mamba_blocks (nn.ModuleList): List of MambaBlock instances.
-        transformer_blocks (nn.ModuleList): List of MultiQueryTransformerBlock instances.
+        transformer_blocks (nn.ModuleList): List of TransformerBlock instances.
         ffn_blocks (nn.ModuleList): List of FeedForward instances.
         norm (nn.LayerNorm): Layer normalization module.
 
@@ -140,6 +157,7 @@ def __init__(
         d_state: int = None,
         transformer_depth: int = 1,
         mamba_depth: int = 1,
+        use_linear_attn: bool = False,
         *args,
         **kwargs,
     ):
@@ -167,15 +185,16 @@ def __init__(
             self.ffn_blocks.append(
                 FeedForward(dim, dim, ff_mult, *args, **kwargs)
             )
-
+
         for _ in range(transformer_depth):
             self.transformer_blocks.append(
-                MultiQueryTransformerBlock(
+                TransformerBlock(
                     dim,
                     heads,
                     dim_head,
                     dropout,
                     ff_mult,
+                    use_linear_attn,
                     *args,
                     **kwargs,
                 )
@@ -247,6 +266,7 @@ def __init__(
         return_embeddings: bool = False,
         transformer_depth: int = 1,
         mamba_depth: int = 1,
+        use_linear_attn=False,
         *args,
         **kwargs,
     ):
@@ -274,6 +294,7 @@ def __init__(
             return_embeddings,
             transformer_depth,
             mamba_depth,
+            use_linear_attn,
             *args,
             **kwargs,
         )
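
Taken together, the new use_linear_attn flag is threaded from MambaTransformer through MambaTransformerblock into TransformerBlock, whose forward pass now routes through LinearAttention instead of MultiQueryAttention when the flag is set. A block-level sketch with illustrative hyperparameters, assuming the zeta modules construct cleanly with these defaults:

import torch
from mamba_transformer import TransformerBlock

block = TransformerBlock(
    dim=512,
    heads=8,
    dim_head=64,
    dropout=0.1,
    ff_mult=4,
    use_linear_attn=True,  # route through LinearAttention instead of MultiQueryAttention
)

x = torch.randn(2, 64, 512)  # (batch, sequence_length, dim)
out = block(x)
print(out.shape)  # torch.Size([2, 64, 512])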
