|
|
|
@ -311,8 +311,8 @@ class InstanceNormGradKernel<platform::CPUDeviceContext, T>
|
|
|
|
|
auto dy_arr = dy_e.reshape(shape);
|
|
|
|
|
auto x_arr = x_e.reshape(shape);
|
|
|
|
|
|
|
|
|
|
auto tmp =
|
|
|
|
|
(x_arr - mean_arr.broadcast(bcast)) * inv_var_arr.broadcast(bcast);
|
|
|
|
|
auto tmp = (x_arr - mean_arr.eval().broadcast(bcast)) *
|
|
|
|
|
inv_var_arr.eval().broadcast(bcast);
|
|
|
|
|
|
|
|
|
|
math::SetConstant<platform::CPUDeviceContext, T> set_constant;
|
|
|
|
|
// math: d_bias = np.sum(d_y, axis=(n,h,w))
|
|
|
|
@ -333,7 +333,8 @@ class InstanceNormGradKernel<platform::CPUDeviceContext, T>
|
|
|
|
|
(tmp * dy_arr).sum(mean_rdims).reshape(param_shape).sum(rdims);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto dy_mean = dy_arr.mean(mean_rdims).reshape(NxC_shape).broadcast(bcast);
|
|
|
|
|
auto dy_mean =
|
|
|
|
|
dy_arr.mean(mean_rdims).reshape(NxC_shape).eval().broadcast(bcast);
|
|
|
|
|
|
|
|
|
|
Eigen::DSizes<int, 2> bcast_param(N, sample_size);
|
|
|
|
|
set_constant(dev_ctx, d_x, static_cast<T>(0));
|
|
|
|
@ -351,6 +352,7 @@ class InstanceNormGradKernel<platform::CPUDeviceContext, T>
|
|
|
|
|
(dy_arr * tmp)
|
|
|
|
|
.mean(mean_rdims)
|
|
|
|
|
.reshape(NxC_shape)
|
|
|
|
|
.eval()
|
|
|
|
|
.broadcast(bcast));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|