From e67cf4437c00927a6db91a5ea0164fd53b934c87 Mon Sep 17 00:00:00 2001
From: VectorSL
Date: Mon, 15 Mar 2021 18:57:20 +0800
Subject: [PATCH] add momentum-decay-scale for gpu pynative

---
 mindspore/ops/operations/__init__.py         |   4 +-
 mindspore/ops/operations/inner_ops.py        |  69 +++++++-
 .../cv/resnet/gpu_resnet_benchmark.py        |   6 +-
 model_zoo/official/cv/resnet/src/momentum.py | 152 ++++++++++++++++++
 4 files changed, 228 insertions(+), 3 deletions(-)
 create mode 100644 model_zoo/official/cv/resnet/src/momentum.py

diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 92470c7e80..6a9feb245d 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -41,7 +41,8 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter,
 from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward,
                         ScalarSummary, TensorSummary, HistogramSummary, Print, Assert)
 from .control_ops import ControlDepend, GeSwitch, Merge
-from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey
+from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey,
+                        FusedWeightScaleApplyMomentum)
 
 from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
                        BitwiseAnd, BitwiseOr,
@@ -165,6 +166,7 @@ __all__ = [
     'NLLLoss',
     'SGD',
     'ApplyMomentum',
+    'FusedWeightScaleApplyMomentum',
     'ExpandDims',
     'Cast',
     'IsSubClass',
diff --git a/mindspore/ops/operations/inner_ops.py b/mindspore/ops/operations/inner_ops.py
index fbfc0a3952..66e3a59262 100644
--- a/mindspore/ops/operations/inner_ops.py
+++ b/mindspore/ops/operations/inner_ops.py
@@ -21,7 +21,7 @@ from ..._checkparam import Rel
 from ...common import dtype as mstype
 from ...common.dtype import tensor, dtype_to_pytype
 from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer
-
+from .. import signature as sig
 
 class ScalarCast(PrimitiveWithInfer):
     """
@@ -372,3 +372,70 @@ class MakeRefKey(Primitive):
 
     def __call__(self):
         pass
+
+class FusedWeightScaleApplyMomentum(PrimitiveWithInfer):
+    """
+    Optimizer that implements the Momentum algorithm with weight decay and loss scale.
+
+    Refer to the paper `On the importance of initialization and momentum in deep
+    learning `_ for more details.
+
+    Refer to :class:`mindspore.nn.Momentum` for more details about the formula and usage.
+
+    Inputs of `variable`, `accumulation` and `gradient` comply with the implicit type conversion rules
+    to make the data types consistent.
+    If they have different data types, the lower priority data type will be converted to the
+    relatively highest priority data type.
+    Data type conversion of Parameter is not supported; a RuntimeError exception will be thrown.
+
+    Inputs:
+        - **weight_decay** (Tensor) - The weight decay value, must be a scalar tensor with float data type.
+          Default: 0.0.
+        - **loss_scale** (Tensor) - The loss scale value, must be a scalar tensor with float data type.
+          Default: 1.0.
+        - **variable** (Parameter) - Weights to be updated. Data type must be float.
+        - **accumulation** (Parameter) - Accumulated gradient value by moment weight.
+          Has the same data type as `variable`.
+        - **learning_rate** (Union[Number, Tensor]) - The learning rate value, must be a float number or
+          a scalar tensor with float data type.
+        - **gradient** (Tensor) - Gradient, has the same data type as `variable`.
+        - **momentum** (Union[Number, Tensor]) - Momentum, must be a float number or
+          a scalar tensor with float data type.
+
+    Outputs:
+        Tensor, parameters to be updated.
+
+    Supported Platforms:
+        ``GPU``
+    Examples:
+        Please refer to the usage in :class:`mindspore.nn.Momentum`, and add weight_decay and loss_scale as inputs.
+    """
+    __mindspore_signature__ = (
+        sig.make_sig('weight_decay', dtype=sig.sig_dtype.T3),
+        sig.make_sig('loss_scale', dtype=sig.sig_dtype.T3),
+        sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
+        sig.make_sig('accumulation', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
+        sig.make_sig('learning_rate', dtype=sig.sig_dtype.T1),
+        sig.make_sig('gradient', dtype=sig.sig_dtype.T),
+        sig.make_sig('momentum', dtype=sig.sig_dtype.T2)
+    )
+
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['weight_decay', 'loss_scale', 'variable', 'accumulation', 'learning_rate',
+                                        'gradient', 'momentum'], outputs=['output'])
+
+    def infer_shape(self, d_shape, s_shape, v_shape, a_shape, l_shape, g_shape, m_shape):
+        return v_shape
+
+    def infer_dtype(self, d_dtype, s_dtype, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
+        valid_dtypes = [mstype.float16, mstype.float32]
+        if v_dtype != mstype.type_refkey and a_dtype != mstype.type_refkey:
+            validator.check_tensor_dtype_valid("v", v_dtype, valid_dtypes, self.name)
+            validator.check_tensor_dtype_valid("a", a_dtype, valid_dtypes, self.name)
+        validator.check_scalar_or_tensor_types_same({"l_dtype": l_dtype}, valid_dtypes, self.name)
+        validator.check_scalar_or_tensor_types_same({"g_dtype": g_dtype}, valid_dtypes, self.name)
+        validator.check_scalar_or_tensor_types_same({"m_dtype": m_dtype}, valid_dtypes, self.name)
+        validator.check_scalar_or_tensor_types_same({"d_dtype": d_dtype}, valid_dtypes, self.name)
+        validator.check_scalar_or_tensor_types_same({"s_dtype": s_dtype}, valid_dtypes, self.name)
+        return v_dtype
diff --git a/model_zoo/official/cv/resnet/gpu_resnet_benchmark.py b/model_zoo/official/cv/resnet/gpu_resnet_benchmark.py
index 7ef8b3497f..78fe2707cb 100644
--- a/model_zoo/official/cv/resnet/gpu_resnet_benchmark.py
+++ b/model_zoo/official/cv/resnet/gpu_resnet_benchmark.py
@@ -33,6 +33,7 @@ import mindspore.dataset as ds
 import mindspore.dataset.vision.c_transforms as C
 from src.resnet_gpu_benchmark import resnet50 as resnet
 from src.CrossEntropySmooth import CrossEntropySmooth
+from src.momentum import Momentum as MomentumWeightDecay
 
 parser = argparse.ArgumentParser(description='Image classification')
 parser.add_argument('--batch_size', type=str, default="256", help='Batch_size: default 256.')
@@ -228,7 +229,10 @@ def train():
     model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
     # Mixed precision
     if compute_type == "fp16":
-        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 1e-4, 1024)
+        if mode == context.PYNATIVE_MODE:
+            opt = MomentumWeightDecay(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 1e-4, 1024)
+        else:
+            opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 1e-4, 1024)
         model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
                       amp_level="O2", keep_batchnorm_fp32=False)
     # define callbacks
diff --git a/model_zoo/official/cv/resnet/src/momentum.py b/model_zoo/official/cv/resnet/src/momentum.py
new file mode 100644
index 0000000000..7dd96e5d0c
--- /dev/null
+++ b/model_zoo/official/cv/resnet/src/momentum.py
@@ -0,0 +1,152 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""momentum"""
+from mindspore.ops import functional as F, composite as C, operations as P
+from mindspore.common.parameter import Parameter
+from mindspore.common.tensor import Tensor
+import mindspore.common.dtype as mstype
+from mindspore._checkparam import Validator
+from mindspore.nn.optim.optimizer import Optimizer
+
+_momentum_opt = C.MultitypeFuncGraph("momentum_opt")
+
+
+@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
+def _tensor_run_opt_ext(opt, weight_decay, scale, momentum, learning_rate, gradient, weight, moment):
+    """Apply momentum optimizer to the weight parameter using Tensor."""
+    success = F.depend(True, opt(weight_decay, scale, weight, moment, learning_rate, gradient, momentum))
+    return success
+
+
+class Momentum(Optimizer):
+    r"""
+    Implements the Momentum algorithm.
+
+    Refer to the paper "On the importance of initialization and momentum in deep learning" for more details.
+
+    .. math::
+            v_{t} = v_{t-1} \ast u + gradients
+
+    If use_nesterov is True:
+
+    .. math::
+            p_{t} = p_{t-1} - (grad \ast lr + v_{t} \ast u \ast lr)
+
+    If use_nesterov is False:
+
+    .. math::
+            p_{t} = p_{t-1} - lr \ast v_{t}
+
+    Here, grad, lr, p, v and u denote the gradients, learning_rate, params, moments and momentum, respectively.
+
+    Note:
+        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
+        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
+        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
+
+        To improve parameter groups performance, the customized order of parameters can be supported.
+
+    Args:
+        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
+            the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
+            "lr", "weight_decay" and "order_params" are the keys that can be parsed.
+
+            - params: Required. The value must be a list of `Parameter`.
+
+            - lr: Optional. If "lr" in the keys, the value of the corresponding learning rate will be used.
+              If not, the `learning_rate` in the API will be used.
+
+            - weight_decay: Optional. If "weight_decay" in the keys, the value of the corresponding weight decay
+              will be used. If not, the `weight_decay` in the API will be used.
+
+            - order_params: Optional. If "order_params" in the keys, the value must be the order of parameters and
+              the order will be followed in the optimizer. There are no other keys in the `dict`, and the parameters
+              that appear in the value of 'order_params' must be in one of the group parameters.
+
+        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
+            When the learning_rate is an Iterable or a one-dimensional Tensor, a dynamic learning rate is used, and
+            the i-th step will take the i-th value as the learning rate. When the learning_rate is a
+            LearningRateSchedule, a dynamic learning rate is used, and the i-th learning rate will be calculated during
+            training according to the formula of the LearningRateSchedule. When the learning_rate is a float or a
+            zero-dimensional Tensor, a fixed learning rate is used. Other cases are not supported. The float learning
+            rate must be equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
+        momentum (float): Hyperparameter of type float, meaning the momentum for the moving average.
+            It must be at least 0.0.
+        weight_decay (int, float): Weight decay (L2 penalty). It must be equal to or greater than 0.0. Default: 0.0.
+        loss_scale (int, float): A floating point value for the loss scale. It must be greater than 0.0. Default: 1.0.
+        use_nesterov (bool): Enable Nesterov momentum. Default: False.
+
+    Inputs:
+        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
+
+    Outputs:
+        tuple[bool], all elements are True.
+
+    Raises:
+        ValueError: If the momentum is less than 0.0.
+        TypeError: If the momentum is not a float or use_nesterov is not a bool.
+
+    Supported Platforms:
+        ``GPU``
+
+    Examples:
+        >>> net = Net()
+        >>> #1) All parameters use the same learning rate and weight decay
+        >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+        >>>
+        >>> #2) Use parameter groups and set different values
+        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
+        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
+        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
+        ...                 {'params': no_conv_params, 'lr': 0.01},
+        ...                 {'order_params': net.trainable_params()}]
+        >>> optim = Momentum(group_params, learning_rate=0.1, momentum=0.9, weight_decay=0.0)
+        >>> # The conv_params's parameters will use a learning rate of default value 0.1 and a weight decay of 0.01.
+        >>> # The no_conv_params's parameters will use a learning rate of 0.01 and a weight decay of default value 0.0.
+        >>> # The final parameter order followed by the optimizer is the value of 'order_params'.
+        >>>
+        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
+    """
+    def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, use_nesterov=False):
+        super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale)
+        Validator.check_value_type("momentum", momentum, [float], self.cls_name)
+        if isinstance(momentum, float) and momentum < 0.0:
+            raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
+        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
+        self.params = self.parameters
+        self.use_nesterov = Validator.check_bool(use_nesterov)
+        self.moments = self.params.clone(prefix="moments", init='zeros')
+        self.hyper_map = C.HyperMap()
+        # Use FusedWeightScaleApplyMomentum to avoid extra kernel launches.
+        self.opt = P.FusedWeightScaleApplyMomentum()
+
+    def construct(self, gradients):
+        params = self.params
+        moments = self.moments
+        weight_decay = Tensor(0.0, mstype.float32)
+        scale = Tensor(1.0, mstype.float32)
+        if self.exec_weight_decay:
+            weight_decay = self.weight_decay_tensor
+        if self.need_scale:
+            scale = self.reciprocal_scale
+        lr = self.get_lr()
+        if self.is_group_lr:
+            success = self.hyper_map(F.partial(_momentum_opt, self.opt, weight_decay, scale, self.momentum),
+                                     lr, gradients, params, moments)
+        else:
+            success = self.hyper_map(F.partial(_momentum_opt, self.opt, weight_decay, scale, self.momentum, lr),
+                                     gradients, params, moments)
+        return success
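
Usage note (illustrative, not part of the patch): in the GPU PyNative fp16 path the new src/momentum.py optimizer replaces the stock Momentum, and weight decay plus loss scale are handed to the single fused FusedWeightScaleApplyMomentum kernel instead of being applied by separate ops. A minimal sketch of the same call pattern used in gpu_resnet_benchmark.py, assuming it is run from the resnet model_zoo directory so that `src.momentum` is importable; the small Dense network and the fixed learning rate are placeholders standing in for resnet50 and its schedule:

    # Sketch only: mirrors the benchmark's hyper-parameters (momentum=0.9, weight_decay=1e-4, loss_scale=1024).
    import mindspore.nn as nn
    from mindspore import context
    from src.momentum import Momentum as MomentumWeightDecay  # optimizer added by this patch

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    net = nn.Dense(10, 10)  # placeholder network; the benchmark uses resnet50
    opt = MomentumWeightDecay(filter(lambda x: x.requires_grad, net.get_parameters()),
                              0.1, 0.9, 1e-4, 1024)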