 import torch
 from torch import nn, Tensor
-from zeta.nn import MambaBlock, FeedForward, MultiQueryAttention
+from zeta.nn import (
+    MambaBlock,
+    FeedForward,
+    MultiQueryAttention,
+)
 import torch.nn.functional as F
+from mamba_transformer.blocks import LinearAttention


 class RMSNorm(nn.Module):
@@ -14,9 +19,9 @@ def forward(self, x: Tensor):
         return F.normalize(x, dim=-1) * self.scale * self.g


-class MultiQueryTransformerBlock(nn.Module):
+class TransformerBlock(nn.Module):
     """
-    MultiQueryTransformerBlock is a module that represents a single block of the Multi-Query Transformer.
+    TransformerBlock is a module that represents a single block of the Multi-Query Transformer.
     It consists of a multi-query attention layer, a feed-forward network, and layer normalization.

     Args:
@@ -38,7 +43,7 @@ class MultiQueryTransformerBlock(nn.Module):

     Methods:
         forward(x: Tensor) -> Tensor:
-            Performs a forward pass of the MultiQueryTransformerBlock.
+            Performs a forward pass of the TransformerBlock.

     """

@@ -49,6 +54,7 @@ def __init__(
         dim_head: int,
         dropout: float = 0.1,
         ff_mult: int = 4,
+        use_linear_attn: bool = False,
         *args,
         **kwargs,
     ):
@@ -58,17 +64,23 @@ def __init__(
         self.dim_head = dim_head
         self.dropout = dropout
         self.ff_mult = ff_mult
+        self.use_linear_attn = use_linear_attn

         self.attn = MultiQueryAttention(dim, heads, *args, **kwargs)

+        # Linear attention branch, used when use_linear_attn is True
+        self.linear_attn = LinearAttention(
+            dim=dim, heads=heads, dim_head=dim_head, dropout=dropout
+        )
+
         self.ffn = FeedForward(dim, dim, ff_mult, *args, **kwargs)

         # Normalization
         self.norm = nn.LayerNorm(dim)

     def forward(self, x: Tensor) -> Tensor:
         """
-        Performs a forward pass of the MultiQueryTransformerBlock.
+        Performs a forward pass of the TransformerBlock.

         Args:
             x (Tensor): The input tensor.
@@ -77,9 +89,14 @@ def forward(self, x: Tensor) -> Tensor:
             Tensor: The output tensor.

         """
-        x, _, _ = self.attn(x)
-        x = self.norm(x)
-        x = self.ffn(x)
+        if self.use_linear_attn:
+            x = self.linear_attn(x)
+            x = self.norm(x)
+            x = self.ffn(x)
+        else:
+            x, _, _ = self.attn(x)
+            x = self.norm(x)
+            x = self.ffn(x)
         return x


@@ -106,7 +123,7 @@ class MambaTransformerblock(nn.Module):
         dropout (float): The dropout rate.
         ff_mult (int): The multiplier for the feed-forward network dimension.
         mamba_blocks (nn.ModuleList): List of MambaBlock instances.
-        transformer_blocks (nn.ModuleList): List of MultiQueryTransformerBlock instances.
+        transformer_blocks (nn.ModuleList): List of TransformerBlock instances.
         ffn_blocks (nn.ModuleList): List of FeedForward instances.
         norm (nn.LayerNorm): Layer normalization module.

@@ -140,6 +157,7 @@ def __init__(
         d_state: int = None,
         transformer_depth: int = 1,
         mamba_depth: int = 1,
+        use_linear_attn: bool = False,
         *args,
         **kwargs,
     ):
@@ -167,15 +185,16 @@ def __init__(
         self.ffn_blocks.append(
             FeedForward(dim, dim, ff_mult, *args, **kwargs)
         )
-
+
         for _ in range(transformer_depth):
             self.transformer_blocks.append(
-                MultiQueryTransformerBlock(
+                TransformerBlock(
                     dim,
                     heads,
                     dim_head,
                     dropout,
                     ff_mult,
+                    use_linear_attn,
                     *args,
                     **kwargs,
                 )
@@ -247,6 +266,7 @@ def __init__(
         return_embeddings: bool = False,
         transformer_depth: int = 1,
         mamba_depth: int = 1,
+        use_linear_attn: bool = False,
         *args,
         **kwargs,
     ):
@@ -274,6 +294,7 @@ def __init__(
             return_embeddings,
             transformer_depth,
             mamba_depth,
+            use_linear_attn,
             *args,
             **kwargs,
         )
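
A minimal usage sketch of the new flag, for context. It relies only on the TransformerBlock constructor parameters visible in this diff; the import path from the package root is an assumption about how the repo exposes the class:

import torch
from mamba_transformer import TransformerBlock  # import path assumed

# Parameter values are illustrative; the names mirror the signature in the diff.
block = TransformerBlock(
    dim=512,
    heads=8,
    dim_head=64,
    dropout=0.1,
    ff_mult=4,
    use_linear_attn=True,  # new flag: route attention through LinearAttention
)

x = torch.randn(1, 128, 512)   # (batch, seq_len, dim)
out = block(x)                 # same shape as the input: (1, 128, 512)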
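
The new import pulls LinearAttention from mamba_transformer.blocks, which is not shown in this diff. As background, the sketch below is one common formulation of linear (kernelized) attention with an elu(x) + 1 feature map, in the style of Katharopoulos et al. (2020). It matches the (dim, heads, dim_head, dropout) constructor call used above, but it is an illustrative stand-in, not the implementation that ships in mamba_transformer.blocks:

import torch
from torch import nn, Tensor
import torch.nn.functional as F


class LinearAttentionSketch(nn.Module):
    # Kernelized attention with O(seq_len) cost: output is computed as
    # phi(q) @ (phi(k)^T v) instead of softmax(q k^T) v, with phi = elu + 1.
    def __init__(self, dim: int, heads: int, dim_head: int, dropout: float = 0.0):
        super().__init__()
        inner = heads * dim_head
        self.heads = heads
        self.to_qkv = nn.Linear(dim, inner * 3, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner, dim), nn.Dropout(dropout))

    def forward(self, x: Tensor) -> Tensor:
        b, n, _ = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        # (batch, seq, heads * dim_head) -> (batch, heads, seq, dim_head)
        q, k, v = (t.reshape(b, n, self.heads, -1).transpose(1, 2) for t in (q, k, v))
        q, k = F.elu(q) + 1, F.elu(k) + 1             # positive feature map
        kv = torch.einsum("bhnd,bhne->bhde", k, v)    # sum_j phi(k_j) v_j^T
        z = 1.0 / (torch.einsum("bhnd,bhd->bhn", q, k.sum(dim=2)) + 1e-6)
        out = torch.einsum("bhnd,bhde,bhn->bhne", q, kv, z)
        out = out.transpose(1, 2).reshape(b, n, -1)   # back to (batch, seq, inner)
        return self.to_out(out)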