|
|
|
@ -14,9 +14,8 @@
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""Loss scale cell for loss scale training."""
|
|
|
|
|
import mindspore.context as context
|
|
|
|
|
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
|
|
|
|
from mindspore.context import ParallelMode
|
|
|
|
|
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
|
|
|
|
|
from .cell_wrapper import TrainOneStepCell
|
|
|
|
|
from ..cell import Cell
|
|
|
|
|
from ...common import Tensor, RowTensor
|
|
|
|
|
from ...common.parameter import Parameter
|
|
|
|
@ -163,7 +162,7 @@ class FixedLossScaleUpdateCell(Cell):
|
|
|
|
|
return overflow
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TrainOneStepWithLossScaleCell(Cell):
|
|
|
|
|
class TrainOneStepWithLossScaleCell(TrainOneStepCell):
|
|
|
|
|
r"""
|
|
|
|
|
Network training with loss scaling.
|
|
|
|
|
|
|
|
|
@ -203,15 +202,8 @@ class TrainOneStepWithLossScaleCell(Cell):
|
|
|
|
|
>>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
|
|
|
|
|
>>> output = train_network(inputs, label, scaling_sens)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, network, optimizer, scale_sense):
|
|
|
|
|
super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
|
|
|
|
|
self.network = network
|
|
|
|
|
self.network.set_grad()
|
|
|
|
|
self.network.add_flags(defer_inline=True)
|
|
|
|
|
self.weights = optimizer.parameters
|
|
|
|
|
self.optimizer = optimizer
|
|
|
|
|
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
|
|
|
|
super(TrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)
|
|
|
|
|
self.hyper_map = C.HyperMap()
|
|
|
|
|
if context.get_context("device_target") == "GPU":
|
|
|
|
|
self.gpu_target = True
|
|
|
|
@ -228,13 +220,6 @@ class TrainOneStepWithLossScaleCell(Cell):
|
|
|
|
|
self.less_equal = LessEqual()
|
|
|
|
|
self.depend_parameter_use = ControlDepend(depend_mode=1)
|
|
|
|
|
self.allreduce = P.AllReduce()
|
|
|
|
|
self.parallel_mode = _get_parallel_mode()
|
|
|
|
|
self.grad_reducer = F.identity
|
|
|
|
|
self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
|
|
|
|
|
if self.reducer_flag:
|
|
|
|
|
mean = _get_gradients_mean()
|
|
|
|
|
degree = _get_device_num()
|
|
|
|
|
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
|
|
|
|
self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
|
|
|
|
|
|
|
|
|
|
self.loss_scaling_manager = None
|
|
|
|
|