|
|
@ -139,7 +139,6 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
|
|
|
|
auto loss_dims = logits_dims;
|
|
|
|
auto loss_dims = logits_dims;
|
|
|
|
loss_dims[rank - 1] = 1;
|
|
|
|
loss_dims[rank - 1] = 1;
|
|
|
|
ctx->SetOutputDim("Loss", loss_dims);
|
|
|
|
ctx->SetOutputDim("Loss", loss_dims);
|
|
|
|
// ctx->SetOutputDim("Loss", {logits_dims[0], 1});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ctx->ShareLoD("Logits", /*->*/ "Softmax");
|
|
|
|
ctx->ShareLoD("Logits", /*->*/ "Softmax");
|
|
|
|
ctx->ShareLoD("Logits", /*->*/ "Loss");
|
|
|
|
ctx->ShareLoD("Logits", /*->*/ "Loss");
|
|
|
|