|
|
|
@ -50,12 +50,14 @@ class ROIPoolOp : public framework::OperatorWithKernel {
|
|
|
|
|
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
|
|
|
|
|
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GT(pooled_height, 0,
|
|
|
|
|
"The pooled output height must greater than 0");
|
|
|
|
|
PADDLE_ENFORCE_GT(pooled_width, 0,
|
|
|
|
|
"The pooled output width must greater than 0");
|
|
|
|
|
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
|
|
|
|
|
"The spatial scale must greater than 0");
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_GT(pooled_height, 0,
|
|
|
|
|
"The pooled output height must greater than 0");
|
|
|
|
|
PADDLE_ENFORCE_GT(pooled_width, 0,
|
|
|
|
|
"The pooled output width must greater than 0");
|
|
|
|
|
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
|
|
|
|
|
"The spatial scale must greater than 0");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto out_dims = input_dims;
|
|
|
|
|
out_dims[0] = rois_dims[0];
|
|
|
|
|