fix conv, pool, conv_trans to decide whether to use cudnn or not

add_depthwiseConv_op_gpu
chengduoZH 8 years ago
parent 78dc93430c
commit 79aa51229a

@ -70,6 +70,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::dynload::HasCUDNN();
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
@ -283,6 +284,7 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::dynload::HasCUDNN();
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;

@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/operators/math/im2col.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/vol2col.h"
#include "paddle/platform/dynload/cudnn.h"
namespace paddle {
namespace operators {

@ -61,6 +61,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::dynload::HasCUDNN();
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
@ -263,6 +264,7 @@ void ConvTransposeOpGrad::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::dynload::HasCUDNN();
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;

@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/operators/math/im2col.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/vol2col.h"
#include "paddle/platform/dynload/cudnn.h"
namespace paddle {
namespace operators {

@ -64,6 +64,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
framework::OpKernelType PoolOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::dynload::HasCUDNN();
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
@ -88,6 +89,7 @@ void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::dynload::HasCUDNN();
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;

@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/pooling.h"
#include "paddle/platform/dynload/cudnn.h"
namespace paddle {
namespace operators {

@ -57,6 +57,10 @@ void EnforceCUDNNLoaded(const char* fn_name) {
bool HasCUDNN() { return true; }
#endif
#ifndef PADDLE_WITH_CUDA
// CPU-only build: cuDNN cannot be loaded, so availability is always false.
// Callers (e.g. GetExpectedKernelType) AND this with the "use_cudnn" attr
// to fall back to the plain kernel library.
bool HasCUDNN() { return false; }
#endif
} // namespace dynload
} // namespace platform
} // namespace paddle

Loading…
Cancel
Save