|
|
|
@ -22,23 +22,35 @@ class AccuracyOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
// Validates inputs/outputs and infers the output shape for the accuracy op.
// Inputs:
//   Out     - top-k values produced by the network (the topk op's output).
//   Indices - top-k indices; assumed to share Out's shape (topk emits both).
//   Label   - ground-truth labels, expected shape [num_rows, 1].
// Output:
//   Accuracy - a single scalar, shape {1}.
void InferShape(framework::InferShapeContext *ctx) const override {
  PADDLE_ENFORCE(ctx->HasInput("Out"),
                 "Input (Out) of accuracy op should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Indices"),
                 "Input (Indices) of accuracy op should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Label"),
                 "Input (Label) of accuracy op should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("Accuracy"),
                 "Output (Accuracy) of AccuracyOp should not be null.");

  auto inference_dim = ctx->GetInputDim("Out");
  auto label_dim = ctx->GetInputDim("Label");
  // Assume indices has same shape with inference, because
  // it's the output of topk.

  PADDLE_ENFORCE_EQ(label_dim.size(), 2, "label's rank must be 2.");
  PADDLE_ENFORCE_EQ(label_dim[1], 1, "label's second dimension must be 1");
  // Each row of the inference tensor must correspond to one label row.
  PADDLE_ENFORCE_EQ(inference_dim[0], label_dim[0],
                    "the inference tensor's num_rows must be"
                    " the same as label.");

  ctx->SetOutputDim("Accuracy", {1});
  ctx->ShareLoD("Out", /*->*/ "Accuracy");
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
// IndicateDataType
|
|
|
|
|
// The kernel's data type is taken from the element type of the "Out"
// input tensor (the inference values), not from the integer label input.
framework::DataType IndicateDataType(
    const framework::ExecutionContext &ctx) const override {
  const auto *out_tensor = ctx.Input<Tensor>("Out");
  return framework::ToDataType(out_tensor->type());
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -48,7 +60,8 @@ class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
framework::OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
// TODO(typhoonzero): support both inference value and indices.
|
|
|
|
|
AddInput("Inference", "topk(indices) the network output");
|
|
|
|
|
AddInput("Out", "topk (inferences) the network output");
|
|
|
|
|
AddInput("Indices", "topk (indices) the network output");
|
|
|
|
|
AddInput("Label", "Label of the training data");
|
|
|
|
|
// TODO(typhoonzero): AddInput("Weight", ...
|
|
|
|
|
AddOutput("Accuracy", "The accuracy of current batch");
|
|
|
|
@ -59,7 +72,7 @@ The accuracy is:
|
|
|
|
|
.. math::
|
|
|
|
|
accuracy = \\frac{NumOfCorrectPredicts}{NumOfAllSamples})
|
|
|
|
|
|
|
|
|
|
Both the input `Inference` and `Label` can carry the LoD (Level of Details)
|
|
|
|
|
Both the input `Out` and `Label` can carry the LoD (Level of Details)
|
|
|
|
|
information, or not. But the output only shares the LoD with input `Inference`.
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
@ -71,6 +84,8 @@ information, or not. But the output only shares the LoD with input `Inference`.
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(accuracy, ops::AccuracyOp, ops::AccuracyOpMaker,
|
|
|
|
|
paddle::framework::EmptyGradOpMaker);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
accuracy, ops::AccuracyKernel<paddle::platform::CPUPlace, int>,
|
|
|
|
|
ops::AccuracyKernel<paddle::platform::CPUPlace, int64_t>);
|
|
|
|
|
// FIXME(typhoonzero): types of T is for infernece data.
|
|
|
|
|
// label data is always int.
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(accuracy,
|
|
|
|
|
ops::AccuracyKernel<paddle::platform::CPUPlace, float>,
|
|
|
|
|
ops::AccuracyKernel<paddle::platform::CPUPlace, double>);
|
|
|
|
|