|
|
|
@ -80,6 +80,33 @@ class TestLookupSpraseTable(OpTest):
|
|
|
|
|
assert (result_array2[3] == w_array[6]).all()
|
|
|
|
|
assert (result_array2[4] == w_array[7]).all()
|
|
|
|
|
|
|
|
|
|
# create and run lookup_table operator
|
|
|
|
|
test_lookup_table = Operator(
|
|
|
|
|
"lookup_sparse_table",
|
|
|
|
|
W='W',
|
|
|
|
|
Ids='Ids',
|
|
|
|
|
Out='Out',
|
|
|
|
|
min=-5.0,
|
|
|
|
|
max=10.0,
|
|
|
|
|
seed=10,
|
|
|
|
|
is_test=True)
|
|
|
|
|
|
|
|
|
|
ids = scope.var("Ids").get_tensor()
|
|
|
|
|
unknown_id = [44, 22, 33]
|
|
|
|
|
ids_array2 = np.array([4, 2, 3, 7, 100000] + unknown_id).astype("int64")
|
|
|
|
|
ids.set(ids_array2, place)
|
|
|
|
|
test_lookup_table.run(scope, place)
|
|
|
|
|
|
|
|
|
|
result_array2 = np.array(out_tensor)
|
|
|
|
|
assert (result_array2[0] == w_array[5]).all()
|
|
|
|
|
assert (result_array2[1] == w_array[1]).all()
|
|
|
|
|
assert (result_array2[2] == w_array[2]).all()
|
|
|
|
|
assert (result_array2[3] == w_array[6]).all()
|
|
|
|
|
assert (result_array2[4] == w_array[7]).all()
|
|
|
|
|
|
|
|
|
|
for i in [5, 6, 7]:
|
|
|
|
|
assert np.all(result_array2[i] == 0)
|
|
|
|
|
|
|
|
|
|
def test_w_is_selected_rows(self):
|
|
|
|
|
places = [core.CPUPlace()]
|
|
|
|
|
# currently only support CPU
|
|
|
|
|