|
|
|
@ -22,18 +22,32 @@ class AccuracyOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Out"),
|
|
|
|
|
"Input (Out) of accuracy op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Indices"),
|
|
|
|
|
"Input (Indices) of accuracy op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Label"),
|
|
|
|
|
"Input (Label) of accuracy op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Accuracy"),
|
|
|
|
|
"Output (Accuracy) of AccuracyOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Correct"),
|
|
|
|
|
"Output (Correct) of AccuracyOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Total"),
|
|
|
|
|
"Output (Total) of AccuracyOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("Out"), true,
|
|
|
|
|
platform::errors::NotFound("Input (Out) of AccuracyOp is not found."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("Indices"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Input (Indices) of AccuracyOp is not found."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("Label"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Input (Label) of AccuracyOp is not found."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Accuracy"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Output (Accuracy) of AccuracyOp is not found."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Correct"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Output (Correct) of AccuracyOp is not found."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Total"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Output (Total) of AccuracyOp is not found."));
|
|
|
|
|
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "Accuracy");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Indices"), "Input", "Indices", "Accuracy");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "Accuracy");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Accuracy"), "Output", "Accuracy",
|
|
|
|
|
"Accuracy");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Correct"), "Output", "Correct", "Accuracy");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Total"), "Output", "Total", "Accuracy");
|
|
|
|
|
|
|
|
|
|
auto inference_dim = ctx->GetInputDim("Out");
|
|
|
|
|
auto label_dim = ctx->GetInputDim("Label");
|
|
|
|
@ -42,22 +56,26 @@ class AccuracyOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
label_dim.size(), 2,
|
|
|
|
|
"ShapeError: label's dimensions of AccuracyOp must be 2. "
|
|
|
|
|
"But received label's dimensions = %d, label's shape = [%s]",
|
|
|
|
|
label_dim.size(), label_dim);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"ShapeError: label's dimensions of AccuracyOp must be 2. "
|
|
|
|
|
"But received label's dimensions = %d, label's shape = [%s]",
|
|
|
|
|
label_dim.size(), label_dim));
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dim[1], 1,
|
|
|
|
|
"ShapeError: label's second dimension of "
|
|
|
|
|
"AccuracyOp must be 1. But received label's "
|
|
|
|
|
"second dimension is = %d, label's shape = [%s]",
|
|
|
|
|
label_dim[1], label_dim);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"ShapeError: label's second dimension of "
|
|
|
|
|
"AccuracyOp must be 1. But received label's "
|
|
|
|
|
"second dimension is = %d, label's shape = [%s]",
|
|
|
|
|
label_dim[1], label_dim));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
inference_dim[0], label_dim[0],
|
|
|
|
|
"ShapeError: the output's num_rows of AccuracyOp must be"
|
|
|
|
|
" the same as label's num_rows. But received output's "
|
|
|
|
|
"shape = [%s], label's shape = [%s], output's num_rows = %d, label's "
|
|
|
|
|
"num_rows = %d",
|
|
|
|
|
inference_dim, label_dim, inference_dim[0], label_dim[0]);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"ShapeError: the output's num_rows of AccuracyOp must be"
|
|
|
|
|
" the same as label's num_rows. But received output's "
|
|
|
|
|
"shape = [%s], label's shape = [%s], output's num_rows = %d, "
|
|
|
|
|
"label's "
|
|
|
|
|
"num_rows = %d",
|
|
|
|
|
inference_dim, label_dim, inference_dim[0], label_dim[0]));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("Accuracy", {1});
|
|
|
|
|