|
|
|
@ -459,9 +459,10 @@ class DistributeTranspiler:
|
|
|
|
|
return pname, splited_param.shape
|
|
|
|
|
return "", []
|
|
|
|
|
|
|
|
|
|
# 1. create vars
|
|
|
|
|
# 1. create vars in pserver program to startup program
|
|
|
|
|
pserver_vars = pserver_program.global_block().vars
|
|
|
|
|
created_var_map = dict()
|
|
|
|
|
for _, var in pserver_program.global_block().vars.iteritems():
|
|
|
|
|
for _, var in pserver_vars.iteritems():
|
|
|
|
|
print("create var for startup", var.name, var.shape)
|
|
|
|
|
tmpvar = s_prog.global_block().create_var(
|
|
|
|
|
name=var.name,
|
|
|
|
@ -469,30 +470,36 @@ class DistributeTranspiler:
|
|
|
|
|
dtype=var.dtype,
|
|
|
|
|
shape=var.shape)
|
|
|
|
|
created_var_map[var.name] = tmpvar
|
|
|
|
|
optimize_op_input_var_names = [
|
|
|
|
|
v.name for v in pserver_program.global_block().vars.values()
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
# 2. rename op outputs
|
|
|
|
|
for op in orig_s_prog.global_block().ops:
|
|
|
|
|
new_outputs = dict()
|
|
|
|
|
# do not append startup op if var is not on this pserver
|
|
|
|
|
op_on_pserver = False
|
|
|
|
|
for key, var in op.outputs.iteritems():
|
|
|
|
|
newname, _ = _get_splited_name_and_shape(var.name)
|
|
|
|
|
if newname:
|
|
|
|
|
op_on_pserver = True
|
|
|
|
|
new_outputs[key] = created_var_map[newname]
|
|
|
|
|
else:
|
|
|
|
|
new_outputs[key] = var
|
|
|
|
|
# do not append startup op if var is not on this pserver
|
|
|
|
|
op_on_pserver = False
|
|
|
|
|
for _, var in op.outputs.iteritems():
|
|
|
|
|
if var.name in optimize_op_input_var_names:
|
|
|
|
|
elif var.name in pserver_vars:
|
|
|
|
|
op_on_pserver = True
|
|
|
|
|
break
|
|
|
|
|
new_outputs[key] = pserver_vars[var.name]
|
|
|
|
|
|
|
|
|
|
# newname, _ = _get_splited_name_and_shape(var.name)
|
|
|
|
|
# if newname:
|
|
|
|
|
# print("updating output", newname, created_var_map[newname])
|
|
|
|
|
# new_outputs[key] = created_var_map[newname]
|
|
|
|
|
# else:
|
|
|
|
|
# print("no update output", key, var)
|
|
|
|
|
# new_outputs[key] = var
|
|
|
|
|
# if var.name in created_var_map or \
|
|
|
|
|
# newname:
|
|
|
|
|
# op_on_pserver = True
|
|
|
|
|
|
|
|
|
|
if op_on_pserver:
|
|
|
|
|
# gaussian_random use attr to determine tensor shape
|
|
|
|
|
if op.type in ["gaussian_random", "fill_constant"]:
|
|
|
|
|
op.attrs["shape"] = new_outputs["Out"].shape
|
|
|
|
|
print("updated shape", op.attrs["shape"])
|
|
|
|
|
s_prog.global_block().append_op(
|
|
|
|
|
type=op.type,
|
|
|
|
|
inputs=op.inputs,
|
|
|
|
|