@ -19,6 +19,7 @@ from op_test import OpTest
class TestSplitOp(OpTest):
def setUp(self):
self._set_op_type()
axis = 1
x = np.random.random((4, 5, 6)).astype('float32')
out = np.split(x, [2, 3], axis)