@@ -923,9 +923,17 @@ def get_config(self):
923923
924924
925925class VarLenColumnEmbedding (Layer ):
926- def __init__ (self , emb_vocab_size , emb_output_dim , dropout_rate = 0. , ** kwargs ):
926+ def __init__ (self , emb_vocab_size , emb_output_dim ,
927+ embeddings_initializer ,
928+ embeddings_regularizer ,
929+ activity_regularizer ,
930+ dropout_rate = 0. ,
931+ ** kwargs ):
927932 self .emb_vocab_size = emb_vocab_size
928933 self .emb_output_dim = emb_output_dim
934+ self .embeddings_initializer = embeddings_initializer
935+ self .embeddings_regularizer = embeddings_regularizer
936+ self .activity_regularizer = activity_regularizer
929937 self .dropout_rate = dropout_rate
930938 super (VarLenColumnEmbedding , self ).__init__ (** kwargs )
931939 self .dropout = None
@@ -937,28 +945,37 @@ def compute_output_shape(self, input_shape):
937945
938946 def build (self , input_shape = None ):
939947 super (VarLenColumnEmbedding , self ).build (input_shape )
940- self .emb_layer = Embedding (input_dim = self .emb_vocab_size , output_dim = self .emb_output_dim )
948+ self .emb_layer = Embedding (input_dim = self .emb_vocab_size ,
949+ output_dim = self .emb_output_dim ,
950+ embeddings_initializer = self .embeddings_initializer ,
951+ embeddings_regularizer = self .embeddings_regularizer ,
952+ activity_regularizer = self .activity_regularizer )
941953 if self .dropout_rate > 0 :
942954 self .dropout = SpatialDropout1D (self .dropout_rate , name = 'var_len_emb_dropout' )
943955 else :
944956 self .dropout = None
945957 self .built = True
946958
947959 def call (self , inputs ):
948- embedding_output = self .emb_layer . call (inputs )
949- embedding_output = embedding_output .reshape (( embedding_output [ 0 ], 1 , - 1 ) )
960+ embedding_output = self .emb_layer (inputs )
961+ embedding_output_reshape = tf .reshape (embedding_output , [ embedding_output . shape [ 0 ], 1 , - 1 ] )
950962 if self .dropout is not None :
951- dropout_output = self .dropout (embedding_output )
963+ dropout_output = self .dropout (embedding_output_reshape )
952964 else :
953- dropout_output = embedding_output
965+ dropout_output = embedding_output_reshape
954966 return dropout_output
955967
956968 def compute_mask (self , inputs , mask = None ):
957969 return None
958970
959971 def get_config (self , ):
960- config = { 'dropout_rate' : self .dropout_rate ,
961- 'emb_layer' : self .emb_layer .get_config ()}
972+ config = { 'dropout_rate' : self .dropout_rate ,
973+ 'emb_layer' : self .emb_layer .get_config (),
974+ 'embeddings_initializer' : self .embeddings_initializer ,
975+ 'embeddings_regularizer' : self .embeddings_regularizer ,
976+ 'emb_vocab_size' : self .emb_vocab_size ,
977+ 'emb_output_dim' : self .emb_output_dim
978+ }
962979 base_config = super (VarLenColumnEmbedding , self ).get_config ()
963980 return dict (list (base_config .items ()) + list (config .items ()))
964981
0 commit comments