diff --git a/examples/Transformer/model.py b/examples/Transformer/model.py
index e8e7fae..8cfb820 100644
--- a/examples/Transformer/model.py
+++ b/examples/Transformer/model.py
@@ -161,7 +161,7 @@ class TransformerEncoderLayer(Module):
     """

     def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
-                 encoder_var=1, attn_mult=1, bias=True, nlayers=1, standparam=False):
+                 encoder_var=1, attn_mult=1, bias=True, nlayers=1, standparam=False, is_causal=False):
         super(TransformerEncoderLayer, self).__init__()
         self.attn_mult = attn_mult
         self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout,attn_mult=attn_mult,
@@ -249,7 +249,7 @@ class MultiheadAttention(Module):
     __constants__ = ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']

     def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False,
-                 add_zero_attn=False, kdim=None, vdim=None, attn_mult=1, encoder_var=1, standparam=False):
+                 add_zero_attn=False, kdim=None, vdim=None, attn_mult=1, encoder_var=1, standparam=False, batch_first=False):
         super(MultiheadAttention, self).__init__()
         self.embed_dim = embed_dim
         self.attn_mult = attn_mult
@@ -257,6 +257,8 @@ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=Fals
         self.vdim = vdim if vdim is not None else embed_dim
         self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
         self.standparam = standparam
+        # FIXME: quick fix; batch_first is accepted for API compatibility but currently ignored
+        self.batch_first = batch_first
         self.num_heads = num_heads
         self.dropout = dropout
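
For reference, a minimal usage sketch of the updated constructor signatures. This assumes examples/Transformer/model.py is importable as a module named model and that the remaining arguments keep their defaults; the dimensions below are illustrative only, not taken from the repository.

    from model import MultiheadAttention, TransformerEncoderLayer

    # The new keyword arguments mirror the torch.nn API so callers that pass them
    # no longer raise TypeError. Note that batch_first is stored but not yet honored
    # (see the FIXME in the diff), so inputs must still be shaped (seq_len, batch, embed_dim).
    attn = MultiheadAttention(embed_dim=256, num_heads=8, batch_first=True)
    layer = TransformerEncoderLayer(d_model=256, nhead=8, is_causal=False)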