follow comments

add_depthwiseConv_op_gpu
chengduoZH 7 years ago
parent 251c6032fb
commit 24f528a1a5

@ -70,9 +70,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
if (paddle::platform::is_cpu_place(ctx.GetPlace())) {
use_cudnn = false;
}
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;

@ -61,9 +61,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
if (paddle::platform::is_cpu_place(ctx.GetPlace())) {
use_cudnn = false;
}
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;

@ -64,9 +64,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
framework::OpKernelType PoolOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
if (paddle::platform::is_cpu_place(ctx.GetPlace())) {
use_cudnn = false;
}
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;

@ -41,10 +41,9 @@ def img_conv_group(input,
param_attr=None,
conv_with_batchnorm=False,
conv_batchnorm_drop_rate=None,
conv_use_cudnn=True,
pool_stride=1,
pool_type=None,
pool_use_cudnn=True):
use_cudnn=True):
"""
Image Convolution Group, Used for vgg net.
"""
@ -76,7 +75,7 @@ def img_conv_group(input,
padding=conv_padding[i],
param_attr=param_attr[i],
act=local_conv_act,
use_cudnn=conv_use_cudnn)
use_cudnn=use_cudnn)
if conv_with_batchnorm[i]:
tmp = layers.batch_norm(input=tmp, act=conv_act)
@ -89,7 +88,7 @@ def img_conv_group(input,
pool_size=pool_size,
pool_type=pool_type,
pool_stride=pool_stride,
use_cudnn=pool_use_cudnn)
use_cudnn=use_cudnn)
return pool_out

Loading…
Cancel
Save