|
|
|
@ -27,15 +27,9 @@ class BCELossOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("X"), true,
|
|
|
|
|
platform::errors::InvalidArgument("Input(X) should be not null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("Label"), true,
|
|
|
|
|
platform::errors::InvalidArgument("Input(Label) should be not null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasOutput("Out"), true,
|
|
|
|
|
platform::errors::InvalidArgument("Output(Out) should be not null."));
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BCELoss");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "BCELoss");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "BCELoss");
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto label_dims = ctx->GetInputDim("Label");
|
|
|
|
@ -74,18 +68,12 @@ class BCELossGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("X"), true,
|
|
|
|
|
platform::errors::InvalidArgument("Input(X) should be not null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("Label"), true,
|
|
|
|
|
platform::errors::InvalidArgument("Input(Label) should be not null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(Out@GRAD) shoudl be not null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Output(X@GRAD) should be not null."));
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BCELossGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "BCELossGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
|
|
|
|
|
framework::GradVarName("Out"), "BCELossGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
|
|
|
|
|
framework::GradVarName("X"), "BCELossGrad");
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
|
|
|
|
@ -152,7 +140,6 @@ class BCELossGradOpMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
|
op->SetInput("Label", this->Input("Label"));
|
|
|
|
|
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
|
|
|
|
|
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
|
|
|
|
|
// op->SetAttrMap(this->Attrs());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|