|
|
|
@ -50,33 +50,33 @@ class TestUnsqueezeOp(OpTest):
|
|
|
|
|
# Correct: Single input index.
|
|
|
|
|
class TestUnsqueezeOp1(TestUnsqueezeOp):
|
|
|
|
|
def init_test_case(self):
|
|
|
|
|
self.ori_shape = (3, 5)
|
|
|
|
|
self.ori_shape = (20, 5)
|
|
|
|
|
self.axes = (-1, )
|
|
|
|
|
self.new_shape = (3, 5, 1)
|
|
|
|
|
self.new_shape = (20, 5, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Correct: Mixed input axis.
|
|
|
|
|
class TestUnsqueezeOp2(TestUnsqueezeOp):
|
|
|
|
|
def init_test_case(self):
|
|
|
|
|
self.ori_shape = (3, 5)
|
|
|
|
|
self.ori_shape = (20, 5)
|
|
|
|
|
self.axes = (0, -1)
|
|
|
|
|
self.new_shape = (1, 3, 5, 1)
|
|
|
|
|
self.new_shape = (1, 20, 5, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Correct: There is duplicated axis.
|
|
|
|
|
class TestUnsqueezeOp3(TestUnsqueezeOp):
|
|
|
|
|
def init_test_case(self):
|
|
|
|
|
self.ori_shape = (3, 2, 5)
|
|
|
|
|
self.ori_shape = (10, 2, 5)
|
|
|
|
|
self.axes = (0, 3, 3)
|
|
|
|
|
self.new_shape = (1, 3, 2, 1, 1, 5)
|
|
|
|
|
self.new_shape = (1, 10, 2, 1, 1, 5)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Correct: Reversed axes.
|
|
|
|
|
class TestUnsqueezeOp4(TestUnsqueezeOp):
|
|
|
|
|
def init_test_case(self):
|
|
|
|
|
self.ori_shape = (3, 2, 5)
|
|
|
|
|
self.ori_shape = (10, 2, 5)
|
|
|
|
|
self.axes = (3, 1, 1)
|
|
|
|
|
self.new_shape = (3, 1, 1, 2, 5, 1)
|
|
|
|
|
self.new_shape = (10, 1, 1, 2, 5, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# axes is a list(with tensor)
|
|
|
|
@ -107,9 +107,9 @@ class TestUnsqueezeOp_AxesTensorList(OpTest):
|
|
|
|
|
self.check_grad(["X"], "Out")
|
|
|
|
|
|
|
|
|
|
def init_test_case(self):
|
|
|
|
|
self.ori_shape = (3, 5)
|
|
|
|
|
self.ori_shape = (20, 5)
|
|
|
|
|
self.axes = (1, 2)
|
|
|
|
|
self.new_shape = (3, 1, 1, 5)
|
|
|
|
|
self.new_shape = (20, 1, 1, 5)
|
|
|
|
|
|
|
|
|
|
def init_attrs(self):
|
|
|
|
|
self.attrs = {}
|
|
|
|
@ -117,30 +117,30 @@ class TestUnsqueezeOp_AxesTensorList(OpTest):
|
|
|
|
|
|
|
|
|
|
class TestUnsqueezeOp1_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
|
|
|
|
|
def init_test_case(self):
|
|
|
|
|
self.ori_shape = (3, 5)
|
|
|
|
|
self.ori_shape = (20, 5)
|
|
|
|
|
self.axes = (-1, )
|
|
|
|
|
self.new_shape = (3, 5, 1)
|
|
|
|
|
self.new_shape = (20, 5, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestUnsqueezeOp2_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
|
|
|
|
|
def init_test_case(self):
|
|
|
|
|
self.ori_shape = (3, 5)
|
|
|
|
|
self.ori_shape = (20, 5)
|
|
|
|
|
self.axes = (0, -1)
|
|
|
|
|
self.new_shape = (1, 3, 5, 1)
|
|
|
|
|
self.new_shape = (1, 20, 5, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestUnsqueezeOp3_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
|
|
|
|
|
def init_test_case(self):
|
|
|
|
|
self.ori_shape = (3, 2, 5)
|
|
|
|
|
self.ori_shape = (10, 2, 5)
|
|
|
|
|
self.axes = (0, 3, 3)
|
|
|
|
|
self.new_shape = (1, 3, 2, 1, 1, 5)
|
|
|
|
|
self.new_shape = (1, 10, 2, 1, 1, 5)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestUnsqueezeOp4_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
|
|
|
|
|
def init_test_case(self):
|
|
|
|
|
self.ori_shape = (3, 2, 5)
|
|
|
|
|
self.ori_shape = (10, 2, 5)
|
|
|
|
|
self.axes = (3, 1, 1)
|
|
|
|
|
self.new_shape = (3, 1, 1, 2, 5, 1)
|
|
|
|
|
self.new_shape = (10, 1, 1, 2, 5, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# axes is a Tensor
|
|
|
|
@ -166,9 +166,9 @@ class TestUnsqueezeOp_AxesTensor(OpTest):
|
|
|
|
|
self.check_grad(["X"], "Out")
|
|
|
|
|
|
|
|
|
|
def init_test_case(self):
|
|
|
|
|
self.ori_shape = (3, 5)
|
|
|
|
|
self.ori_shape = (20, 5)
|
|
|
|
|
self.axes = (1, 2)
|
|
|
|
|
self.new_shape = (3, 1, 1, 5)
|
|
|
|
|
self.new_shape = (20, 1, 1, 5)
|
|
|
|
|
|
|
|
|
|
def init_attrs(self):
|
|
|
|
|
self.attrs = {}
|
|
|
|
@ -176,30 +176,30 @@ class TestUnsqueezeOp_AxesTensor(OpTest):
|
|
|
|
|
|
|
|
|
|
class TestUnsqueezeOp1_AxesTensor(TestUnsqueezeOp_AxesTensor):
|
|
|
|
|
def init_test_case(self):
|
|
|
|
|
self.ori_shape = (3, 5)
|
|
|
|
|
self.ori_shape = (20, 5)
|
|
|
|
|
self.axes = (-1, )
|
|
|
|
|
self.new_shape = (3, 5, 1)
|
|
|
|
|
self.new_shape = (20, 5, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestUnsqueezeOp2_AxesTensor(TestUnsqueezeOp_AxesTensor):
|
|
|
|
|
def init_test_case(self):
|
|
|
|
|
self.ori_shape = (3, 5)
|
|
|
|
|
self.ori_shape = (20, 5)
|
|
|
|
|
self.axes = (0, -1)
|
|
|
|
|
self.new_shape = (1, 3, 5, 1)
|
|
|
|
|
self.new_shape = (1, 20, 5, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestUnsqueezeOp3_AxesTensor(TestUnsqueezeOp_AxesTensor):
|
|
|
|
|
def init_test_case(self):
|
|
|
|
|
self.ori_shape = (3, 2, 5)
|
|
|
|
|
self.ori_shape = (10, 2, 5)
|
|
|
|
|
self.axes = (0, 3, 3)
|
|
|
|
|
self.new_shape = (1, 3, 2, 1, 1, 5)
|
|
|
|
|
self.new_shape = (1, 10, 2, 1, 1, 5)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestUnsqueezeOp4_AxesTensor(TestUnsqueezeOp_AxesTensor):
|
|
|
|
|
def init_test_case(self):
|
|
|
|
|
self.ori_shape = (3, 2, 5)
|
|
|
|
|
self.ori_shape = (10, 2, 5)
|
|
|
|
|
self.axes = (3, 1, 1)
|
|
|
|
|
self.new_shape = (3, 1, 1, 2, 5, 1)
|
|
|
|
|
self.new_shape = (10, 1, 1, 2, 5, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# test api
|
|
|
|
|