|
|
|
@ -154,6 +154,8 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
int width = in_dims[3];
|
|
|
|
|
int rois_num = rois->dims()[0];
|
|
|
|
|
|
|
|
|
|
if (rois_num == 0) return;
|
|
|
|
|
|
|
|
|
|
auto in_stride = framework::stride(in_dims);
|
|
|
|
|
auto roi_stride = framework::stride(rois->dims());
|
|
|
|
|
auto out_stride = framework::stride(out->dims());
|
|
|
|
@ -278,6 +280,10 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
const T* out_grad_data = out_grad->data<T>();
|
|
|
|
|
T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
|
|
|
|
math::SetConstant<DeviceContext, T> set_zero;
|
|
|
|
|
set_zero(dev_ctx, in_grad, static_cast<T>(0));
|
|
|
|
|
|
|
|
|
|
auto in_stride = framework::stride(in->dims());
|
|
|
|
|
auto roi_stride = framework::stride(rois->dims());
|
|
|
|
|
auto out_stride = framework::stride(out_grad->dims());
|
|
|
|
|