@@ -922,38 +922,43 @@ def get_config(self):
922922 return dict (list (base_config .items ()) + list (config .items ()))
923923
924924
925- class VarLenColumnEmbedding (Embedding ):
926- def __init__ (self , pooling_strategy = 'max' , dropout_rate = 0. , ** kwargs ):
927- if pooling_strategy not in ['mean' , 'max' ]:
928- raise ValueError ("Param strategy should is one of mean, max" )
929- self .pooling_strategy = pooling_strategy
930- self .dropout_rate = dropout_rate # 支持dropout
925+ class VarLenColumnEmbedding (Layer ):
926+ def __init__ (self , emb_vocab_size , emb_output_dim , dropout_rate = 0. , ** kwargs ):
927+ self .emb_vocab_size = emb_vocab_size
928+ self .emb_output_dim = emb_output_dim
929+ self .dropout_rate = dropout_rate
931930 super (VarLenColumnEmbedding , self ).__init__ (** kwargs )
932- self ._dropout = None
931+ self .dropout = None
932+ self .emb_layer = None
933+
934+ def compute_output_shape (self , input_shape ):
935+ n_dim = input_shape [1 ]
936+ return input_shape [0 ] , self .emb_output_dim * n_dim
933937
934938 def build (self , input_shape = None ):
935939 super (VarLenColumnEmbedding , self ).build (input_shape )
940+ self .emb_layer = Embedding (input_dim = self .emb_vocab_size , output_dim = self .emb_output_dim )
936941 if self .dropout_rate > 0 :
937- self ._dropout = SpatialDropout1D (self .dropout_rate , name = 'var_len_emb_dropout' )
942+ self .dropout = SpatialDropout1D (self .dropout_rate , name = 'var_len_emb_dropout' )
938943 else :
939- self ._dropout = None
944+ self .dropout = None
940945 self .built = True
941946
942947 def call (self , inputs ):
943- embedding_output = super ( VarLenColumnEmbedding , self ) .call (inputs )
944-
945- if self ._dropout is not None :
946- dropout_output = self ._dropout (embedding_output )
948+ embedding_output = self . emb_layer .call (inputs )
949+ embedding_output = embedding_output . reshape (( embedding_output [ 0 ], 1 , - 1 ))
950+ if self .dropout is not None :
951+ dropout_output = self .dropout (embedding_output )
947952 else :
948953 dropout_output = embedding_output
949-
950954 return dropout_output
951955
952956 def compute_mask (self , inputs , mask = None ):
953957 return None
954958
955959 def get_config (self , ):
956- config = {'pooling_strategy' : self .pooling_strategy }
960+ config = { 'dropout_rate' : self .dropout_rate ,
961+ 'emb_layer' : self .emb_layer .get_config ()}
957962 base_config = super (VarLenColumnEmbedding , self ).get_config ()
958963 return dict (list (base_config .items ()) + list (config .items ()))
959964
0 commit comments