diff --git a/model_zoo/mass/src/transformer/transformer_for_train.py b/model_zoo/mass/src/transformer/transformer_for_train.py index eb75e2d7b9..656b9e6f40 100644 --- a/model_zoo/mass/src/transformer/transformer_for_train.py +++ b/model_zoo/mass/src/transformer/transformer_for_train.py @@ -18,7 +18,7 @@ from mindspore.ops import operations as P from mindspore.ops import functional as F from mindspore.ops import composite as C from mindspore.common.tensor import Tensor -from mindspore.common.parameter import Parameter, ParameterTuple +from mindspore.common.parameter import Parameter from mindspore.common import dtype as mstype from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.train.parallel_utils import ParallelMode @@ -231,7 +231,7 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell): super(TransformerTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) self.network = network self.network.add_flags(defer_inline=True) - self.weights = ParameterTuple(network.trainable_params()) + self.weights = optimizer.parameters self.optimizer = optimizer self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)