|
|
|
@ -61,6 +61,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const {
|
|
|
|
|
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
|
|
|
|
|
use_cudnn &= platform::dynload::HasCUDNN();
|
|
|
|
|
framework::LibraryType library_;
|
|
|
|
|
if (use_cudnn) {
|
|
|
|
|
library_ = framework::LibraryType::kCUDNN;
|
|
|
|
@ -263,6 +264,7 @@ void ConvTransposeOpGrad::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const {
|
|
|
|
|
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
|
|
|
|
|
use_cudnn &= platform::dynload::HasCUDNN();
|
|
|
|
|
framework::LibraryType library_;
|
|
|
|
|
if (use_cudnn) {
|
|
|
|
|
library_ = framework::LibraryType::kCUDNN;
|
|
|
|
|