fix spp test (#22675)

revert-22710-feature/integrated_ps_api
dyning 5 years ago committed by GitHub
parent 1a595d8e90
commit 769c032fc4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -25,8 +25,11 @@ class TestSppOp(OpTest):
def setUp(self):
self.op_type = "spp"
self.init_test_case()
input = np.random.random(self.shape).astype("float64")
nsize, csize, hsize, wsize = input.shape
nsize, csize, hsize, wsize = self.shape
data = np.array(list(range(nsize * csize * hsize * wsize)))
input = data.reshape(self.shape)
input_random = np.random.random(self.shape).astype("float64")
input = input + input_random
out_level_flatten = []
for i in range(self.pyramid_height):
bins = np.power(2, i)
@ -55,7 +58,6 @@ class TestSppOp(OpTest):
'pyramid_height': self.pyramid_height,
'pooling_type': self.pool_type
}
self.outputs = {'Out': output.astype('float64')}
def test_check_output(self):
@ -65,7 +67,7 @@ class TestSppOp(OpTest):
self.check_grad(['X'], 'Out')
def init_test_case(self):
self.shape = [3, 2, 4, 4]
self.shape = [3, 2, 16, 16]
self.pyramid_height = 3
self.pool2D_forward_naive = max_pool2D_forward_naive
self.pool_type = "max"

@ -29,7 +29,6 @@ NEED_TO_FIX_OP_LIST = [
'scatter',
'smooth_l1_loss',
'soft_relu',
'spp',
'squared_l2_distance',
'tree_conv',
]

Loading…
Cancel
Save