|
|
|
@ -188,7 +188,7 @@ class WideDeepModel(nn.Cell):
|
|
|
|
|
self.deep_layer_act,
|
|
|
|
|
use_activation=False, convert_dtype=True, drop_out=config.dropout_flag)
|
|
|
|
|
|
|
|
|
|
self.embeddinglookup = nn.EmbeddingLookup()
|
|
|
|
|
self.embeddinglookup = nn.EmbeddingLookup(target='DEVICE')
|
|
|
|
|
self.mul = P.Mul()
|
|
|
|
|
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
|
|
|
|
self.reshape = P.Reshape()
|
|
|
|
@ -206,11 +206,11 @@ class WideDeepModel(nn.Cell):
|
|
|
|
|
"""
|
|
|
|
|
mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
|
|
|
|
|
# Wide layer
|
|
|
|
|
wide_id_weight = self.embeddinglookup(self.wide_w, id_hldr, 0)
|
|
|
|
|
wide_id_weight = self.embeddinglookup(self.wide_w, id_hldr)
|
|
|
|
|
wx = self.mul(wide_id_weight, mask)
|
|
|
|
|
wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1))
|
|
|
|
|
# Deep layer
|
|
|
|
|
deep_id_embs = self.embeddinglookup(self.embedding_table, id_hldr, 0)
|
|
|
|
|
deep_id_embs = self.embeddinglookup(self.embedding_table, id_hldr)
|
|
|
|
|
vx = self.mul(deep_id_embs, mask)
|
|
|
|
|
deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim))
|
|
|
|
|
deep_in = self.dense_layer_1(deep_in)
|
|
|
|
|