|
|
@ -13,6 +13,7 @@
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
from py_paddle import swig_paddle
|
|
|
|
from py_paddle import swig_paddle
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import unittest
|
|
|
|
import unittest
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -36,6 +37,17 @@ class TestArguments(unittest.TestCase):
|
|
|
|
np_arr = iv.toNumpyArrayInplace()
|
|
|
|
np_arr = iv.toNumpyArrayInplace()
|
|
|
|
self.assertEqual(np_arr.shape, (6, ))
|
|
|
|
self.assertEqual(np_arr.shape, (6, ))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_arguments_shape(self):
|
|
|
|
|
|
|
|
h, w = 4, 6
|
|
|
|
|
|
|
|
v = np.random.rand(2, h * w)
|
|
|
|
|
|
|
|
m = swig_paddle.Matrix.createDense(v.flatten(), 2, h * w)
|
|
|
|
|
|
|
|
args = swig_paddle.Arguments.createArguments(1)
|
|
|
|
|
|
|
|
args.setSlotValue(0, m)
|
|
|
|
|
|
|
|
args.setSlotFrameHeight(0, h)
|
|
|
|
|
|
|
|
args.setSlotFrameWidth(0, w)
|
|
|
|
|
|
|
|
self.assertEqual(args.getSlotFrameHeight(), h)
|
|
|
|
|
|
|
|
self.assertEqual(args.getSlotFrameWidth(), w)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
if __name__ == '__main__':
|
|
|
|
swig_paddle.initPaddle("--use_gpu=0")
|
|
|
|
swig_paddle.initPaddle("--use_gpu=0")
|
|
|
|