|
|
|
@ -24,7 +24,7 @@ class TestUnsqueezeOp(OpTest):
|
|
|
|
|
self.init_test_case()
|
|
|
|
|
self.op_type = "unsqueeze"
|
|
|
|
|
self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
|
|
|
|
|
self.attrs = {"axes": self.axes, "inplace": False}
|
|
|
|
|
self.init_attrs()
|
|
|
|
|
self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}
|
|
|
|
|
|
|
|
|
|
def test_check_output(self):
|
|
|
|
@ -38,6 +38,9 @@ class TestUnsqueezeOp(OpTest):
|
|
|
|
|
self.axes = (1, 2)
|
|
|
|
|
self.new_shape = (3, 1, 1, 5)
|
|
|
|
|
|
|
|
|
|
def init_attrs(self):
|
|
|
|
|
self.attrs = {"axes": self.axes, "inplace": False}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Correct: Single input index.
|
|
|
|
|
class TestUnsqueezeOp1(TestUnsqueezeOp):
|
|
|
|
@ -70,6 +73,9 @@ class TestUnsqueezeOpInplace1(TestUnsqueezeOp):
|
|
|
|
|
self.axes = (0, 2)
|
|
|
|
|
self.new_shape = (1, 3, 1, 5)
|
|
|
|
|
|
|
|
|
|
def init_attrs(self):
|
|
|
|
|
self.attrs = {"axes": self.axes, "inplace": True}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Correct: Inplace. There is mins index.
|
|
|
|
|
class TestUnsqueezeOpInplace2(TestUnsqueezeOp):
|
|
|
|
@ -78,6 +84,9 @@ class TestUnsqueezeOpInplace2(TestUnsqueezeOp):
|
|
|
|
|
self.axes = (0, -2)
|
|
|
|
|
self.new_shape = (1, 3, 1, 5)
|
|
|
|
|
|
|
|
|
|
def init_attrs(self):
|
|
|
|
|
self.attrs = {"axes": self.axes, "inplace": True}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Correct: Inplace. There is duplicated axis.
|
|
|
|
|
class TestUnsqueezeOpInplace3(TestUnsqueezeOp):
|
|
|
|
@ -86,6 +95,9 @@ class TestUnsqueezeOpInplace3(TestUnsqueezeOp):
|
|
|
|
|
self.axes = (0, 3, 3)
|
|
|
|
|
self.new_shape = (1, 3, 2, 1, 1, 5)
|
|
|
|
|
|
|
|
|
|
def init_attrs(self):
|
|
|
|
|
self.attrs = {"axes": self.axes, "inplace": True}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
unittest.main()
|
|
|
|
|