|
|
|
@ -100,6 +100,41 @@ class TestConcatOp5(TestConcatOp):
|
|
|
|
|
self.axis = -3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestConcatOp6(TestConcatOp):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.op_type = "concat"
|
|
|
|
|
self.dtype = self.get_dtype()
|
|
|
|
|
self.init_test_data()
|
|
|
|
|
self.lod = [[20, 80]]
|
|
|
|
|
self.out_lod = [[20, 80, 20, 80, 20, 80]]
|
|
|
|
|
self.inputs = {
|
|
|
|
|
'X': [('x0', (self.x0, self.lod)), ('x1', (self.x1, self.lod)),
|
|
|
|
|
('x2', (self.x2, self.lod))]
|
|
|
|
|
}
|
|
|
|
|
self.attrs = {'axis': self.axis}
|
|
|
|
|
if self.axis < 0:
|
|
|
|
|
self.actual_axis = self.axis + len(self.x0.shape)
|
|
|
|
|
self.actual_axis = self.actual_axis if self.actual_axis > 0 else 0
|
|
|
|
|
else:
|
|
|
|
|
self.actual_axis = self.axis
|
|
|
|
|
out = np.concatenate((self.x0, self.x1, self.x2), axis=self.actual_axis)
|
|
|
|
|
self.outputs = {'Out': (out, self.out_lod)}
|
|
|
|
|
|
|
|
|
|
def test_check_output(self):
|
|
|
|
|
self.check_output(check_dygraph=False)
|
|
|
|
|
|
|
|
|
|
def test_check_grad(self):
|
|
|
|
|
self.check_grad(['x0'], 'Out', check_dygraph=False)
|
|
|
|
|
self.check_grad(['x1'], 'Out', check_dygraph=False)
|
|
|
|
|
self.check_grad(['x2'], 'Out', check_dygraph=False)
|
|
|
|
|
|
|
|
|
|
def init_test_data(self):
|
|
|
|
|
self.x0 = np.random.random([100]).astype(self.dtype)
|
|
|
|
|
self.x1 = np.random.random([100]).astype(self.dtype)
|
|
|
|
|
self.x2 = np.random.random([100]).astype(self.dtype)
|
|
|
|
|
self.axis = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_test_AxisTensor(parent):
|
|
|
|
|
class TestConcatAxisTensor(parent):
|
|
|
|
|
def setUp(self):
|
|
|
|
@ -134,6 +169,7 @@ create_test_AxisTensor(TestConcatOp2)
|
|
|
|
|
create_test_AxisTensor(TestConcatOp3)
|
|
|
|
|
create_test_AxisTensor(TestConcatOp4)
|
|
|
|
|
create_test_AxisTensor(TestConcatOp5)
|
|
|
|
|
create_test_AxisTensor(TestConcatOp6)
|
|
|
|
|
|
|
|
|
|
#----------------Concat Fp16----------------
|
|
|
|
|
|
|
|
|
@ -155,6 +191,7 @@ create_test_fp16(TestConcatOp2)
|
|
|
|
|
create_test_fp16(TestConcatOp3)
|
|
|
|
|
create_test_fp16(TestConcatOp4)
|
|
|
|
|
create_test_fp16(TestConcatOp5)
|
|
|
|
|
create_test_fp16(TestConcatOp6)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestConcatOpError(unittest.TestCase):
|
|
|
|
|