|
|
|
@ -24,23 +24,33 @@ class LayerNormGrad(Expander):
|
|
|
|
|
|
|
|
|
|
def _expand(self, graph_builder):
|
|
|
|
|
x, dy, variance, mean, gamma = self.inputs
|
|
|
|
|
processor = self.processor
|
|
|
|
|
begin_norm_axis = self.attrs['begin_norm_axis']
|
|
|
|
|
begin_params_axis = self.attrs['begin_params_axis']
|
|
|
|
|
epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-11
|
|
|
|
|
epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-12
|
|
|
|
|
|
|
|
|
|
ori_dtype = x.dtype
|
|
|
|
|
if processor == 'aicore' and ori_dtype == 'float16':
|
|
|
|
|
x = graph_builder.emit('Cast', [x], attrs={'dst_type': 'float32'})
|
|
|
|
|
dy = graph_builder.emit('Cast', [dy], attrs={'dst_type': 'float32'})
|
|
|
|
|
variance = graph_builder.emit('Cast', [variance], attrs={'dst_type': 'float32'})
|
|
|
|
|
mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float32'})
|
|
|
|
|
gamma = graph_builder.emit('Cast', [gamma], attrs={'dst_type': 'float32'})
|
|
|
|
|
|
|
|
|
|
if begin_norm_axis < 0:
|
|
|
|
|
begin_norm_axis += len(x.shape)
|
|
|
|
|
if begin_params_axis < 0:
|
|
|
|
|
begin_params_axis += len(x.shape)
|
|
|
|
|
|
|
|
|
|
norm_axis = tuple(range(begin_norm_axis, len(x.shape)))
|
|
|
|
|
param_axis = tuple(range(0, begin_params_axis))
|
|
|
|
|
|
|
|
|
|
reduce_size = 1.0
|
|
|
|
|
for i in norm_axis:
|
|
|
|
|
reduce_size *= x.shape[i]
|
|
|
|
|
|
|
|
|
|
# set some constant val.
|
|
|
|
|
eps = graph_builder.value(x.dtype, epsilon)
|
|
|
|
|
const_one = graph_builder.value(x.dtype, 1.0)
|
|
|
|
|
const_neg_half = graph_builder.value(x.dtype, -0.5)
|
|
|
|
|
const_neg_two = graph_builder.value(x.dtype, -2.0)
|
|
|
|
|
const_two = graph_builder.value(x.dtype, 2.0)
|
|
|
|
@ -49,42 +59,55 @@ class LayerNormGrad(Expander):
|
|
|
|
|
|
|
|
|
|
# cal dg db
|
|
|
|
|
var_eps = graph_builder.emit('Add', [variance, eps])
|
|
|
|
|
sqrt_var_eps = graph_builder.emit('Sqrt', [var_eps])
|
|
|
|
|
rsqrt_var_eps = graph_builder.emit('RealDiv', [const_one, sqrt_var_eps])
|
|
|
|
|
var_eps_log = graph_builder.emit('Log', [var_eps])
|
|
|
|
|
var_eps_mul = graph_builder.emit('Mul', [var_eps_log, const_neg_half])
|
|
|
|
|
rsqrt_var_eps = graph_builder.emit('Exp', [var_eps_mul])
|
|
|
|
|
|
|
|
|
|
x_sub_mean = graph_builder.emit('Sub', [x, mean])
|
|
|
|
|
|
|
|
|
|
x_sub_mean_mul_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, x_sub_mean])
|
|
|
|
|
dg_mul = graph_builder.emit('Mul', [dy, x_sub_mean_mul_rsqrt_var_eps])
|
|
|
|
|
dg = graph_builder.emit('ReduceSum', [dg_mul], attrs={'reduce_axis': param_axis, 'keep_dims': False})
|
|
|
|
|
db = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': param_axis, 'keep_dims': False})
|
|
|
|
|
|
|
|
|
|
# cal sum_1
|
|
|
|
|
tmp_var_eps = graph_builder.emit('Mul', [sqrt_var_eps, var_eps])
|
|
|
|
|
r_tmp_var_eps = graph_builder.emit('RealDiv', [const_one, tmp_var_eps])
|
|
|
|
|
x_sub_mean_mul_r_tmp_var_eps = graph_builder.emit('Mul', [x_sub_mean, r_tmp_var_eps])
|
|
|
|
|
# pd_var
|
|
|
|
|
tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, rsqrt_var_eps])
|
|
|
|
|
r_tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, tmp_var_eps])
|
|
|
|
|
|
|
|
|
|
dy_mul_gamma = graph_builder.emit('Mul', [dy, gamma])
|
|
|
|
|
tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean_mul_r_tmp_var_eps])
|
|
|
|
|
sum_1_mul = graph_builder.emit('Mul', [const_neg_half, tmp_mul])
|
|
|
|
|
sum_1 = graph_builder.emit('ReduceSum', [sum_1_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
|
|
|
|
|
|
|
|
|
|
# cal sum_2
|
|
|
|
|
sum_2 = graph_builder.emit('ReduceSum', [dy_mul_gamma], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
|
|
|
|
|
|
|
|
|
|
# cal sum_3
|
|
|
|
|
sum_3_mul = graph_builder.emit('Mul', [const_neg_two, x_sub_mean])
|
|
|
|
|
sum_3 = graph_builder.emit('ReduceSum', [sum_3_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
|
|
|
|
|
|
|
|
|
|
# cal dx = dx1 + dx2 + dx3
|
|
|
|
|
dx_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps])
|
|
|
|
|
sum_1_mul_two = graph_builder.emit('Mul', [sum_1, const_two])
|
|
|
|
|
sum_1_mul_two_tmp = graph_builder.emit('Mul', [sum_1_mul_two, mean_cof])
|
|
|
|
|
dx_2 = graph_builder.emit('Mul', [sum_1_mul_two_tmp, x_sub_mean])
|
|
|
|
|
neg_rsqrt_var_eps = graph_builder.emit('Mul', [const_neg_one, rsqrt_var_eps])
|
|
|
|
|
neg_rsqrt_var_eps_mul_sum_2 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, sum_2])
|
|
|
|
|
sum_1_mul_sum_3 = graph_builder.emit('Mul', [sum_1, sum_3])
|
|
|
|
|
mean_cof_mul_sum_1_mul_sum_3 = graph_builder.emit('Mul', [mean_cof, sum_1_mul_sum_3])
|
|
|
|
|
add_tmp = graph_builder.emit('Add', [neg_rsqrt_var_eps_mul_sum_2, mean_cof_mul_sum_1_mul_sum_3])
|
|
|
|
|
dx_3 = graph_builder.emit('Mul', [add_tmp, mean_cof])
|
|
|
|
|
dx_tmp = graph_builder.emit('Add', [dx_1, dx_2])
|
|
|
|
|
dx = graph_builder.emit('Add', [dx_tmp, dx_3])
|
|
|
|
|
tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean])
|
|
|
|
|
padvar_mul1 = graph_builder.emit('ReduceSum', [tmp_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
|
|
|
|
|
padvar_mul3 = graph_builder.emit('Mul', [padvar_mul1, r_tmp_var_eps])
|
|
|
|
|
pd_var = graph_builder.emit('Mul', [padvar_mul3, const_neg_half])
|
|
|
|
|
|
|
|
|
|
# pd_mean
|
|
|
|
|
pdmean1_sum = graph_builder.emit('ReduceSum', [dy_mul_gamma],
|
|
|
|
|
attrs={'reduce_axis': norm_axis, 'keep_dims': True})
|
|
|
|
|
neg_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, const_neg_one])
|
|
|
|
|
pd_mean_1 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, pdmean1_sum])
|
|
|
|
|
|
|
|
|
|
pdmean2_mul1 = graph_builder.emit('Mul', [const_neg_two, x_sub_mean])
|
|
|
|
|
pdmean2_sum = graph_builder.emit('ReduceSum', [pdmean2_mul1],
|
|
|
|
|
attrs={'reduce_axis': norm_axis, 'keep_dims': True})
|
|
|
|
|
pdmean2_mul3 = graph_builder.emit('Mul', [pdmean2_sum, mean_cof])
|
|
|
|
|
pd_mean_2 = graph_builder.emit('Mul', [pdmean2_mul3, pd_var])
|
|
|
|
|
|
|
|
|
|
pd_mean = graph_builder.emit('Add', [pd_mean_1, pd_mean_2])
|
|
|
|
|
|
|
|
|
|
# cal dx
|
|
|
|
|
pd_x_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps])
|
|
|
|
|
|
|
|
|
|
pdx2_mul = graph_builder.emit('Mul', [pd_var, x_sub_mean])
|
|
|
|
|
pdx2_mul_two = graph_builder.emit('Mul', [pdx2_mul, const_two])
|
|
|
|
|
pd_x_2 = graph_builder.emit('Mul', [pdx2_mul_two, mean_cof])
|
|
|
|
|
|
|
|
|
|
pd_x_3 = graph_builder.emit('Mul', [pd_mean, mean_cof])
|
|
|
|
|
|
|
|
|
|
dx_tmp = graph_builder.emit('Add', [pd_x_1, pd_x_2])
|
|
|
|
|
dx = graph_builder.emit('Add', [dx_tmp, pd_x_3])
|
|
|
|
|
|
|
|
|
|
if processor == 'aicore' and ori_dtype == 'float16':
|
|
|
|
|
dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'})
|
|
|
|
|
dg = graph_builder.emit('Cast', [dg], attrs={'dst_type': 'float16'})
|
|
|
|
|
db = graph_builder.emit('Cast', [db], attrs={'dst_type': 'float16'})
|
|
|
|
|
return dx, dg, db
|
|
|
|
|