|
|
|
@ -99,10 +99,10 @@ class NormGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto dx_e = framework::EigenVector<T>::Flatten(*out_dx);
|
|
|
|
|
|
|
|
|
|
Eigen::DSizes<int, 3> shape(pre, n, post);
|
|
|
|
|
Eigen::DSizes<int, 2> norm_shape(pre, post);
|
|
|
|
|
Eigen::DSizes<int, 3> rshape(pre, 1, post);
|
|
|
|
|
auto x = x_e.reshape(shape);
|
|
|
|
|
auto dy = dy_e.reshape(shape);
|
|
|
|
|
auto norm = norm_e.reshape(norm_shape);
|
|
|
|
|
auto norm = norm_e.reshape(rshape);
|
|
|
|
|
auto dx = dx_e.reshape(shape);
|
|
|
|
|
|
|
|
|
|
framework::Tensor rsum;
|
|
|
|
@ -111,7 +111,6 @@ class NormGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
Eigen::DSizes<int, 1> rdim(1);
|
|
|
|
|
Eigen::DSizes<int, 3> bcast(1, n, 1);
|
|
|
|
|
Eigen::DSizes<int, 3> rshape(pre, 1, post);
|
|
|
|
|
|
|
|
|
|
// dx = ( dy/sqrt(sum(x*x)) ) * [1 - x*sum(x) / (sum(x*x) + e)]
|
|
|
|
|
// = [dy - dy * x * sum(x) / (sum(x*x) + e)] / sqrt(sum(x*x))
|
|
|
|
|