!3298 modify model_zoo code

Merge pull request !3298 from changzherui/mod_zoo_code
pull/3298/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit ade60ad3d3

@ -18,7 +18,7 @@ from mindspore.ops import operations as P
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode from mindspore.train.parallel_utils import ParallelMode
@ -231,7 +231,7 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
super(TransformerTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) super(TransformerTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.add_flags(defer_inline=True) self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params()) self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, self.grad = C.GradOperation('grad', get_by_list=True,
sens_param=True) sens_param=True)

Loading…
Cancel
Save