|
|
|
@ -130,13 +130,22 @@ class DistributedAdam(DistributedOptimizerImplBase):
|
|
|
|
|
find multi-sparse-table
|
|
|
|
|
"""
|
|
|
|
|
table_names = set()
|
|
|
|
|
cnt = 0
|
|
|
|
|
tmp_list = []
|
|
|
|
|
ret_list = []
|
|
|
|
|
for loss in losses:
|
|
|
|
|
for op in loss.block.program.global_block().ops:
|
|
|
|
|
if op.type == "lookup_table":
|
|
|
|
|
if op.attr('is_distributed') is True:
|
|
|
|
|
table_name = op.input("W")[0]
|
|
|
|
|
table_names.add(table_name)
|
|
|
|
|
return list(table_names)
|
|
|
|
|
if table_name not in table_names:
|
|
|
|
|
table_names.add(table_name)
|
|
|
|
|
tmp_list.append([table_name, cnt])
|
|
|
|
|
cnt += 1
|
|
|
|
|
tmp_list.sort(key=lambda k: k[1])
|
|
|
|
|
for x in tmp_list:
|
|
|
|
|
ret_list.append(x[0])
|
|
|
|
|
return ret_list
|
|
|
|
|
|
|
|
|
|
def _minimize(self,
|
|
|
|
|
losses,
|
|
|
|
|