|
|
|
@ -41,7 +41,7 @@ struct CudnnActivationFunctor {
|
|
|
|
|
TensorDescriptor x_desc, out_desc;
|
|
|
|
|
x_desc.set(x);
|
|
|
|
|
out_desc.set(GET_DATA_SAFELY(out, "Output", "Out", "CudnnActivation"));
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::cudnnActivationForward(
|
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnActivationForward(
|
|
|
|
|
ctx_.cudnn_handle(), act_desc.desc(),
|
|
|
|
|
platform::CudnnDataType<T>::kOne(), x_desc.desc(), x.data<T>(),
|
|
|
|
|
platform::CudnnDataType<T>::kZero(), out_desc.desc(),
|
|
|
|
@ -67,7 +67,7 @@ struct CudnnActivationGradFunctor {
|
|
|
|
|
out_desc.set(out);
|
|
|
|
|
dout_desc.set(dout);
|
|
|
|
|
dx_desc.set(GET_DATA_SAFELY(dx, "Output", "X@GRAD", "CudnnActivationGrad"));
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::cudnnActivationBackward(
|
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnActivationBackward(
|
|
|
|
|
ctx_.cudnn_handle(), act_desc.desc(),
|
|
|
|
|
platform::CudnnDataType<T>::kOne(), out_desc.desc(), out.data<T>(),
|
|
|
|
|
dout_desc.desc(), dout.data<T>(), x_desc.desc(), x.data<T>(),
|
|
|
|
|