|
|
|
@ -21,7 +21,6 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
|
|
|
|
@ -48,6 +47,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Y");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
// CrossEntropy's data type just determined by "X"
|
|
|
|
|
framework::DataType IndicateDataType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
@ -59,7 +59,6 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
|
|
|
|
@ -94,6 +93,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
// CrossEntropy's data type just determined by "X"
|
|
|
|
|
framework::DataType IndicateDataType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
|