|
|
|
@ -21,22 +21,36 @@ class BoxClipOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Input"),
|
|
|
|
|
"Input(Input) of BoxClipOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("ImInfo"),
|
|
|
|
|
"Input(ImInfo) of BoxClipOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
|
|
|
|
|
platform::errors::NotFound("Input(Input) of BoxClipOp "
|
|
|
|
|
"is not found."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("ImInfo"), true,
|
|
|
|
|
platform::errors::NotFound("Input(ImInfo) of BoxClipOp "
|
|
|
|
|
"is not found."));
|
|
|
|
|
|
|
|
|
|
auto input_box_dims = ctx->GetInputDim("Input");
|
|
|
|
|
auto im_info_dims = ctx->GetInputDim("ImInfo");
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
auto input_box_size = input_box_dims.size();
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_box_dims[input_box_size - 1], 4,
|
|
|
|
|
"The last dimension of Input must be 4");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
input_box_dims[input_box_size - 1], 4,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The last dimension "
|
|
|
|
|
"of Input must be 4. But received last dimension = %d",
|
|
|
|
|
input_box_dims[input_box_size - 1]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(im_info_dims.size(), 2,
|
|
|
|
|
"The rank of Input(Input) in BoxClipOp must be 2");
|
|
|
|
|
PADDLE_ENFORCE_EQ(im_info_dims[1], 3,
|
|
|
|
|
"The last dimension of ImInfo must be 3");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The rank of "
|
|
|
|
|
"Input(Input) in BoxClipOp must be 2. But received "
|
|
|
|
|
"rank = %d",
|
|
|
|
|
im_info_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
im_info_dims[1], 3,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The last dimension "
|
|
|
|
|
"of ImInfo must be 3. But received last dimension = %d",
|
|
|
|
|
im_info_dims[1]));
|
|
|
|
|
}
|
|
|
|
|
ctx->ShareDim("Input", /*->*/ "Output");
|
|
|
|
|
ctx->ShareLoD("Input", /*->*/ "Output");
|
|
|
|
|