|
|
|
@ -36,7 +36,9 @@ class SoftmaxKernel : public framework::OpKernel<T> {
|
|
|
|
|
Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_ON_INFERENCE
|
|
|
|
|
math::SoftmaxFunctor<DeviceContext, T, true>()(
|
|
|
|
|
math::SoftmaxFunctor<
|
|
|
|
|
DeviceContext, T,
|
|
|
|
|
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>()(
|
|
|
|
|
context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
|
|
|
|
|
#else
|
|
|
|
|
math::SoftmaxFunctor<DeviceContext, T, false>()(
|
|
|
|
|