|
|
|
@ -453,8 +453,7 @@ class DistributeTranspiler:
|
|
|
|
|
if self.has_distributed_lookup_table:
|
|
|
|
|
pserver_index = self.pserver_endpoints.index(endpoint)
|
|
|
|
|
table_opt_block = self._create_table_optimize_block(
|
|
|
|
|
pserver_index, pserver_program, opt_state_block or
|
|
|
|
|
pserver_program.global_block())
|
|
|
|
|
pserver_index, pserver_program, pre_block_idx)
|
|
|
|
|
prefetch_block = self._create_prefetch_block(
|
|
|
|
|
pserver_index, pserver_program, table_opt_block)
|
|
|
|
|
|
|
|
|
@ -665,7 +664,7 @@ class DistributeTranspiler:
|
|
|
|
|
return prefetch_block
|
|
|
|
|
|
|
|
|
|
def _create_table_optimize_block(self, pserver_index, pserver_program,
|
|
|
|
|
append_block):
|
|
|
|
|
pre_block_idx):
|
|
|
|
|
def _clone_var(block, var, persistable=True):
|
|
|
|
|
assert isinstance(var, Variable)
|
|
|
|
|
return block.create_var(
|
|
|
|
@ -702,7 +701,7 @@ class DistributeTranspiler:
|
|
|
|
|
op for op in self.optimize_ops
|
|
|
|
|
if op.input("Param")[0] == self.table_name
|
|
|
|
|
][0]
|
|
|
|
|
table_opt_block = pserver_program.create_block(append_block.idx)
|
|
|
|
|
table_opt_block = pserver_program.create_block(pre_block_idx)
|
|
|
|
|
# only support sgd now
|
|
|
|
|
assert table_opt_op.type == "sgd"
|
|
|
|
|
|
|
|
|
|