From 381d06549c1fce69193ba4b268606baa07a30431 Mon Sep 17 00:00:00 2001 From: lvchangquan Date: Mon, 14 Sep 2020 12:30:34 +0800 Subject: [PATCH] fix a bug with add set_grad() in wide_and_deep network --- model_zoo/official/gnn/gat/src/utils.py | 2 +- .../official/recommend/wide_and_deep/src/wide_and_deep.py | 3 ++- .../recommend/wide_and_deep_multitable/src/wide_and_deep.py | 3 ++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/model_zoo/official/gnn/gat/src/utils.py b/model_zoo/official/gnn/gat/src/utils.py index 8c68fa9696..23e0c6c306 100644 --- a/model_zoo/official/gnn/gat/src/utils.py +++ b/model_zoo/official/gnn/gat/src/utils.py @@ -138,6 +138,7 @@ class TrainOneStepCell(nn.Cell): def __init__(self, network, optimizer, sens=1.0): super(TrainOneStepCell, self).__init__(auto_prefix=True) self.network = network + self.network.set_grad() self.network.add_flags(defer_inline=True) self.weights = ParameterTuple(network.trainable_params()) self.optimizer = optimizer @@ -167,7 +168,6 @@ class TrainGAT(nn.Cell): def __init__(self, network, num_class, label, mask, learning_rate, l2_coeff): super(TrainGAT, self).__init__(auto_prefix=False) self.network = network - self.network.set_grad() loss_net = LossNetWrapper(network, num_class, label, mask, l2_coeff) optimizer = nn.Adam(loss_net.trainable_params(), learning_rate=learning_rate) diff --git a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py index d231e456ce..f7b8914b90 100644 --- a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py +++ b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py @@ -328,7 +328,6 @@ class TrainStepWrap(nn.Cell): parallel_mode = context.get_auto_parallel_context("parallel_mode") is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) self.network = network - self.network.set_grad() self.network.set_train() self.trainable_params = network.trainable_params() weights_w = [] @@ -361,6 +360,8 @@ class TrainStepWrap(nn.Cell): self.sens = sens self.loss_net_w = IthOutputCell(network, output_index=0) self.loss_net_d = IthOutputCell(network, output_index=1) + self.loss_net_w.set_grad() + self.loss_net_d.set_grad() self.reducer_flag = False self.grad_reducer_w = None diff --git a/model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_deep.py b/model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_deep.py index 1358f6f76b..6fe30d3b0e 100644 --- a/model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_deep.py +++ b/model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_deep.py @@ -510,7 +510,6 @@ class TrainStepWrap(nn.Cell): def __init__(self, network, config, sens=1000.0): super(TrainStepWrap, self).__init__() self.network = network - self.network.set_grad() self.network.set_train() self.trainable_params = network.trainable_params() weights_w = [] @@ -546,6 +545,8 @@ class TrainStepWrap(nn.Cell): self.sens = sens self.loss_net_w = IthOutputCell(network, output_index=0) self.loss_net_d = IthOutputCell(network, output_index=1) + self.loss_net_w.set_grad() + self.loss_net_w.set_grad() self.reducer_flag = False self.grad_reducer_w = None