|
|
|
@ -207,27 +207,35 @@ class TestUnsqueezeAPI(OpTest):
|
|
|
|
|
def test_api(self):
|
|
|
|
|
input = np.random.random([3, 2, 5]).astype("float32")
|
|
|
|
|
x = fluid.data(name='x', shape=[3, 2, 5], dtype="float32")
|
|
|
|
|
positive_3 = fluid.layers.fill_constant([1], "int32", 3)
|
|
|
|
|
axes_tensor = fluid.data(name='axes_tensor', shape=[3], dtype="int32")
|
|
|
|
|
positive_3_int32 = fluid.layers.fill_constant([1], "int32", 3)
|
|
|
|
|
positive_1_int64 = fluid.layers.fill_constant([1], "int64", 1)
|
|
|
|
|
axes_tensor_int32 = fluid.data(
|
|
|
|
|
name='axes_tensor_int32', shape=[3], dtype="int32")
|
|
|
|
|
axes_tensor_int64 = fluid.data(
|
|
|
|
|
name='axes_tensor_int64', shape=[3], dtype="int64")
|
|
|
|
|
|
|
|
|
|
out_1 = fluid.layers.unsqueeze(x, axes=[3, 1, 1])
|
|
|
|
|
out_2 = fluid.layers.unsqueeze(x, axes=[positive_3, 1, 1])
|
|
|
|
|
out_3 = fluid.layers.unsqueeze(x, axes=axes_tensor)
|
|
|
|
|
out_2 = fluid.layers.unsqueeze(
|
|
|
|
|
x, axes=[positive_3_int32, positive_1_int64, 1])
|
|
|
|
|
out_3 = fluid.layers.unsqueeze(x, axes=axes_tensor_int32)
|
|
|
|
|
out_4 = fluid.layers.unsqueeze(x, axes=3)
|
|
|
|
|
out_5 = fluid.layers.unsqueeze(x, axes=axes_tensor_int64)
|
|
|
|
|
|
|
|
|
|
exe = fluid.Executor(place=fluid.CPUPlace())
|
|
|
|
|
res_1, res_2, res_3, res_4 = exe.run(
|
|
|
|
|
res_1, res_2, res_3, res_4, res_5 = exe.run(
|
|
|
|
|
fluid.default_main_program(),
|
|
|
|
|
feed={
|
|
|
|
|
"x": input,
|
|
|
|
|
"axes_tensor": np.array([3, 1, 1]).astype("int32")
|
|
|
|
|
"axes_tensor_int32": np.array([3, 1, 1]).astype("int32"),
|
|
|
|
|
"axes_tensor_int64": np.array([3, 1, 1]).astype("int64")
|
|
|
|
|
},
|
|
|
|
|
fetch_list=[out_1, out_2, out_3, out_4])
|
|
|
|
|
fetch_list=[out_1, out_2, out_3, out_4, out_5])
|
|
|
|
|
|
|
|
|
|
assert np.array_equal(res_1, input.reshape([3, 1, 1, 2, 5, 1]))
|
|
|
|
|
assert np.array_equal(res_2, input.reshape([3, 1, 1, 2, 5, 1]))
|
|
|
|
|
assert np.array_equal(res_3, input.reshape([3, 1, 1, 2, 5, 1]))
|
|
|
|
|
assert np.array_equal(res_4, input.reshape([3, 2, 5, 1]))
|
|
|
|
|
assert np.array_equal(res_5, input.reshape([3, 1, 1, 2, 5, 1]))
|
|
|
|
|
|
|
|
|
|
def test_error(self):
|
|
|
|
|
def test_axes_type():
|
|
|
|
|