|
|
|
@ -21,7 +21,7 @@ from mindspore.ops import operations as P
|
|
|
|
|
from mindspore.ops import functional as F
|
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
|
from mindspore.common.tensor import Tensor
|
|
|
|
|
from mindspore.common.parameter import Parameter, ParameterTuple
|
|
|
|
|
from mindspore.common.parameter import Parameter
|
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
|
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
|
|
|
|
from mindspore.train.parallel_utils import ParallelMode
|
|
|
|
@ -270,7 +270,7 @@ class BertTrainOneStepCell(nn.Cell):
|
|
|
|
|
def __init__(self, network, optimizer, sens=1.0):
    """Wrap ``network`` so one call runs forward, backward and one optimizer step.

    Args:
        network: the loss-producing Cell to train; its output is what gets
            backpropagated.
        optimizer: optimizer applied to the gradients; its ``parameters``
            attribute supplies the weight list used for ``get_by_list`` grads.
        sens (float): sensitivity (gradient scale) fed into backprop.
            Default: 1.0.
    """
    super(BertTrainOneStepCell, self).__init__(auto_prefix=False)
    self.network = network
    # Take the weight list directly from the optimizer. The original code
    # first built ParameterTuple(network.trainable_params()) and then
    # immediately overwrote it with optimizer.parameters — the first
    # assignment was dead code and could disagree with the optimizer's
    # own view of which parameters it updates.
    self.weights = optimizer.parameters
    self.optimizer = optimizer
    # get_by_list=True: differentiate w.r.t. self.weights rather than inputs;
    # sens_param=True: caller supplies the initial sensitivity (self.sens).
    self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
    self.sens = sens
|
|
|
|
@ -349,7 +349,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
|
|
|
|
def __init__(self, network, optimizer, scale_update_cell=None):
|
|
|
|
|
super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
|
|
|
|
|
self.network = network
|
|
|
|
|
self.weights = ParameterTuple(network.trainable_params())
|
|
|
|
|
self.weights = optimizer.parameters
|
|
|
|
|
self.optimizer = optimizer
|
|
|
|
|
self.grad = C.GradOperation('grad',
|
|
|
|
|
get_by_list=True,
|
|
|
|
|