|
|
|
@ -26,39 +26,58 @@ class ROIPoolOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of ROIPoolOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("ROIs"),
|
|
|
|
|
"Input(ROIs) of ROIPoolOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of ROIPoolOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Argmax"),
|
|
|
|
|
"Output(Argmax) of ROIPoolOp should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "roi_pool");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("ROIs"), "Input", "ROIs", "roi_pool");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "roi_pool");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Argmax"), "Output", "Argmax", "roi_pool");
|
|
|
|
|
|
|
|
|
|
auto input_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto rois_dims = ctx->GetInputDim("ROIs");
|
|
|
|
|
|
|
|
|
|
if (ctx->HasInput("RoisLod")) {
|
|
|
|
|
auto rois_lod_dims = ctx->GetInputDim("RoisLod");
|
|
|
|
|
PADDLE_ENFORCE(rois_lod_dims.size() == 1, "");
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE(input_dims.size() == 4,
|
|
|
|
|
"The format of input tensor is NCHW.");
|
|
|
|
|
PADDLE_ENFORCE(rois_dims.size() == 2,
|
|
|
|
|
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4)"
|
|
|
|
|
"given as [[x1, y1, x2, y2], ...].");
|
|
|
|
|
PADDLE_ENFORCE(rois_dims[1] == kROISize,
|
|
|
|
|
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4)"
|
|
|
|
|
"given as [[x1, y1, x2, y2], ...].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_dims.size(), 4,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The input data should be a four-dimensional "
|
|
|
|
|
"tensor with [N,C,H,W], but received input data with "
|
|
|
|
|
" %d dimension",
|
|
|
|
|
input_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
rois_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"ROIs should be a 2-D LoDTensor with shape (num_rois, 4)"
|
|
|
|
|
"given as [[x1, y1, x2, y2], ...], but received ROIs is "
|
|
|
|
|
"%d-dimensional LoDTensor",
|
|
|
|
|
rois_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
rois_dims[1], kROISize,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"ROIs should be a 2-D LoDTensor with shape (num_rois, 4)"
|
|
|
|
|
"given as [[x1, y1, x2, y2], ...]. But the second dimension of "
|
|
|
|
|
"the received data is %d",
|
|
|
|
|
rois_dims[1]));
|
|
|
|
|
|
|
|
|
|
int pooled_height = ctx->Attrs().Get<int>("pooled_height");
|
|
|
|
|
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
|
|
|
|
|
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GT(pooled_height, 0,
|
|
|
|
|
"The pooled output height must greater than 0");
|
|
|
|
|
platform::errors::OutOfRange(
|
|
|
|
|
"The pooled output height must be greater than 0"
|
|
|
|
|
"but received height is %d",
|
|
|
|
|
pooled_height));
|
|
|
|
|
PADDLE_ENFORCE_GT(pooled_width, 0,
|
|
|
|
|
"The pooled output width must greater than 0");
|
|
|
|
|
platform::errors::OutOfRange(
|
|
|
|
|
"The pooled output width must be greater than 0"
|
|
|
|
|
"but received width is %d",
|
|
|
|
|
pooled_width));
|
|
|
|
|
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
|
|
|
|
|
"The spatial scale must greater than 0");
|
|
|
|
|
platform::errors::OutOfRange(
|
|
|
|
|
"The spatial scale must be greater than 0, "
|
|
|
|
|
"but received spatial scale is %f",
|
|
|
|
|
spatial_scale));
|
|
|
|
|
|
|
|
|
|
auto out_dims = input_dims;
|
|
|
|
|
out_dims[0] = rois_dims[0];
|
|
|
|
@ -84,10 +103,10 @@ class ROIPoolGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"The gradient of Out should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")),
|
|
|
|
|
"The gradient of X should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
|
|
|
|
|
framework::GradVarName("Out"), "roi_pool");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
|
|
|
|
|
framework::GradVarName("X"), "roi_pool");
|
|
|
|
|
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|