Merge pull request #4841 from QiJune/pybind_selected_rows
export SelectedRows to Pythonrevert-4814-Add_sequence_project_op
commit
cdc236cb82
@ -0,0 +1,37 @@
|
|||||||
|
import paddle.v2.framework.core as core
|
||||||
|
import unittest
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class TestSelectedRows(unittest.TestCase):
|
||||||
|
def test_selected_rows(self):
|
||||||
|
place = core.CPUPlace()
|
||||||
|
height = 10
|
||||||
|
rows = [0, 4, 7]
|
||||||
|
row_numel = 10
|
||||||
|
selcted_rows = core.SelectedRows(rows, row_numel)
|
||||||
|
np_array = np.ones((len(rows), height)).astype("float32")
|
||||||
|
np_array[0, 0] = 2.0
|
||||||
|
np_array[2, 8] = 4.0
|
||||||
|
tensor = selcted_rows.get_tensor()
|
||||||
|
tensor.set(np_array, place)
|
||||||
|
|
||||||
|
# compare rows
|
||||||
|
self.assertEqual(0, selcted_rows.rows()[0])
|
||||||
|
self.assertEqual(4, selcted_rows.rows()[1])
|
||||||
|
self.assertEqual(7, selcted_rows.rows()[2])
|
||||||
|
|
||||||
|
# compare height
|
||||||
|
self.assertEqual(10, selcted_rows.height())
|
||||||
|
|
||||||
|
# compare tensor
|
||||||
|
self.assertAlmostEqual(2.0,
|
||||||
|
selcted_rows.get_tensor().get_float_element(0))
|
||||||
|
self.assertAlmostEqual(1.0,
|
||||||
|
selcted_rows.get_tensor().get_float_element(1))
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
4.0, selcted_rows.get_tensor().get_float_element(2 * row_numel + 8))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Loading…
Reference in new issue