add test_lookup_sparse_table_op

panyx0718-patch-1
Qiao Longfei 6 years ago
parent 8d205c853c
commit bad0c27e6e

@ -80,6 +80,33 @@ class TestLookupSpraseTable(OpTest):
assert (result_array2[3] == w_array[6]).all() assert (result_array2[3] == w_array[6]).all()
assert (result_array2[4] == w_array[7]).all() assert (result_array2[4] == w_array[7]).all()
# create and run lookup_table operator
test_lookup_table = Operator(
"lookup_sparse_table",
W='W',
Ids='Ids',
Out='Out',
min=-5.0,
max=10.0,
seed=10,
is_test=True)
ids = scope.var("Ids").get_tensor()
unknown_id = [44, 22, 33]
ids_array2 = np.array([4, 2, 3, 7, 100000] + unknown_id).astype("int64")
ids.set(ids_array2, place)
test_lookup_table.run(scope, place)
result_array2 = np.array(out_tensor)
assert (result_array2[0] == w_array[5]).all()
assert (result_array2[1] == w_array[1]).all()
assert (result_array2[2] == w_array[2]).all()
assert (result_array2[3] == w_array[6]).all()
assert (result_array2[4] == w_array[7]).all()
for i in [5, 6, 7]:
assert np.all(result_array2[i] == 0)
def test_w_is_selected_rows(self): def test_w_is_selected_rows(self):
places = [core.CPUPlace()] places = [core.CPUPlace()]
# currently only support CPU # currently only support CPU

Loading…
Cancel
Save