fix unit test

emailweixu-patch-1
chengduoZH 7 years ago
parent ca0177190f
commit ae0ea54159

@ -233,13 +233,13 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace());
auto d_x_map = EigenMatrixMapRowMajor<T>(d_x->data<T>(), left, right);
auto triple_product = [](T ele) { return ele * ele * ele; };
auto neg_inv_std = [](T ele) { return T(-1.0) * std::sqrt(1 / ele); };
auto triple_product = [](T ele) { return ele * ele; };
auto neg_inv_std = [](T ele) { return -std::sqrt(1 / ele); };
auto inv_std_scale_func = [scale_data](T ele) {
return std::sqrt(1 / ele) * scale_data;
};
auto neg_inv_std_scale_func = [scale_data](T ele) {
return T(-1.0) * std::sqrt(1 / ele) * scale_data;
return -std::sqrt(1 / ele) * scale_data;
};
// dy_dx
auto dx_end = var_map.unaryExpr(inv_std_scale_func)
@ -260,10 +260,13 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
auto dvar_end = var_map.unaryExpr(neg_inv_std)
.unaryExpr(triple_product)
.cwiseProduct(dvar_end_0);
auto dx_var = (1.0f / right) *
auto dx_var = (T(1.0) / right) *
(x_map - mean_map.replicate(1, right))
.cwiseProduct(dvar_end.replicate(1, right));
// d_x = (1. / N) * scale * inv_var * (N * d_y - np.sum(d_y, axis=0)
// - (X - mean) * inv_var * inv_var * np.sum(d_y * (X - mean), axis=0))
d_x_map = dx_end + dx_mean + dx_var;
}
}

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save