|
|
|
@ -27,6 +27,8 @@ class SppKernel : public framework::OpKernel<T> {
|
|
|
|
|
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
|
|
|
|
|
auto* out = context.Output<framework::Tensor>("Out");
|
|
|
|
|
int pyramid_height = context.template Attr<int>("pyramid_height");
|
|
|
|
|
std::string pooling_type =
|
|
|
|
|
context.template Attr<std::string>("pooling_type");
|
|
|
|
|
out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
auto out_stride = framework::stride(out->dims());
|
|
|
|
|
int input_h = in_x->dims()[2];
|
|
|
|
@ -48,10 +50,17 @@ class SppKernel : public framework::OpKernel<T> {
|
|
|
|
|
framework::DDim output_shape(framework::make_ddim(output_shape_vec));
|
|
|
|
|
out_level.mutable_data<T>(output_shape, context.GetPlace());
|
|
|
|
|
// pooling
|
|
|
|
|
math::Pool2dFunctor<DeviceContext, math::MaxPool<T>, T> pool_forward;
|
|
|
|
|
math::MaxPool<T> max_process;
|
|
|
|
|
pool_forward(context.template device_context<DeviceContext>(), *in_x,
|
|
|
|
|
kernel_size, strides, paddings, max_process, &out_level);
|
|
|
|
|
if (pooling_type == "max") {
|
|
|
|
|
math::Pool2dFunctor<DeviceContext, math::MaxPool<T>, T> pool_forward;
|
|
|
|
|
math::MaxPool<T> max_process;
|
|
|
|
|
pool_forward(context.template device_context<DeviceContext>(), *in_x,
|
|
|
|
|
kernel_size, strides, paddings, max_process, &out_level);
|
|
|
|
|
} else if (pooling_type == "avg") {
|
|
|
|
|
math::Pool2dFunctor<DeviceContext, math::AvgPool<T>, T> pool_forward;
|
|
|
|
|
math::AvgPool<T> avg_process;
|
|
|
|
|
pool_forward(context.template device_context<DeviceContext>(), *in_x,
|
|
|
|
|
kernel_size, strides, paddings, avg_process, &out_level);
|
|
|
|
|
}
|
|
|
|
|
// flatten pooling output shape
|
|
|
|
|
int output_flatten_w = in_x->dims()[1] * bins * bins;
|
|
|
|
|
std::vector<int64_t> output_flatten_shape_vec(
|
|
|
|
@ -79,6 +88,8 @@ class SppGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
framework::Tensor* in_x_grad =
|
|
|
|
|
context.Output<framework::Tensor>(framework::GradVarName("X"));
|
|
|
|
|
int pyramid_height = context.template Attr<int>("pyramid_height");
|
|
|
|
|
std::string pooling_type =
|
|
|
|
|
context.template Attr<std::string>("pooling_type");
|
|
|
|
|
auto& device_ctx = context.template device_context<DeviceContext>();
|
|
|
|
|
math::SetConstant<DeviceContext, T> zero;
|
|
|
|
|
in_x_grad->mutable_data<T>(context.GetPlace());
|
|
|
|
@ -130,10 +141,19 @@ class SppGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
outgrad_level.ShareDataWith(outgrad_level);
|
|
|
|
|
outgrad_level.Resize(out_shape);
|
|
|
|
|
// pooling backward
|
|
|
|
|
math::MaxPool2dGradFunctor<DeviceContext, T> pool2d_backward;
|
|
|
|
|
pool2d_backward(context.template device_context<DeviceContext>(), *in_x,
|
|
|
|
|
if (pooling_type == "max") {
|
|
|
|
|
math::MaxPool2dGradFunctor<DeviceContext, T> pool2d_backward;
|
|
|
|
|
pool2d_backward(context.template device_context<DeviceContext>(), *in_x,
|
|
|
|
|
*&out_level, *&outgrad_level, kernel_size, strides,
|
|
|
|
|
paddings, in_x_grad);
|
|
|
|
|
} else if (pooling_type == "avg") {
|
|
|
|
|
math::Pool2dGradFunctor<DeviceContext, math::AvgPoolGrad<T>, T>
|
|
|
|
|
pool_backward;
|
|
|
|
|
math::AvgPoolGrad<T> avg_process;
|
|
|
|
|
pool_backward(context.template device_context<DeviceContext>(), *in_x,
|
|
|
|
|
*&out_level, *&outgrad_level, kernel_size, strides,
|
|
|
|
|
paddings, in_x_grad);
|
|
|
|
|
paddings, avg_process, in_x_grad);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|