Skip to content

Commit 27a8622

Browse files
authored
modify dropoutnet in case of batch size mismatch (#505)
* modify dropoutnet in case of batch size mismatch
1 parent 4468723 commit 27a8622

File tree

1 file changed

+14
-21
lines changed

1 file changed

+14
-21
lines changed

easy_rec/python/model/dropoutnet.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from easy_rec.python.model.easy_rec_model import EasyRecModel
88
from easy_rec.python.protos.loss_pb2 import LossType
99
from easy_rec.python.utils.proto_util import copy_obj
10-
from easy_rec.python.utils.shape_utils import get_shape_list
1110

1211
from easy_rec.python.protos.dropoutnet_pb2 import DropoutNet as DropoutNetConfig # NOQA
1312
from easy_rec.python.loss.softmax_loss_with_negative_mining import softmax_loss_with_negative_mining # NOQA
@@ -22,6 +21,14 @@ def cosine_similarity(user_emb, item_emb):
2221
tf.multiply(user_emb, item_emb), axis=1, name='cosine')
2322
return user_item_sim
2423

24+
def bernoulli_dropout(x, rate, training=False):
25+
if rate == 0.0 or not training:
26+
return x
27+
keep_rate = 1.0 - rate
28+
dist = tf.distributions.Bernoulli(probs=keep_rate, dtype=x.dtype)
29+
mask = dist.sample(sample_shape=tf.stack([tf.shape(x)[0], 1]))
30+
return x * mask / keep_rate
31+
2532

2633
class DropoutNet(EasyRecModel):
2734

@@ -68,8 +75,6 @@ def __init__(self,
6875
assert self.item_content_feature is not None or self.item_preference_feature is not None, 'no item feature'
6976

7077
def build_predict_graph(self):
71-
batch_size = get_shape_list(self.item_content_feature)[0]
72-
7378
num_user_dnn_layer = len(self.user_tower_layers.hidden_units)
7479
last_user_hidden = self.user_tower_layers.hidden_units.pop()
7580
num_item_dnn_layer = len(self.item_tower_layers.hidden_units)
@@ -85,15 +90,9 @@ def build_predict_graph(self):
8590
content_feature = user_content_dnn(self.user_content_feature)
8691
user_features.append(content_feature)
8792
if self.user_preference_feature is not None:
88-
if self._is_training:
89-
prob = tf.random.uniform([batch_size])
90-
user_prefer_feature = tf.where(
91-
tf.less(prob, self._model_config.user_dropout_rate),
92-
tf.zeros_like(self.user_preference_feature),
93-
self.user_preference_feature)
94-
else:
95-
user_prefer_feature = self.user_preference_feature
96-
93+
user_prefer_feature = bernoulli_dropout(self.user_preference_feature,
94+
self._model_config.user_dropout_rate,
95+
self._is_training)
9796
user_prefer_dnn = dnn.DNN(self.user_preference_layers, self._l2_reg,
9897
'user_preference', self._is_training)
9998
prefer_feature = user_prefer_dnn(user_prefer_feature)
@@ -119,15 +118,9 @@ def build_predict_graph(self):
119118
content_feature = item_content_dnn(self.item_content_feature)
120119
item_features.append(content_feature)
121120
if self.item_preference_feature is not None:
122-
if self._is_training:
123-
prob = tf.random.uniform([batch_size])
124-
item_prefer_feature = tf.where(
125-
tf.less(prob, self._model_config.item_dropout_rate),
126-
tf.zeros_like(self.item_preference_feature),
127-
self.item_preference_feature)
128-
else:
129-
item_prefer_feature = self.item_preference_feature
130-
121+
item_prefer_feature = bernoulli_dropout(self.item_preference_feature,
122+
self._model_config.item_dropout_rate,
123+
self._is_training)
131124
item_prefer_dnn = dnn.DNN(self.item_preference_layers, self._l2_reg,
132125
'item_preference', self._is_training)
133126
prefer_feature = item_prefer_dnn(item_prefer_feature)

0 commit comments

Comments
 (0)