|
|
|
@ -22,18 +22,30 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("MaxProbs"),
|
|
|
|
|
"Input(MaxProbs) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Indices"),
|
|
|
|
|
"Input(Indices) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Labels"),
|
|
|
|
|
"Input(Labels) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchMetrics"),
|
|
|
|
|
"Output(BatchMetrics) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("AccumMetrics"),
|
|
|
|
|
"Output(AccumMetrics) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("AccumStatesInfo"),
|
|
|
|
|
"Output(AccumStatesInfo) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("MaxProbs"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"PrecisionRecallOp Input(MaxProbs) should not be null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("Indices"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"PrecisionRecallOp Input(Indices) should not be null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("Labels"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"PrecisionRecallOp Input(Labels) should not be null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasOutput("BatchMetrics"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"PrecisionRecallOp Output(BatchMetrics) should not be null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasOutput("AccumMetrics"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"PrecisionRecallOp Output(AccumMetrics) should not be null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasOutput("AccumStatesInfo"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"PrecisionRecallOp Output(AccumStatesInfo) should not be null."));
|
|
|
|
|
|
|
|
|
|
int64_t cls_num =
|
|
|
|
|
static_cast<int64_t>(ctx->Attrs().Get<int>("class_number"));
|
|
|
|
@ -42,37 +54,61 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(max_probs_dims[1], 1,
|
|
|
|
|
"Each instance contains one max probability, so the "
|
|
|
|
|
"shape of Input(MaxProbs) should be [batch_size, 1].");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Each instance of PrecisionRecallOp "
|
|
|
|
|
"Input(MaxProbs) contains one max probability, "
|
|
|
|
|
"the shape of Input(MaxProbs) should be "
|
|
|
|
|
"[batch_size, 1], the 2nd dimension of "
|
|
|
|
|
"Input(MaxProbs) should be 1. But the 2nd "
|
|
|
|
|
"dimension we received is %d",
|
|
|
|
|
max_probs_dims[1]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->GetInputDim("Indices"), max_probs_dims,
|
|
|
|
|
"The shape of Input(Indices) should bes same with max_probs_dims");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of PrecisionRecallOp Input(Indices) should be same "
|
|
|
|
|
"with "
|
|
|
|
|
"max_probs_dims. But received the shape of Input(Indices) is "
|
|
|
|
|
"[%d, %d], max_probs_dims is [%d, %d]",
|
|
|
|
|
ctx->GetInputDim("Indices")[0], ctx->GetInputDim("Indices")[1],
|
|
|
|
|
max_probs_dims[0], max_probs_dims[1]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
max_probs_dims[0], labels_dims[0],
|
|
|
|
|
"The 1st dimension of Input(MaxProbs) and "
|
|
|
|
|
"Input(Labels) both are batch_size and the shape should "
|
|
|
|
|
"be the same.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The 1st dimension of PrecisionRecallOp Input(MaxProbs) and "
|
|
|
|
|
"Input(Labels) both should be batch_size"
|
|
|
|
|
"But the 1st dimension we received max_probs_dims[0] = %d, "
|
|
|
|
|
"labels_dims[0] = %d",
|
|
|
|
|
max_probs_dims[0], labels_dims[0]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(labels_dims[1], 1,
|
|
|
|
|
"The 2nd dimension of Input(Labels) contains instance "
|
|
|
|
|
"label and the shape should be equal to 1.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The 2nd dimension of PrecisionRecallOp "
|
|
|
|
|
"Input(Labels) contains instance label and "
|
|
|
|
|
"the shape should be equal to 1. But the 2nd "
|
|
|
|
|
"dimension we received is %d",
|
|
|
|
|
labels_dims[1]));
|
|
|
|
|
}
|
|
|
|
|
if (ctx->HasInput("Weights")) {
|
|
|
|
|
auto weights_dims = ctx->GetInputDim("Weights");
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(weights_dims,
|
|
|
|
|
framework::make_ddim({max_probs_dims[0], 1}),
|
|
|
|
|
"The shape of Input(Weights) should be "
|
|
|
|
|
"[batch_size, 1].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
weights_dims, framework::make_ddim({max_probs_dims[0], 1}),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of PrecisionRecallOp Input(Weights) should be "
|
|
|
|
|
"[batch_size, 1]. But the shape we received is [%d, %d]",
|
|
|
|
|
weights_dims[0], weights_dims[1]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (ctx->HasInput("StatesInfo")) {
|
|
|
|
|
auto states_dims = ctx->GetInputDim("StatesInfo");
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(states_dims, framework::make_ddim({cls_num, 4}),
|
|
|
|
|
"The shape of Input(StatesInfo) should be "
|
|
|
|
|
"[class_number, 4].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
states_dims, framework::make_ddim({cls_num, 4}),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of PrecisionRecallOp Input(StatesInfo) should be "
|
|
|
|
|
"[class_number, 4]. But the shape we received is [%d, %d]",
|
|
|
|
|
states_dims[0], states_dims[1]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|