Fix bp of roi perspective transform op. (#17216)

revert-17304-fix_default_paddle_version
whs 6 years ago committed by qingqing01
parent 7bd1d03ee5
commit 7d7e29957f

@ -466,6 +466,10 @@ class CUDAROIPerspectiveTransformGradOpKernel : public framework::OpKernel<T> {
auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T> set_zero;
set_zero(ctx.cuda_device_context(), in_grad, static_cast<T>(0));
const T* out_grad_data = out_grad->data<T>();
const int* out2in_idx_data = out2in_idx->data<int>();
const T* out2in_w_data = out2in_w->data<T>();

Loading…
Cancel
Save