- Fix to GPU

test=develop
panyx0718-patch-1
Jacek Czaja 6 years ago
parent 513bb6c151
commit be80bb4f28

@ -36,7 +36,9 @@ class SoftmaxKernel : public framework::OpKernel<T> {
Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
#ifdef PADDLE_ON_INFERENCE
math::SoftmaxFunctor<DeviceContext, T, true>()(
math::SoftmaxFunctor<
DeviceContext, T,
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>()(
context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
#else
math::SoftmaxFunctor<DeviceContext, T, false>()(

Loading…
Cancel
Save