|
|
|
@ -165,85 +165,118 @@ class MineHardExamplesOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("ClsLoss"),
|
|
|
|
|
"Input(ClsLoss) of MineHardExamplesOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("MatchIndices"),
|
|
|
|
|
"Input(MatchIndices) of MineHardExamplesOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("MatchDist"),
|
|
|
|
|
"Input(MatchDist) of MineHardExamplesOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("NegIndices"),
|
|
|
|
|
"Output(NegIndices) of MineHardExamplesOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("UpdatedMatchIndices"),
|
|
|
|
|
"Output(UpdatedMatchIndices) of MineHardExamplesOp should "
|
|
|
|
|
"not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("ClsLoss"), "Input", "ClsLoss",
|
|
|
|
|
"mine_hard_examples");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("MatchIndices"), "Input", "MatchIndices",
|
|
|
|
|
"mine_hard_examples");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("MatchDist"), "Input", "MatchDist",
|
|
|
|
|
"mine_hard_examples");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("NegIndices"), "Output", "NegIndices",
|
|
|
|
|
"mine_hard_examples");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("UpdatedMatchIndices"), "Output",
|
|
|
|
|
"UpdatedMatchIndices", "mine_hard_examples");
|
|
|
|
|
|
|
|
|
|
auto cls_loss_dims = ctx->GetInputDim("ClsLoss");
|
|
|
|
|
auto idx_dims = ctx->GetInputDim("MatchIndices");
|
|
|
|
|
auto dis_dims = ctx->GetInputDim("MatchDist");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(cls_loss_dims.size(), 2UL,
|
|
|
|
|
"The shape of ClsLoss is [N, Np].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(idx_dims.size(), 2UL,
|
|
|
|
|
"The shape of MatchIndices is [N, Np].");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of ClsLoss is [N, Np]. But received %d.",
|
|
|
|
|
cls_loss_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
idx_dims.size(), 2UL,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of MatchIndices is [N, Np]. But received %d.",
|
|
|
|
|
idx_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(dis_dims.size(), 2UL,
|
|
|
|
|
"The shape of MatchDist is [N, Np].");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of MatchDist is [N, Np]. But received %d.",
|
|
|
|
|
dis_dims.size()));
|
|
|
|
|
|
|
|
|
|
if (ctx->HasInput("LocLoss")) {
|
|
|
|
|
auto loc_loss_dims = ctx->GetInputDim("LocLoss");
|
|
|
|
|
PADDLE_ENFORCE_EQ(loc_loss_dims.size(), 2UL,
|
|
|
|
|
"The shape of LocLoss is [N, Np].");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of LocLoss is [N, Np]. But received %d.",
|
|
|
|
|
loc_loss_dims.size()));
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
cls_loss_dims[0], loc_loss_dims[0],
|
|
|
|
|
"Batch size of ClsLoss and LocLoss must be the same.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
cls_loss_dims[1], loc_loss_dims[1],
|
|
|
|
|
"Prior box number of ClsLoss and LocLoss must be the same.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(cls_loss_dims[0], loc_loss_dims[0],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Batch size of ClsLoss and LocLoss must be the "
|
|
|
|
|
"same. But received batch size of ClsLoss was "
|
|
|
|
|
"%d, batch size of LocLoss was %d.",
|
|
|
|
|
cls_loss_dims[0], loc_loss_dims[0]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(cls_loss_dims[1], loc_loss_dims[1],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Prior box number of ClsLoss and LocLoss must be "
|
|
|
|
|
"the same. But received box number of ClsLoss "
|
|
|
|
|
"was %d, box number of LocLoss was %d.",
|
|
|
|
|
cls_loss_dims[1], loc_loss_dims[1]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
cls_loss_dims[0], idx_dims[0],
|
|
|
|
|
"Batch size of ClsLoss and MatchIndices must be the same.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
cls_loss_dims[1], idx_dims[1],
|
|
|
|
|
"Prior box number of ClsLoss and MatchIndices must be the same.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
cls_loss_dims[0], dis_dims[0],
|
|
|
|
|
"Batch size of ClsLoss and MatchDist must be the same.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(cls_loss_dims[0], idx_dims[0],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Batch size of ClsLoss and MatchIndices must be "
|
|
|
|
|
"the same. But received batch size of ClsLoss was "
|
|
|
|
|
"%d, batch size of MatchIndices was %d.",
|
|
|
|
|
cls_loss_dims[0], idx_dims[0]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
cls_loss_dims[1], idx_dims[1],
|
|
|
|
|
"Prior box number of ClsLoss and MatchDist must be the same.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Prior box number of ClsLoss and "
|
|
|
|
|
"MatchIndices must be the same. But received box number of "
|
|
|
|
|
"ClsLoss was %d, box number of MatchIndices was %d.",
|
|
|
|
|
cls_loss_dims[1], idx_dims[1]));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(cls_loss_dims[0], dis_dims[0],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Batch size of ClsLoss and MatchDist must be the "
|
|
|
|
|
"same. But received batch size of ClsLoss was %d, "
|
|
|
|
|
"batch size of MatchDist was %d.",
|
|
|
|
|
cls_loss_dims[0], dis_dims[0]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(cls_loss_dims[1], idx_dims[1],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Prior box number of ClsLoss and MatchDist must be "
|
|
|
|
|
"the same. But received box number of ClsLoss was "
|
|
|
|
|
"%d, box number of MatchDist was %d.",
|
|
|
|
|
cls_loss_dims[1], idx_dims[1]));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto mining_type =
|
|
|
|
|
GetMiningType(ctx->Attrs().Get<std::string>("mining_type"));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NE(mining_type, MiningType::kNone,
|
|
|
|
|
"mining_type must be hard_example or max_negative");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"mining_type must be hard_example or max_negative"));
|
|
|
|
|
|
|
|
|
|
if (mining_type == MiningType::kMaxNegative) {
|
|
|
|
|
auto neg_pos_ratio = ctx->Attrs().Get<float>("neg_pos_ratio");
|
|
|
|
|
auto neg_dist_threshold = ctx->Attrs().Get<float>("neg_dist_threshold");
|
|
|
|
|
PADDLE_ENFORCE_GT(
|
|
|
|
|
neg_pos_ratio, 0.0f,
|
|
|
|
|
"neg_pos_ratio must greater than zero in max_negative mode");
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
neg_dist_threshold, 1.0f,
|
|
|
|
|
"neg_dist_threshold must less than one in max_negative mode");
|
|
|
|
|
PADDLE_ENFORCE_GT(
|
|
|
|
|
neg_dist_threshold, 0.0f,
|
|
|
|
|
"neg_dist_threshold must greater than zero in max_negative mode");
|
|
|
|
|
PADDLE_ENFORCE_GT(neg_pos_ratio, 0.0f,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"neg_pos_ratio must greater than zero in "
|
|
|
|
|
"max_negative mode. But received %f.",
|
|
|
|
|
neg_pos_ratio));
|
|
|
|
|
PADDLE_ENFORCE_LT(neg_dist_threshold, 1.0f,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"neg_dist_threshold must less than one in "
|
|
|
|
|
"max_negative mode. But received %f.",
|
|
|
|
|
neg_dist_threshold));
|
|
|
|
|
PADDLE_ENFORCE_GT(neg_dist_threshold, 0.0f,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"neg_dist_threshold must greater "
|
|
|
|
|
"than zero in max_negative mode. But received %f.",
|
|
|
|
|
neg_dist_threshold));
|
|
|
|
|
} else if (mining_type == MiningType::kHardExample) {
|
|
|
|
|
auto sample_size = ctx->Attrs().Get<int>("sample_size");
|
|
|
|
|
PADDLE_ENFORCE_GT(
|
|
|
|
|
sample_size, 0,
|
|
|
|
|
"sample_size must greater than zero in hard_example mode");
|
|
|
|
|
PADDLE_ENFORCE_GT(sample_size, 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"sample_size must greater than zero in "
|
|
|
|
|
"hard_example mode. But received %d.",
|
|
|
|
|
sample_size));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("UpdatedMatchIndices", idx_dims);
|
|
|
|
|