@@ -62,6 +62,12 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   bool use_cudnn = ctx.Attr<bool>("use_cudnn");
   use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
+#ifdef PADDLE_WITH_CUDA
+  if (platform::is_gpu_place(ctx.GetPlace())) {
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
+  }
+#endif
   framework::LibraryType library_;
   if (use_cudnn) {
     library_ = framework::LibraryType::kCUDNN;
@@ -265,6 +271,12 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   bool use_cudnn = ctx.Attr<bool>("use_cudnn");
   use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
+#ifdef PADDLE_WITH_CUDA
+  if (platform::is_gpu_place(ctx.GetPlace())) {
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
+  }
+#endif
   framework::LibraryType library_;
   if (use_cudnn) {
     library_ = framework::LibraryType::kCUDNN;
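Both hunks apply the same guard: even when the "use_cudnn" attribute is set, the cuDNN kernel library is only selected if the op was built with CUDA, is placed on a GPU, and the device context actually exposes a cuDNN handle. Below is a minimal, self-contained C++ sketch of that selection pattern; all names in it (SelectLibrary, is_gpu_place, has_cudnn_handle, LibraryType) are illustrative stand-ins, not Paddle's API.

// Sketch only: stand-ins for the Paddle types and checks used in the diff.
#include <iostream>

enum class LibraryType { kPlain, kCUDNN };

bool is_gpu_place() { return true; }      // assume: the op is placed on a GPU
bool has_cudnn_handle() { return true; }  // stands in for dev_ctx.cudnn_handle() != nullptr

LibraryType SelectLibrary(bool use_cudnn_attr) {
  bool use_cudnn = use_cudnn_attr;        // the "use_cudnn" attribute
  use_cudnn &= is_gpu_place();            // cuDNN only makes sense on a GPU place
#ifdef PADDLE_WITH_CUDA
  if (is_gpu_place()) {
    use_cudnn &= has_cudnn_handle();      // the runtime check the diff adds
  }
#else
  use_cudnn = false;                      // sketch-only: no CUDA build, never pick cuDNN
#endif
  return use_cudnn ? LibraryType::kCUDNN : LibraryType::kPlain;
}

int main() {
  // Prints "kCUDNN" when compiled with -DPADDLE_WITH_CUDA, otherwise "kPlain".
  std::cout << (SelectLibrary(true) == LibraryType::kCUDNN ? "kCUDNN" : "kPlain") << "\n";
  return 0;
}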