|
|
|
@ -22,18 +22,19 @@ class AucOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Inference"),
|
|
|
|
|
"Input of Inference must be initialized.");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
|
|
|
|
|
"Input of Inference must be initialized.");
|
|
|
|
|
auto *inference = ctx.Input<framework::Tensor>("Inference");
|
|
|
|
|
auto *label = ctx.Input<framework::Tensor>("Label");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(inference->dims(), label->dims(),
|
|
|
|
|
void InferShape(framework::InferShapeContextBase *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Inference"),
|
|
|
|
|
"Input of Inference must be initialized.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Label"),
|
|
|
|
|
"Input of Label must be initialized.");
|
|
|
|
|
auto inference_dim = ctx->GetInputDim("Inference");
|
|
|
|
|
auto label_dim = ctx->GetInputDim("Label");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(inference_dim, label_dim,
|
|
|
|
|
"inference and label should have same shape");
|
|
|
|
|
|
|
|
|
|
ctx.Output<framework::LoDTensor>("AUC")->Resize({1});
|
|
|
|
|
ctx->SetOutputDim("AUC", {1});
|
|
|
|
|
ctx->ShareLoD("Inference", /*->*/ "AUC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|