refine code

add_depthwiseConv_op_gpu
chengduoZH 7 years ago
parent ed7e74abf2
commit c9641a03dc

@ -71,6 +71,12 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
}
#endif
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
@ -285,6 +291,13 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
}
#endif
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;

@ -62,6 +62,12 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
}
#endif
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
@ -265,6 +271,12 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
}
#endif
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;

@ -65,6 +65,12 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
}
#endif
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
@ -90,6 +96,12 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
}
#endif
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;

Loading…
Cancel
Save