|
|
|
@ -16,11 +16,12 @@ LOOKUP_TABLE_TYPE = "lookup_table"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def find_distributed_lookup_table(program):
|
|
|
|
|
# process lookup_table_op
|
|
|
|
|
# 1. check all lookup_table_op is distributed
|
|
|
|
|
# 2. check all lookup_table_op share the same table.
|
|
|
|
|
distributed_lookup_table_ops = []
|
|
|
|
|
# support only one distributed_lookup_table now
|
|
|
|
|
"""
|
|
|
|
|
Find distribute lookup table in program.
|
|
|
|
|
We only support one distribute table now.
|
|
|
|
|
:param program:
|
|
|
|
|
:return: table_name or None
|
|
|
|
|
"""
|
|
|
|
|
table_name = None
|
|
|
|
|
|
|
|
|
|
for op in program.global_block().ops:
|
|
|
|
@ -31,7 +32,6 @@ def find_distributed_lookup_table(program):
|
|
|
|
|
if table_name != op.input("W")[0]:
|
|
|
|
|
raise RuntimeError("all distributed lookup_table_ops"
|
|
|
|
|
" should have only one table")
|
|
|
|
|
distributed_lookup_table_ops.append(op)
|
|
|
|
|
else:
|
|
|
|
|
if table_name is not None:
|
|
|
|
|
assert op.input("W")[0] != table_name
|
|
|
|
|