|
|
|
@ -689,15 +689,6 @@ class DistributeTranspiler:
|
|
|
|
|
|
|
|
|
|
def _create_table_optimize_block(self, pserver_index, pserver_program,
|
|
|
|
|
pre_block_idx, grad_to_block_id):
|
|
|
|
|
def _clone_var(block, var, persistable=True):
|
|
|
|
|
assert isinstance(var, Variable)
|
|
|
|
|
return block.create_var(
|
|
|
|
|
name=var.name,
|
|
|
|
|
shape=var.shape,
|
|
|
|
|
dtype=var.dtype,
|
|
|
|
|
type=var.type,
|
|
|
|
|
persistable=persistable)
|
|
|
|
|
|
|
|
|
|
# STEP: create table optimize block
|
|
|
|
|
# create table param and grad var in pserver program
|
|
|
|
|
origin_param_var = self.origin_program.global_block().vars[
|
|
|
|
@ -708,11 +699,11 @@ class DistributeTranspiler:
|
|
|
|
|
dtype=origin_param_var.dtype,
|
|
|
|
|
type=core.VarDesc.VarType.SELECTED_ROWS,
|
|
|
|
|
persistable=True)
|
|
|
|
|
grad_var = _clone_var(
|
|
|
|
|
pserver_program.global_block(),
|
|
|
|
|
# parameter must be selected rows
|
|
|
|
|
param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS)
|
|
|
|
|
grad_var = pserver_program.global_block().clone_variable(
|
|
|
|
|
self.origin_program.global_block().vars[grad_var_name(
|
|
|
|
|
self.table_name)],
|
|
|
|
|
persistable=False)
|
|
|
|
|
self.table_name)])
|
|
|
|
|
|
|
|
|
|
# create table optimize block in pserver program
|
|
|
|
|
table_opt_op = [
|
|
|
|
|