Pre Merge pull request !13475 from peiwenfang/adapt_for_layernorm_in_ascend

pull/13475/MERGE
peiwenfang 4 years ago committed by Gitee
commit b4431d7f65

@ -17,33 +17,118 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD
@VLD.add_format(DF.FRAC_NZ, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('begin_norm_axis', 'begin_params_axis', 'epsilon')
class LayerNorm(Expander):
"""LayerNorm expander"""
def to_frac_z_axis(self, ori_shape, ori_axis):
    """
    Map axes of the original shape onto axes of the fractal NZ shape.

    An axis that refers to one of the last two dimensions of the original
    shape is rewritten to its first fractal NZ axis and a second, companion
    axis is appended to the end of the result. Every other axis is only
    normalized to a non-negative index.

    Parameters
    ----------
    ori_shape: list or tuple
        original shape of input
    ori_axis: list or tuple
        axes of the original shape to operate on (negative values allowed)

    Returns
    -------
    output: list
        axes of the fractal NZ shape: one entry per input axis, in input
        order, followed by the companion axes in the order they were produced
    """
    rank = len(ori_shape)
    last_axis = rank - 1
    second_last_axis = rank - 2
    mapped_axes = []
    companion_axes = []
    for axis in ori_axis:
        # Normalize negative axes so the comparisons below see 0..rank-1.
        axis_index = (axis + rank) % rank
        if axis_index == last_axis:
            # The original code split this case on the sign of the raw axis
            # but both branches were identical, so a single branch suffices.
            mapped_axes.append(axis_index - 1)
            companion_axes.append(axis_index + 2)
        elif axis_index == second_last_axis:
            mapped_axes.append(axis_index + 1)
            companion_axes.append(axis_index + 2)
        else:
            mapped_axes.append(axis_index)
    # Companion axes always come after all mapped axes.
    return mapped_axes + companion_axes
def infer_shape_from_fractalNz(self, fractal):
    """Recover the original shape from a fractal NZ shape.

    All dimensions before the last four are copied unchanged. The last four
    dimensions are folded into two: the product of the two middle ones,
    followed by the product of the outer two.
    """
    rank = len(fractal)
    # Index-based copy keeps the result empty when rank is below four.
    leading_dims = [fractal[idx] for idx in range(rank - 4)]
    rows = fractal[rank - 3] * fractal[rank - 2]
    cols = fractal[rank - 4] * fractal[rank - 1]
    return leading_dims + [rows, cols]
def get_reduced_ori_shape(self, shape, axis):
    """Return `shape` as a list with every dimension whose index is in `axis` set to 1."""
    return [1 if idx in axis else dim for idx, dim in enumerate(shape)]
def _expand(self, graph_builder):
input_x, input_gamma, input_beta = self.inputs
processor = self.processor
begin_norm_axis = self.attrs['begin_norm_axis']
epsilon = self.attrs['epsilon']
ori_dtype = input_x.dtype
if processor == 'aicore' and ori_dtype == 'float16':
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
input_gamma = graph_builder.emit('Cast', [input_gamma], attrs={'dst_type': 'float32'})
input_beta = graph_builder.emit('Cast', [input_beta], attrs={'dst_type': 'float32'})
ori_shape_x = input_x.shape
if input_x.data_format == DF.FRAC_NZ:
ori_shape_x = self.infer_shape_from_fractalNz(input_x.shape)
# Calculate the scaling ratio of the average
if begin_norm_axis < 0:
begin_norm_axis += len(input_x.shape)
begin_norm_axis += len(ori_shape_x)
reduce_axis = ()
for i, _ in enumerate(input_x.shape):
for i, _ in enumerate(ori_shape_x):
if i > begin_norm_axis or i == begin_norm_axis:
reduce_axis = reduce_axis + (i,)
reduce_elts = 1.0
for i in reduce_axis:
reduce_elts *= input_x.shape[i]
reduce_elts *= ori_shape_x[i]
if input_x.data_format == DF.FRAC_NZ:
reduce_axis = self.to_frac_z_axis(ori_shape_x, reduce_axis)
ori_shape_x = self.get_reduced_ori_shape(ori_shape_x, reduce_axis) # after reduced
mean_cof = 1.0 / reduce_elts
mean_cof_v = graph_builder.value(input_x.dtype, mean_cof)
# Calculate mean
mean_red = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
mean_red = graph_builder.emit('ReduceSum', [input_x],
attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
mean = graph_builder.emit('Mul', [mean_red, mean_cof_v])
if input_x.data_format == DF.FRAC_NZ:
mean = graph_builder.emit('Reshape', [mean], attrs={'shape': ori_shape_x})
# Calculate variance
variance_sub = graph_builder.emit('Sub', [input_x, mean])
@ -51,6 +136,8 @@ class LayerNorm(Expander):
variance_red = graph_builder.emit('ReduceSum', [variance_mul],
attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
variance = graph_builder.emit('Mul', [variance_red, mean_cof_v])
if input_x.data_format == DF.FRAC_NZ:
variance = graph_builder.emit('Reshape', [variance], attrs={'shape': ori_shape_x})
# Calculate normalize
normalize_sub = graph_builder.emit('Sub', [input_x, mean])
@ -60,7 +147,11 @@ class LayerNorm(Expander):
normalize_mul = graph_builder.emit('Mul', [normalize_sub, normlize_rsqrt])
# Calculate scale and translate
scale_mul = graph_builder.emit('Mul', [input_gamma, normalize_mul])
scale_mul = graph_builder.emit('Mul', [normalize_mul, input_gamma])
res = graph_builder.emit('Add', [scale_mul, input_beta])
if processor == 'aicore' and ori_dtype == 'float16':
res = graph_builder.emit('Cast', [res], attrs={'dst_type': 'float16'})
mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float16'})
variance = graph_builder.emit('Cast', [variance], attrs={'dst_type': 'float16'})
return res, mean, variance

@ -24,23 +24,33 @@ class LayerNormGrad(Expander):
def _expand(self, graph_builder):
x, dy, variance, mean, gamma = self.inputs
processor = self.processor
begin_norm_axis = self.attrs['begin_norm_axis']
begin_params_axis = self.attrs['begin_params_axis']
epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-11
epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-12
ori_dtype = x.dtype
if processor == 'aicore' and ori_dtype == 'float16':
x = graph_builder.emit('Cast', [x], attrs={'dst_type': 'float32'})
dy = graph_builder.emit('Cast', [dy], attrs={'dst_type': 'float32'})
variance = graph_builder.emit('Cast', [variance], attrs={'dst_type': 'float32'})
mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float32'})
gamma = graph_builder.emit('Cast', [gamma], attrs={'dst_type': 'float32'})
if begin_norm_axis < 0:
begin_norm_axis += len(x.shape)
if begin_params_axis < 0:
begin_params_axis += len(x.shape)
norm_axis = tuple(range(begin_norm_axis, len(x.shape)))
param_axis = tuple(range(0, begin_params_axis))
reduce_size = 1.0
for i in norm_axis:
reduce_size *= x.shape[i]
# set some constant val.
eps = graph_builder.value(x.dtype, epsilon)
const_one = graph_builder.value(x.dtype, 1.0)
const_neg_half = graph_builder.value(x.dtype, -0.5)
const_neg_two = graph_builder.value(x.dtype, -2.0)
const_two = graph_builder.value(x.dtype, 2.0)
@ -49,42 +59,55 @@ class LayerNormGrad(Expander):
# cal dg db
var_eps = graph_builder.emit('Add', [variance, eps])
sqrt_var_eps = graph_builder.emit('Sqrt', [var_eps])
rsqrt_var_eps = graph_builder.emit('RealDiv', [const_one, sqrt_var_eps])
var_eps_log = graph_builder.emit('Log', [var_eps])
var_eps_mul = graph_builder.emit('Mul', [var_eps_log, const_neg_half])
rsqrt_var_eps = graph_builder.emit('Exp', [var_eps_mul])
x_sub_mean = graph_builder.emit('Sub', [x, mean])
x_sub_mean_mul_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, x_sub_mean])
dg_mul = graph_builder.emit('Mul', [dy, x_sub_mean_mul_rsqrt_var_eps])
dg = graph_builder.emit('ReduceSum', [dg_mul], attrs={'reduce_axis': param_axis, 'keep_dims': False})
db = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': param_axis, 'keep_dims': False})
# cal sum_1
tmp_var_eps = graph_builder.emit('Mul', [sqrt_var_eps, var_eps])
r_tmp_var_eps = graph_builder.emit('RealDiv', [const_one, tmp_var_eps])
x_sub_mean_mul_r_tmp_var_eps = graph_builder.emit('Mul', [x_sub_mean, r_tmp_var_eps])
# pd_var
tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, rsqrt_var_eps])
r_tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, tmp_var_eps])
dy_mul_gamma = graph_builder.emit('Mul', [dy, gamma])
tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean_mul_r_tmp_var_eps])
sum_1_mul = graph_builder.emit('Mul', [const_neg_half, tmp_mul])
sum_1 = graph_builder.emit('ReduceSum', [sum_1_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
# cal sum_2
sum_2 = graph_builder.emit('ReduceSum', [dy_mul_gamma], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
# cal sum_3
sum_3_mul = graph_builder.emit('Mul', [const_neg_two, x_sub_mean])
sum_3 = graph_builder.emit('ReduceSum', [sum_3_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
# cal dx = dx1 + dx2 + dx3
dx_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps])
sum_1_mul_two = graph_builder.emit('Mul', [sum_1, const_two])
sum_1_mul_two_tmp = graph_builder.emit('Mul', [sum_1_mul_two, mean_cof])
dx_2 = graph_builder.emit('Mul', [sum_1_mul_two_tmp, x_sub_mean])
neg_rsqrt_var_eps = graph_builder.emit('Mul', [const_neg_one, rsqrt_var_eps])
neg_rsqrt_var_eps_mul_sum_2 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, sum_2])
sum_1_mul_sum_3 = graph_builder.emit('Mul', [sum_1, sum_3])
mean_cof_mul_sum_1_mul_sum_3 = graph_builder.emit('Mul', [mean_cof, sum_1_mul_sum_3])
add_tmp = graph_builder.emit('Add', [neg_rsqrt_var_eps_mul_sum_2, mean_cof_mul_sum_1_mul_sum_3])
dx_3 = graph_builder.emit('Mul', [add_tmp, mean_cof])
dx_tmp = graph_builder.emit('Add', [dx_1, dx_2])
dx = graph_builder.emit('Add', [dx_tmp, dx_3])
tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean])
padvar_mul1 = graph_builder.emit('ReduceSum', [tmp_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
padvar_mul3 = graph_builder.emit('Mul', [padvar_mul1, r_tmp_var_eps])
pd_var = graph_builder.emit('Mul', [padvar_mul3, const_neg_half])
# pd_mean
pdmean1_sum = graph_builder.emit('ReduceSum', [dy_mul_gamma],
attrs={'reduce_axis': norm_axis, 'keep_dims': True})
neg_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, const_neg_one])
pd_mean_1 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, pdmean1_sum])
pdmean2_mul1 = graph_builder.emit('Mul', [const_neg_two, x_sub_mean])
pdmean2_sum = graph_builder.emit('ReduceSum', [pdmean2_mul1],
attrs={'reduce_axis': norm_axis, 'keep_dims': True})
pdmean2_mul3 = graph_builder.emit('Mul', [pdmean2_sum, mean_cof])
pd_mean_2 = graph_builder.emit('Mul', [pdmean2_mul3, pd_var])
pd_mean = graph_builder.emit('Add', [pd_mean_1, pd_mean_2])
# cal dx
pd_x_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps])
pdx2_mul = graph_builder.emit('Mul', [pd_var, x_sub_mean])
pdx2_mul_two = graph_builder.emit('Mul', [pdx2_mul, const_two])
pd_x_2 = graph_builder.emit('Mul', [pdx2_mul_two, mean_cof])
pd_x_3 = graph_builder.emit('Mul', [pd_mean, mean_cof])
dx_tmp = graph_builder.emit('Add', [pd_x_1, pd_x_2])
dx = graph_builder.emit('Add', [dx_tmp, pd_x_3])
if processor == 'aicore' and ori_dtype == 'float16':
dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'})
dg = graph_builder.emit('Cast', [dg], attrs={'dst_type': 'float16'})
db = graph_builder.emit('Cast', [db], attrs={'dst_type': 'float16'})
return dx, dg, db

