Fix CPU implementation of roi_align_op backward (#18728)

DDDivano-patch-1
qingqing01 6 years ago committed by GitHub
parent 70b03760fd
commit 3429e65aa8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -154,6 +154,8 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
int width = in_dims[3];
int rois_num = rois->dims()[0];
if (rois_num == 0) return;
auto in_stride = framework::stride(in_dims);
auto roi_stride = framework::stride(rois->dims());
auto out_stride = framework::stride(out->dims());
@ -278,6 +280,10 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> {
const T* out_grad_data = out_grad->data<T>();
T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, in_grad, static_cast<T>(0));
auto in_stride = framework::stride(in->dims());
auto roi_stride = framework::stride(rois->dims());
auto out_stride = framework::stride(out_grad->dims());

Loading…
Cancel
Save