|
|
|
@ -53,18 +53,11 @@ class TestLookupTableIdsIsSelectedRows(OpTest):
|
|
|
|
|
def check_with_place(self, place):
|
|
|
|
|
scope = core.Scope()
|
|
|
|
|
|
|
|
|
|
# create and initialize Grad Variable
|
|
|
|
|
# create and initialize Variable
|
|
|
|
|
height = 10
|
|
|
|
|
rows = [0, 4, 4, 7]
|
|
|
|
|
row_numel = 12
|
|
|
|
|
|
|
|
|
|
ids_selected_rows = scope.var('Ids').get_selected_rows()
|
|
|
|
|
ids_selected_rows.set_height(height)
|
|
|
|
|
ids_selected_rows.set_rows(rows)
|
|
|
|
|
np_array = np.ones((len(rows), row_numel)).astype("float32")
|
|
|
|
|
ids_tensor = ids_selected_rows.get_tensor()
|
|
|
|
|
ids_tensor.set(np_array, place)
|
|
|
|
|
|
|
|
|
|
# create and initialize W Variable
|
|
|
|
|
W = scope.var('W').get_tensor()
|
|
|
|
|
W_array = np.full((height, row_numel), 1.0).astype("float32")
|
|
|
|
@ -72,20 +65,26 @@ class TestLookupTableIdsIsSelectedRows(OpTest):
|
|
|
|
|
W_array[i] *= i
|
|
|
|
|
W.set(W_array, place)
|
|
|
|
|
|
|
|
|
|
# create and initialize Ids Variable
|
|
|
|
|
ids_selected_rows = scope.var('Ids').get_selected_rows()
|
|
|
|
|
ids_selected_rows.set_height(len(rows))
|
|
|
|
|
ids_selected_rows.set_rows(rows)
|
|
|
|
|
np_array = np.ones((len(rows), row_numel)).astype("float32")
|
|
|
|
|
ids_tensor = ids_selected_rows.get_tensor()
|
|
|
|
|
ids_tensor.set(np_array, place)
|
|
|
|
|
|
|
|
|
|
# create Out Variable
|
|
|
|
|
Out = scope.var('Out').get_selected_rows()
|
|
|
|
|
Out_array = np.full((len(rows), row_numel), -1.0).astype("float32")
|
|
|
|
|
Out.set_height(height)
|
|
|
|
|
Out.set_rows(rows)
|
|
|
|
|
Out_tensor = Out.get_tensor()
|
|
|
|
|
Out_tensor.set(Out_array, place)
|
|
|
|
|
|
|
|
|
|
# create and run concat_rows_op operator
|
|
|
|
|
# create and run lookup_table operator
|
|
|
|
|
concat_rows_op = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
|
|
|
|
|
concat_rows_op.run(scope, place)
|
|
|
|
|
|
|
|
|
|
# get and compare result
|
|
|
|
|
# get result from Out
|
|
|
|
|
Out_tensor = Out.get_tensor()
|
|
|
|
|
result_array = np.array(Out_tensor)
|
|
|
|
|
|
|
|
|
|
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
|
|
|
|
|
for idx, row in enumerate(rows):
|
|
|
|
|
assert (row == result_array[idx]).all()
|
|
|
|
|
|
|
|
|
|