expand layernorm_grad op

pull/8564/head
zengzitao 4 years ago
parent 5f7a9bd0b8
commit 326540cbbd

@ -28,3 +28,4 @@ from .tanh_grad import expand_tanhgrad
from .maximum_grad import expand_maximumgrad
from .minimum_grad import expand_minimumgrad
from .dropout_grad import expand_dropoutgrad
from .layernorm_grad import expand_layernormgrad

@ -0,0 +1,121 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for LayerNormGrad"""
from mindspore._extends.graph_kernel.model import model_builder as builder
def expand_layernormgrad(expand_info):
"""LayerNormGrad expander"""
# get op info.
x_desc = expand_info['input_desc'][0]
dy_desc = expand_info['input_desc'][1]
var_desc = expand_info['input_desc'][2]
mean_desc = expand_info['input_desc'][3]
gamma_desc = expand_info['input_desc'][4]
begin_norm_axis = None
begin_params_axis = None
epsilon = 1e-11
for item in expand_info['attr']:
if 'begin_norm_axis' in item:
begin_norm_axis = item['begin_norm_axis']
if 'begin_params_axis' in item:
begin_params_axis = item['begin_params_axis']
if 'epsilon' in item:
epsilon = item['epsilon']
shape_x = x_desc['shape']
if begin_norm_axis < 0:
begin_norm_axis += len(shape_x)
if begin_params_axis < 0:
begin_params_axis += len(shape_x)
norm_axis = tuple(range(begin_norm_axis, len(shape_x)))
param_axis = tuple(range(0, begin_params_axis))
reduce_size = 1.0
for i in norm_axis:
reduce_size *= shape_x[i]
graph_builder = builder.GraphBuilder()
with graph_builder.graph_scope('main') as graph_scope:
# create input tensors.
x = graph_builder.tensor(x_desc['shape'], x_desc['data_type'], x_desc['format'])
dy = graph_builder.tensor(dy_desc['shape'], dy_desc['data_type'], dy_desc['format'])
variance = graph_builder.tensor(var_desc['shape'], var_desc['data_type'], var_desc['format'])
mean = graph_builder.tensor(mean_desc['shape'], mean_desc['data_type'], mean_desc['format'])
gamma = graph_builder.tensor(gamma_desc['shape'], gamma_desc['data_type'], gamma_desc['format'])
graph_scope.set_input(x, dy, variance, mean, gamma)
# set some constant val.
eps = graph_builder.value(x.dtype, epsilon, x.data_format)
const_one = graph_builder.value(x.dtype, 1.0, x.data_format)
const_neg_half = graph_builder.value(x.dtype, -0.5, x.data_format)
const_neg_two = graph_builder.value(x.dtype, -2.0, x.data_format)
const_two = graph_builder.value(x.dtype, 2.0, x.data_format)
const_neg_one = graph_builder.value(x.dtype, -1.0, x.data_format)
mean_cof = graph_builder.value(x.dtype, (1.0 / reduce_size), x.data_format)
# cal dg db
# dg = np.sum(dy * np.power(var + epsilon, -0.5) * (x - mean), axis=tuple(param_axis), keepdims=True)
# db = np.sum(dy, axis=tuple(param_axis), keepdims=True)
var_eps = graph_builder.emit('TensorAdd', [variance, eps])
sqrt_var_eps = graph_builder.emit('Sqrt', [var_eps])
rsqrt_var_eps = graph_builder.emit('RealDiv', [const_one, sqrt_var_eps])
x_sub_mean = graph_builder.emit('Sub', [x, mean])
x_sub_mean_mul_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, x_sub_mean])
dg_mul = graph_builder.emit('Mul', [dy, x_sub_mean_mul_rsqrt_var_eps])
dg = graph_builder.emit('ReduceSum', [dg_mul], attrs={'reduce_axis': param_axis, 'keep_dims': False})
db = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': param_axis, 'keep_dims': False})
# cal sum_1
# sum1 = np.sum((-0.5) * dy * gamma * (x - mean) * np.power(var + epsilon, -1.5), axis=tuple(norm_axis),
# keepdims=True)
tmp_var_eps = graph_builder.emit('Mul', [sqrt_var_eps, var_eps])
r_tmp_var_eps = graph_builder.emit('RealDiv', [const_one, tmp_var_eps])
x_sub_mean_mul_r_tmp_var_eps = graph_builder.emit('Mul', [x_sub_mean, r_tmp_var_eps])
dy_mul_gamma = graph_builder.emit('Mul', [dy, gamma])
tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean_mul_r_tmp_var_eps])
sum_1_mul = graph_builder.emit('Mul', [const_neg_half, tmp_mul])
sum_1 = graph_builder.emit('ReduceSum', [sum_1_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
# cal sum_2
# sum2 = np.sum(dy * gamma, axis=tuple(norm_axis), keepdims=True)
sum_2 = graph_builder.emit('ReduceSum', [dy_mul_gamma], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
# cal sum_3
# sum3 = np.sum(-2.0 * (x - mean), axis=tuple(norm_axis), keepdims=True)
sum_3_mul = graph_builder.emit('Mul', [const_neg_two, x_sub_mean])
sum_3 = graph_builder.emit('ReduceSum', [sum_3_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
# cal dx = dx1 + dx2 + dx3
# dx1 = dy * gamma * rsqrt_var_eps
# dx2 = sum1 * 2.0 / mean_cof * x_sub_mean
# dx3 = (1.0 / mean_cof) * (-1.0 * rsqrt_var_eps * sum2 + 1.0 / mean_cof * sum1 * sum3)
dx_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps])
sum_1_mul_two = graph_builder.emit('Mul', [sum_1, const_two])
sum_1_mul_two_tmp = graph_builder.emit('Mul', [sum_1_mul_two, mean_cof])
dx_2 = graph_builder.emit('Mul', [sum_1_mul_two_tmp, x_sub_mean])
neg_rsqrt_var_eps = graph_builder.emit('Mul', [const_neg_one, rsqrt_var_eps])
neg_rsqrt_var_eps_mul_sum_2 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, sum_2])
sum_1_mul_sum_3 = graph_builder.emit('Mul', [sum_1, sum_3])
mean_cof_mul_sum_1_mul_sum_3 = graph_builder.emit('Mul', [mean_cof, sum_1_mul_sum_3])
add_tmp = graph_builder.emit('TensorAdd', [neg_rsqrt_var_eps_mul_sum_2, mean_cof_mul_sum_1_mul_sum_3])
dx_3 = graph_builder.emit('Mul', [add_tmp, mean_cof])
dx_tmp = graph_builder.emit('TensorAdd', [dx_1, dx_2])
dx = graph_builder.emit('TensorAdd', [dx_tmp, dx_3])
# set graph output.
graph_scope.set_output(dx, dg, db)
graph = graph_builder.get()[0]
return graph

@ -105,6 +105,9 @@ def test_adam():
assert np.allclose(opt.v.data.asnumpy(), v_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_adam_weight_decay():
np.random.seed(0)
beta1 = np.array([0.9]).astype(np.float32)

@ -15,9 +15,12 @@
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.nn as nn
from mindspore.ops.operations import _grad_ops as G
import mindspore.ops.operations as P
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
@ -32,6 +35,43 @@ class Net(Cell):
return self.layernorm(x, y, z)
class LayerNormGradNet(nn.Cell):
def __init__(self, begin_norm_axis, begin_params_axis):
super(LayerNormGradNet, self).__init__()
self.norm = G.LayerNormGrad(begin_norm_axis, begin_params_axis)
def construct(self, dy, x, var, mean, gamma):
return self.norm(dy, x, var, mean, gamma)
def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis):
begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape)
begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape)
norm_axis = [i for i in range(begin_norm_axis, len(x.shape))]
param_axis = [i for i in range(0, begin_params_axis)]
num = 1
for i in range(begin_norm_axis, len(x.shape)):
num *= x.shape[i]
mean = np.mean(x, axis=tuple(norm_axis), keepdims=True)
var = np.var(x, axis=tuple(norm_axis), keepdims=True)
gamma = gamma.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:]))
dg = np.sum(dy * np.power(var + epsilon, -0.5) * (x - mean), axis=tuple(param_axis), keepdims=True)
db = np.sum(dy, axis=tuple(param_axis), keepdims=True)
sum1 = np.sum((-0.5) * dy * gamma * (x - mean) * np.power(var + epsilon, -1.5), axis=tuple(norm_axis),
keepdims=True)
sum2 = np.sum(dy * gamma, axis=tuple(norm_axis), keepdims=True)
sum3 = np.sum(-2.0 * (x - mean), axis=tuple(norm_axis), keepdims=True)
dx1 = dy * gamma * np.power(var + epsilon, -0.5)
dx2 = sum1 * 2.0 / num * (x - mean)
dx3 = ((-1.0) * np.power(var + epsilon, -0.5) * sum2 + (1.0 / num) * sum1 * sum3) * (1.0 / num)
dx = dx1 + dx2 + dx3
return dx, dg, db, mean, var
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -75,3 +115,30 @@ def test_basic():
assert res2
else:
assert False
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_layernormgrad():
np.random.seed(0)
begin_norm_axis = 1
begin_params_axis = 1
x_np = np.random.randn(4096, 3072).astype(np.float32)
dy_np = np.random.randn(4096, 3072).astype(np.float32)
gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
epsilon = 1e-11
dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)
dy_ms = Tensor(dy_np)
x_ms = Tensor(x_np)
var_ms = Tensor(var_np)
mean_ms = Tensor(mean_np)
gamma_ms = Tensor(gamma_np)
net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)
assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)

Loading…
Cancel
Save