diff --git a/mindspore/_extends/graph_kernel/expanders/__init__.py b/mindspore/_extends/graph_kernel/expanders/__init__.py
index 83696e8e80..938fa0c6ad 100644
--- a/mindspore/_extends/graph_kernel/expanders/__init__.py
+++ b/mindspore/_extends/graph_kernel/expanders/__init__.py
@@ -28,3 +28,4 @@ from .tanh_grad import expand_tanhgrad
 from .maximum_grad import expand_maximumgrad
 from .minimum_grad import expand_minimumgrad
 from .dropout_grad import expand_dropoutgrad
+from .layernorm_grad import expand_layernormgrad
diff --git a/mindspore/_extends/graph_kernel/expanders/layernorm_grad.py b/mindspore/_extends/graph_kernel/expanders/layernorm_grad.py
new file mode 100644
index 0000000000..b8129083ed
--- /dev/null
+++ b/mindspore/_extends/graph_kernel/expanders/layernorm_grad.py
@@ -0,0 +1,121 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""generate json desc for LayerNormGrad"""
+from mindspore._extends.graph_kernel.model import model_builder as builder
+
+
+def expand_layernormgrad(expand_info):
+    """LayerNormGrad expander"""
+    # get op info.
+    x_desc = expand_info['input_desc'][0]
+    dy_desc = expand_info['input_desc'][1]
+    var_desc = expand_info['input_desc'][2]
+    mean_desc = expand_info['input_desc'][3]
+    gamma_desc = expand_info['input_desc'][4]
+    begin_norm_axis = None
+    begin_params_axis = None
+    epsilon = 1e-11
+    for item in expand_info['attr']:
+        if 'begin_norm_axis' in item:
+            begin_norm_axis = item['begin_norm_axis']
+        if 'begin_params_axis' in item:
+            begin_params_axis = item['begin_params_axis']
+        if 'epsilon' in item:
+            epsilon = item['epsilon']
+
+    shape_x = x_desc['shape']
+    if begin_norm_axis < 0:
+        begin_norm_axis += len(shape_x)
+    if begin_params_axis < 0:
+        begin_params_axis += len(shape_x)
+    norm_axis = tuple(range(begin_norm_axis, len(shape_x)))
+    param_axis = tuple(range(0, begin_params_axis))
+    reduce_size = 1.0
+    for i in norm_axis:
+        reduce_size *= shape_x[i]
+
+    graph_builder = builder.GraphBuilder()
+    with graph_builder.graph_scope('main') as graph_scope:
+        # create input tensors.
+        x = graph_builder.tensor(x_desc['shape'], x_desc['data_type'], x_desc['format'])
+        dy = graph_builder.tensor(dy_desc['shape'], dy_desc['data_type'], dy_desc['format'])
+        variance = graph_builder.tensor(var_desc['shape'], var_desc['data_type'], var_desc['format'])
+        mean = graph_builder.tensor(mean_desc['shape'], mean_desc['data_type'], mean_desc['format'])
+        gamma = graph_builder.tensor(gamma_desc['shape'], gamma_desc['data_type'], gamma_desc['format'])
+        graph_scope.set_input(x, dy, variance, mean, gamma)
+
+        # create the constant values.
+        eps = graph_builder.value(x.dtype, epsilon, x.data_format)
+        const_one = graph_builder.value(x.dtype, 1.0, x.data_format)
+        const_neg_half = graph_builder.value(x.dtype, -0.5, x.data_format)
+        const_neg_two = graph_builder.value(x.dtype, -2.0, x.data_format)
+        const_two = graph_builder.value(x.dtype, 2.0, x.data_format)
+        const_neg_one = graph_builder.value(x.dtype, -1.0, x.data_format)
+        mean_cof = graph_builder.value(x.dtype, (1.0 / reduce_size), x.data_format)
+
+        # calculate dg and db
+        # dg = np.sum(dy * np.power(var + epsilon, -0.5) * (x - mean), axis=tuple(param_axis), keepdims=True)
+        # db = np.sum(dy, axis=tuple(param_axis), keepdims=True)
+        var_eps = graph_builder.emit('TensorAdd', [variance, eps])
+        sqrt_var_eps = graph_builder.emit('Sqrt', [var_eps])
+        rsqrt_var_eps = graph_builder.emit('RealDiv', [const_one, sqrt_var_eps])
+        x_sub_mean = graph_builder.emit('Sub', [x, mean])
+        x_sub_mean_mul_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, x_sub_mean])
+        dg_mul = graph_builder.emit('Mul', [dy, x_sub_mean_mul_rsqrt_var_eps])
+        dg = graph_builder.emit('ReduceSum', [dg_mul], attrs={'reduce_axis': param_axis, 'keep_dims': False})
+        db = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': param_axis, 'keep_dims': False})
+
+        # calculate sum_1
+        # sum1 = np.sum((-0.5) * dy * gamma * (x - mean) * np.power(var + epsilon, -1.5), axis=tuple(norm_axis),
+        #               keepdims=True)
+        tmp_var_eps = graph_builder.emit('Mul', [sqrt_var_eps, var_eps])
+        r_tmp_var_eps = graph_builder.emit('RealDiv', [const_one, tmp_var_eps])
+        x_sub_mean_mul_r_tmp_var_eps = graph_builder.emit('Mul', [x_sub_mean, r_tmp_var_eps])
+        dy_mul_gamma = graph_builder.emit('Mul', [dy, gamma])
+        tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean_mul_r_tmp_var_eps])
+        sum_1_mul = graph_builder.emit('Mul', [const_neg_half, tmp_mul])
+        sum_1 = graph_builder.emit('ReduceSum', [sum_1_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
+
+        # calculate sum_2
+        # sum2 = np.sum(dy * gamma, axis=tuple(norm_axis), keepdims=True)
+        sum_2 = graph_builder.emit('ReduceSum', [dy_mul_gamma], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
+
+        # calculate sum_3
+        # sum3 = np.sum(-2.0 * (x - mean), axis=tuple(norm_axis), keepdims=True)
+        sum_3_mul = graph_builder.emit('Mul', [const_neg_two, x_sub_mean])
+        sum_3 = graph_builder.emit('ReduceSum', [sum_3_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
+
+        # calculate dx = dx1 + dx2 + dx3, where mean_cof is 1.0 / reduce_size
+        # dx1 = dy * gamma * rsqrt_var_eps
+        # dx2 = sum1 * 2.0 * mean_cof * x_sub_mean
+        # dx3 = mean_cof * (-1.0 * rsqrt_var_eps * sum2 + mean_cof * sum1 * sum3)
+        dx_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps])
+        sum_1_mul_two = graph_builder.emit('Mul', [sum_1, const_two])
+        sum_1_mul_two_tmp = graph_builder.emit('Mul', [sum_1_mul_two, mean_cof])
+        dx_2 = graph_builder.emit('Mul', [sum_1_mul_two_tmp, x_sub_mean])
+        neg_rsqrt_var_eps = graph_builder.emit('Mul', [const_neg_one, rsqrt_var_eps])
+        neg_rsqrt_var_eps_mul_sum_2 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, sum_2])
+        sum_1_mul_sum_3 = graph_builder.emit('Mul', [sum_1, sum_3])
+        mean_cof_mul_sum_1_mul_sum_3 = graph_builder.emit('Mul', [mean_cof, sum_1_mul_sum_3])
+        add_tmp = graph_builder.emit('TensorAdd', [neg_rsqrt_var_eps_mul_sum_2, mean_cof_mul_sum_1_mul_sum_3])
+        dx_3 = graph_builder.emit('Mul', [add_tmp, mean_cof])
+        dx_tmp = graph_builder.emit('TensorAdd', [dx_1, dx_2])
+        dx = graph_builder.emit('TensorAdd', [dx_tmp, dx_3])
+
+        # set the graph outputs.
+        graph_scope.set_output(dx, dg, db)
+
+    graph = graph_builder.get()[0]
+    return graph
diff --git a/tests/st/ops/graph_kernel/test_fused_adam.py b/tests/st/ops/graph_kernel/test_fused_adam.py
index ac230a2fac..49ccab9108 100644
--- a/tests/st/ops/graph_kernel/test_fused_adam.py
+++ b/tests/st/ops/graph_kernel/test_fused_adam.py
@@ -105,6 +105,9 @@ def test_adam():
     assert np.allclose(opt.v.data.asnumpy(), v_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True)
 
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
 def test_adam_weight_decay():
     np.random.seed(0)
     beta1 = np.array([0.9]).astype(np.float32)
diff --git a/tests/st/ops/graph_kernel/test_layernorm.py b/tests/st/ops/graph_kernel/test_layernorm.py
index 668389cd80..4dc1805783 100644
--- a/tests/st/ops/graph_kernel/test_layernorm.py
+++ b/tests/st/ops/graph_kernel/test_layernorm.py
@@ -15,9 +15,12 @@
 import numpy as np
 import pytest
+
 import mindspore.context as context
 from mindspore import Tensor
 from mindspore.nn import Cell
+import mindspore.nn as nn
+from mindspore.ops.operations import _grad_ops as G
 import mindspore.ops.operations as P
 
 context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
@@ -32,6 +35,43 @@
         return self.layernorm(x, y, z)
 
 
+class LayerNormGradNet(nn.Cell):
+    def __init__(self, begin_norm_axis, begin_params_axis):
+        super(LayerNormGradNet, self).__init__()
+        self.norm = G.LayerNormGrad(begin_norm_axis, begin_params_axis)
+
+    def construct(self, dy, x, var, mean, gamma):
+        return self.norm(dy, x, var, mean, gamma)
+
+
+def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis):
+    begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape)
+    begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape)
+
+    norm_axis = [i for i in range(begin_norm_axis, len(x.shape))]
+    param_axis = [i for i in range(0, begin_params_axis)]
+    num = 1
+    for i in range(begin_norm_axis, len(x.shape)):
+        num *= x.shape[i]
+
+    mean = np.mean(x, axis=tuple(norm_axis), keepdims=True)
+    var = np.var(x, axis=tuple(norm_axis), keepdims=True)
+    gamma = gamma.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:]))
+    dg = np.sum(dy * np.power(var + epsilon, -0.5) * (x - mean), axis=tuple(param_axis), keepdims=True)
+    db = np.sum(dy, axis=tuple(param_axis), keepdims=True)
+
+    sum1 = np.sum((-0.5) * dy * gamma * (x - mean) * np.power(var + epsilon, -1.5), axis=tuple(norm_axis),
+                  keepdims=True)
+    sum2 = np.sum(dy * gamma, axis=tuple(norm_axis), keepdims=True)
+    sum3 = np.sum(-2.0 * (x - mean), axis=tuple(norm_axis), keepdims=True)
+
+    dx1 = dy * gamma * np.power(var + epsilon, -0.5)
+    dx2 = sum1 * 2.0 / num * (x - mean)
+    dx3 = ((-1.0) * np.power(var + epsilon, -0.5) * sum2 + (1.0 / num) * sum1 * sum3) * (1.0 / num)
+    dx = dx1 + dx2 + dx3
+    return dx, dg, db, mean, var
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -75,3 +115,30 @@ def test_basic():
         assert res2
     else:
         assert False
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_layernormgrad():
+    np.random.seed(0)
+    begin_norm_axis = 1
+    begin_params_axis = 1
+    x_np = np.random.randn(4096, 3072).astype(np.float32)
+    dy_np = np.random.randn(4096, 3072).astype(np.float32)
+    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    epsilon = 1e-11
+    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
+                                                                   begin_params_axis)
+
+    dy_ms = Tensor(dy_np)
+    x_ms = Tensor(x_np)
+    var_ms = Tensor(var_np)
+    mean_ms = Tensor(mean_np)
+    gamma_ms = Tensor(gamma_np)
+
+    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
+    dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)
+    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
+    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
+    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
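
Note: the dx decomposition that both the expander and LayerNormGradReference implement (dx = dx1 + dx2 + dx3, with sum1/sum2/sum3 reduced over norm_axis) can be sanity-checked against central finite differences of a plain NumPy forward pass. The sketch below is not part of the patch: it assumes only NumPy plus the LayerNormGradReference function from test_layernorm.py, and layernorm_forward is a hypothetical helper mirroring the math the expander assumes (beta omitted, since it does not affect dx).

    # Standalone numerical cross-check of the dx formula (not part of the patch).
    import numpy as np

    def layernorm_forward(x, gamma, eps, begin_norm_axis):
        # Normalize over the trailing axes, matching the reference:
        # np.mean / np.var (biased) over norm_axis, then scale by gamma.
        norm_axis = tuple(range(begin_norm_axis, x.ndim))
        mu = x.mean(axis=norm_axis, keepdims=True)
        var = x.var(axis=norm_axis, keepdims=True)
        return gamma * (x - mu) / np.sqrt(var + eps)

    np.random.seed(1)
    x = np.random.randn(3, 4)       # float64 keeps the finite differences accurate
    dy = np.random.randn(3, 4)
    gamma = np.random.randn(4)
    eps = 1e-11
    dx, dg, db, _, _ = LayerNormGradReference(x, dy, gamma, eps, 1, 1)

    # Central differences on the scalar loss L = sum(dy * layernorm_forward(x)).
    h, num_dx = 1e-6, np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += h
        xm[idx] -= h
        num_dx[idx] = ((dy * layernorm_forward(xp, gamma, eps, 1)).sum()
                       - (dy * layernorm_forward(xm, gamma, eps, 1)).sum()) / (2 * h)
    assert np.allclose(dx, num_dx, atol=1e-5)

With h = 1e-6 in float64, the central-difference error is orders of magnitude below the 1e-5 tolerance, so a failure here would point at the gradient formula rather than the check itself.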