@@ -13,6 +13,7 @@
 limitations under the License. */
 
 #include "paddle/operators/softmax_with_cross_entropy_op.h"
+#include <paddle/function/TensorType.h>
 
 namespace paddle {
 namespace operators {
@@ -115,6 +116,11 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
     ctx->ShareLoD("Logits", /*->*/ "Softmax");
     ctx->ShareLoD("Logits", /*->*/ "Loss");
   }
+
+  framework::DataType IndicateDataType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::ToDataType(ctx.Input<Tensor>("Logits")->type());
+  }
 };
 
 class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
@@ -149,6 +155,11 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
     ctx->SetOutputDim(framework::GradVarName("Logits"),
                       ctx->GetInputDim("Softmax"));
   }
+
+  framework::DataType IndicateDataType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::ToDataType(ctx.Input<Tensor>("Logits")->type());
+  }
 };
 
 }  // namespace operators