@ -47,6 +47,8 @@ std::vector<PrimitivePtr> GetExpandOps() {
prim::kPrimSquare,
prim::kPrimGeLUGrad,
prim::kPrimAssignAdd,
prim::kPrimLayerNorm,
prim::kPrimLayerNormGrad,
#if ENABLE_D
prim::kPrimTile,
prim::kPrimSqrtGrad,
@ -67,8 +69,6 @@ std::vector<PrimitivePtr> GetExpandOps() {
prim::kPrimDropout,
prim::kPrimDropoutGrad,
prim::kPrimSoftmax,
prim::kPrimLayerNorm,
prim::kPrimLayerNormGrad,
prim::kPrimRelu,
prim::kPrimReluGrad,
prim::kPrimSigmoid,

@ -54,7 +54,6 @@
#include "debug/data_dump/dump_json_parser.h"
#include "debug/tensor_load.h"
#include "debug/anf_ir_utils.h"
#include "backend/optimizer/graph_kernel/shape_ops_splitter.h"
#include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"
#include "backend/session/ascend_auto_monad.h"
#include "debug/data_dump/e2e_dump_util.h"

@ -223,12 +223,12 @@ def test_bert_precision(enable_graph_kernel=False):
# assertion occurs while the loss value, overflow state or loss_scale value is wrong
loss_value = np.array(callback.loss_list)
assert np.allclose(loss_value[0], 12.2066, 0, 0.0005)
if enable_graph_kernel:
expect_loss_value = [12.206627, 11.840489, 11.798470, 11.796345, 11.790964, 12.366766, 11.971539, 12.576565,
12.185522, 12.386192]
else:
assert np.allclose(loss_value[0], 12.2066, 0, 0.0005)
expect_loss_value = [12.206587, 11.966410, 11.965916, 11.975922, 11.970262, 12.608881, 12.174048, 12.840656,
12.407923, 12.631133]
print("loss value: {}".format(loss_value))

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save