Skip to content

Commit b7ca9c6

Browse files
committed
update var len embedding
1 parent 4643439 commit b7ca9c6

File tree

7 files changed

+32
-32
lines changed

7 files changed

+32
-32
lines changed

deeptables/models/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def __new__(cls,
132132
gpu_usage_strategy=consts.GPU_USAGE_STRATEGY_GROWTH,
133133
distribute_strategy=None,
134134
var_len_categorical_columns=None,
135-
# a tuple3, format is (column_name, separator, pool_strategy), pool_strategy is one of max,avg; e.g. [('genres', '|', 'avg' )]
135+
# a tuple2, format is (column_name, separator); e.g. [('genres', '|')]
136136
):
137137

138138
if var_len_categorical_columns is not None and len(var_len_categorical_columns) > 0:

deeptables/models/deepmodel.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def __build_model(self, task, num_classes, nets, categorical_columns, continuous
271271
if len(embeddings) == 1:
272272
flatten_emb_layer = Flatten(name='flatten_embeddings')(embeddings[0])
273273
else:
274-
flatten_emb_layer = Flatten(name='flatten_embeddings')(Concatenate(name='concat_embeddings_axis_0', axis=1)(embeddings))
274+
flatten_emb_layer = Flatten(name='flatten_embeddings')(Concatenate(name='concat_embeddings_axis_0', axis=-1)(embeddings))
275275

276276
self.model_desc.nets = nets
277277
self.model_desc.stacking = config.stacking_op
@@ -407,9 +407,8 @@ def __build_embeddings(self, categorical_columns, categorical_inputs,
407407
for column in var_len_categorical_columns:
408408
# todo add var len embedding description
409409
input_layer = var_len_inputs[column.name]
410-
var_len_embeddings = VarLenColumnEmbedding(pooling_strategy=column.pooling_strategy,
411-
input_dim=column.vocabulary_size,
412-
output_dim=column.embeddings_output_dim,
410+
var_len_embeddings = VarLenColumnEmbedding(emb_vocab_size=column.vocabulary_size,
411+
emb_output_dim=column.embeddings_output_dim,
413412
dropout_rate=embedding_dropout,
414413
name=consts.LAYER_PREFIX_EMBEDDING + column.name,
415414
embeddings_initializer=self.config.embeddings_initializer,

deeptables/models/layers.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -922,38 +922,43 @@ def get_config(self):
922922
return dict(list(base_config.items()) + list(config.items()))
923923

924924

925-
class VarLenColumnEmbedding(Embedding):
926-
def __init__(self, pooling_strategy='max', dropout_rate=0., **kwargs):
927-
if pooling_strategy not in ['mean', 'max']:
928-
raise ValueError("Param strategy should is one of mean, max")
929-
self.pooling_strategy = pooling_strategy
930-
self.dropout_rate = dropout_rate # 支持dropout
925+
class VarLenColumnEmbedding(Layer):
926+
def __init__(self, emb_vocab_size, emb_output_dim, dropout_rate=0. , **kwargs):
927+
self.emb_vocab_size = emb_vocab_size
928+
self.emb_output_dim = emb_output_dim
929+
self.dropout_rate = dropout_rate
931930
super(VarLenColumnEmbedding, self).__init__(**kwargs)
932-
self._dropout = None
931+
self.dropout = None
932+
self.emb_layer = None
933+
934+
def compute_output_shape(self, input_shape):
935+
n_dim = input_shape[1]
936+
return input_shape[0] , self.emb_output_dim * n_dim
933937

934938
def build(self, input_shape=None):
935939
super(VarLenColumnEmbedding, self).build(input_shape)
940+
self.emb_layer = Embedding(input_dim=self.emb_vocab_size, output_dim=self.emb_output_dim)
936941
if self.dropout_rate > 0:
937-
self._dropout = SpatialDropout1D(self.dropout_rate, name='var_len_emb_dropout')
942+
self.dropout = SpatialDropout1D(self.dropout_rate, name='var_len_emb_dropout')
938943
else:
939-
self._dropout = None
944+
self.dropout = None
940945
self.built = True
941946

942947
def call(self, inputs):
943-
embedding_output = super(VarLenColumnEmbedding, self).call(inputs)
944-
945-
if self._dropout is not None:
946-
dropout_output = self._dropout(embedding_output)
948+
embedding_output = self.emb_layer.call(inputs)
949+
embedding_output = embedding_output.reshape((embedding_output[0], 1, -1))
[NOTE(review): `embedding_output[0]` indexes the tensor, not its shape — presumably `tf.shape(embedding_output)[0]` (or -1) was intended, and Keras tensors are reshaped with `tf.reshape`, not a `.reshape` method; also this 3-D reshape appears inconsistent with the 2-D `compute_output_shape` added above — confirm against the follow-up commit.]
950+
if self.dropout is not None:
951+
dropout_output = self.dropout(embedding_output)
947952
else:
948953
dropout_output = embedding_output
949-
950954
return dropout_output
951955

952956
def compute_mask(self, inputs, mask=None):
953957
return None
954958

955959
def get_config(self, ):
956-
config = {'pooling_strategy': self.pooling_strategy}
960+
config = { 'dropout_rate': self.dropout_rate,
961+
'emb_layer': self.emb_layer.get_config()}
957962
base_config = super(VarLenColumnEmbedding, self).get_config()
958963
return dict(list(base_config.items()) + list(config.items()))
959964

deeptables/models/metainfo.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,21 +55,20 @@ class VarLenCategoricalColumn(collections.namedtuple('VarLenCategoricalColumn',
5555
'embeddings_output_dim',
5656
'dtype',
5757
'input_name',
58-
'sep',
59-
'pooling_strategy',
58+
'sep'
6059
])):
6160

6261
def __hash__(self):
6362
return self.name.__hash__()
6463

65-
def __new__(cls, name, vocabulary_size, embeddings_output_dim=10, dtype='int32', input_name=None, sep="|", pooling_strategy='max'):
64+
def __new__(cls, name, vocabulary_size, embeddings_output_dim=10, dtype='int32', input_name=None, sep="|"):
6665
if input_name is None:
6766
input_name = consts.INPUT_PREFIX_CAT + name
6867
if embeddings_output_dim == 0:
6968
embeddings_output_dim = int(round(vocabulary_size ** 0.25))
7069
# max_elements_length need a variable not const
7170
return super(VarLenCategoricalColumn, cls).__new__(cls, name, vocabulary_size, embeddings_output_dim, dtype,
72-
input_name, sep, pooling_strategy)
71+
input_name, sep)
7372

7473

7574
class ContinuousColumn(collections.namedtuple('ContinuousColumn',

deeptables/models/preprocessor.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -279,10 +279,8 @@ def _prepare_features(self, X):
279279
else:
280280
var_len_column_names.append(v[0])
281281
var_len_col_sep_dict = {v[0]: v[1] for v in var_len_categorical_columns}
282-
var_len_col_pooling_strategy_dict = {v[0]: v[2] for v in var_len_categorical_columns}
283282
else:
284283
var_len_col_sep_dict = {}
285-
var_len_col_pooling_strategy_dict = {}
286284

287285
X_shape = self._get_shape(X)
288286
unique_upper_limit = round(X_shape[0] ** self.config.cat_exponent)
@@ -299,8 +297,7 @@ def _prepare_features(self, X):
299297

300298
# handle var len feature
301299
if c in var_len_column_names:
302-
self.__append_var_len_categorical_col(c, nunique, var_len_col_sep_dict[c],
303-
var_len_col_pooling_strategy_dict[c])
300+
self.__append_var_len_categorical_col(c, nunique, var_len_col_sep_dict[c])
304301
continue
305302

306303
if self.config.categorical_columns is not None and isinstance(self.config.categorical_columns, list):
@@ -454,7 +451,7 @@ def _gbm_features_to_continuous_cols(self, X, gbmencoder):
454451
# return [name for name in gbmencoder.new_columns]
455452
return gbmencoder.new_columns
456453

457-
def __append_var_len_categorical_col(self, name, voc_size, sep, pooling_strategy):
454+
def __append_var_len_categorical_col(self, name, voc_size, sep):
458455
logger.debug(f'Var len categorical variables {name} appended.')
459456

460457
if self.config.fixed_embedding_dim:
@@ -470,8 +467,7 @@ def __append_var_len_categorical_col(self, name, voc_size, sep, pooling_strategy
470467
voc_size,
471468
embedding_output_dim if embedding_output_dim > 0 else min(
472469
4 * int(pow(voc_size, 0.25)), 20),
473-
sep=sep,
474-
pooling_strategy=pooling_strategy)
470+
sep=sep)
475471

476472
self.var_len_categorical_columns.append(vc)
477473

deeptables/tests/models/config_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,5 @@ def test_embeddings_output_dim(self):
2323
dt = deeptable.DeepTable(config=conf)
2424

2525
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
26+
2627
model, history = dt.fit(X_train, y_train, epochs=1)

deeptables/utils/dataset_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __call__(self, X, y=None, *, batch_size, shuffle, drop_remainder):
5252
train_data.append(tf.constant(np.array(X[col.name].tolist()).astype(consts.DATATYPE_TENSOR_FLOAT).tolist()))
5353

5454
if y is None:
55-
ds = tf.data.Dataset.from_tensor_slices(train_data, name='train_x')
55+
ds = tf.data.Dataset.from_tensor_slices((tuple(train_data), ), name='train_x')
5656
else:
5757
y = tf.constant(np.array(y).tolist())
5858
if self.task == consts.TASK_MULTICLASS:

0 commit comments

Comments (0)