|
|
|
@ -40,10 +40,12 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
|
|
|
|
|
softmax->mutable_data<T>(context.GetPlace());
|
|
|
|
|
loss->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
int axis_dim = logits->dims()[logits->dims().size()-1];
|
|
|
|
|
|
|
|
|
|
auto& dev_ctx =
|
|
|
|
|
context.template device_context<platform::CPUDeviceContext>();
|
|
|
|
|
math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
|
|
|
|
|
dev_ctx, -1, logits, softmax);
|
|
|
|
|
dev_ctx, axis_dim, logits, softmax);
|
|
|
|
|
math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
|
|
|
|
|
dev_ctx, loss, softmax, labels, context.Attr<bool>("soft_label"),
|
|
|
|
|
context.Attr<int>("ignore_index"));
|
|
|
|
|