|
|
|
@ -25,9 +25,15 @@ class TestConcatOp(OpTest):
|
|
|
|
|
self.init_test_data()
|
|
|
|
|
self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]}
|
|
|
|
|
self.attrs = {'axis': self.axis}
|
|
|
|
|
if self.axis < 0:
|
|
|
|
|
self.actual_axis = self.axis + len(self.x0.shape)
|
|
|
|
|
self.actual_axis = self.actual_axis if self.actual_axis > 0 else 0
|
|
|
|
|
else:
|
|
|
|
|
self.actual_axis = self.axis
|
|
|
|
|
|
|
|
|
|
self.outputs = {
|
|
|
|
|
'Out': np.concatenate(
|
|
|
|
|
(self.x0, self.x1, self.x2), axis=self.axis)
|
|
|
|
|
(self.x0, self.x1, self.x2), axis=self.actual_axis)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def test_check_output(self):
|
|
|
|
@ -75,5 +81,13 @@ class TestConcatOp4(TestConcatOp):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestConcatOp5(TestConcatOp):
|
|
|
|
|
def init_test_data(self):
|
|
|
|
|
self.x0 = np.random.random((2, 1, 4, 5)).astype('float32')
|
|
|
|
|
self.x1 = np.random.random((2, 2, 4, 5)).astype('float32')
|
|
|
|
|
self.x2 = np.random.random((2, 3, 4, 5)).astype('float32')
|
|
|
|
|
self.axis = -3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
unittest.main()
|
|
|
|
|