parent
5f7a9bd0b8
commit
326540cbbd
@ -0,0 +1,121 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ===========================================================================
|
||||||
|
"""generate json desc for LayerNormGrad"""
|
||||||
|
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||||
|
|
||||||
|
|
||||||
|
def expand_layernormgrad(expand_info):
|
||||||
|
"""LayerNormGrad expander"""
|
||||||
|
# get op info.
|
||||||
|
x_desc = expand_info['input_desc'][0]
|
||||||
|
dy_desc = expand_info['input_desc'][1]
|
||||||
|
var_desc = expand_info['input_desc'][2]
|
||||||
|
mean_desc = expand_info['input_desc'][3]
|
||||||
|
gamma_desc = expand_info['input_desc'][4]
|
||||||
|
begin_norm_axis = None
|
||||||
|
begin_params_axis = None
|
||||||
|
epsilon = 1e-11
|
||||||
|
for item in expand_info['attr']:
|
||||||
|
if 'begin_norm_axis' in item:
|
||||||
|
begin_norm_axis = item['begin_norm_axis']
|
||||||
|
if 'begin_params_axis' in item:
|
||||||
|
begin_params_axis = item['begin_params_axis']
|
||||||
|
if 'epsilon' in item:
|
||||||
|
epsilon = item['epsilon']
|
||||||
|
|
||||||
|
shape_x = x_desc['shape']
|
||||||
|
if begin_norm_axis < 0:
|
||||||
|
begin_norm_axis += len(shape_x)
|
||||||
|
if begin_params_axis < 0:
|
||||||
|
begin_params_axis += len(shape_x)
|
||||||
|
norm_axis = tuple(range(begin_norm_axis, len(shape_x)))
|
||||||
|
param_axis = tuple(range(0, begin_params_axis))
|
||||||
|
reduce_size = 1.0
|
||||||
|
for i in norm_axis:
|
||||||
|
reduce_size *= shape_x[i]
|
||||||
|
|
||||||
|
graph_builder = builder.GraphBuilder()
|
||||||
|
with graph_builder.graph_scope('main') as graph_scope:
|
||||||
|
# create input tensors.
|
||||||
|
x = graph_builder.tensor(x_desc['shape'], x_desc['data_type'], x_desc['format'])
|
||||||
|
dy = graph_builder.tensor(dy_desc['shape'], dy_desc['data_type'], dy_desc['format'])
|
||||||
|
variance = graph_builder.tensor(var_desc['shape'], var_desc['data_type'], var_desc['format'])
|
||||||
|
mean = graph_builder.tensor(mean_desc['shape'], mean_desc['data_type'], mean_desc['format'])
|
||||||
|
gamma = graph_builder.tensor(gamma_desc['shape'], gamma_desc['data_type'], gamma_desc['format'])
|
||||||
|
graph_scope.set_input(x, dy, variance, mean, gamma)
|
||||||
|
|
||||||
|
# set some constant val.
|
||||||
|
eps = graph_builder.value(x.dtype, epsilon, x.data_format)
|
||||||
|
const_one = graph_builder.value(x.dtype, 1.0, x.data_format)
|
||||||
|
const_neg_half = graph_builder.value(x.dtype, -0.5, x.data_format)
|
||||||
|
const_neg_two = graph_builder.value(x.dtype, -2.0, x.data_format)
|
||||||
|
const_two = graph_builder.value(x.dtype, 2.0, x.data_format)
|
||||||
|
const_neg_one = graph_builder.value(x.dtype, -1.0, x.data_format)
|
||||||
|
mean_cof = graph_builder.value(x.dtype, (1.0 / reduce_size), x.data_format)
|
||||||
|
|
||||||
|
# cal dg db
|
||||||
|
# dg = np.sum(dy * np.power(var + epsilon, -0.5) * (x - mean), axis=tuple(param_axis), keepdims=True)
|
||||||
|
# db = np.sum(dy, axis=tuple(param_axis), keepdims=True)
|
||||||
|
var_eps = graph_builder.emit('TensorAdd', [variance, eps])
|
||||||
|
sqrt_var_eps = graph_builder.emit('Sqrt', [var_eps])
|
||||||
|
rsqrt_var_eps = graph_builder.emit('RealDiv', [const_one, sqrt_var_eps])
|
||||||
|
x_sub_mean = graph_builder.emit('Sub', [x, mean])
|
||||||
|
x_sub_mean_mul_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, x_sub_mean])
|
||||||
|
dg_mul = graph_builder.emit('Mul', [dy, x_sub_mean_mul_rsqrt_var_eps])
|
||||||
|
dg = graph_builder.emit('ReduceSum', [dg_mul], attrs={'reduce_axis': param_axis, 'keep_dims': False})
|
||||||
|
db = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': param_axis, 'keep_dims': False})
|
||||||
|
|
||||||
|
# cal sum_1
|
||||||
|
# sum1 = np.sum((-0.5) * dy * gamma * (x - mean) * np.power(var + epsilon, -1.5), axis=tuple(norm_axis),
|
||||||
|
# keepdims=True)
|
||||||
|
tmp_var_eps = graph_builder.emit('Mul', [sqrt_var_eps, var_eps])
|
||||||
|
r_tmp_var_eps = graph_builder.emit('RealDiv', [const_one, tmp_var_eps])
|
||||||
|
x_sub_mean_mul_r_tmp_var_eps = graph_builder.emit('Mul', [x_sub_mean, r_tmp_var_eps])
|
||||||
|
dy_mul_gamma = graph_builder.emit('Mul', [dy, gamma])
|
||||||
|
tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean_mul_r_tmp_var_eps])
|
||||||
|
sum_1_mul = graph_builder.emit('Mul', [const_neg_half, tmp_mul])
|
||||||
|
sum_1 = graph_builder.emit('ReduceSum', [sum_1_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
|
||||||
|
|
||||||
|
# cal sum_2
|
||||||
|
# sum2 = np.sum(dy * gamma, axis=tuple(norm_axis), keepdims=True)
|
||||||
|
sum_2 = graph_builder.emit('ReduceSum', [dy_mul_gamma], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
|
||||||
|
|
||||||
|
# cal sum_3
|
||||||
|
# sum3 = np.sum(-2.0 * (x - mean), axis=tuple(norm_axis), keepdims=True)
|
||||||
|
sum_3_mul = graph_builder.emit('Mul', [const_neg_two, x_sub_mean])
|
||||||
|
sum_3 = graph_builder.emit('ReduceSum', [sum_3_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
|
||||||
|
|
||||||
|
# cal dx = dx1 + dx2 + dx3
|
||||||
|
# dx1 = dy * gamma * rsqrt_var_eps
|
||||||
|
# dx2 = sum1 * 2.0 / mean_cof * x_sub_mean
|
||||||
|
# dx3 = (1.0 / mean_cof) * (-1.0 * rsqrt_var_eps * sum2 + 1.0 / mean_cof * sum1 * sum3)
|
||||||
|
dx_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps])
|
||||||
|
sum_1_mul_two = graph_builder.emit('Mul', [sum_1, const_two])
|
||||||
|
sum_1_mul_two_tmp = graph_builder.emit('Mul', [sum_1_mul_two, mean_cof])
|
||||||
|
dx_2 = graph_builder.emit('Mul', [sum_1_mul_two_tmp, x_sub_mean])
|
||||||
|
neg_rsqrt_var_eps = graph_builder.emit('Mul', [const_neg_one, rsqrt_var_eps])
|
||||||
|
neg_rsqrt_var_eps_mul_sum_2 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, sum_2])
|
||||||
|
sum_1_mul_sum_3 = graph_builder.emit('Mul', [sum_1, sum_3])
|
||||||
|
mean_cof_mul_sum_1_mul_sum_3 = graph_builder.emit('Mul', [mean_cof, sum_1_mul_sum_3])
|
||||||
|
add_tmp = graph_builder.emit('TensorAdd', [neg_rsqrt_var_eps_mul_sum_2, mean_cof_mul_sum_1_mul_sum_3])
|
||||||
|
dx_3 = graph_builder.emit('Mul', [add_tmp, mean_cof])
|
||||||
|
dx_tmp = graph_builder.emit('TensorAdd', [dx_1, dx_2])
|
||||||
|
dx = graph_builder.emit('TensorAdd', [dx_tmp, dx_3])
|
||||||
|
|
||||||
|
# set graph output.
|
||||||
|
graph_scope.set_output(dx, dg, db)
|
||||||
|
|
||||||
|
graph = graph_builder.get()[0]
|
||||||
|
return graph
|
Loading…
Reference in new issue