|
|
|
@ -115,18 +115,18 @@ class TestLookupTableWIsSelectedRows(OpTest):
|
|
|
|
|
w_array = np.ones((len(rows), row_numel)).astype("float32")
|
|
|
|
|
for i in range(len(rows)):
|
|
|
|
|
w_array[i] *= i
|
|
|
|
|
ids_tensor = w_selected_rows.get_tensor()
|
|
|
|
|
ids_tensor.set(w_array, place)
|
|
|
|
|
w_tensor = w_selected_rows.get_tensor()
|
|
|
|
|
w_tensor.set(w_array, place)
|
|
|
|
|
|
|
|
|
|
# create Out Variable
|
|
|
|
|
Out_tensor = scope.var('Out').get_tensor()
|
|
|
|
|
out_tensor = scope.var('Out').get_tensor()
|
|
|
|
|
|
|
|
|
|
# create and run lookup_table operator
|
|
|
|
|
lookup_table = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
|
|
|
|
|
lookup_table.run(scope, place)
|
|
|
|
|
|
|
|
|
|
# get result from Out
|
|
|
|
|
result_array = np.array(Out_tensor)
|
|
|
|
|
result_array = np.array(out_tensor)
|
|
|
|
|
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
|
|
|
|
|
for idx, row in enumerate(ids_array):
|
|
|
|
|
assert (row[0] == result_array[idx]).all()
|
|
|
|
|