77from easy_rec .python .model .easy_rec_model import EasyRecModel
88from easy_rec .python .protos .loss_pb2 import LossType
99from easy_rec .python .utils .proto_util import copy_obj
10- from easy_rec .python .utils .shape_utils import get_shape_list
1110
1211from easy_rec .python .protos .dropoutnet_pb2 import DropoutNet as DropoutNetConfig # NOQA
1312from easy_rec .python .loss .softmax_loss_with_negative_mining import softmax_loss_with_negative_mining # NOQA
@@ -22,6 +21,14 @@ def cosine_similarity(user_emb, item_emb):
2221 tf .multiply (user_emb , item_emb ), axis = 1 , name = 'cosine' )
2322 return user_item_sim
2423
24+ def bernoulli_dropout (x , rate , training = False ):
25+ if rate == 0.0 or not training :
26+ return x
27+ keep_rate = 1.0 - rate
28+ dist = tf .distributions .Bernoulli (probs = keep_rate , dtype = x .dtype )
29+ mask = dist .sample (sample_shape = tf .stack ([tf .shape (x )[0 ], 1 ]))
30+ return x * mask / keep_rate
31+
2532
2633class DropoutNet (EasyRecModel ):
2734
@@ -68,8 +75,6 @@ def __init__(self,
6875 assert self .item_content_feature is not None or self .item_preference_feature is not None , 'no item feature'
6976
7077 def build_predict_graph (self ):
71- batch_size = get_shape_list (self .item_content_feature )[0 ]
72-
7378 num_user_dnn_layer = len (self .user_tower_layers .hidden_units )
7479 last_user_hidden = self .user_tower_layers .hidden_units .pop ()
7580 num_item_dnn_layer = len (self .item_tower_layers .hidden_units )
@@ -85,15 +90,9 @@ def build_predict_graph(self):
8590 content_feature = user_content_dnn (self .user_content_feature )
8691 user_features .append (content_feature )
8792 if self .user_preference_feature is not None :
88- if self ._is_training :
89- prob = tf .random .uniform ([batch_size ])
90- user_prefer_feature = tf .where (
91- tf .less (prob , self ._model_config .user_dropout_rate ),
92- tf .zeros_like (self .user_preference_feature ),
93- self .user_preference_feature )
94- else :
95- user_prefer_feature = self .user_preference_feature
96-
93+ user_prefer_feature = bernoulli_dropout (self .user_preference_feature ,
94+ self ._model_config .user_dropout_rate ,
95+ self ._is_training )
9796 user_prefer_dnn = dnn .DNN (self .user_preference_layers , self ._l2_reg ,
9897 'user_preference' , self ._is_training )
9998 prefer_feature = user_prefer_dnn (user_prefer_feature )
@@ -119,15 +118,9 @@ def build_predict_graph(self):
119118 content_feature = item_content_dnn (self .item_content_feature )
120119 item_features .append (content_feature )
121120 if self .item_preference_feature is not None :
122- if self ._is_training :
123- prob = tf .random .uniform ([batch_size ])
124- item_prefer_feature = tf .where (
125- tf .less (prob , self ._model_config .item_dropout_rate ),
126- tf .zeros_like (self .item_preference_feature ),
127- self .item_preference_feature )
128- else :
129- item_prefer_feature = self .item_preference_feature
130-
121+ item_prefer_feature = bernoulli_dropout (self .item_preference_feature ,
122+ self ._model_config .item_dropout_rate ,
123+ self ._is_training )
131124 item_prefer_dnn = dnn .DNN (self .item_preference_layers , self ._l2_reg ,
132125 'item_preference' , self ._is_training )
133126 prefer_feature = item_prefer_dnn (item_prefer_feature )
0 commit comments