examples/Transformer/model.py: 6 changes (4 additions, 2 deletions)
@@ -161,7 +161,7 @@ class TransformerEncoderLayer(Module):
     """

     def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
-                 encoder_var=1, attn_mult=1, bias=True, nlayers=1, standparam=False):
+                 encoder_var=1, attn_mult=1, bias=True, nlayers=1, standparam=False, is_causal=False):
         super(TransformerEncoderLayer, self).__init__()
         self.attn_mult = attn_mult
         self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout,attn_mult=attn_mult,
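This first hunk only threads the new is_causal flag through the TransformerEncoderLayer constructor; nothing in the diff shown here builds or applies a causal mask. A minimal sketch of how a caller might construct the layer and prepare such a mask with plain PyTorch (the import path and the mask handling are assumptions for illustration, not part of this PR):

    import torch
    from model import TransformerEncoderLayer  # assumed import path (examples/Transformer/model.py)

    # Construct the layer with the new keyword; the remaining arguments keep the defaults from the diff.
    layer = TransformerEncoderLayer(d_model=512, nhead=8, is_causal=True)

    # A standard additive causal mask: position i may not attend to positions j > i.
    seq_len = 16
    causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)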
@@ -249,14 +249,16 @@ class MultiheadAttention(Module):
     __constants__ = ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']

     def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False,
-                 add_zero_attn=False, kdim=None, vdim=None, attn_mult=1, encoder_var=1, standparam=False):
+                 add_zero_attn=False, kdim=None, vdim=None, attn_mult=1, encoder_var=1, standparam=False, batch_first=False):
         super(MultiheadAttention, self).__init__()
         self.embed_dim = embed_dim
         self.attn_mult = attn_mult
         self.kdim = kdim if kdim is not None else embed_dim
         self.vdim = vdim if vdim is not None else embed_dim
         self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
         self.standparam = standparam
+        # FIXME: quickfix since we are ignoring it
+        self.batch_first = batch_first

         self.num_heads = num_heads
         self.dropout = dropout
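The FIXME above marks batch_first as a stored-but-ignored quickfix. For context, honoring the flag in an attention module usually amounts to transposing between (batch, seq, embed) and (seq, batch, embed) around the core computation, as torch.nn.MultiheadAttention does. A hedged sketch of what that could look like in this module's forward; the _attend helper is hypothetical and none of this is in the PR:

    def forward(self, query, key, value, **kwargs):
        # Inputs arrive as (batch, seq, embed) when batch_first=True; move the
        # sequence axis first so the rest of the module can assume (seq, batch, embed).
        if self.batch_first:
            query, key, value = (x.transpose(0, 1) for x in (query, key, value))
        attn_output, attn_weights = self._attend(query, key, value, **kwargs)  # hypothetical internal helper
        if self.batch_first:
            attn_output = attn_output.transpose(0, 1)
        return attn_output, attn_weights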