@@ -50,6 +50,15 @@ OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
 RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
 )
 RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
+DIST_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Dist
+LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
+
+PRINT_LOG = False
+
+
+def log(*args):
+    if PRINT_LOG:
+        print(args)
 
 
 class VarBlock:
@@ -127,6 +136,7 @@ class DistributeTranspilerConfig(object):
     slice_var_up = True
     split_method = None
     min_block_size = 8192
+    print_log = False
 
 
 class DistributeTranspiler(object):
@@ -174,6 +184,9 @@ class DistributeTranspiler(object):
         if self.config.split_method is None:
             self.config.split_method = RoundRobin
 
+        global PRINT_LOG
+        if self.config.print_log:
+            PRINT_LOG = True
         assert (self.config.min_block_size >= 8192)
         assert (self.config.split_method.__bases__[0] == PSDispatcher)
 
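# Illustrative usage sketch (not part of the patch): the print_log flag and the
# module-level PRINT_LOG/log() helper added above give callers an opt-in debug
# trace. A minimal sketch, assuming the usual fluid entry points for the
# transpiler:
import paddle.fluid as fluid

config = fluid.DistributeTranspilerConfig()
config.print_log = True  # attribute added in this diff; routes log() to stdout
transpiler = fluid.DistributeTranspiler(config=config)
# transpiler.transpile(trainer_id=0, pservers="127.0.0.1:6174", trainers=1)
# would now print the "append lr op: ..." and "adding param_grad pair: ..."
# messages emitted through log().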
@@ -257,12 +270,12 @@ class DistributeTranspiler(object):
             splited_grad_varname = grad_varname
             if len(splited_vars) == 1:
                 splited_grad_varname = splited_vars[0].name
-                index = find_op_by_output_arg(program.global_block(),
-                                              splited_grad_varname)
+                index = find_op_by_output_arg(
+                    program.global_block(), splited_grad_varname, reverse=True)
             elif len(splited_vars) > 1:
                 orig_var = program.global_block().vars[splited_grad_varname]
-                index = find_op_by_output_arg(program.global_block(),
-                                              splited_grad_varname)
+                index = find_op_by_output_arg(
+                    program.global_block(), splited_grad_varname, reverse=True)
                 self._insert_split_op(program, orig_var, index, splited_vars)
                 index += 1
             else:
@@ -301,7 +314,7 @@ class DistributeTranspiler(object):
             self.grad_name_to_send_dummy_out[
                 self.table_name] = program.global_block().create_var(
                     name=framework.generate_control_dev_var_name())
-        input_deps = self.grad_name_to_send_dummy_out.values()
+        input_deps = list(self.grad_name_to_send_dummy_out.values())
 
         program.global_block().append_op(
             type="send_barrier",
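# Illustrative note (not part of the patch): the list(...) wrapper above matters
# once the transpiler runs under Python 3, where dict.values() returns a lazy
# view rather than a list, so it cannot safely be reused as a concrete op input
# list.
deps = {"w@GRAD": "dummy_0", "b@GRAD": "dummy_1"}
view = deps.values()            # Python 3: dict_values view, tracks later changes
concrete = list(deps.values())  # snapshot that behaves like the Python 2 result
assert isinstance(concrete, list)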
@@ -377,7 +390,10 @@ class DistributeTranspiler(object):
                 type="concat",
                 inputs={"X": splited_var},
                 outputs={"Out": [orig_param]},
-                attrs={"axis": 0})
+                attrs={
+                    "axis": 0,
+                    RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
+                })
 
         self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist)
 
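# Illustrative sketch (assumption, not part of the patch): ops tagged with the
# Dist role above can be picked back out of a program the same way the new
# _get_lr_ops() filters by the LRSched role.
import paddle.fluid as fluid
from paddle.fluid import core

role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
dist_role = core.op_proto_and_checker_maker.OpRole.Dist
prog = fluid.default_main_program()
dist_ops = [
    op for op in prog.global_block().ops
    if op.has_attr(role_attr_name) and int(op.attr(role_attr_name)) == int(dist_role)
]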
@@ -496,9 +512,9 @@ class DistributeTranspiler(object):
         # NOTE: assume blocks of the same variable is not distributed
         # on the same pserver, only change param/grad varnames for
         # trainers to fetch.
-        sys.stderr.write("get_pserver_program() is deprecated, call\
-get_pserver_programs() to get pserver main and startup\
-in a single call.")
+        sys.stderr.write("get_pserver_program() is deprecated, call \
+get_pserver_programs() to get pserver main and startup \
+in a single call.")
         # step1
         pserver_program = Program()
         pserver_program.random_seed = self.origin_program.random_seed
@@ -615,22 +631,31 @@ class DistributeTranspiler(object):
         for idx, opt_op in enumerate(opt_op_on_pserver):
             per_opt_block = pserver_program._create_block(pre_block_idx)
             optimize_blocks.append(per_opt_block)
+            optimize_target_param_name = opt_op.attr(OP_ROLE_VAR_ATTR_NAME)[0]
             # append grad merging ops before clip and weight decay
-            # cases may like:
-            # L2Decay op -> clip op -> optimize
+            # e.g. merge grad -> L2Decay op -> clip op -> optimize
+            merged_var = None
             for _, op in enumerate(self.optimize_ops):
-                # find the origin @GRAD var before clipping
-                grad_varname_for_block = __op_have_grad_input__(op)
-                if ufind.is_connected(op, opt_op) and grad_varname_for_block:
+                # find the origin grad var before clipping/L2Decay,
+                # merged_var should be the input var name of L2Decay
+                grad_varname_for_block = op.attr(OP_ROLE_VAR_ATTR_NAME)[1]
+                if op.attr(OP_ROLE_VAR_ATTR_NAME)[
+                        0] == optimize_target_param_name:
                     merged_var = self._append_pserver_grad_merge_ops(
                         per_opt_block, grad_varname_for_block, endpoint,
                         grad_to_block_id, self.origin_program)
-                    break  # append optimize op once then append other ops.
-            for _, op in enumerate(self.optimize_ops):
-                # optimizer is connected to itself
-                if ufind.is_connected(op, opt_op) and op not in global_ops:
-                    __append_optimize_op__(op, per_opt_block, grad_to_block_id,
-                                           merged_var, lr_ops)
+                    if merged_var:
+                        break  # append optimize op once then append other ops.
+            if merged_var:
+                for _, op in enumerate(self.optimize_ops):
+                    # optimizer is connected to itself
+                    if op.attr(OP_ROLE_VAR_ATTR_NAME)[0] == optimize_target_param_name and \
+                            op not in global_ops:
+                        log("append opt op: ", op.type, op.input_arg_names,
+                            merged_var)
+                        __append_optimize_op__(op, per_opt_block,
+                                               grad_to_block_id, merged_var,
+                                               lr_ops)
 
         # dedup grad to ids list
         grad_to_block_id = list(set(grad_to_block_id))
@@ -726,17 +751,17 @@ class DistributeTranspiler(object):
         Returns:
             Program: parameter server side startup program.
         """
-        sys.stderr.write("get_startup_program() is deprecated, call\
-get_pserver_programs() to get pserver main and startup\
-in a single call.")
+        sys.stderr.write("get_startup_program() is deprecated, call \
+get_pserver_programs() to get pserver main and startup \
+in a single call.")
         if pserver_program != None:
-            sys.stderr.write("passing pserver_program to get_startup_program()\
-is deprecated, you can use new API get_pserver_programs() to\
-get both pserver main program and startup program.")
+            sys.stderr.write("passing pserver_program to get_startup_program() \
+is deprecated, you can use new API get_pserver_programs() to \
+get both pserver main program and startup program.")
         if startup_program != None:
-            sys.stderr.write("passing startup_program to get_startup_program()\
-is deprecated, use fluid.program_guard() or pass this argument\
-to transpile() call.")
+            sys.stderr.write("passing startup_program to get_startup_program() \
+is deprecated, use fluid.program_guard() or pass this argument \
+to transpile() call.")
 
         s_prog = Program()
         orig_s_prog = self.startup_program
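# Illustrative note (not part of the patch): the only change to these warnings
# is the space before each trailing backslash. Inside a string literal a
# backslash-newline joins the lines with no separator, so the old messages ran
# words together:
old_msg = "is deprecated, call\
get_pserver_programs()"
assert old_msg == "is deprecated, callget_pserver_programs()"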
@@ -1302,7 +1327,10 @@ class DistributeTranspiler(object):
                 type="split_selected_rows",
                 inputs={"X": orig_var},
                 outputs={"Out": splited_vars},
-                attrs={"height_sections": height_sections})
+                attrs={
+                    "height_sections": height_sections,
+                    RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
+                })
         elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR:
             sections = []
             for v in splited_vars:
@@ -1312,8 +1340,10 @@ class DistributeTranspiler(object):
                 type="split_byref",
                 inputs={"X": orig_var},
                 outputs={"Out": splited_vars},
-                attrs={"sections": sections}  # assume split evenly
-            )
+                attrs={
+                    "sections": sections,
+                    RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
+                })
         else:
             AssertionError("Variable type should be in set "
                            "[LOD_TENSOR, SELECTED_ROWS]")
@@ -1381,15 +1411,15 @@ class DistributeTranspiler(object):
         if not grad_block:
             # do not append this op if current endpoint
             # is not dealing with this grad block
-            return
+            return None
         orig_varname, block_name, trainer_name = self._get_varname_parts(
             grad_block.name)
         if block_name:
             merged_var_name = '.'.join([orig_varname, block_name])
         else:
             merged_var_name = orig_varname
-        merged_var = \
-            pserver_block.vars[merged_var_name]
 
+        merged_var = pserver_block.vars[merged_var_name]
+        grad_to_block_id.append(merged_var.name + ":" + str(optimize_block.idx))
         if self.sync_mode and self.trainer_num > 1:
             vars2merge = []
@@ -1473,7 +1503,6 @@ class DistributeTranspiler(object):
         outputs = self._get_output_map_from_op(
             self.origin_program.global_block().vars, opt_op)
         outputs["ParamOut"] = new_inputs["Param"]
-
         optimize_block.append_op(
             type=opt_op.type,
             inputs=new_inputs,
@@ -1618,6 +1647,16 @@ class DistributeTranspiler(object):
         return iomap
 
     def _get_lr_ops(self):
         lr_ops = []
+        block = self.origin_program.global_block()
+        for op in block.ops:
+            if int(op.attr(RPC_OP_ROLE_ATTR_NAME)) == int(
+                    LR_SCHED_OP_ROLE_ATTR_VALUE):
+                lr_ops.append(op)
+                log("append lr op: ", op.type)
+        return lr_ops
+
+    def _get_lr_ops_deprecated(self):
+        lr_ops = []
         # find learning rate variables by optimize op
         lr_vars = set()
@@ -1670,20 +1709,21 @@ class DistributeTranspiler(object):
         block = self.origin_program.global_block()
         opt_ops = []
         params_grads = []
+        # tmp set to dedup
+        optimize_params = set()
         origin_var_dict = self.origin_program.global_block().vars
         for op in block.ops:
             if self._is_opt_role_op(op):
                 opt_ops.append(op)
-                # HACK(wuyi): if we find grad vars from input of optimize
-                # ops, we may get the output of clip op. Use syntax "@GRAD"
-                # and op_role_var to get the pair.
-                for input_name in op.input_arg_names:
-                    if input_name.find("@GRAD") != -1 and \
-                        op.attr(RPC_OP_ROLE_ATTR_NAME):
-                        param_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[0]
+                if op.attr(OP_ROLE_VAR_ATTR_NAME):
+                    param_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[0]
+                    grad_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[1]
+                    if not param_name in optimize_params:
+                        optimize_params.add(param_name)
+                        log("adding param_grad pair: ", param_name, grad_name)
                         params_grads.append([
                             origin_var_dict[param_name],
-                            origin_var_dict[input_name]
+                            origin_var_dict[grad_name]
                         ])
             else:
                 pass
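# Illustrative sketch (assumption, not part of the patch): the rewritten loop
# above relies on every optimize op carrying its (parameter, gradient) pair in
# the op_role_var attribute, so index 0 is the parameter name and index 1 the
# gradient name. Reading the pairs back out of a program looks like this:
import paddle.fluid as fluid
from paddle.fluid import core

op_role_var = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
pairs = []
for op in fluid.default_main_program().global_block().ops:
    if op.has_attr(op_role_var) and op.attr(op_role_var):
        pairs.append((op.attr(op_role_var)[0], op.attr(op_role_var)[1]))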