|
|
|
@ -408,11 +408,16 @@ class DistributeTranspiler:
|
|
|
|
|
pserver_vars = pserver_program.global_block().vars
|
|
|
|
|
created_var_map = dict()
|
|
|
|
|
for _, var in pserver_vars.iteritems():
|
|
|
|
|
tmpvar = s_prog.global_block().create_var(
|
|
|
|
|
name=var.name,
|
|
|
|
|
persistable=var.persistable,
|
|
|
|
|
dtype=var.dtype,
|
|
|
|
|
shape=var.shape)
|
|
|
|
|
if var.type == core.VarDesc.VarType.STEP_SCOPES:
|
|
|
|
|
tmpvar = s_prog.global_block().create_var(
|
|
|
|
|
name=var.name, persistable=var.persistable, type=var.type)
|
|
|
|
|
else:
|
|
|
|
|
tmpvar = s_prog.global_block().create_var(
|
|
|
|
|
name=var.name,
|
|
|
|
|
persistable=var.persistable,
|
|
|
|
|
type=var.type,
|
|
|
|
|
dtype=var.dtype,
|
|
|
|
|
shape=var.shape)
|
|
|
|
|
created_var_map[var.name] = tmpvar
|
|
|
|
|
|
|
|
|
|
# 2. rename op outputs
|
|
|
|
@ -708,11 +713,18 @@ class DistributeTranspiler:
|
|
|
|
|
varlist = [varlist]
|
|
|
|
|
|
|
|
|
|
for var in varlist:
|
|
|
|
|
program.global_block().create_var(
|
|
|
|
|
name=var.name,
|
|
|
|
|
persistable=var.persistable,
|
|
|
|
|
dtype=var.dtype,
|
|
|
|
|
shape=var.shape)
|
|
|
|
|
print("##### deal var: ", var)
|
|
|
|
|
if var.type == core.VarDesc.VarType.STEP_SCOPES:
|
|
|
|
|
program.global_block().create_var(
|
|
|
|
|
name=var.name,
|
|
|
|
|
persistable=var.persistable,
|
|
|
|
|
type=var.type)
|
|
|
|
|
else:
|
|
|
|
|
program.global_block().create_var(
|
|
|
|
|
name=var.name,
|
|
|
|
|
persistable=var.persistable,
|
|
|
|
|
dtype=var.dtype,
|
|
|
|
|
shape=var.shape)
|
|
|
|
|
|
|
|
|
|
optimize_block.append_op(
|
|
|
|
|
type=opt_op.type,
|
|
|
|
|