Hey!
it looks like there is a mess in the tensor arrangements in a few places here:
```python
q, k, v = rearrange_many((q, k, v), 'b (h d) n -> b h n d', h = self.heads)
```
It does not make sense to go to `'b h n d'` here: the last dim ends up being the channels dim `d`, not the sequence dim.
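A minimal sketch of that shape question, using plain `einops.rearrange` on dummy tensors (all sizes below are assumptions, for illustration only):

```python
import torch
from einops import rearrange

# assumed sizes, for illustration only
b, heads, dim_head, n = 2, 8, 16, 32
x = torch.randn(b, heads * dim_head, n)  # (b, (h d), n), channels-first as in the quoted line

q = rearrange(x, 'b (h d) n -> b h n d', h = heads)
print(q.shape)  # torch.Size([2, 8, 32, 16]) -- last dim is d (channels), not the sequence

# keeping the sequence dim last would leave the tensor conv-friendly:
q_alt = rearrange(x, 'b (h d) n -> b h d n', h = heads)
print(q_alt.shape)  # torch.Size([2, 8, 16, 32]) -- last dim is n (sequence)
```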
Then next:
```python
projs = rearrange_many(projs.split(self.heads // self.groups, dim = 1), 'b h n d -> (b h) n d')
```
we split the heads into groups and merge them into the batch, then try to apply a convolution that expects channels as the second dim, which in our case is now the seq-len dim (see the sketch below).
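A hedged sketch of that mismatch, assuming the depthwise conv follows the standard `nn.Conv1d` convention of `(batch, channels, length)` input; `rearrange_many` is replaced with a plain list comprehension to keep this self-contained:

```python
import torch
from torch import nn
from einops import rearrange

# assumed sizes, for illustration only
b, heads, groups, dim_head, n = 2, 8, 4, 16, 32
projs = torch.randn(b, heads, n, dim_head)  # 'b h n d', as produced above

chunks = projs.split(heads // groups, dim = 1)  # `groups` tensors of shape (b, h/g, n, d)
merged = [rearrange(t, 'b h n d -> (b h) n d') for t in chunks]
print(merged[0].shape)  # torch.Size([4, 32, 16]) -- dim 1 is seq-len n, not channels

conv = nn.Conv1d(dim_head, dim_head, kernel_size = 3, groups = dim_head)
# conv(merged[0]) would raise: the layer expects dim_head (= 16) input channels,
# but dim 1 of the merged tensor is the sequence length (= 32)
```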
There is also:
```python
ds_convs.append(CausalDepthwiseConv1d(inner_dim, kernel_size))
```
which somehow sets up the convolution layer to expect the full `inner_dim` as input channels, even though the grouped tensors above carry fewer channels each.
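To make the channel-count point concrete, here is a hedged sketch; the `CausalDepthwiseConv1d` below is my own stand-in (a left-padded depthwise `nn.Conv1d`), not the repo's actual implementation:

```python
import torch
from torch import nn
import torch.nn.functional as F

class CausalDepthwiseConv1d(nn.Module):
    """Stand-in assumption: depthwise conv with left padding so it stays causal."""
    def __init__(self, dim, kernel_size):
        super().__init__()
        self.pad = kernel_size - 1
        self.conv = nn.Conv1d(dim, dim, kernel_size, groups = dim)

    def forward(self, x):              # x: (batch, channels, seq_len)
        x = F.pad(x, (self.pad, 0))    # pad only on the left
        return self.conv(x)

heads, dim_head, kernel_size = 8, 16, 3
inner_dim = heads * dim_head  # 128

conv_full = CausalDepthwiseConv1d(inner_dim, kernel_size)  # as in the quoted line
conv_head = CausalDepthwiseConv1d(dim_head, kernel_size)   # what a per-group tensor carries

x = torch.randn(4, dim_head, 32)  # one merged group, transposed to channels-first
print(conv_head(x).shape)  # torch.Size([4, 16, 32])
# conv_full(x) would raise: it expects inner_dim (= 128) input channels, not 16
```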