@@ -300,6 +300,9 @@ class DistributeTranspiler:
             pass
         return orig_shape
 
+    def _op_input_var(self, op, varname):
+        pass
+
     def _is_op_on_pserver(self, endpoint, all_ops, idx):
         """
         Recursively check if the op need to run on current server.
@@ -309,29 +312,35 @@
             p.name for p in self.param_grad_ep_mapping[endpoint]["params"]
         ]
         op = all_ops[idx]
-        if op.inputs.has_key("Param"):
-            if op.inputs["Param"].name in param_names:
+        input_names = set(op.input_names)
+        # TODO(typhoonzero): using Param and Grad input name to identify
+        # that the operator is an optimization operator, need a better way.
+        if "Param" in input_names:
+            if op.input("Param")[0] in param_names:
                 return True
             else:
                 for n in param_names:
-                    if same_or_split_var(n, op.inputs[
-                            "Param"].name) and n != op.inputs["Param"].name:
+                    if same_or_split_var(n, op.input("Param")[0]) \
+                            and n != op.input("Param")[0]:
                         return True
                 return False
         else:
             j = idx - 1
             while j >= 0:
                 prev_op = all_ops[j]
-                prev_output_names = [o.name for o in prev_op.outputs.values()]
-                prev_input_names = [o.name for o in prev_op.inputs.values()]
+                # prev_output_names = [o.name for o in prev_op.outputs.values()]
+                # prev_input_names = [o.name for o in prev_op.inputs.values()]
+                # NOTE(typhoonzero): consider list input/output
+                prev_output_names = prev_op.desc.output_arg_names()
+                prev_input_names = prev_op.desc.input_arg_names()
                 found1 = False
                 found2 = False
-                for _, v in op.inputs.iteritems():
-                    if v.name in prev_output_names:
+                for varname in op.desc.input_arg_names():
+                    if varname in prev_output_names:
                         found1 = self._is_op_on_pserver(endpoint, all_ops, j)
                 # later ops may produce output for prev op's next batch use.
-                for _, v in op.outputs.iteritems():
-                    if v.name in prev_input_names:
+                for varname in op.desc.output_arg_names():
+                    if varname in prev_input_names:
                         found2 = self._is_op_on_pserver(endpoint, all_ops, j)
                 if found1 or found2:
                     return True
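The hunk above switches `_is_op_on_pserver` from the removed `op.inputs`/`op.outputs` dicts to `op.input(...)` and `op.desc.*_arg_names()`. The control flow is easier to see stripped of the framework types; the sketch below uses plain dicts as stand-ins for operators (none of these names are Paddle API) and drops the split-variable matching that the real code does through `same_or_split_var`.

# Editor's sketch, not part of the patch: plain dicts stand in for fluid ops.
def is_op_on_pserver(param_names, all_ops, idx):
    op = all_ops[idx]
    if "Param" in op["inputs"]:
        # optimization op: it belongs here iff its Param lives on this pserver
        return op["inputs"]["Param"] in param_names
    # auxiliary op: walk earlier ops and recurse when it feeds (or is fed by)
    # one of them, mirroring the found1/found2 checks above
    for j in range(idx - 1, -1, -1):
        prev = all_ops[j]
        feeds_prev = set(op["outputs"].values()) & set(prev["inputs"].values())
        fed_by_prev = set(op["inputs"].values()) & set(prev["outputs"].values())
        if (feeds_prev or fed_by_prev) and is_op_on_pserver(param_names, all_ops, j):
            return True
    return False

ops = [
    {"inputs": {"Param": "w_0", "LearningRate": "lr_0"},
     "outputs": {"ParamOut": "w_0"}},                          # sgd-like op
    {"inputs": {"X": "lr_base"}, "outputs": {"Out": "lr_0"}},  # feeds the sgd op
]
print(is_op_on_pserver(["w_0"], ops, 1))  # True: op 1 produces lr_0 for op 0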
@@ -342,11 +351,11 @@
         new_inputs = dict()
         # update param/grad shape first, then other inputs like
         # moment can use the updated shape
-        for key, var in opt_op.inputs.iteritems():
+        for key in opt_op.input_names:
             if key == "Grad":
                 grad_block = None
                 for g in self.param_grad_ep_mapping[endpoint]["grads"]:
-                    if same_or_split_var(g.name, var.name):
+                    if same_or_split_var(g.name, opt_op.input(key)[0]):
                         grad_block = g
                         break
                 if not grad_block:
@@ -376,7 +385,7 @@
                 # param is already created on global program
                 param_block = None
                 for p in self.param_grad_ep_mapping[endpoint]["params"]:
-                    if same_or_split_var(p.name, var.name):
+                    if same_or_split_var(p.name, opt_op.input(key)[0]):
                         param_block = p
                         break
                 if not param_block:
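Both hunks above now compare pserver-side block names against `opt_op.input(key)[0]` through `same_or_split_var`, which is defined elsewhere in distribute_transpiler.py. A minimal stand-in, assuming split blocks are named `<original>.block<N>` (an assumption about the naming scheme, not a quote of the real helper), behaves like this:

# Editor's stand-in for same_or_split_var; the ".block<N>" suffix is assumed.
def same_or_split_var(var_name, orig_name):
    return var_name == orig_name or var_name.startswith(orig_name + ".block")

print(same_or_split_var("fc_w", "fc_w"))         # True: the whole parameter
print(same_or_split_var("fc_w.block1", "fc_w"))  # True: one split block of it
print(same_or_split_var("fc_b", "fc_w"))         # False: unrelated variable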
@@ -389,11 +398,12 @@
 
                 new_inputs[key] = tmpvar
 
-        for key, var in opt_op.inputs.iteritems():
+        for key in opt_op.input_names:
             if key in ["Param", "Grad"]:
                 continue
             # update accumulator variable shape
             param_shape = new_inputs["Param"].shape
+            var = program.global_block().vars[opt_op.input(key)[0]]
             new_shape = self._get_optimizer_input_shape(opt_op.type, key,
                                                         var.shape, param_shape)
             tmpvar = program.global_block().create_var(
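The `var = program.global_block().vars[opt_op.input(key)[0]]` line added above recovers the variable object that the old `opt_op.inputs.iteritems()` loop used to yield, so `_get_optimizer_input_shape` can still reshape accumulators to follow the split parameter. A rough worked example of why that reshape matters; the adam branch here is a simplified stand-in, not a quote of the real dispatch:

# Editor's sketch: accumulators must shrink with the split parameter block.
def get_optimizer_input_shape(op_type, varkey, orig_shape, param_shape):
    if op_type == "adam" and varkey in ["Moment1", "Moment2"]:
        return param_shape            # moments track the split block
    return orig_shape                 # e.g. LearningRate keeps its shape

param_shape = (500, 10)               # this pserver's half of a (1000, 10) weight
print(get_optimizer_input_shape("adam", "Moment1", (1000, 10), param_shape))  # (500, 10)
print(get_optimizer_input_shape("adam", "LearningRate", (1,), param_shape))   # (1,)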
@@ -412,30 +422,44 @@
                 shape=new_shape)
 
         # change output's ParamOut variable
-        opt_op.outputs["ParamOut"] = new_inputs["Param"]
+        outputs = self._get_output_map_from_op(program.global_block().vars, opt_op)
+        outputs["ParamOut"] = new_inputs["Param"]
         program.global_block().append_op(
             type=opt_op.type,
             inputs=new_inputs,
-            outputs=opt_op.outputs,
+            outputs=outputs,
             attrs=opt_op.attrs)
 
     def _append_pserver_non_opt_ops(self, program, pserver_program, opt_op):
         # Append the ops for parameters that do not need to be optimized/updated
-        for _, var in opt_op.inputs.iteritems():
-            program.global_block().create_var(
-                name=var.name,
-                persistable=var.persistable,
-                dtype=var.dtype,
-                shape=var.shape)
-            pserver_program.global_block().create_var(
-                name=var.name,
-                persistable=var.persistable,
-                dtype=var.dtype,
-                shape=var.shape)
+        inputs = self._get_input_map_from_op(self.program.global_block().vars,
+                                             opt_op)
+        for var in inputs.itervalues():
+            if type(var) == list:
+                varlist = var
+            else:
+                varlist = [var]
+            for var in varlist:
+                # TODO(typhoonzero): will remove below line later.
+                program.global_block().create_var(
+                    name=var.name,
+                    persistable=var.persistable,
+                    dtype=var.dtype,
+                    shape=var.shape)
+                if not pserver_program.global_block().vars.has_key(var.name):
+                    pserver_program.global_block().create_var(
+                        name=var.name,
+                        persistable=var.persistable,
+                        dtype=var.dtype,
+                        shape=var.shape)
+
+        outputs = self._get_output_map_from_op(self.program.global_block().vars,
+                                               opt_op)
+
         program.global_block().append_op(
             type=opt_op.type,
-            inputs=opt_op.inputs,
-            outputs=opt_op.outputs,
+            inputs=inputs,
+            outputs=outputs,
             attrs=opt_op.attrs)
 
     def get_pserver_program(self, endpoint):
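`_get_input_map_from_op` can map one input slot to either a single variable or a list of them, which is why the rewritten `_append_pserver_non_opt_ops` normalizes each value before creating variables. The same pattern in isolation (the `Var` tuple and the names are illustrative, not Paddle types):

# Editor's sketch of the list-or-single normalization used above.
import collections

Var = collections.namedtuple("Var", ["name", "shape"])

inputs = {
    "Param": Var("fc_w", (784, 10)),                  # single variable
    "Grad": [Var("fc_w@GRAD.block0", (392, 10)),      # list of split blocks
             Var("fc_w@GRAD.block1", (392, 10))],
}

for var in inputs.values():      # the patch uses .itervalues() (Python 2)
    varlist = var if isinstance(var, list) else [var]
    for v in varlist:
        print(v.name, v.shape)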
@@ -472,7 +496,7 @@
                                                       self.optimize_ops, idx)
             if not is_op_on_pserver:
                 continue
-            if opt_op.inputs.has_key("Grad"):
+            if "Grad" in opt_op.desc.input_arg_names():
                 self._append_pserver_ops(optimize_sub_program, pserver_program,
                                          opt_op, endpoint)
             else:
@@ -499,6 +523,30 @@
         pserver_program.sync_with_cpp()
         return pserver_program
 
+    def _get_input_map_from_op(self, varmap, op):
+        iomap = dict()
+        for key in op.input_names:
+            vars = []
+            for varname in op.input(key):
+                vars.append(varmap[varname])
+            if len(vars) == 1:
+                iomap[key] = vars[0]
+            else:
+                iomap[key] = vars
+        return iomap
+
+    def _get_output_map_from_op(self, varmap, op):
+        iomap = dict()
+        for key in op.output_names:
+            vars = []
+            for varname in op.output(key):
+                vars.append(varmap[varname])
+            if len(vars) == 1:
+                iomap[key] = vars[0]
+            else:
+                iomap[key] = vars
+        return iomap
+
     def get_startup_program(self, endpoint, pserver_program):
         """
         Get startup program for current parameter server.
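The two helpers added above rebuild the inputs/outputs dicts that `append_op` expects from the op's argument lists, collapsing one-element lists to a bare variable. A self-contained imitation of that behavior; `FakeOp` only mimics `op.input_names` / `op.input(key)` and is not a Paddle class:

# Editor's imitation of the new _get_input_map_from_op helper.
class FakeOp(object):
    def __init__(self, inputs):
        self._inputs = inputs            # e.g. {"X": ["a", "b"], "Y": ["c"]}

    @property
    def input_names(self):
        return list(self._inputs.keys())

    def input(self, key):
        return self._inputs[key]


def get_input_map(varmap, op):
    iomap = dict()
    for key in op.input_names:
        vars = [varmap[name] for name in op.input(key)]
        iomap[key] = vars[0] if len(vars) == 1 else vars
    return iomap


varmap = {"a": "var_a", "b": "var_b", "c": "var_c"}
print(get_input_map(varmap, FakeOp({"X": ["a", "b"], "Y": ["c"]})))
# {'X': ['var_a', 'var_b'], 'Y': 'var_c'}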
@@ -529,17 +577,21 @@
 
         # 2. rename op outputs
         for op in orig_s_prog.global_block().ops:
             new_inputs = dict()
             new_outputs = dict()
             # do not append startup op if var is not on this pserver
             op_on_pserver = False
-            for key, var in op.outputs.iteritems():
-                newname, _ = _get_splited_name_and_shape(var.name)
+            for key in op.output_names:
+                newname, _ = _get_splited_name_and_shape(op.output(key)[0])
                 if newname:
                     op_on_pserver = True
                     new_outputs[key] = created_var_map[newname]
-                elif var.name in pserver_vars:
+                elif op.output(key)[0] in pserver_vars:
                     op_on_pserver = True
-                    new_outputs[key] = pserver_vars[var.name]
+                    new_outputs[key] = pserver_vars[op.output(key)[0]]
+
+            # most startup program ops have no inputs
+            new_inputs = self._get_input_map_from_op(pserver_vars, op)
+
             if op_on_pserver:
                 if op.type in [
@@ -548,7 +600,7 @@
                     op.attrs["shape"] = new_outputs["Out"].shape
                 s_prog.global_block().append_op(
                     type=op.type,
-                    inputs=op.inputs,
+                    inputs=new_inputs,
                     outputs=new_outputs,
                     attrs=op.attrs)
         return s_prog
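In the startup-program hunks above, each op output is re-targeted at the split block that lives on this pserver, and for random/constant initializer ops the `shape` attribute is patched to the block's shape. A toy version of that renaming decision, assuming a lookup table from original names to (split name, split shape); the table and names are illustrative only, not the transpiler's actual data structures:

# Editor's sketch of the output-renaming decision in get_startup_program.
split_map = {"fc_w": ("fc_w.block0", (392, 10))}

def get_splited_name_and_shape(varname):
    return split_map.get(varname, (None, None))

for out_name in ["fc_w", "fc_b"]:
    newname, shape = get_splited_name_and_shape(out_name)
    if newname:
        # rename the output and overwrite op.attrs["shape"] with the block shape
        print(out_name, "->", newname, "shape", shape)
    else:
        print(out_name, "is not a split block; keep it only if it is a pserver var")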