import tensorflow as tf


def din_attention(query, keys, keys_length, is_softmax=False):
    """
    Attention module of the DIN (Deep Interest Network) model.

    Args:
        query (tf.Tensor): target item embedding, shape=(B, H)
        keys (tf.Tensor): user behavior (history) sequence, shape=(B, T, H)
        keys_length (tf.Tensor): valid lengths of the behavior sequences, used to build the mask, shape=(B,)
        is_softmax (bool): whether to apply a softmax activation to the attention weights

    Returns:
        tf.Tensor, the weighted-sum pooling result, shape=(B, H)
    """

    embedding_dim = query.shape[-1].value
    query = tf.tile(query, multiples=[1, tf.shape(keys)[1]])  # (B, H*T)
    query = tf.reshape(query, shape=(-1, tf.shape(keys)[1], embedding_dim))  # (B, T, H)
    cross_all = tf.concat([query, keys, query - keys, query * keys], axis=-1)  # (B, T, 4*H)
    d_layer_1_all = tf.layers.dense(cross_all, 64, activation=tf.nn.relu, name='f1_att', reuse=tf.AUTO_REUSE)  # (B, T, 64)
    d_layer_2_all = tf.layers.dense(d_layer_1_all, 32, activation=tf.nn.relu, name='f2_att', reuse=tf.AUTO_REUSE)  # (B, T, 32)
    d_layer_3_all = tf.layers.dense(d_layer_2_all, 1, activation=None, name='f3_att', reuse=tf.AUTO_REUSE)  # (B, T, 1)
    output_weight = d_layer_3_all  # (B, T, 1)

    # mask for the padded positions in the behavior sequence
    keys_mask = tf.sequence_mask(keys_length, tf.shape(keys)[1])  # (B, T)
    keys_mask = tf.expand_dims(keys_mask, -1)  # align with output_weight, (B, T, 1)

    if is_softmax:
        paddings = tf.ones_like(output_weight) * (-2 ** 32 + 1)  # (B, T, 1)
        output_weight = tf.where(keys_mask, output_weight, paddings)  # (B, T, 1)
        # scale to keep the softmax out of its saturated region and avoid vanishing gradients
        output_weight = output_weight / (embedding_dim ** 0.5)  # (B, T, 1)
        output_weight = tf.nn.softmax(output_weight, axis=1)  # (B, T, 1)
    else:  # as in the original paper, do not apply a softmax activation
        # keep the learned weights but zero out the padded positions
        output_weight = tf.where(keys_mask, output_weight, tf.zeros_like(output_weight))  # (B, T, 1)

    outputs = tf.matmul(output_weight, keys, transpose_a=True)  # (B, 1, T) * (B, T, H) = (B, 1, H)
    outputs = tf.squeeze(outputs, 1)  # (B, H)

    return outputs


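# Typical usage (a hypothetical sketch, not part of the original file): in a DIN-style
# network the pooled history vector returned by din_attention is usually concatenated
# with the candidate item embedding before the final MLP, e.g.
#
#   pooled = din_attention(item_emb, hist_emb, hist_len, is_softmax=False)  # (B, H)
#   mlp_input = tf.concat([item_emb, pooled], axis=-1)                      # (B, 2*H)
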
if __name__ == "__main__":
    # smoke test with B=2, T=3, H=4
    # fake_keys = tf.zeros(shape=(2, 3, 4))
    fake_keys = tf.random_normal(shape=(2, 3, 4))
    fake_query = tf.random_normal(shape=(2, 4))
    fake_keys_length = tf.constant([0, 1])  # valid lengths of the two behavior sequences
    attention_out1 = din_attention(fake_query, fake_keys, fake_keys_length, is_softmax=False)
    attention_out2 = din_attention(fake_query, fake_keys, fake_keys_length, is_softmax=True)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print("Without softmax activation:")
        print(sess.run(attention_out1))
        print("With softmax activation:")
        print(sess.run(attention_out2))