From 2fecdede6b6cca648707bd83e3926c6937b350e6 Mon Sep 17 00:00:00 2001
From: Wei Luning
Date: Thu, 9 Apr 2020 23:37:29 +0800
Subject: [PATCH] support amp when model eval, fix example of UnsortedSegmentSum

---
 .../ccsrc/parallel/step_auto_parallel.cc |  9 ++++
 mindspore/ops/operations/array_ops.py    |  7 +--
 mindspore/ops/operations/nn_ops.py       | 26 ++++++-----
 mindspore/train/amp.py                   | 42 +++++++++--------
 mindspore/train/model.py                 |  8 ++--
 tests/train_step_wrap.py                 | 45 +------------------
 6 files changed, 59 insertions(+), 78 deletions(-)

diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc
index fe6be575ee..3f1e18183a 100644
--- a/mindspore/ccsrc/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc
@@ -636,6 +636,15 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
     // Dealing with the RefKey case
     auto refkeys = cnode_with_refkeys.second;
     auto cnode = cnode_with_refkeys.first;
+
+    auto cnode_ptr = cnode->cast<CNodePtr>();
+    if (cnode_ptr == nullptr || !IsValueNode<Primitive>(cnode_ptr->input(0))) {
+      continue;
+    }
+    if (!IsAutoParallelCareNode(cnode_ptr)) {
+      continue;
+    }
+
     if (refkeys.size() > 1) {
       MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << " 's inputs have more than 1 RefKeys.";
     }
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index 850e895ad0..a7c3f50440 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -1235,10 +1235,11 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
         Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.

     Examples:
-        >>> input_x = [1, 2, 3, 4]
-        >>> segment_ids = [0, 0, 1, 2]
+        >>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
+        >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
         >>> num_segments = 4
-        >>> type = P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
+        >>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
+        [3, 3, 4, 0]
     """

     @prim_attr_register
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 91f6d7ec01..acccfbaba3 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -22,6 +22,8 @@ from functools import reduce
 import numpy as np

 from ... import context
+from ..._c_expression import signature_rw as sig_rw
+from ..._c_expression import signature_kind as sig_kind
 from ..._checkparam import ParamValidator as validator
 from ..._checkparam import Rel, check_bool, check_int_positive
 from ...common import dtype as mstype
@@ -1297,29 +1299,31 @@ class ApplyMomentum(PrimitiveWithInfer):
                                filter(lambda x: x.requires_grad, net.get_parameters()))
         >>> model = Model(net, loss, opt)
     """
-
+    __mindspore_signature__ = (
+        ('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
+        ('accumulation', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
+        ('learning_rate', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
+        ('gradient', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
+        ('momentum', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
+    )
     @prim_attr_register
     def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0):
         self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'],
                                 outputs=['output'])

     def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape):
-        validator.check(f'variable shape {v_shape}', len(v_shape), '', 0, Rel.GT)
-        validator.check(f'accumulation shape {a_shape}', len(a_shape), '', 0, Rel.GT)
-        validator.check(f'learning rate shape {l_shape}', len(l_shape), '', 0, Rel.GE)
-        validator.check(f'gradient shape {g_shape}', len(g_shape), '', 0, Rel.GE)
-        validator.check(f'momentum shape {m_shape}', len(m_shape), '', 0, Rel.GE)
         return v_shape

     def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
-        validator.check_subclass("v_dtype", v_dtype, mstype.tensor)
-        validator.check_subclass("a_dtype", a_dtype, mstype.tensor)
-        v_type = validator.check_typename("v_dtype", v_dtype, [mstype.float16, mstype.float32, mstype.float64])
-        validator.check_typename("a_dtype", a_dtype, [mstype.float16, mstype.float32, mstype.float64])
+        if v_dtype != mstype.type_refkey and a_dtype != mstype.type_refkey:
+            validator.check_subclass("v_dtype", v_dtype, mstype.tensor)
+            validator.check_subclass("a_dtype", a_dtype, mstype.tensor)
+            validator.check_typename("v_dtype", v_dtype, [mstype.float16, mstype.float32, mstype.float64])
+            validator.check_typename("a_dtype", a_dtype, [mstype.float16, mstype.float32, mstype.float64])
         validator.check_typename("l_dtype", l_dtype, [mstype.float16, mstype.float32, mstype.float64])
         validator.check_typename("g_dtype", g_dtype, [mstype.float16, mstype.float32, mstype.float64])
         validator.check_typename("m_dtype", m_dtype, [mstype.float16, mstype.float32, mstype.float64])
-        return v_type
+        return g_dtype


 class SmoothL1Loss(PrimitiveWithInfer):
diff --git a/mindspore/train/amp.py b/mindspore/train/amp.py
index e909b44e40..c4c115ef27 100644
--- a/mindspore/train/amp.py
+++ b/mindspore/train/amp.py
@@ -82,6 +82,29 @@ def _check_kwargs(key_words):
     if loss_scale_manager:
         validator.check_isinstance('loss_scale_manager', loss_scale_manager, LossScaleManager)

+
+def _add_loss_network(network, loss_fn, cast_model_type):
+    class WithLossCell(nn.Cell):
+        "Wrap loss for amp. Cast network output back to float32"
+
+        def __init__(self, backbone, loss_fn):
+            super(WithLossCell, self).__init__(auto_prefix=False)
+            self._backbone = backbone
+            self._loss_fn = loss_fn
+
+        def construct(self, data, label):
+            out = self._backbone(data)
+            label = _mp_cast_helper(mstype.float32, label)
+            return self._loss_fn(F.cast(out, mstype.float32), label)
+
+    validator.check_isinstance('loss_fn', loss_fn, nn.Cell)
+    if cast_model_type == mstype.float16:
+        network = WithLossCell(network, loss_fn)
+    else:
+        network = nn.WithLossCell(network, loss_fn)
+    return network
+
+
 def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
     """
     Build the mixed precision training cell automatically.
@@ -117,24 +140,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
         _do_keep_batchnorm_fp32(network)

     if loss_fn:
-        class WithLossCell(nn.Cell):
-            "Wrap loss for amp. Cast network output back to float32"
-
-            def __init__(self, backbone, loss_fn):
-                super(WithLossCell, self).__init__(auto_prefix=False)
-                self._backbone = backbone
-                self._loss_fn = loss_fn
-
-            def construct(self, data, label):
-                out = self._backbone(data)
-                label = _mp_cast_helper(mstype.float32, label)
-                return self._loss_fn(F.cast(out, mstype.float32), label)
-
-        validator.check_isinstance('loss_fn', loss_fn, nn.Cell)
-        if config.cast_model_type == mstype.float16:
-            network = WithLossCell(network, loss_fn)
-        else:
-            network = nn.WithLossCell(network, loss_fn)
+        network = _add_loss_network(network, loss_fn, config.cast_model_type)

     if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
         network = _VirtualDatasetCell(network)
diff --git a/mindspore/train/model.py b/mindspore/train/model.py
index 833fb07256..a1acec859c 100755
--- a/mindspore/train/model.py
+++ b/mindspore/train/model.py
@@ -24,8 +24,7 @@ from .. import context
 from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _callback_wrapper
 from ..nn.metrics import Loss
-from ..nn.wrap import WithLossCell, WithEvalCell, \
-    DataWrapper
+from ..nn.wrap import WithLossCell, DataWrapper, WithEvalCell
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
 from .parallel_utils import ParallelMode
 from ..common import dtype as mstype
@@ -151,7 +150,10 @@ class Model:
         else:
             if self._loss_fn is None:
                 raise ValueError("loss_fn can not be None.")
-            self._eval_network = WithEvalCell(self._network, self._loss_fn)
+            if self._optimizer:
+                self._eval_network = self._train_network.network
+            else:
+                self._eval_network = WithEvalCell(self._network, self._loss_fn)
             self._eval_indexes = [0, 1, 2]

     def _clear_metrics(self):
diff --git a/tests/train_step_wrap.py b/tests/train_step_wrap.py
index 7289c01004..d48e25b837 100644
--- a/tests/train_step_wrap.py
+++ b/tests/train_step_wrap.py
@@ -21,47 +21,6 @@ from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore import Parameter, ParameterTuple

-
-run_opt = C.MultitypeFuncGraph("run_opt")
-
-# pylint: disable=unused-argument
-@run_opt.register("Function", "Int", "Number", "Number",
-                  "Tensor", "Tensor", "Tensor")
-def tensor_run_opt(opt, iterator, learning_rate, momentum,
-                   gradient, variable, moment):
-    success = True
-    new_weight = opt(gradient, moment, variable, learning_rate, momentum)
-    success = F.depend(success, P.Assign()(variable, new_weight))
-    return success
-
-
-class OptimizerByMomentum(nn.Cell):
-    """
-    OptimizerByMomentum definition
-    """
-    # list of tensor
-    def __init__(self, weights):
-        super(OptimizerByMomentum, self).__init__()
-        self.learning_rate = Parameter(0.1, name="learning_rate")
-        self.momentum = Parameter(0.05, name="momentum")
-        self.iter = Parameter(0, name="iter")
-
-        self.weights = weights
-        self.moments = weights.clone(prefix="moments", init='zeros')
-
-        self.hyper_map = C.HyperMap()
-        self.opt = P.ApplyMomentum()
-
-    def construct(self, grads):
-        success = True
-        weights = self.weights
-        moments = self.moments
-        success = self.hyper_map(
-            F.partial(run_opt, self.opt, self.iter,
-                      self.learning_rate, self.momentum), grads, weights, moments)
-        # self.learning_rate = updata_lr(self.learning_rate, self.momentum)
-        return success
-
 class TrainStepWrap(nn.Cell):
     """
     TrainStepWrap definition
@@ -71,7 +30,7 @@ class TrainStepWrap(nn.Cell):
         self.network = network
         self.network.set_train()
         self.weights = ParameterTuple(network.trainable_params())
-        self.optimizer = OptimizerByMomentum(self.weights)
+        self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
         self.hyper_map = C.HyperMap()
         self.grad = C.GradOperation('grad', get_by_list=True)

@@ -107,7 +66,7 @@ class TrainStepWrap2(nn.Cell):
         self.network = network
         self.network.set_train()
         self.weights = ParameterTuple(network.get_parameters())
-        self.optimizer = OptimizerByMomentum(self.weights)
+        self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
         self.hyper_map = C.HyperMap()
         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         self.sens = sens
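
Usage sketch (reviewer note, not part of the patch): with this change, Model reuses the loss-wrapped train backbone (self._train_network.network) for evaluation when an optimizer is present, so a network cast to float16 by amp is also the one exercised by model.eval. The snippet below is a minimal illustration of that path under assumptions: LeNet5, train_ds and eval_ds are placeholder names, and Model is assumed to accept the amp_level keyword as in the surrounding code base.

    import mindspore.nn as nn
    from mindspore.train.model import Model

    net = LeNet5()                                 # placeholder backbone
    loss = nn.SoftmaxCrossEntropyWithLogits()      # loss_fn is required to build the eval network
    opt = nn.Momentum(net.trainable_params(), 0.1, 0.9)

    # amp_level='O2' is assumed to cast the model to float16, the case handled by
    # _add_loss_network and by the new _eval_network selection above.
    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, amp_level='O2')
    model.train(1, train_ds)
    acc = model.eval(eval_ds)                      # now runs on the amp-wrapped train backbone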
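
For the corrected UnsortedSegmentSum example, the documented output follows from summing the inputs that share a segment id: ids [0, 0, 1, 2] send 1 + 2 to segment 0, 3 to segment 1, 4 to segment 2, and segment 3 receives nothing, giving [3, 3, 4, 0]. A self-contained version of the docstring example (only the import lines are added here):

    import mindspore
    from mindspore import Tensor
    from mindspore.ops import operations as P

    input_x = Tensor([1, 2, 3, 4], mindspore.float32)
    segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
    num_segments = 4
    # Sums entries of input_x into num_segments buckets selected by segment_ids.
    print(P.UnsortedSegmentSum()(input_x, segment_ids, num_segments))  # [3. 3. 4. 0.]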