fix deformable_conv small cases, test=develop (#22441)

revert-22710-feature/integrated_ps_api
Bai Yifan 5 years ago committed by GitHub
parent 943cb8c664
commit c8b90d8f9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -154,32 +154,11 @@ class TestModulatedDeformableConvOp(OpTest):
'Output',
max_relative_error=0.05)
def test_check_grad_no_filter(self):
self.check_grad(
['Input', 'Offset', 'Mask'],
'Output',
max_relative_error=0.1,
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self):
self.check_grad(
['Filter', 'Offset', 'Mask'],
'Output',
max_relative_error=0.1,
no_grad_set=set(['Input']))
def test_check_grad_no_offset_no_mask(self):
self.check_grad(
['Input', 'Filter'],
'Output',
max_relative_error=0.1,
no_grad_set=set(['Offset', 'Mask']))
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.input_size = [2, 4, 4, 4] # NCHW
self.input_size = [2, 8, 4, 4] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [4, f_c, 3, 3]
@ -229,7 +208,7 @@ class TestWithDilation(TestModulatedDeformableConvOp):
def init_test_case(self):
self.pad = [2, 2]
self.stride = [1, 1]
self.input_size = [2, 3, 4, 4] # NCHW
self.input_size = [4, 3, 4, 4] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
@ -250,14 +229,14 @@ class TestWithDilation(TestModulatedDeformableConvOp):
self.dilations = [2, 2]
class TestWith1x1(TestModulatedDeformableConvOp):
class TestWith3x3(TestModulatedDeformableConvOp):
def init_test_case(self):
self.pad = [0, 0]
self.pad = [1, 1]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 1, 1]
self.filter_size = [6, f_c, 3, 3]
self.im2col_step = 1
self.deformable_groups = 1
offset_c = 2 * self.deformable_groups * self.filter_size[

@ -21,7 +21,6 @@ NEED_TO_FIX_OP_LIST = [
'fused_elemwise_activation',
'bilinear_tensor_product',
'conv2d_transpose',
'deformable_conv',
'depthwise_conv2d_transpose',
'grid_sampler',
'hierarchical_sigmoid',

Loading…
Cancel
Save