[ROCM] fix test_conv2d_transpose_op (#31749)

2.0.1-rocm-post
ronnywang 4 years ago committed by GitHub
parent a45c8ca69d
commit 8c19d7aa2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -202,7 +202,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
int iwo_groups = groups;
int c_groups = 1;
#if CUDNN_VERSION_MIN(7, 0, 1)
#if defined(PADDLE_WITH_HIP) || CUDNN_VERSION_MIN(7, 0, 1)
iwo_groups = 1;
c_groups = groups;
groups = 1;
@ -452,7 +452,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
int iwo_groups = groups;
int c_groups = 1;
#if CUDNN_VERSION_MIN(7, 0, 1)
#if defined(PADDLE_WITH_HIP) || CUDNN_VERSION_MIN(7, 0, 1)
iwo_groups = 1;
c_groups = groups;
groups = 1;

@ -116,7 +116,7 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs):
class TestConv2DTransposeOp(OpTest):
def setUp(self):
# init as conv transpose
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.need_check_grad = True
self.is_test = False
self.use_cudnn = False

Loading…
Cancel
Save