Skip to content

Commit 333b4d3

Browse files
committed
update config
1 parent b7ca9c6 commit 333b4d3

File tree

1 file changed

+25
-8
lines changed

1 file changed

+25
-8
lines changed

deeptables/models/layers.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -923,9 +923,17 @@ def get_config(self):
923923

924924

925925
class VarLenColumnEmbedding(Layer):
926-
def __init__(self, emb_vocab_size, emb_output_dim, dropout_rate=0. , **kwargs):
926+
def __init__(self, emb_vocab_size, emb_output_dim,
927+
embeddings_initializer,
928+
embeddings_regularizer,
929+
activity_regularizer,
930+
dropout_rate=0.,
931+
**kwargs):
927932
self.emb_vocab_size = emb_vocab_size
928933
self.emb_output_dim = emb_output_dim
934+
self.embeddings_initializer = embeddings_initializer
935+
self.embeddings_regularizer = embeddings_regularizer
936+
self.activity_regularizer = activity_regularizer
929937
self.dropout_rate = dropout_rate
930938
super(VarLenColumnEmbedding, self).__init__(**kwargs)
931939
self.dropout = None
@@ -937,28 +945,37 @@ def compute_output_shape(self, input_shape):
937945

938946
def build(self, input_shape=None):
939947
super(VarLenColumnEmbedding, self).build(input_shape)
940-
self.emb_layer = Embedding(input_dim=self.emb_vocab_size, output_dim=self.emb_output_dim)
948+
self.emb_layer = Embedding(input_dim=self.emb_vocab_size,
949+
output_dim=self.emb_output_dim,
950+
embeddings_initializer=self.embeddings_initializer,
951+
embeddings_regularizer=self.embeddings_regularizer,
952+
activity_regularizer=self.activity_regularizer)
941953
if self.dropout_rate > 0:
942954
self.dropout = SpatialDropout1D(self.dropout_rate, name='var_len_emb_dropout')
943955
else:
944956
self.dropout = None
945957
self.built = True
946958

947959
def call(self, inputs):
948-
embedding_output = self.emb_layer.call(inputs)
949-
embedding_output = embedding_output.reshape((embedding_output[0], 1, -1))
960+
embedding_output = self.emb_layer(inputs)
961+
embedding_output_reshape = tf.reshape(embedding_output, [embedding_output.shape[0], 1, -1])
950962
if self.dropout is not None:
951-
dropout_output = self.dropout(embedding_output)
963+
dropout_output = self.dropout(embedding_output_reshape)
952964
else:
953-
dropout_output = embedding_output
965+
dropout_output = embedding_output_reshape
954966
return dropout_output
955967

956968
def compute_mask(self, inputs, mask=None):
957969
return None
958970

959971
def get_config(self, ):
960-
config = { 'dropout_rate': self.dropout_rate,
961-
'emb_layer': self.emb_layer.get_config()}
972+
config = { 'dropout_rate': self.dropout_rate,
973+
'emb_layer': self.emb_layer.get_config(),
974+
'embeddings_initializer': self.embeddings_initializer,
975+
'embeddings_regularizer': self.embeddings_regularizer,
976+
'emb_vocab_size': self.emb_vocab_size,
977+
'emb_output_dim': self.emb_output_dim
978+
}
962979
base_config = super(VarLenColumnEmbedding, self).get_config()
963980
return dict(list(base_config.items()) + list(config.items()))
964981

0 commit comments

Comments
 (0)