|
|
|
@ -23,19 +23,15 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of Yolov3LossOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("GTBox"),
|
|
|
|
|
"Input(GTBox) of Yolov3LossOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("GTLabel"),
|
|
|
|
|
"Input(GTLabel) of Yolov3LossOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Loss"),
|
|
|
|
|
"Output(Loss) of Yolov3LossOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("ObjectnessMask"),
|
|
|
|
|
"Output(ObjectnessMask) of Yolov3LossOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("GTMatchMask"),
|
|
|
|
|
"Output(GTMatchMask) of Yolov3LossOp should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Yolov3LossOp");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("GTBox"), "Input", "GTBox", "Yolov3LossOp");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("GTLabel"), "Input", "GTLabel",
|
|
|
|
|
"Yolov3LossOp");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Loss"), "Output", "Loss", "Yolov3LossOp");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("ObjectnessMask"), "Output", "ObjectnessMask",
|
|
|
|
|
"Yolov3LossOp");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("GTMatchMask"), "Output", "GTMatchMask",
|
|
|
|
|
"Yolov3LossOp");
|
|
|
|
|
|
|
|
|
|
auto dim_x = ctx->GetInputDim("X");
|
|
|
|
|
auto dim_gtbox = ctx->GetInputDim("GTBox");
|
|
|
|
@ -46,44 +42,96 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
|
|
|
|
|
int mask_num = anchor_mask.size();
|
|
|
|
|
auto class_num = ctx->Attrs().Get<int>("class_num");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(dim_x.size(), 4,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X) should be a 4-D tensor. But received "
|
|
|
|
|
"X dimension size(%s)",
|
|
|
|
|
dim_x.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(dim_x[2], dim_x[3],
|
|
|
|
|
"Input(X) dim[3] and dim[4] should be euqal.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X) dim[3] and dim[4] should be euqal."
|
|
|
|
|
"But received dim[3](%s) != dim[4](%s)",
|
|
|
|
|
dim_x[2], dim_x[3]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dim_x[1], mask_num * (5 + class_num),
|
|
|
|
|
"Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
|
|
|
|
|
"+ class_num)).");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
|
|
|
|
|
"+ class_num))."
|
|
|
|
|
"But received dim[1](%s) != (anchor_mask_number * "
|
|
|
|
|
"(5+class_num)(%s).",
|
|
|
|
|
dim_x[1], mask_num * (5 + class_num)));
|
|
|
|
|
PADDLE_ENFORCE_EQ(dim_gtbox.size(), 3,
|
|
|
|
|
"Input(GTBox) should be a 3-D tensor");
|
|
|
|
|
PADDLE_ENFORCE_EQ(dim_gtbox[2], 4, "Input(GTBox) dim[2] should be 5");
|
|
|
|
|
PADDLE_ENFORCE_EQ(dim_gtlabel.size(), 2,
|
|
|
|
|
"Input(GTLabel) should be a 2-D tensor");
|
|
|
|
|
PADDLE_ENFORCE_EQ(dim_gtlabel[0], dim_gtbox[0],
|
|
|
|
|
"Input(GTBox) and Input(GTLabel) dim[0] should be same");
|
|
|
|
|
PADDLE_ENFORCE_EQ(dim_gtlabel[1], dim_gtbox[1],
|
|
|
|
|
"Input(GTBox) and Input(GTLabel) dim[1] should be same");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(GTBox) should be a 3-D tensor, but "
|
|
|
|
|
"received gtbox dimension size(%s)",
|
|
|
|
|
dim_gtbox.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(dim_gtbox[2], 4,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(GTBox) dim[2] should be 4",
|
|
|
|
|
"But receive dim[2](%s) != 5. ", dim_gtbox[2]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dim_gtlabel.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(GTLabel) should be a 2-D tensor,"
|
|
|
|
|
"But received Input(GTLabel) dimension size(%s) != 2.",
|
|
|
|
|
dim_gtlabel.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dim_gtlabel[0], dim_gtbox[0],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(GTBox) dim[0] and Input(GTLabel) dim[0] should be same,"
|
|
|
|
|
"But received Input(GTLabel) dim[0](%s) != "
|
|
|
|
|
"Input(GTBox) dim[0](%s)",
|
|
|
|
|
dim_gtlabel[0], dim_gtbox[0]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dim_gtlabel[1], dim_gtbox[1],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(GTBox) and Input(GTLabel) dim[1] should be same,"
|
|
|
|
|
"But received Input(GTBox) dim[1](%s) != Input(GTLabel) "
|
|
|
|
|
"dim[1](%s)",
|
|
|
|
|
dim_gtbox[1], dim_gtlabel[1]));
|
|
|
|
|
PADDLE_ENFORCE_GT(anchors.size(), 0,
|
|
|
|
|
"Attr(anchors) length should be greater then 0.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Attr(anchors) length should be greater then 0."
|
|
|
|
|
"But received anchors length(%s)",
|
|
|
|
|
anchors.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(anchors.size() % 2, 0,
|
|
|
|
|
"Attr(anchors) length should be even integer.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Attr(anchors) length should be even integer."
|
|
|
|
|
"But received anchors length(%s)",
|
|
|
|
|
anchors.size()));
|
|
|
|
|
for (size_t i = 0; i < anchor_mask.size(); i++) {
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
anchor_mask[i], anchor_num,
|
|
|
|
|
"Attr(anchor_mask) should not crossover Attr(anchors).");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Attr(anchor_mask) should not crossover Attr(anchors)."
|
|
|
|
|
"But received anchor_mask[i](%s) > anchor_num(%s)",
|
|
|
|
|
anchor_mask[i], anchor_num));
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_GT(class_num, 0,
|
|
|
|
|
"Attr(class_num) should be an integer greater then 0.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Attr(class_num) should be an integer greater then 0."
|
|
|
|
|
"But received class_num(%s) < 0",
|
|
|
|
|
class_num));
|
|
|
|
|
|
|
|
|
|
if (ctx->HasInput("GTScore")) {
|
|
|
|
|
auto dim_gtscore = ctx->GetInputDim("GTScore");
|
|
|
|
|
PADDLE_ENFORCE_EQ(dim_gtscore.size(), 2,
|
|
|
|
|
"Input(GTScore) should be a 2-D tensor");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(GTScore) should be a 2-D tensor"
|
|
|
|
|
"But received GTScore dimension(%s)",
|
|
|
|
|
dim_gtbox.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dim_gtscore[0], dim_gtbox[0],
|
|
|
|
|
"Input(GTBox) and Input(GTScore) dim[0] should be same");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(GTBox) and Input(GTScore) dim[0] should be same"
|
|
|
|
|
"But received GTBox dim[0](%s) != GTScore dim[0](%s)",
|
|
|
|
|
dim_gtbox[0], dim_gtscore[0]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dim_gtscore[1], dim_gtbox[1],
|
|
|
|
|
"Input(GTBox) and Input(GTScore) dim[1] should be same");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(GTBox) and Input(GTScore) dim[1] should be same"
|
|
|
|
|
"But received GTBox dim[1](%s) != GTScore dim[1](%s)",
|
|
|
|
|
dim_gtscore[1], dim_gtbox[1]));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> dim_out({dim_x[0]});
|
|
|
|
@ -245,9 +293,12 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
|
|
|
|
|
"Input(Loss@GRAD) should not be null");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("X"), true,
|
|
|
|
|
platform::errors::NotFound("Input(X) should not be null"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput(framework::GradVarName("Loss")), true,
|
|
|
|
|
platform::errors::NotFound("Input(Loss@GRAD) should not be null"));
|
|
|
|
|
auto dim_x = ctx->GetInputDim("X");
|
|
|
|
|
if (ctx->HasOutput(framework::GradVarName("X"))) {
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
|
|
|
|
|