diff --git a/mindspore/_extends/graph_kernel/expanders/__init__.py b/mindspore/_extends/graph_kernel/expanders/__init__.py
index 850a58c0e5..9e2f7e2ea4 100644
--- a/mindspore/_extends/graph_kernel/expanders/__init__.py
+++ b/mindspore/_extends/graph_kernel/expanders/__init__.py
@@ -35,3 +35,4 @@ from .sqrt_grad import SqrtGrad
 from .square import Square
 from .tanh_grad import TanhGrad
 from .tile import Tile
+from .lamb_apply_optimizer_assign import LambApplyOptimizerAssign
diff --git a/mindspore/_extends/graph_kernel/expanders/lamb_apply_optimizer_assign.py b/mindspore/_extends/graph_kernel/expanders/lamb_apply_optimizer_assign.py
new file mode 100644
index 0000000000..8820cbfb41
--- /dev/null
+++ b/mindspore/_extends/graph_kernel/expanders/lamb_apply_optimizer_assign.py
@@ -0,0 +1,76 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""generate json desc for LambApplyOptimizerAssign"""
+from ._utils import Expander, ExpanderInfoValidator as VLD
+
+
+@VLD.check_all_formats_same
+class LambApplyOptimizerAssign(Expander):
+    """LambApplyOptimizerAssign expander"""
+
+    def _expand(self, graph_builder):
+
+        [grad, inputv, inputm, input_param, beta_1, one_minus_beta_1, beta_2, one_minus_beta_2, epsilon, steps,
+         do_use_weight, weight_decay_rate] = self.inputs
+
+        # next_v
+        square_grad = graph_builder.emit('Mul', [grad, grad])
+        mul_3_result = graph_builder.emit('Mul', [square_grad, one_minus_beta_2])
+        mul_2_result = graph_builder.emit('Mul', [inputv, beta_2])
+        next_v = graph_builder.emit('Add', [mul_2_result, mul_3_result])
+
+        # next_m
+        mul_0_result = graph_builder.emit('Mul', [inputm, beta_1])
+        mul_1_result = graph_builder.emit('Mul', [grad, one_minus_beta_1])
+        next_m = graph_builder.emit('Add', [mul_0_result, mul_1_result])
+
+        shape = next_m.shape
+        const_one = graph_builder.value(beta_2.dtype, 1)
+
+        beta_1_tensor = graph_builder.emit('BroadcastTo', [beta_1], attrs={'shape': shape})
+        beta_2_tensor = graph_builder.emit('BroadcastTo', [beta_2], attrs={'shape': shape})
+
+        # pow
+        beta_1_log = graph_builder.emit('Log', [beta_1_tensor])
+        mul_res_1 = graph_builder.emit('Mul', [beta_1_log, steps])
+        beta_1_steps = graph_builder.emit('Exp', [mul_res_1])
+
+        neg_beta_1_step = graph_builder.emit('Neg', [beta_1_steps])
+        beta1_correction = graph_builder.emit('Add', [neg_beta_1_step, const_one])
+
+        next_m_unbiased = graph_builder.emit('RealDiv', [next_m, beta1_correction])
+
+        # pow
+        beta_2_log = graph_builder.emit('Log', [beta_2_tensor])
+        mul_res_2 = graph_builder.emit('Mul', [beta_2_log, steps])
+        beta_2_steps = graph_builder.emit('Exp', [mul_res_2])
+
+        neg_beta_2_step = graph_builder.emit('Neg', [beta_2_steps])
+        beta2_correction = graph_builder.emit('Add', [neg_beta_2_step, const_one])
+
+        next_v_unbiased = graph_builder.emit('RealDiv', [next_v, beta2_correction])
+        # update
+        sqrt_next_v = graph_builder.emit('Sqrt', [next_v_unbiased])
+
+        add_2_result = graph_builder.emit('Add', [sqrt_next_v, epsilon])
+        update = graph_builder.emit('RealDiv', [next_m_unbiased, add_2_result])
+        # update do_use_weight_decay
+        do_use_weight_mul = graph_builder.emit('Mul', [input_param, weight_decay_rate])
+        do_use_weight_decay = graph_builder.emit('Mul', [do_use_weight_mul, do_use_weight])
+        update = graph_builder.emit('Add', [do_use_weight_decay, update])
+
+        res = [update, next_v, next_m]
+
+        return res
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
index 90800ffd35..c83e24e240 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
@@ -44,6 +44,7 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
     prim::kPrimTile,
     prim::kPrimSqrtGrad,
     prim::kPrimClipByNormNoDivSum,
+    prim::kLambApplyOptimizerAssign,
 #elif ENABLE_GPU
     prim::kPrimBiasAdd,
     prim::kPrimBiasAddGrad,
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index ff39a5a10f..a2b9db3c3b 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -298,7 +298,7 @@ inline const PrimitivePtr kPrimClipByNormNoDivSum = std::make_shared<Primitive>(
 inline const PrimitivePtr kPrimTensorMove = std::make_shared<Primitive>("TensorMove");
 inline const PrimitivePtr kPrimL2Normalize = std::make_shared<Primitive>("L2Normalize");
 inline const PrimitivePtr kPrimCustomExtractFeatures = std::make_shared<Primitive>("CustomExtractFeatures");
-
+inline const PrimitivePtr kLambApplyOptimizerAssign = std::make_shared<Primitive>("LambApplyOptimizerAssign");
 // Comm ops
 inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
 inline const PrimitivePtr kPrimMirrorMiniStep = std::make_shared<Primitive>("_MirrorMiniStepOperator");
diff --git a/tests/st/ops/graph_kernel/test_lamb_apply_optimizer_assign.py b/tests/st/ops/graph_kernel/test_lamb_apply_optimizer_assign.py
new file mode 100644
index 0000000000..b1e3e469c3
--- /dev/null
+++ b/tests/st/ops/graph_kernel/test_lamb_apply_optimizer_assign.py
@@ -0,0 +1,72 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.lamb_apply_optimizer_assign = P.LambApplyOptimizerAssign()
+
+    def construct(self, grad, inputv, inputm, input_param, beta_1, one_minus_beta_1, beta_2, one_minus_beta_2, epsilon,
+                  steps, do_use_weight, weight_decay_rate):
+        return self.lamb_apply_optimizer_assign(grad, inputv, inputm, input_param, beta_1, one_minus_beta_1, beta_2,
+                                                one_minus_beta_2, epsilon, steps, do_use_weight, weight_decay_rate)
+
+def get_output(grad, inputv, inputm, input_param, beta_1, one_minus_beta_1, beta_2, one_minus_beta_2, epsilon, steps,
+               do_use_weight, weight_decay_rate, enable_graph_kernel=False):
+    context.set_context(enable_graph_kernel=enable_graph_kernel)
+    opt = Net()
+    output = opt(Tensor(grad), Tensor(inputv), Tensor(inputm), Tensor(input_param), Tensor(beta_1),
+                 Tensor(one_minus_beta_1), Tensor(beta_2), Tensor(one_minus_beta_2), Tensor(epsilon), Tensor(steps),
+                 Tensor(do_use_weight), Tensor(weight_decay_rate))
+    return output
+
+def lamb_apply_optimizer_assign():
+
+    grad = np.array([0.01, 0.03, 0.05]).astype(np.float32)
+    inputv = np.array([1.2, 3.4, 5.6]).astype(np.float32)
+    inputm = np.array([0.11, 0.33, 0.55]).astype(np.float32)
+    input_param = np.array([1, 3, 5]).astype(np.float32)
+    beta_1 = np.array([0.9]).astype(np.float32)
+    beta_2 = np.array([0.999]).astype(np.float32)
+    one_minus_beta_1 = (np.array([1.0]) - np.array([0.9])).astype(np.float32)
+    one_minus_beta_2 = (np.array([1.0]) - np.array([0.999])).astype(np.float32)
+    epsilon = np.array([1e-6]).astype(np.float32)
+    steps = np.array([10]).astype(np.float32)
+    do_use_weight = np.array([1]).astype(np.float32)
+    weight_decay_rate = np.array([0.021]).astype(np.float32)
+
+    expect = get_output(grad, inputv, inputm, input_param, beta_1, one_minus_beta_1, beta_2, one_minus_beta_2, epsilon,
+                        steps, do_use_weight, weight_decay_rate, False)
+    output = get_output(grad, inputv, inputm, input_param, beta_1, one_minus_beta_1, beta_2, one_minus_beta_2, epsilon,
+                        steps, do_use_weight, weight_decay_rate, True)
+
+    e1, e2, e3 = list(expect)
+    o1, o2, o3 = list(output)
+
+    assert np.allclose(o1.asnumpy(), e1.asnumpy())
+    assert np.allclose(o2.asnumpy(), e2.asnumpy())
+    assert np.allclose(o3.asnumpy(), e3.asnumpy())
+
+def test_lamb_apply_optimizer_assign_ascend():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    lamb_apply_optimizer_assign()
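For reference, the expander above decomposes LambApplyOptimizerAssign into basic graph-kernel ops (Mul, Add, Log, Exp, Neg, RealDiv, Sqrt). Below is a minimal NumPy sketch of the same arithmetic, useful as an independent cross-check of the decomposition. It is illustrative only and not part of the patch; the helper name lamb_apply_optimizer_assign_reference is made up for this note and is not a MindSpore API.

# Reference sketch (not part of the patch): mirrors the ops emitted by the expander.
import numpy as np

def lamb_apply_optimizer_assign_reference(grad, inputv, inputm, input_param,
                                          beta_1, one_minus_beta_1,
                                          beta_2, one_minus_beta_2,
                                          epsilon, steps,
                                          do_use_weight, weight_decay_rate):
    # Second-moment estimate: next_v = beta_2 * v + (1 - beta_2) * grad^2
    next_v = beta_2 * inputv + one_minus_beta_2 * grad * grad
    # First-moment estimate: next_m = beta_1 * m + (1 - beta_1) * grad
    next_m = beta_1 * inputm + one_minus_beta_1 * grad
    # Bias corrections; the expander computes beta^steps as Exp(steps * Log(beta))
    beta1_correction = 1.0 - np.exp(steps * np.log(beta_1))
    beta2_correction = 1.0 - np.exp(steps * np.log(beta_2))
    next_m_unbiased = next_m / beta1_correction
    next_v_unbiased = next_v / beta2_correction
    # Adam-style update plus the optional weight-decay term
    update = next_m_unbiased / (np.sqrt(next_v_unbiased) + epsilon)
    update = update + do_use_weight * weight_decay_rate * input_param
    # Same ordering as the expander's outputs: [update, next_v, next_m]
    return update, next_v, next_m

Feeding this sketch the same float32 inputs as the test above should give values close to both the fused op and the expanded graph, which is what the test's np.allclose comparison relies on.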