|
|
|
@ -69,7 +69,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
math::SoftmaxFunctor<platform::GPUPlace, T>()(context.device_context(),
|
|
|
|
|
logits, softmax);
|
|
|
|
|
math::CrossEntropyFunctor<platform::GPUPlace, T>()(
|
|
|
|
|
context, loss, softmax, labels, context.Attr<bool>("softLabel"));
|
|
|
|
|
context.device_context(), loss, softmax, labels,
|
|
|
|
|
context.Attr<bool>("softLabel"));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|