@@ -61,7 +61,9 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
  use_cudnn &= platform::dynload::HasCUDNN();
  if (paddle::platform::is_cpu_place(ctx.GetPlace())) {
    use_cudnn = false;
  }
  framework::LibraryType library_;
  if (use_cudnn) {
    library_ = framework::LibraryType::kCUDNN;
@@ -264,7 +266,9 @@ void ConvTransposeOpGrad::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
  use_cudnn &= platform::dynload::HasCUDNN();
  if (paddle::platform::is_cpu_place(ctx.GetPlace())) {
    use_cudnn = false;
  }
  framework::LibraryType library_;
  if (use_cudnn) {
    library_ = framework::LibraryType::kCUDNN;
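Both hunks apply the same guard when choosing the kernel library: the `use_cudnn` attribute alone is not enough to select the cuDNN kernel; cuDNN must actually be loadable (`platform::dynload::HasCUDNN()`) and the op must not be placed on CPU, otherwise the plain library is used. Below is a minimal standalone sketch of that selection logic, not Paddle code: `HasCUDNN`, `is_cpu_place`, and `SelectLibrary` here are simplified stand-ins introduced only for illustration.

```cpp
// Standalone sketch (illustrative only, not the Paddle implementation).
#include <iostream>

enum class LibraryType { kPlain, kCUDNN };

// Stand-in for platform::dynload::HasCUDNN(): can libcudnn be loaded at runtime?
bool HasCUDNN() { return true; }

// Stand-in for paddle::platform::is_cpu_place(place).
bool is_cpu_place(bool place_is_cpu) { return place_is_cpu; }

// Mirrors the decision made in GetExpectedKernelType above.
LibraryType SelectLibrary(bool use_cudnn_attr, bool place_is_cpu) {
  bool use_cudnn = use_cudnn_attr;
  use_cudnn &= HasCUDNN();           // the attribute is only honored if cuDNN is available
  if (is_cpu_place(place_is_cpu)) {  // cuDNN kernels cannot run on a CPU place
    use_cudnn = false;
  }
  return use_cudnn ? LibraryType::kCUDNN : LibraryType::kPlain;
}

int main() {
  // A CPU place forces the plain kernel even when use_cudnn is set.
  std::cout << (SelectLibrary(true, /*place_is_cpu=*/true) == LibraryType::kPlain) << "\n";   // prints 1
  // A GPU place with cuDNN available selects the cuDNN kernel.
  std::cout << (SelectLibrary(true, /*place_is_cpu=*/false) == LibraryType::kCUDNN) << "\n";  // prints 1
}
```

In the operators themselves, the selected `library_` is then combined with the input's data type and layout into the `framework::OpKernelType` that `GetExpectedKernelType` returns; that part lies outside the truncated hunks above.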