expand layernorm_grad op

pull/8564/head
zengzitao 4 years ago
parent 5f7a9bd0b8
commit 326540cbbd

@@ -28,3 +28,4 @@ from .tanh_grad import expand_tanhgrad
from .maximum_grad import expand_maximumgrad
from .minimum_grad import expand_minimumgrad
from .dropout_grad import expand_dropoutgrad
from .layernorm_grad import expand_layernormgrad

@@ -0,0 +1,121 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for LayerNormGrad"""
from mindspore._extends.graph_kernel.model import model_builder as builder


def expand_layernormgrad(expand_info):
    """LayerNormGrad expander"""
    # get op info.
    x_desc = expand_info['input_desc'][0]
    dy_desc = expand_info['input_desc'][1]
    var_desc = expand_info['input_desc'][2]
    mean_desc = expand_info['input_desc'][3]
    gamma_desc = expand_info['input_desc'][4]
    begin_norm_axis = None
    begin_params_axis = None
    epsilon = 1e-11
    for item in expand_info['attr']:
        if 'begin_norm_axis' in item:
            begin_norm_axis = item['begin_norm_axis']
        if 'begin_params_axis' in item:
            begin_params_axis = item['begin_params_axis']
        if 'epsilon' in item:
            epsilon = item['epsilon']
    shape_x = x_desc['shape']
    if begin_norm_axis < 0:
        begin_norm_axis += len(shape_x)
    if begin_params_axis < 0:
        begin_params_axis += len(shape_x)
    norm_axis = tuple(range(begin_norm_axis, len(shape_x)))
    param_axis = tuple(range(0, begin_params_axis))
    reduce_size = 1.0
    for i in norm_axis:
        reduce_size *= shape_x[i]
    graph_builder = builder.GraphBuilder()
    with graph_builder.graph_scope('main') as graph_scope:
        # create input tensors.
        x = graph_builder.tensor(x_desc['shape'], x_desc['data_type'], x_desc['format'])
        dy = graph_builder.tensor(dy_desc['shape'], dy_desc['data_type'], dy_desc['format'])
        variance = graph_builder.tensor(var_desc['shape'], var_desc['data_type'], var_desc['format'])
        mean = graph_builder.tensor(mean_desc['shape'], mean_desc['data_type'], mean_desc['format'])
        gamma = graph_builder.tensor(gamma_desc['shape'], gamma_desc['data_type'], gamma_desc['format'])
        graph_scope.set_input(x, dy, variance, mean, gamma)

        # create the constant values used below.
        eps = graph_builder.value(x.dtype, epsilon, x.data_format)
        const_one = graph_builder.value(x.dtype, 1.0, x.data_format)
        const_neg_half = graph_builder.value(x.dtype, -0.5, x.data_format)
        const_neg_two = graph_builder.value(x.dtype, -2.0, x.data_format)
        const_two = graph_builder.value(x.dtype, 2.0, x.data_format)
        const_neg_one = graph_builder.value(x.dtype, -1.0, x.data_format)
        mean_cof = graph_builder.value(x.dtype, (1.0 / reduce_size), x.data_format)

        # compute dg and db:
        # dg = np.sum(dy * np.power(var + epsilon, -0.5) * (x - mean), axis=tuple(param_axis), keepdims=True)
        # db = np.sum(dy, axis=tuple(param_axis), keepdims=True)
        var_eps = graph_builder.emit('TensorAdd', [variance, eps])
        sqrt_var_eps = graph_builder.emit('Sqrt', [var_eps])
        rsqrt_var_eps = graph_builder.emit('RealDiv', [const_one, sqrt_var_eps])
        x_sub_mean = graph_builder.emit('Sub', [x, mean])
        x_sub_mean_mul_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, x_sub_mean])
        dg_mul = graph_builder.emit('Mul', [dy, x_sub_mean_mul_rsqrt_var_eps])
        dg = graph_builder.emit('ReduceSum', [dg_mul], attrs={'reduce_axis': param_axis, 'keep_dims': False})
        db = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': param_axis, 'keep_dims': False})

        # compute sum1:
        # sum1 = np.sum((-0.5) * dy * gamma * (x - mean) * np.power(var + epsilon, -1.5), axis=tuple(norm_axis),
        #               keepdims=True)
        tmp_var_eps = graph_builder.emit('Mul', [sqrt_var_eps, var_eps])
        r_tmp_var_eps = graph_builder.emit('RealDiv', [const_one, tmp_var_eps])
        x_sub_mean_mul_r_tmp_var_eps = graph_builder.emit('Mul', [x_sub_mean, r_tmp_var_eps])
        dy_mul_gamma = graph_builder.emit('Mul', [dy, gamma])
        tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean_mul_r_tmp_var_eps])
        sum_1_mul = graph_builder.emit('Mul', [const_neg_half, tmp_mul])
        sum_1 = graph_builder.emit('ReduceSum', [sum_1_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})

        # compute sum2:
        # sum2 = np.sum(dy * gamma, axis=tuple(norm_axis), keepdims=True)
        sum_2 = graph_builder.emit('ReduceSum', [dy_mul_gamma], attrs={'reduce_axis': norm_axis, 'keep_dims': True})

        # compute sum3:
        # sum3 = np.sum(-2.0 * (x - mean), axis=tuple(norm_axis), keepdims=True)
        sum_3_mul = graph_builder.emit('Mul', [const_neg_two, x_sub_mean])
        sum_3 = graph_builder.emit('ReduceSum', [sum_3_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
        # compute dx = dx1 + dx2 + dx3, where mean_cof = 1.0 / reduce_size:
        # dx1 = dy * gamma * rsqrt_var_eps
        # dx2 = sum1 * 2.0 * mean_cof * x_sub_mean
        # dx3 = mean_cof * (-1.0 * rsqrt_var_eps * sum2 + mean_cof * sum1 * sum3)
        dx_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps])
        sum_1_mul_two = graph_builder.emit('Mul', [sum_1, const_two])
        sum_1_mul_two_tmp = graph_builder.emit('Mul', [sum_1_mul_two, mean_cof])
        dx_2 = graph_builder.emit('Mul', [sum_1_mul_two_tmp, x_sub_mean])
        neg_rsqrt_var_eps = graph_builder.emit('Mul', [const_neg_one, rsqrt_var_eps])
        neg_rsqrt_var_eps_mul_sum_2 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, sum_2])
        sum_1_mul_sum_3 = graph_builder.emit('Mul', [sum_1, sum_3])
        mean_cof_mul_sum_1_mul_sum_3 = graph_builder.emit('Mul', [mean_cof, sum_1_mul_sum_3])
        add_tmp = graph_builder.emit('TensorAdd', [neg_rsqrt_var_eps_mul_sum_2, mean_cof_mul_sum_1_mul_sum_3])
        dx_3 = graph_builder.emit('Mul', [add_tmp, mean_cof])
        dx_tmp = graph_builder.emit('TensorAdd', [dx_1, dx_2])
        dx = graph_builder.emit('TensorAdd', [dx_tmp, dx_3])
        # set graph output.
        graph_scope.set_output(dx, dg, db)

    graph = graph_builder.get()[0]
    return graph
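For reference, the expander is driven by a dictionary of tensor descriptors and attributes. Below is a minimal, hypothetical sketch of such a call: the shapes, dtype strings, and format string are illustrative assumptions, while the field names ('input_desc', 'attr', 'shape', 'data_type', 'format') are the ones the expander reads above.

# hypothetical expand_info for a (4096, 3072) float32 LayerNormGrad;
# shapes and format strings here are assumptions for illustration
example_info = {
    'input_desc': [
        {'shape': [4096, 3072], 'data_type': 'float32', 'format': 'DefaultFormat'},  # x
        {'shape': [4096, 3072], 'data_type': 'float32', 'format': 'DefaultFormat'},  # dy
        {'shape': [4096, 1], 'data_type': 'float32', 'format': 'DefaultFormat'},     # variance
        {'shape': [4096, 1], 'data_type': 'float32', 'format': 'DefaultFormat'},     # mean
        {'shape': [3072], 'data_type': 'float32', 'format': 'DefaultFormat'},        # gamma
    ],
    'attr': [{'begin_norm_axis': 1}, {'begin_params_axis': 1}, {'epsilon': 1e-11}],
}
graph = expand_layernormgrad(example_info)  # returns the expanded graph IR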

@@ -105,6 +105,9 @@ def test_adam():
    assert np.allclose(opt.v.data.asnumpy(), v_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_adam_weight_decay():
    np.random.seed(0)
    beta1 = np.array([0.9]).astype(np.float32)

@@ -15,9 +15,12 @@
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.nn as nn
from mindspore.ops.operations import _grad_ops as G
import mindspore.ops.operations as P

context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
@@ -32,6 +35,43 @@ class Net(Cell):
        return self.layernorm(x, y, z)


class LayerNormGradNet(nn.Cell):
    def __init__(self, begin_norm_axis, begin_params_axis):
        super(LayerNormGradNet, self).__init__()
        self.norm = G.LayerNormGrad(begin_norm_axis, begin_params_axis)

    def construct(self, x, dy, var, mean, gamma):
        return self.norm(x, dy, var, mean, gamma)


def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis):
    """NumPy reference implementation of LayerNormGrad."""
    begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape)
    begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape)
    norm_axis = [i for i in range(begin_norm_axis, len(x.shape))]
    param_axis = [i for i in range(0, begin_params_axis)]
    num = 1
    for i in range(begin_norm_axis, len(x.shape)):
        num *= x.shape[i]
    mean = np.mean(x, axis=tuple(norm_axis), keepdims=True)
    var = np.var(x, axis=tuple(norm_axis), keepdims=True)
    gamma = gamma.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:]))
    dg = np.sum(dy * np.power(var + epsilon, -0.5) * (x - mean), axis=tuple(param_axis), keepdims=True)
    db = np.sum(dy, axis=tuple(param_axis), keepdims=True)
    sum1 = np.sum((-0.5) * dy * gamma * (x - mean) * np.power(var + epsilon, -1.5), axis=tuple(norm_axis),
                  keepdims=True)
    sum2 = np.sum(dy * gamma, axis=tuple(norm_axis), keepdims=True)
    sum3 = np.sum(-2.0 * (x - mean), axis=tuple(norm_axis), keepdims=True)
    dx1 = dy * gamma * np.power(var + epsilon, -0.5)
    dx2 = sum1 * 2.0 / num * (x - mean)
    dx3 = ((-1.0) * np.power(var + epsilon, -0.5) * sum2 + (1.0 / num) * sum1 * sum3) * (1.0 / num)
    dx = dx1 + dx2 + dx3
    return dx, dg, db, mean, var
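As a quick independent sanity check of the backward formulas above (not part of the commit), the analytic dx can be compared against a central finite difference of the scalar loss L = sum(dy * LayerNorm(x)). The small forward helper below is a standard gamma-scaled layer norm with no beta term; both helper names are hypothetical.

def _layernorm_forward(x, gamma, epsilon, begin_norm_axis):
    # layer norm forward (gamma scale only, no beta), used solely to
    # finite-difference-check the analytic dx from the reference above
    norm_axis = tuple(range(begin_norm_axis, x.ndim))
    mean = np.mean(x, axis=norm_axis, keepdims=True)
    var = np.var(x, axis=norm_axis, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon) * gamma


def _check_dx_against_finite_difference():
    np.random.seed(1)
    x = np.random.randn(2, 3)
    dy = np.random.randn(2, 3)
    gamma = np.random.randn(3)
    eps, delta = 1e-11, 1e-6
    dx, _, _, _, _ = LayerNormGradReference(x, dy, gamma, eps, 1, 1)
    dx_fd = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        # central difference of L = sum(dy * layernorm(x)) w.r.t. x[idx]
        xp, xm = x.copy(), x.copy()
        xp[idx] += delta
        xm[idx] -= delta
        lp = np.sum(dy * _layernorm_forward(xp, gamma, eps, 1))
        lm = np.sum(dy * _layernorm_forward(xm, gamma, eps, 1))
        dx_fd[idx] = (lp - lm) / (2 * delta)
    assert np.allclose(dx, dx_fd, rtol=1e-4, atol=1e-6)

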
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@@ -75,3 +115,30 @@ def test_basic():
        assert res2
    else:
        assert False


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_layernormgrad():
    np.random.seed(0)
    begin_norm_axis = 1
    begin_params_axis = 1
    x_np = np.random.randn(4096, 3072).astype(np.float32)
    dy_np = np.random.randn(4096, 3072).astype(np.float32)
    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    epsilon = 1e-11
    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
                                                                  begin_params_axis)
    dy_ms = Tensor(dy_np)
    x_ms = Tensor(x_np)
    var_ms = Tensor(var_np)
    mean_ms = Tensor(mean_np)
    gamma_ms = Tensor(gamma_np)
    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
    dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)
    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
