diff --git a/model_zoo/official/gnn/gat/src/utils.py b/model_zoo/official/gnn/gat/src/utils.py index 8c68fa9696..23e0c6c306 100644 --- a/model_zoo/official/gnn/gat/src/utils.py +++ b/model_zoo/official/gnn/gat/src/utils.py @@ -138,6 +138,7 @@ class TrainOneStepCell(nn.Cell): def __init__(self, network, optimizer, sens=1.0): super(TrainOneStepCell, self).__init__(auto_prefix=True) self.network = network + self.network.set_grad() self.network.add_flags(defer_inline=True) self.weights = ParameterTuple(network.trainable_params()) self.optimizer = optimizer @@ -167,7 +168,6 @@ class TrainGAT(nn.Cell): def __init__(self, network, num_class, label, mask, learning_rate, l2_coeff): super(TrainGAT, self).__init__(auto_prefix=False) self.network = network - self.network.set_grad() loss_net = LossNetWrapper(network, num_class, label, mask, l2_coeff) optimizer = nn.Adam(loss_net.trainable_params(), learning_rate=learning_rate) diff --git a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py index d231e456ce..f7b8914b90 100644 --- a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py +++ b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py @@ -328,7 +328,6 @@ class TrainStepWrap(nn.Cell): parallel_mode = context.get_auto_parallel_context("parallel_mode") is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) self.network = network - self.network.set_grad() self.network.set_train() self.trainable_params = network.trainable_params() weights_w = [] @@ -361,6 +360,8 @@ class TrainStepWrap(nn.Cell): self.sens = sens self.loss_net_w = IthOutputCell(network, output_index=0) self.loss_net_d = IthOutputCell(network, output_index=1) + self.loss_net_w.set_grad() + self.loss_net_d.set_grad() self.reducer_flag = False self.grad_reducer_w = None diff --git a/model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_deep.py b/model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_deep.py index 519d65efea..9904ef028d 100644 --- a/model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_deep.py +++ b/model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_deep.py @@ -509,7 +509,6 @@ class TrainStepWrap(nn.Cell): def __init__(self, network, config, sens=1000.0): super(TrainStepWrap, self).__init__() self.network = network - self.network.set_grad() self.network.set_train() self.trainable_params = network.trainable_params() weights_w = [] @@ -544,6 +543,8 @@ class TrainStepWrap(nn.Cell): self.sens = sens self.loss_net_w = IthOutputCell(network, output_index=0) self.loss_net_d = IthOutputCell(network, output_index=1) + self.loss_net_w.set_grad() + self.loss_net_w.set_grad() self.reducer_flag = False self.grad_reducer_w = None