|
|
|
@ -410,7 +410,8 @@ class DistributeTranspiler:
|
|
|
|
|
attrs={"axis": 0})
|
|
|
|
|
|
|
|
|
|
if self.has_distributed_lookup_table:
|
|
|
|
|
self._replace_lookup_table_op_with_prefetch(program, pserver_endpoints)
|
|
|
|
|
self._replace_lookup_table_op_with_prefetch(program,
|
|
|
|
|
pserver_endpoints)
|
|
|
|
|
self._split_table_grad_and_add_send_vars(program, pserver_endpoints)
|
|
|
|
|
|
|
|
|
|
def get_trainer_program(self):
|
|
|
|
@ -631,7 +632,8 @@ class DistributeTranspiler:
|
|
|
|
|
return s_prog
|
|
|
|
|
|
|
|
|
|
# transpiler function for dis lookup_table
|
|
|
|
|
def _replace_lookup_table_op_with_prefetch(self, program, pserver_endpoints):
|
|
|
|
|
def _replace_lookup_table_op_with_prefetch(self, program,
|
|
|
|
|
pserver_endpoints):
|
|
|
|
|
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
|
|
|
|
|
self.prefetch_input_vars = None
|
|
|
|
|
self.prefetch_output_vars = None
|
|
|
|
|