Refine unitest im2seq op. (#22372)

revert-22710-feature/integrated_ps_api
whs 5 years ago committed by Tao Luo
parent de9edf58ef
commit 5f655d2cef

@ -15,7 +15,7 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest, skip_check_grad_ci
def get_output_shape(attrs, in_shape, img_real_size): def get_output_shape(attrs, in_shape, img_real_size):
@ -142,7 +142,6 @@ class TestBlockExpandOp(OpTest):
x = np.random.uniform(0.1, 1, [ x = np.random.uniform(0.1, 1, [
self.batch_size, self.img_channels, self.img_height, self.img_width self.batch_size, self.img_channels, self.img_height, self.img_width
]).astype("float32") ]).astype("float32")
real_size = np.array([]).astype("float32") real_size = np.array([]).astype("float32")
out = Im2Sequence(x, real_size, self.attrs) out = Im2Sequence(x, real_size, self.attrs)
self.inputs = {'X': x} self.inputs = {'X': x}
@ -194,6 +193,9 @@ class TestBlockExpandOpCase4(TestBlockExpandOp):
} }
@skip_check_grad_ci(
reason="Since 'real_size' is used just in forward computation, we don't test the gradient here."
)
class TestBlockExpandOpCase5(OpTest): class TestBlockExpandOpCase5(OpTest):
def config(self): def config(self):
self.batch_size = 1 self.batch_size = 1
@ -206,6 +208,7 @@ class TestBlockExpandOpCase5(OpTest):
'paddings': [2, 1, 2, 1], 'paddings': [2, 1, 2, 1],
'out_stride': [2, 2], 'out_stride': [2, 2],
} }
self.real_size = np.array([[8, 10], [5, 8]]).astype("float32")
def setUp(self): def setUp(self):
self.config() self.config()
@ -213,16 +216,15 @@ class TestBlockExpandOpCase5(OpTest):
x = np.random.uniform(0.1, 1, [ x = np.random.uniform(0.1, 1, [
self.batch_size, self.img_channels, self.img_height, self.img_width self.batch_size, self.img_channels, self.img_height, self.img_width
]).astype("float32") ]).astype("float32")
real_size = np.array([[8, 10], [5, 8]]).astype("float32") out = np.array(Im2Sequence(x, self.real_size, self.attrs))
out = np.array(Im2Sequence(x, real_size, self.attrs)) self.inputs = {'X': x, 'Y': self.real_size}
self.inputs = {'X': x, 'Y': real_size} #l ??
self.outputs = {'Out': out} self.outputs = {'Out': out}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
class TestBlockExpandOpCase6(OpTest): class TestBlockExpandOpCase6(TestBlockExpandOpCase5):
def config(self): def config(self):
self.batch_size = 3 self.batch_size = 3
self.img_channels = 1 self.img_channels = 1
@ -234,23 +236,10 @@ class TestBlockExpandOpCase6(OpTest):
'paddings': [0, 0, 0, 0], 'paddings': [0, 0, 0, 0],
'out_stride': [1, 1], 'out_stride': [1, 1],
} }
self.real_size = np.array([[8, 10], [5, 8], [5, 8]]).astype("float32")
def setUp(self):
self.config()
self.op_type = "im2sequence"
x = np.random.uniform(0.1, 1, [
self.batch_size, self.img_channels, self.img_height, self.img_width
]).astype("float32")
real_size = np.array([[8, 10], [5, 8], [5, 8]]).astype("float32")
out = np.array(Im2Sequence(x, real_size, self.attrs))
self.inputs = {'X': x, 'Y': real_size} #l ??
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
class TestBlockExpandOpCase7(OpTest): class TestBlockExpandOpCase7(TestBlockExpandOpCase6):
def config(self): def config(self):
self.batch_size = 2 self.batch_size = 2
self.img_channels = 2 self.img_channels = 2
@ -262,22 +251,8 @@ class TestBlockExpandOpCase7(OpTest):
'paddings': [1, 0, 1, 0], 'paddings': [1, 0, 1, 0],
'out_stride': [2, 2], 'out_stride': [2, 2],
} }
self.real_size = np.array([[6, 6], [4, 4]]).astype("float32")
def setUp(self):
self.config()
self.op_type = "im2sequence"
x = np.random.uniform(0.1, 1, [
self.batch_size, self.img_channels, self.img_height, self.img_width
]).astype("float32")
real_size = np.array([[6, 6], [4, 4]]).astype("float32")
out = np.array(Im2Sequence(x, real_size, self.attrs))
self.inputs = {'X': x, 'Y': real_size}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
#set shiftwidth=4 set expandtab set tabstop=4

Loading…
Cancel
Save