fix conv depthwise bug (#27278)

Fix conv deepwise bug when in_channels=1.
disable_ut_1
LielinJiang 5 years ago committed by GitHub
parent bbad3414e8
commit a685435962
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -267,8 +267,8 @@ def conv1d(x,
dilation = utils.convert_to_list(dilation, 1, 'dilation') + [1]
l_type = "conv2d"
if (num_channels == groups and num_filters % num_channels == 0 and
not use_cudnn):
if (num_channels == groups and num_channels != 1 and
num_filters % num_channels == 0 and not use_cudnn):
l_type = 'depthwise_conv2d'
use_cudnn = False
@ -491,7 +491,8 @@ def conv2d(x,
dilation = utils.convert_to_list(dilation, 2, 'dilation')
l_type = "conv2d"
if (num_channels == groups and num_filters % num_channels == 0):
if (num_channels == groups and num_channels != 1 and
num_filters % num_channels == 0):
l_type = 'depthwise_conv2d'
use_cudnn = False
@ -761,7 +762,8 @@ def conv_transpose1d(x,
op_type = 'conv2d_transpose'
num_filters = weight.shape[1]
if (num_channels == groups and num_filters == 1 and not use_cudnn):
if (num_channels == groups and num_channels != 1 and num_filters == 1 and
not use_cudnn):
op_type = 'depthwise_conv2d_transpose'
use_cudnn = False
@ -1010,7 +1012,7 @@ def conv_transpose2d(x,
op_type = 'conv2d_transpose'
num_filters = weight.shape[1]
if (num_channels == groups and num_filters == 1):
if (num_channels == groups and num_channels != 1 and num_filters == 1):
op_type = 'depthwise_conv2d_transpose'
use_cudnn = False

Loading…
Cancel
Save