|
|
|
@ -43,7 +43,7 @@ class BoxCoderOp : public framework::OperatorWithKernel {
|
|
|
|
|
if (prior_box_var_dims.size() == 1) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
prior_box_var_dims[0], 4,
|
|
|
|
|
"The 1st dimension of Input(PriorBoxVar) should be 1"
|
|
|
|
|
"The 1st dimension of Input(PriorBoxVar) should be 4"
|
|
|
|
|
"when the rank is 1.");
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
@ -52,37 +52,36 @@ class BoxCoderOp : public framework::OperatorWithKernel {
|
|
|
|
|
"the dimension of Input(PriorBox when the rank is 2.)");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto code_type =
|
|
|
|
|
GetBoxCodeType(ctx->Attrs().Get<std::string>("code_type"));
|
|
|
|
|
int axis = ctx->Attrs().Get<int>("axis");
|
|
|
|
|
if (code_type == BoxCodeType::kEncodeCenterSize) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(target_box_dims.size(), 2,
|
|
|
|
|
"The rank of Input of TargetBox must be 2");
|
|
|
|
|
PADDLE_ENFORCE_EQ(target_box_dims[1], 4,
|
|
|
|
|
"The shape of TargetBox is [M, 4]");
|
|
|
|
|
ctx->SetOutputDim(
|
|
|
|
|
"OutputBox",
|
|
|
|
|
framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4}));
|
|
|
|
|
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(target_box_dims.size(), 3,
|
|
|
|
|
"The rank of Input of TargetBox must be 3");
|
|
|
|
|
if (axis == 0) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]);
|
|
|
|
|
} else if (axis == 1) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(target_box_dims[0], prior_box_dims[0]);
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("axis must be 0 or 1.");
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]);
|
|
|
|
|
ctx->ShareDim("TargetBox", /*->*/ "OutputBox");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (code_type == BoxCodeType::kDecodeCenterSize && axis == 1) {
|
|
|
|
|
ctx->ShareLoD("PriorBox", /*->*/ "OutputBox");
|
|
|
|
|
auto code_type = GetBoxCodeType(ctx->Attrs().Get<std::string>("code_type"));
|
|
|
|
|
int axis = ctx->Attrs().Get<int>("axis");
|
|
|
|
|
if (code_type == BoxCodeType::kEncodeCenterSize) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(target_box_dims.size(), 2,
|
|
|
|
|
"The rank of Input of TargetBox must be 2");
|
|
|
|
|
PADDLE_ENFORCE_EQ(target_box_dims[1], 4,
|
|
|
|
|
"The shape of TargetBox is [M, 4]");
|
|
|
|
|
ctx->SetOutputDim(
|
|
|
|
|
"OutputBox",
|
|
|
|
|
framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4}));
|
|
|
|
|
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(target_box_dims.size(), 3,
|
|
|
|
|
"The rank of Input of TargetBox must be 3");
|
|
|
|
|
if (axis == 0) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]);
|
|
|
|
|
} else if (axis == 1) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(target_box_dims[0], prior_box_dims[0]);
|
|
|
|
|
} else {
|
|
|
|
|
ctx->ShareLoD("TargetBox", /*->*/ "OutputBox");
|
|
|
|
|
PADDLE_THROW("axis must be 0 or 1.");
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]);
|
|
|
|
|
ctx->ShareDim("TargetBox", /*->*/ "OutputBox");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (code_type == BoxCodeType::kDecodeCenterSize && axis == 1) {
|
|
|
|
|
ctx->ShareLoD("PriorBox", /*->*/ "OutputBox");
|
|
|
|
|
} else {
|
|
|
|
|
ctx->ShareLoD("TargetBox", /*->*/ "OutputBox");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|