|
|
|
@ -251,7 +251,9 @@ class TestPadAPI(unittest.TestCase):
|
|
|
|
|
mode,
|
|
|
|
|
value=0,
|
|
|
|
|
data_format="NCDHW"):
|
|
|
|
|
if data_format == "NCDHW":
|
|
|
|
|
if mode == "constant" and len(pad) == len(input_data.shape) * 2:
|
|
|
|
|
pad = np.reshape(pad, (-1, 2)).tolist()
|
|
|
|
|
elif data_format == "NCDHW":
|
|
|
|
|
pad = [
|
|
|
|
|
(0, 0),
|
|
|
|
|
(0, 0),
|
|
|
|
@ -316,6 +318,7 @@ class TestPadAPI(unittest.TestCase):
|
|
|
|
|
paddle.disable_static()
|
|
|
|
|
input_shape = (1, 2, 3, 4, 5)
|
|
|
|
|
pad = [1, 2, 1, 1, 3, 4]
|
|
|
|
|
pad_3 = [1, 2, 1, 1, 3, 4, 5, 6, 7, 8]
|
|
|
|
|
mode = "constant"
|
|
|
|
|
value = 100
|
|
|
|
|
input_data = np.random.rand(*input_shape).astype(np.float32)
|
|
|
|
@ -323,6 +326,8 @@ class TestPadAPI(unittest.TestCase):
|
|
|
|
|
input_data, pad, mode, value, data_format="NCDHW")
|
|
|
|
|
np_out2 = self._get_numpy_out(
|
|
|
|
|
input_data, pad, mode, value, data_format="NDHWC")
|
|
|
|
|
np_out3 = self._get_numpy_out(
|
|
|
|
|
input_data, pad_3, mode, value, data_format="NCDHW")
|
|
|
|
|
tensor_data = paddle.to_tensor(input_data)
|
|
|
|
|
|
|
|
|
|
y1 = F.pad(tensor_data,
|
|
|
|
@ -335,14 +340,21 @@ class TestPadAPI(unittest.TestCase):
|
|
|
|
|
mode=mode,
|
|
|
|
|
value=value,
|
|
|
|
|
data_format="NDHWC")
|
|
|
|
|
y3 = F.pad(tensor_data,
|
|
|
|
|
pad=pad_3,
|
|
|
|
|
mode=mode,
|
|
|
|
|
value=value,
|
|
|
|
|
data_format="NCDHW")
|
|
|
|
|
|
|
|
|
|
self.assertTrue(np.allclose(y1.numpy(), np_out1))
|
|
|
|
|
self.assertTrue(np.allclose(y2.numpy(), np_out2))
|
|
|
|
|
self.assertTrue(np.allclose(y3.numpy(), np_out3))
|
|
|
|
|
|
|
|
|
|
def test_dygraph_2(self):
|
|
|
|
|
paddle.disable_static()
|
|
|
|
|
input_shape = (2, 3, 4, 5)
|
|
|
|
|
pad = [1, 1, 3, 4]
|
|
|
|
|
pad_3 = [1, 2, 1, 1, 3, 4, 5, 6]
|
|
|
|
|
mode = "constant"
|
|
|
|
|
value = 100
|
|
|
|
|
input_data = np.random.rand(*input_shape).astype(np.float32)
|
|
|
|
@ -350,6 +362,8 @@ class TestPadAPI(unittest.TestCase):
|
|
|
|
|
input_data, pad, mode, value, data_format="NCHW")
|
|
|
|
|
np_out2 = self._get_numpy_out(
|
|
|
|
|
input_data, pad, mode, value, data_format="NHWC")
|
|
|
|
|
np_out3 = self._get_numpy_out(
|
|
|
|
|
input_data, pad_3, mode, value, data_format="NCHW")
|
|
|
|
|
|
|
|
|
|
tensor_data = paddle.to_tensor(input_data)
|
|
|
|
|
tensor_pad = paddle.to_tensor(pad, dtype="int32")
|
|
|
|
@ -364,14 +378,21 @@ class TestPadAPI(unittest.TestCase):
|
|
|
|
|
mode=mode,
|
|
|
|
|
value=value,
|
|
|
|
|
data_format="NHWC")
|
|
|
|
|
y3 = F.pad(tensor_data,
|
|
|
|
|
pad=pad_3,
|
|
|
|
|
mode=mode,
|
|
|
|
|
value=value,
|
|
|
|
|
data_format="NCHW")
|
|
|
|
|
|
|
|
|
|
self.assertTrue(np.allclose(y1.numpy(), np_out1))
|
|
|
|
|
self.assertTrue(np.allclose(y2.numpy(), np_out2))
|
|
|
|
|
self.assertTrue(np.allclose(y3.numpy(), np_out3))
|
|
|
|
|
|
|
|
|
|
def test_dygraph_3(self):
|
|
|
|
|
paddle.disable_static()
|
|
|
|
|
input_shape = (3, 4, 5)
|
|
|
|
|
pad = [3, 4]
|
|
|
|
|
pad_3 = [3, 4, 5, 6, 7, 8]
|
|
|
|
|
mode = "constant"
|
|
|
|
|
value = 100
|
|
|
|
|
input_data = np.random.rand(*input_shape).astype(np.float32)
|
|
|
|
@ -379,6 +400,8 @@ class TestPadAPI(unittest.TestCase):
|
|
|
|
|
input_data, pad, mode, value, data_format="NCL")
|
|
|
|
|
np_out2 = self._get_numpy_out(
|
|
|
|
|
input_data, pad, mode, value, data_format="NLC")
|
|
|
|
|
np_out3 = self._get_numpy_out(
|
|
|
|
|
input_data, pad_3, mode, value, data_format="NCL")
|
|
|
|
|
tensor_data = paddle.to_tensor(input_data)
|
|
|
|
|
tensor_pad = paddle.to_tensor(pad, dtype="int32")
|
|
|
|
|
|
|
|
|
@ -392,9 +415,15 @@ class TestPadAPI(unittest.TestCase):
|
|
|
|
|
mode=mode,
|
|
|
|
|
value=value,
|
|
|
|
|
data_format="NLC")
|
|
|
|
|
y3 = F.pad(tensor_data,
|
|
|
|
|
pad=pad_3,
|
|
|
|
|
mode=mode,
|
|
|
|
|
value=value,
|
|
|
|
|
data_format="NCL")
|
|
|
|
|
|
|
|
|
|
self.assertTrue(np.allclose(y1.numpy(), np_out1))
|
|
|
|
|
self.assertTrue(np.allclose(y2.numpy(), np_out2))
|
|
|
|
|
self.assertTrue(np.allclose(y3.numpy(), np_out3))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestPad1dAPI(unittest.TestCase):
|
|
|
|
|