|
|
|
@ -133,54 +133,47 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* in = ctx.Input<framework::Tensor>("X");
|
|
|
|
|
auto* rois = ctx.Input<framework::Tensor>("ROIs");
|
|
|
|
|
auto* argmax = ctx.Input<framework::Tensor>("Argmax");
|
|
|
|
|
|
|
|
|
|
auto* out_grad =
|
|
|
|
|
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto* x_grad =
|
|
|
|
|
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
|
|
|
|
|
|
|
|
|
|
auto pooled_height = ctx.Attr<int>("pooled_height");
|
|
|
|
|
auto pooled_width = ctx.Attr<int>("pooled_width");
|
|
|
|
|
|
|
|
|
|
if (x_grad) {
|
|
|
|
|
int channels = in->dims()[1];
|
|
|
|
|
auto in_stride = framework::stride(in->dims());
|
|
|
|
|
auto roi_stride = framework::stride(rois->dims());
|
|
|
|
|
|
|
|
|
|
if (in_grad) {
|
|
|
|
|
const int64_t* rois_data = rois->data<int64_t>();
|
|
|
|
|
int rois_num = rois->dims()[0];
|
|
|
|
|
|
|
|
|
|
T* x_grad_data = x_grad->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
const T* out_grad_data = out_grad->data<T>();
|
|
|
|
|
const int64_t* argmax_data = argmax->data<int64_t>();
|
|
|
|
|
T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
math::SetConstant<Place, T> set_zero;
|
|
|
|
|
set_zero(ctx.device_context(), x_grad, static_cast<T>(0));
|
|
|
|
|
set_zero(ctx.device_context(), in_grad, static_cast<T>(0));
|
|
|
|
|
|
|
|
|
|
size_t roi_offset = roi_stride[0];
|
|
|
|
|
size_t batch_offset = in_stride[0];
|
|
|
|
|
size_t channel_offset = in_stride[1];
|
|
|
|
|
auto in_stride = framework::stride(in->dims());
|
|
|
|
|
auto argmax_stride = framework::stride(argmax->dims());
|
|
|
|
|
auto roi_stride = framework::stride(rois->dims());
|
|
|
|
|
auto out_stride = framework::stride(out_grad->dims());
|
|
|
|
|
|
|
|
|
|
const T* out_grad_data = out_grad->data<T>();
|
|
|
|
|
size_t pool_channel_offset = pooled_height * pooled_width;
|
|
|
|
|
const int64_t* argmax_data = argmax->data<int64_t>();
|
|
|
|
|
int rois_num = rois->dims()[0];
|
|
|
|
|
int channels = in->dims()[1];
|
|
|
|
|
|
|
|
|
|
for (size_t n = 0; n < rois_num; ++n) {
|
|
|
|
|
size_t roi_batch_idx = rois_data[0];
|
|
|
|
|
T* batch_grad_data = x_grad_data + batch_offset * roi_batch_idx;
|
|
|
|
|
for (int n = 0; n < rois_num; ++n) {
|
|
|
|
|
int roi_batch_idx = rois_data[0];
|
|
|
|
|
T* batch_grad_data = in_grad_data + roi_batch_idx * in_stride[0];
|
|
|
|
|
for (int c = 0; c < channels; ++c) {
|
|
|
|
|
for (int ph = 0; ph < pooled_height; ++ph) {
|
|
|
|
|
for (int pw = 0; pw < pooled_width; ++pw) {
|
|
|
|
|
size_t pool_index = ph * pooled_width + pw;
|
|
|
|
|
|
|
|
|
|
int pool_index = ph * pooled_width + pw;
|
|
|
|
|
if (argmax_data[pool_index] >= 0) {
|
|
|
|
|
size_t index = static_cast<size_t>(argmax_data[pool_index]);
|
|
|
|
|
auto index = argmax_data[pool_index];
|
|
|
|
|
batch_grad_data[index] += out_grad_data[pool_index];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
batch_grad_data += channel_offset;
|
|
|
|
|
out_grad_data += pool_channel_offset;
|
|
|
|
|
argmax_data += pool_channel_offset;
|
|
|
|
|
batch_grad_data += in_stride[1];
|
|
|
|
|
out_grad_data += out_stride[1];
|
|
|
|
|
argmax_data += argmax_stride[1];
|
|
|
|
|
}
|
|
|
|
|
rois_data += roi_offset;
|
|
|
|
|
rois_data += roi_stride[0];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|