From b2d98d275125bc167ae7348a34ccfbfe51e99588 Mon Sep 17 00:00:00 2001 From: shibeiji Date: Tue, 3 Nov 2020 09:41:36 +0800 Subject: [PATCH] add tbe fusion operators LambApplyOptimizerAssign and LambApplyWeightAssign for lamb optimizer --- .../ccsrc/transform/graph_ir/op_adapter_map.h | 2 + .../elewise_calculation_ops_declare.cc | 20 +++ .../elewise_calculation_ops_declare.h | 6 + mindspore/nn/optim/lamb.py | 54 +++++++- mindspore/ops/_op_impl/tbe/__init__.py | 2 + .../tbe/lamb_apply_optimizer_assign.py | 55 ++++++++ .../_op_impl/tbe/lamb_apply_weight_assign.py | 42 ++++++ mindspore/ops/_selected_ops.py | 12 ++ mindspore/ops/operations/__init__.py | 2 +- mindspore/ops/operations/inner_ops.py | 129 ++++++++++++++++++ .../test_bert_tdt_lossscale.py | 4 +- .../bert_precision/test_bert_tdt_lossscale.py | 8 +- 12 files changed, 328 insertions(+), 8 deletions(-) create mode 100644 mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py create mode 100644 mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h b/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h index 3a49b5d3c7..1e8e3a6b04 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h +++ b/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h @@ -202,6 +202,8 @@ constexpr const char kNameCase[] = "Case"; constexpr const char kNameAssert[] = "Assert"; constexpr const char kNameCTCGreedyDecoder[] = "CTCGreedyDecoder"; constexpr const char kNameReverseV2[] = "ReverseV2"; +constexpr const char kNameLambApplyWeightAssign[] = "LambApplyWeightAssign"; +constexpr const char kNameLambApplyOptimizerAssign[] = "LambApplyOptimizerAssign"; class OpAdapterMap { public: diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/elewise_calculation_ops_declare.cc b/mindspore/ccsrc/transform/graph_ir/op_declare/elewise_calculation_ops_declare.cc index 348aa9dbb4..e3f05965f3 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_declare/elewise_calculation_ops_declare.cc +++ b/mindspore/ccsrc/transform/graph_ir/op_declare/elewise_calculation_ops_declare.cc @@ -362,4 +362,24 @@ INPUT_MAP(Atan2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; ATTR_MAP(Atan2) = EMPTY_ATTR_MAP; OUTPUT_MAP(Atan2) = {{0, OUTPUT_DESC(y)}}; REG_ADPT_DESC(Atan2, kNameAtan2, ADPT_DESC(Atan2)) + +// LambApplyOptimizerAssign +INPUT_MAP(LambApplyOptimizerAssign) = { + {1, INPUT_DESC(grad)}, {2, INPUT_DESC(inputv)}, {3, INPUT_DESC(inputm)}, + {4, INPUT_DESC(input3)}, {5, INPUT_DESC(mul0_x)}, {6, INPUT_DESC(mul1_x)}, + {7, INPUT_DESC(mul2_x)}, {8, INPUT_DESC(mul3_x)}, {9, INPUT_DESC(add2_y)}, + {10, INPUT_DESC(steps)}, {11, INPUT_DESC(do_use_weight)}, {12, INPUT_DESC(weight_decay_rate)}}; +ATTR_MAP(LambApplyOptimizerAssign) = EMPTY_ATTR_MAP; +OUTPUT_MAP(LambApplyOptimizerAssign) = {{0, OUTPUT_DESC(output0)}, {1, OUTPUT_DESC(inputv)}, {2, OUTPUT_DESC(inputm)}}; +REG_ADPT_DESC(LambApplyOptimizerAssign, kNameLambApplyOptimizerAssign, ADPT_DESC(LambApplyOptimizerAssign)) + +// LambApplyWeightAssign +INPUT_MAP(LambApplyWeightAssign) = {{1, INPUT_DESC(input0)}, + {2, INPUT_DESC(input1)}, + {3, INPUT_DESC(input2)}, + {4, INPUT_DESC(input3)}, + {5, INPUT_DESC(input_param)}}; +ATTR_MAP(LambApplyWeightAssign) = EMPTY_ATTR_MAP; +OUTPUT_MAP(LambApplyWeightAssign) = {{0, OUTPUT_DESC(input_param)}}; +REG_ADPT_DESC(LambApplyWeightAssign, kNameLambApplyWeightAssign, ADPT_DESC(LambApplyWeightAssign)) } // namespace mindspore::transform diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/elewise_calculation_ops_declare.h 
b/mindspore/ccsrc/transform/graph_ir/op_declare/elewise_calculation_ops_declare.h
index 27a1407d3b..0f780aefee 100644
--- a/mindspore/ccsrc/transform/graph_ir/op_declare/elewise_calculation_ops_declare.h
+++ b/mindspore/ccsrc/transform/graph_ir/op_declare/elewise_calculation_ops_declare.h
@@ -189,5 +189,11 @@ DECLARE_OP_USE_OUTPUT(Round)
 
 DECLARE_OP_ADAPTER(Atan2)
 DECLARE_OP_USE_OUTPUT(Atan2)
+
+DECLARE_OP_ADAPTER(LambApplyOptimizerAssign)
+DECLARE_OP_USE_OUTPUT(LambApplyOptimizerAssign)
+
+DECLARE_OP_ADAPTER(LambApplyWeightAssign)
+DECLARE_OP_USE_OUTPUT(LambApplyWeightAssign)
 }  // namespace mindspore::transform
 #endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ELEWISE_CALCULATION_OPS_DECLARE_H_
diff --git a/mindspore/nn/optim/lamb.py b/mindspore/nn/optim/lamb.py
index 19be613615..cffabcf393 100755
--- a/mindspore/nn/optim/lamb.py
+++ b/mindspore/nn/optim/lamb.py
@@ -111,6 +111,52 @@ def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v
         return op_cast(next_param, F.dtype(param))
     return gradient
 
+_lamb_opt_ascend = C.MultitypeFuncGraph("lamb_opt_ascend")
+
+@_lamb_opt_ascend.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
+                           "Tensor", "Bool", "Bool")
+def _update_run_op_ascend(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag,
+                          optim_filter):
+    """
+    Update parameters when the device target is Ascend.
+
+    Args:
+        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
+        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
+        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
+        lr (Tensor): Learning rate.
+        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
+        global_step (Tensor): Global step.
+        param (Tensor): Parameters.
+        m (Tensor): m value of parameters.
+        v (Tensor): v value of parameters.
+        gradient (Tensor): Gradient of parameters.
+        decay_flag (bool): Specifies whether the parameter is updated with weight decay.
+        optim_filter (bool): Whether to apply the parameter update.
+
+    Returns:
+        Tensor, the result of the parameter update when `optim_filter` is True, otherwise the original gradient.
+ """ + if optim_filter: + op_cast = P.Cast() + op_norm = layer.Norm() + op_lamb_apply_optimizer_assign = P.LambApplyOptimizerAssign() + op_lamb_apply_weight_assign = P.LambApplyWeightAssign() + + param_fp32 = op_cast(param, mstype.float32) + gradient_fp32 = op_cast(gradient, mstype.float32) + new_global_step = op_cast(global_step + num_one, mstype.float32) + weight_decay_flag = op_cast(decay_flag, mstype.float32) + + update, _, _ = op_lamb_apply_optimizer_assign(gradient_fp32, v, m, param_fp32, + beta1, 1.0 - beta1, beta2, 1.0 - beta2, eps, + new_global_step, weight_decay_flag, weight_decay) + w_norm = op_norm(param_fp32) + g_norm = op_norm(update) + update = F.depend(update, op_lamb_apply_weight_assign(w_norm, g_norm, lr, update, param)) + return update + return gradient + lamb_opt_graph_kernel = C.MultitypeFuncGraph("lamb_opt_graph_kernel") @@ -279,6 +325,7 @@ class Lamb(Optimizer): self.hyper_map = C.HyperMap() self.enable_graph_kernel = context.get_context("device_target") == "Ascend" and \ context.get_context("enable_graph_kernel") + self.device_ascend = context.get_context("device_target") == "Ascend" def construct(self, gradients): lr = self.get_lr() @@ -299,19 +346,20 @@ class Lamb(Optimizer): self.global_step, lr, self.weight_decay), self.params, self.moments1, self.moments2, gradients, self.decay_flags) else: + lamb_opt = _lamb_opt_ascend if self.device_ascend else _lamb_opt if self.is_group: if self.is_group_lr: - optim_result = self.hyper_map(F.partial(_lamb_opt, self.beta1, self.beta2, self.eps, + optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps, self.global_step), lr, self.weight_decay, self.params, self.moments1, self.moments2, gradients, self.decay_flags, self.optim_filter) else: - optim_result = self.hyper_map(F.partial(_lamb_opt, self.beta1, self.beta2, self.eps, + optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps, self.global_step, lr), self.weight_decay, self.params, self.moments1, self.moments2, gradients, self.decay_flags, self.optim_filter) else: - optim_result = self.hyper_map(F.partial(_lamb_opt, self.beta1, self.beta2, self.eps, + optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps, self.global_step, lr, self.weight_decay), self.params, self.moments1, self.moments2, gradients, self.decay_flags, self.optim_filter) diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index d704687320..0f27fd921b 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -351,3 +351,5 @@ from .conv3d import _conv3d_tbe from .conv3d_backprop_input import _conv3d_backprop_input_tbe from .conv3d_backprop_filter import _conv3d_backprop_filter_tbe from .conv3d_transpose import _conv3d_transpose_tbe +from .lamb_apply_optimizer_assign import _lamb_apply_optimizer_assign_tbe +from .lamb_apply_weight_assign import _lamb_apply_weight_assign_tbe diff --git a/mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py b/mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py new file mode 100644 index 0000000000..8740babdc2 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py @@ -0,0 +1,55 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""LambApplyOptimizerAssign op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+lamb_apply_optimizer_assign_op_info = TBERegOp("LambApplyOptimizerAssign") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("lamb_apply_optimizer_assign.so") \
+    .compute_cost(10) \
+    .kernel_name("lamb_apply_optimizer_assign") \
+    .partial_flag(True) \
+    .input(0, "grad", False, "required", "all") \
+    .input(1, "inputv", False, "required", "all") \
+    .input(2, "inputm", False, "required", "all") \
+    .input(3, "input3", False, "required", "all") \
+    .input(4, "mul0_x", False, "required", "all") \
+    .input(5, "mul1_x", False, "required", "all") \
+    .input(6, "mul2_x", False, "required", "all") \
+    .input(7, "mul3_x", False, "required", "all") \
+    .input(8, "add2_y", False, "required", "all") \
+    .input(9, "steps", False, "required", "all") \
+    .input(10, "do_use_weight", False, "required", "all") \
+    .input(11, "weight_decay_rate", False, "required", "all") \
+    .output(0, "output0", False, "required", "all") \
+    .output(1, "inputv", False, "required", "all") \
+    .output(2, "inputm", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
+                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
+                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
+                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
+                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
+                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
+                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(lamb_apply_optimizer_assign_op_info)
+def _lamb_apply_optimizer_assign_tbe():
+    """LambApplyOptimizerAssign TBE register"""
+    return
diff --git a/mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py b/mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py
new file mode 100644
index 0000000000..5a5015a01b
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py
@@ -0,0 +1,42 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ + +"""LambApplyWeightAssign op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +lamb_apply_weight_assign_op_info = TBERegOp("LambApplyWeightAssign") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("lamb_apply_weight_assign.so") \ + .compute_cost(10) \ + .kernel_name("lamb_apply_weight_assign") \ + .partial_flag(True) \ + .input(0, "input0", False, "required", "all") \ + .input(1, "input1", False, "required", "all") \ + .input(2, "input2", False, "required", "all") \ + .input(3, "input3", False, "required", "all") \ + .input(4, "input_param", False, "required", "all") \ + .output(0, "input_param", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, + DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(lamb_apply_weight_assign_op_info) +def _lamb_apply_weight_assign_tbe(): + """LambApplyWeightAssign TBE register""" + return diff --git a/mindspore/ops/_selected_ops.py b/mindspore/ops/_selected_ops.py index dd5ee410e7..f0cb494327 100644 --- a/mindspore/ops/_selected_ops.py +++ b/mindspore/ops/_selected_ops.py @@ -112,3 +112,15 @@ class LambUpdateWithLR: class LambNextMV: def __call__(self, *args): pass + + +@op_selector +class LambApplyOptimizerAssign: + def __call__(self, *args): + pass + + +@op_selector +class LambApplyWeightAssign: + def __call__(self, *args): + pass diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 3e116715ee..00d37aefe5 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -41,7 +41,7 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, TensorSummary, HistogramSummary, Print, Assert) from .control_ops import ControlDepend, GeSwitch, Merge -from .inner_ops import ScalarCast, Randperm, NoRepeatNGram +from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr, BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub, diff --git a/mindspore/ops/operations/inner_ops.py b/mindspore/ops/operations/inner_ops.py index 5c74acae48..10572819bb 100644 --- a/mindspore/ops/operations/inner_ops.py +++ b/mindspore/ops/operations/inner_ops.py @@ -172,3 +172,132 @@ class NoRepeatNGram(PrimitiveWithInfer): valid_values = (mstype.float16, mstype.float32, mstype.float64) validator.check_type_name("log_type", log_type, valid_values, self.name) return log_type + + +class LambApplyOptimizerAssign(PrimitiveWithInfer): + r""" + Updates gradients by LAMB optimizer algorithm. Get the compute ratio. + + The Lamb optimzier is proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes + `_. + + The updating formulas are as follows, + + .. 
math:: + \begin{array}{ll} \\ + m = \beta_1 * m + (1 - \beta_1) * g \\ + v = \beta_2 * v + (1 - \beta_2) * g * g \\ + m = \frac{m}{1 - \beta_1^t} \\ + v = \frac{v}{1 - \beta_2^t} \\ + r = \frac{m}{\sqrt{v} + \epsilon} \\ + w = w - l * \frac{\left \| w \right \|}{\left \| r \right \|} * (r + \lambda * w)) + \end{array} + + :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents + `gradient`, :math:`l` represents learning rate `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`, + :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent `beta1_power` and + `beta2_power`, :math:`\lambda` represents `weight_decay`, :math:`w` represents `var`, :math:`\epsilon` represents + `epsilon`. + + Inputs: + - **gradient** (Tensor) - Gradient of parameters, float32/float16. + - **v** (Tensor) - the 2nd moment vector in the updating formula, has the same type as `gradient`. + - **m** (Tensor) - The 1st moment vector in the updating formula, has the same type as `gradient`. + - **var** (Tensor) - Weights to be updated, has the same type as `gradient`. + - **beta1** (Tensor) - :math:`beta_1` in the updating formula, float32/float16. + - **sub1** (Tensor) - :math:`1-beta_1` in the updating formula, has the same type as `beta1`. + - **beta2** (Tensor) - :math:`beta_2` in the updating formula, has the same type as `beta1`. + - **sub2** (Tensor) - :math:`1-beta_2` in the updating formula, has the same type as `beta1`. + - **epsilon** (Tensor) - Term added to the denominator, has the same type as `beta1`. + - **steps** (Tensor) - :math:`t` in the updating formula, global step, has the same type as `beta1`. + - **lr** (Tensor) - :math:`l` in the updating formula, learning rate, has the same type as `beta1`. + - **decay_flag** (Tensor) -Specify whether param upadte with weight decay, has the same type as `beta1`. + - **weight_decay** (Tensor) - :math:`\lambda` in the updating formula, has the same type as `beta1`. + + Outputs: + Tensor, the compute ratio r. + - **update** (Tensor) - :math:`r + \lambda * w` in the updating formula. The same shape and data type as `m`. + - **v** (Tensor) - the 2nd moment vector in the updating formula after updated inplace, + has the same type as `gradient`. + - **m** (Tensor) - The 1st moment vector in the updating formula after updated inplace, + has the same type as `gradient`. 
+ + Supported Platforms: + ``Ascend`` + """ + @prim_attr_register + def __init__(self): + """Initialize LambApplyOptimizerAssign""" + + def infer_shape(self, grad_shape, v_shape, m_shape, var_shape, beta1_shape, sub1_shape, + beta2_shape, sub2_shape, eps_shape, steps_shape, use_weight_shape, weight_decay_shape): + validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name) + return m_shape, v_shape, m_shape + + def infer_dtype(self, grad_dtype, v_dtype, m_dtype, var_dtype, beta1_dtype, sub1_dtype, + beta2_dtype, sub2_dtype, eps_dtype, steps_dtype, use_weight_dtype, weight_decay_dtype): + args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype} + validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) + + args = {"beta1": beta1_dtype, "sub1": sub1_dtype, "beta2": beta2_dtype, "sub2": sub2_dtype, + "eps": eps_dtype, "steps": steps_dtype, "use_weight": use_weight_dtype, + "weight_decay": weight_decay_dtype} + validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True) + return m_dtype, v_dtype, v_dtype + + +class LambApplyWeightAssign(PrimitiveWithInfer): + r""" + Updates gradients by LAMB optimizer algorithm. The weight update part. + + The Lamb optimzier is proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes + `_. + + The updating formulas are as follows, + + .. math:: + \begin{array}{ll} \\ + m = \beta_1 * m + (1 - \beta_1) * g \\ + v = \beta_2 * v + (1 - \beta_2) * g * g \\ + m = \frac{m}{1 - \beta_1^t} \\ + v = \frac{v}{1 - \beta_2^t} \\ + r = \frac{m}{\sqrt{v} + \epsilon} \\ + w = w - l * \frac{\left \| w \right \|}{\left \| r \right \|} * (r + \lambda * w)) + \end{array} + + :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents + `gradient`, :math:`l` represents learning rate `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`, + :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent `beta1_power` and + `beta2_power`, :math:`\lambda` represents `weight_decay`, :math:`w` represents `var`, :math:`\epsilon` represents + `epsilon`. + + Inputs: + - **w_norm** (Tensor) - :math:`\left \| w \right \|` in the updating formula, float32/float16. + - **g_norm** (Tensor) - :math:`\left \| r \right \|` in the updating formula, has the same type as `w_norm`. + - **lr** (Tensor) - :math:`l` in the updating formula, the learning rate, float32/float16. + - **update** (Tensor) -:math:`r + \lambda * w`in the updating formula, float32/float16. + - **var** (Tensor) - Weights to be updated, the same shape and type as `update`. + + Outputs: + - **var** (Tensor) - Weights to be updated in place, the same shape and type as `var` in inputs. 
+ + Supported Platforms: + ``Ascend`` + """ + @prim_attr_register + def __init__(self): + """Initialize LambApplyWeightAssign""" + + def infer_shape(self, w_norm_shape, g_norm_shape, lr_shape, update_shape, var_shape): + validator.check("var_shape", var_shape, "update_shape", update_shape, Rel.EQ, self.name) + return var_shape + + def infer_dtype(self, w_norm_dtype, g_norm_dtype, lr_dtype, update_dtype, var_dtype): + args = {"var": var_dtype, "update": update_dtype} + validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) + + args = {"w_norm": w_norm_dtype, "g_norm": g_norm_dtype, "lr": lr_dtype} + validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True) + return var_dtype diff --git a/tests/st/networks/models/bert/bert_performance/test_bert_tdt_lossscale.py b/tests/st/networks/models/bert/bert_performance/test_bert_tdt_lossscale.py index 5dabf02415..67e0cc1aa4 100644 --- a/tests/st/networks/models/bert/bert_performance/test_bert_tdt_lossscale.py +++ b/tests/st/networks/models/bert/bert_performance/test_bert_tdt_lossscale.py @@ -229,7 +229,7 @@ def test_bert_performance(): # assertion occurs while the loss value, overflow state or loss_scale value is wrong loss_value = np.array(callback.loss_list) - expect_loss_value = [10.235566, 10.207392, 10.206976] + expect_loss_value = [11.325791, 11.285011, 11.284766] print("loss value: {}".format(loss_value)) assert np.allclose(loss_value, expect_loss_value, 0, 0.0005) @@ -239,7 +239,7 @@ def test_bert_performance(): assert (overflow == expect_overflow).all() loss_scale = np.array(callback.lossscale_list) - expect_loss_scale = [262144.0, 262144.0, 262144.0] + expect_loss_scale = [65536.0, 65536.0, 65536.0] print("loss scale: {}".format(loss_scale)) assert np.allclose(loss_scale, expect_loss_scale, 0, 0) diff --git a/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py b/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py index 3f6d9e72c5..a2d7427784 100644 --- a/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py +++ b/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py @@ -225,8 +225,12 @@ def test_bert_percision(enable_graph_kernel=False): loss_value = np.array(callback.loss_list) assert np.allclose(loss_value[0], 12.2065868, 0, 0.000001) - expect_loss_value = [12.2065868, 11.8651543, 11.8282356, 11.8266964, 11.8210478, 12.4073524, 12.0055466, - 12.6212320, 12.2229223, 12.4272099] + if enable_graph_kernel: + expect_loss_value = [12.2065868, 11.8651543, 11.8282356, 11.8266964, 11.8210478, 12.4073524, 12.0055466, + 12.6212320, 12.2229223, 12.4272099] + else: + expect_loss_value = [12.2065868, 11.94102, 11.931558, 11.938105, 11.932648, 12.556579, 12.130686, 12.783716, + 12.360179, 12.578461] print("loss value: {}".format(loss_value)) assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)
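
Reviewer note (illustration only, not part of the patch): the sketch below shows how the two new fused kernels are expected to compose into a single LAMB step, mirroring `_update_run_op_ascend` in mindspore/nn/optim/lamb.py above. It assumes an Ascend device and a MindSpore build that already contains these primitives; the `LambStep` cell, the toy shapes, and the hyper-parameter values are made up for the example.

    # Minimal sketch of one fused LAMB update on Ascend (not part of the patch).
    import numpy as np
    from mindspore import Parameter, Tensor, context, nn
    from mindspore.common import dtype as mstype
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


    class LambStep(nn.Cell):
        """One fused LAMB update for a single parameter, following _update_run_op_ascend."""

        def __init__(self, shape):
            super(LambStep, self).__init__()
            self.opt_assign = P.LambApplyOptimizerAssign()
            self.weight_assign = P.LambApplyWeightAssign()
            self.norm = nn.Norm()
            self.param = Parameter(Tensor(np.ones(shape), mstype.float32), name="param")
            self.m = Parameter(Tensor(np.zeros(shape), mstype.float32), name="m")
            self.v = Parameter(Tensor(np.zeros(shape), mstype.float32), name="v")

        def construct(self, grad, beta1, beta2, eps, step, decay_flag, weight_decay, lr):
            # First kernel: refresh m and v in place and return the update term r + lambda * w.
            # decay_flag is passed as a float tensor (0.0 or 1.0), as in lamb.py.
            update, _, _ = self.opt_assign(grad, self.v, self.m, self.param, beta1, 1.0 - beta1,
                                           beta2, 1.0 - beta2, eps, step, decay_flag, weight_decay)
            # Second kernel: scale by the trust ratio ||w|| / ||update|| and write back to param.
            w_norm = self.norm(self.param)
            g_norm = self.norm(update)
            return self.weight_assign(w_norm, g_norm, lr, update, self.param)


    def scalar(value):
        """Wrap a Python number into a 1-element float32 tensor, as the optimizer does."""
        return Tensor(np.array([value]), mstype.float32)


    net = LambStep(shape=(2, 3))
    grad = Tensor(np.full((2, 3), 0.1), mstype.float32)
    # beta1, beta2, eps, step, decay_flag, weight_decay, lr -- illustrative values only.
    new_param = net(grad, scalar(0.9), scalar(0.999), scalar(1e-6), scalar(1.0),
                    scalar(1.0), scalar(0.01), scalar(0.001))
    print(new_param)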