|
|
|
@ -139,10 +139,8 @@ class PoolGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto& dev_ctx = context.template device_context<DeviceContext>();
|
|
|
|
|
if (in_x_grad) {
|
|
|
|
|
in_x_grad->mutable_data<T>(context.GetPlace());
|
|
|
|
|
auto temp = framework::EigenVector<T>::Flatten(*in_x_grad);
|
|
|
|
|
temp.device(
|
|
|
|
|
*context.template device_context<DeviceContext>().eigen_device()) =
|
|
|
|
|
temp.constant(static_cast<T>(0));
|
|
|
|
|
paddle::operators::math::SetConstant<DeviceContext, T> set_constant;
|
|
|
|
|
set_constant(dev_ctx, in_x_grad, 0.0);
|
|
|
|
|
|
|
|
|
|
switch (ksize.size()) {
|
|
|
|
|
case 2: {
|
|
|
|
|