From cebe9f8198b1912ba48a85c604e3a433a6974a2b Mon Sep 17 00:00:00 2001 From: yao_yf Date: Wed, 16 Dec 2020 10:51:32 +0800 Subject: [PATCH] wide_and_deep_dropout_do_mask_remove --- model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py index 0dda5d58e0..149908185f 100644 --- a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py +++ b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py @@ -212,6 +212,7 @@ class WideDeepModel(nn.Cell): self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target, slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE) self.dense_layer_1.dropout.dropout.shard(((1, get_group_size()),)) + self.dense_layer_1.dropout.dropout_do_mask.shard(((1, get_group_size()),)) self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1))) self.dense_layer_1.matmul.add_prim_attr("field_size", self.field_size) self.deep_mul.shard(((1, 1, get_group_size()), (1, 1, 1))) @@ -233,6 +234,7 @@ class WideDeepModel(nn.Cell): self.wide_mul.shard(((1, get_group_size(), 1), (1, get_group_size(), 1))) self.reduce_sum.shard(((1, get_group_size(), 1),)) self.dense_layer_1.dropout.dropout.shard(((1, get_group_size()),)) + self.dense_layer_1.dropout.dropout_do_mask.shard(((1, get_group_size()),)) self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1))) self.embedding_table = self.deep_embeddinglookup.embedding_table elif parameter_server: