|
|
|
@ -42,17 +42,17 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AucOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("Inference",
|
|
|
|
|
"A floating point `Tensor` of arbitrary shape and whose values"
|
|
|
|
|
"are in the range `[0, 1]`.");
|
|
|
|
|
"A floating point tensor of arbitrary shape and whose values"
|
|
|
|
|
"are in the range [0, 1].");
|
|
|
|
|
AddInput("Label",
|
|
|
|
|
"A `Tensor` whose shape matches "
|
|
|
|
|
"`Inference`. Will be cast to `bool`.");
|
|
|
|
|
"A tensor whose shape matches "
|
|
|
|
|
"Inference. Will be cast to bool.");
|
|
|
|
|
// TODO(typhoonzero): support weight input
|
|
|
|
|
AddOutput("AUC",
|
|
|
|
|
"A scalar `Tensor` representing the "
|
|
|
|
|
"A scalar representing the "
|
|
|
|
|
"current area-under-curve.");
|
|
|
|
|
|
|
|
|
|
AddAttr<std::string>("curve", "Possible curves are ROC and PR")
|
|
|
|
|
AddAttr<std::string>("curve", "Curve type, can be 'ROC' or 'PR'.")
|
|
|
|
|
.SetDefault("ROC");
|
|
|
|
|
AddAttr<int>("num_thresholds",
|
|
|
|
|
"The number of thresholds to use when discretizing the"
|
|
|
|
@ -62,7 +62,8 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddComment(
|
|
|
|
|
R"DOC(Computes the AUC according forward output and label.
|
|
|
|
|
Best to use for binary classification evaluations.
|
|
|
|
|
If `label` can be values other than 0 and 1, it will be cast
|
|
|
|
|
|
|
|
|
|
If input label contains values other than 0 and 1, it will be cast
|
|
|
|
|
to bool.
|
|
|
|
|
|
|
|
|
|
You can find the definations here:
|
|
|
|
|