@@ -39,8 +39,8 @@ import six
 
 from .ps_dispatcher import RoundRobin, HashName, PSDispatcher
 from .. import core, framework
 from ..framework import Program, default_main_program, \
     default_startup_program, Block, \
     Parameter, grad_var_name
 from .details import *
 from functools import reduce
@@ -178,7 +178,7 @@ class DistributeTranspiler(object):
                                                             pserver_program)
         elif role == "TRAINER":
             trainer_program = t.get_trainer_program()
 
         # for nccl2 mode
         config = fluid.DistributeTranspilerConfig()
         config.mode = "nccl2"
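
Note: the docstring fragment above shows only the tail of the PSERVER/TRAINER branch and the first two lines of the nccl2 configuration. A minimal sketch of how the nccl2 path is typically completed follows; the endpoint values are placeholders and the transpile() keyword names (trainers, current_endpoint) are assumptions to be checked against the Fluid version at hand, not something this diff specifies.

import paddle.fluid as fluid

# Hypothetical worker layout for nccl2 mode; replace with real endpoints.
worker_endpoints = "192.168.0.1:6170,192.168.0.2:6170"
current_endpoint = "192.168.0.1:6170"

config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
t = fluid.DistributeTranspiler(config=config)
# Assumed keyword names; in nccl2 mode the transpiler only rewrites the
# trainer program, so only get_trainer_program() is used afterwards.
t.transpile(trainer_id=0, trainers=worker_endpoints,
            current_endpoint=current_endpoint)
trainer_program = t.get_trainer_program()
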
@@ -534,7 +534,7 @@ class DistributeTranspiler(object):
                 })
 
         for varname, splited_var in six.iteritems(self.param_var_mapping):
-            #add concat ops to merge splited parameters received from parameter servers.
+            # add concat ops to merge splited parameters received from parameter servers.
             if len(splited_var) <= 1:
                 continue
             # NOTE: if enable memory optimization, origin vars maybe removed.
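
The comment corrected in the hunk above refers to parameters that were split into row blocks across parameter servers and must be re-assembled on the trainer after being received. A minimal NumPy-only sketch of the merge that the generated concat ops perform (the array shape and block count are illustrative, not taken from the transpiler):

import numpy as np

# A 10x4 parameter split into three row blocks, as if each block lived on a
# different parameter server.
param = np.random.rand(10, 4).astype("float32")
splited_blocks = np.array_split(param, 3, axis=0)

# Concatenating the received blocks along axis 0 restores the original
# parameter, which is what the inserted concat op does for each split var.
merged = np.concatenate(splited_blocks, axis=0)
assert np.array_equal(merged, param)
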
@@ -734,19 +734,14 @@ in a single call.")
             table_opt_block = self._create_table_optimize_block(
                 pserver_index, pserver_program, pre_block_idx, grad_to_block_id)
             optimize_blocks.append(table_opt_block)
-            prefetch_var_name_to_block_id = self._create_prefetch_block(
+            lookup_table_var_name_to_block_id = self._create_prefetch_block(
                 pserver_index, pserver_program, table_opt_block)
             checkpoint_block_id = self._create_checkpoint_save_block(
                 pserver_program, table_opt_block.idx)
 
             pserver_program._distributed_lookup_table = self.table_name
-
-        # NOTE: if has_distributed_lookup_table is False, then prefetch_block will
-        # not be executed, so it's safe to use optimize_block to hold the place
-        if self.has_distributed_lookup_table:
-            assert len(prefetch_var_name_to_block_id) > 0
-        else:
-            assert len(prefetch_var_name_to_block_id) == 0
+            prefetch_var_name_to_block_id.extend(
+                lookup_table_var_name_to_block_id)
 
         attrs = {
             "optimize_blocks": optimize_blocks,
@@ -755,11 +750,14 @@ in a single call.")
             "sync_mode": self.sync_mode,
             "grad_to_block_id": grad_to_block_id,
         }
-        if len(prefetch_var_name_to_block_id) > 0:
-            attrs['prefetch_var_name_to_block_id'] \
-                = prefetch_var_name_to_block_id
-            attrs['checkpint_block_id'] = checkpoint_block_id
+
+        if self.has_distributed_lookup_table:
+            attrs['checkpint_block_id'] = checkpoint_block_id
+
+        if len(prefetch_var_name_to_block_id) > 0:
+            attrs[
+                'prefetch_var_name_to_block_id'] = prefetch_var_name_to_block_id
 
         # step5 append the listen_and_serv op
         pserver_program.global_block().append_op(
             type="listen_and_serv",
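
Taken together, the two hunks above make the listen_and_serv attributes depend on what actually exists: checkpint_block_id is attached only when a distributed lookup table is present (it is only computed in that branch), and prefetch_var_name_to_block_id is attached whenever any prefetch blocks were collected. A minimal sketch of the pserver-side call path this guards, assuming a trainer program with an optimizer has already been built in the default main program and no distributed lookup table is used (endpoints are placeholders, not values from this diff):

import paddle.fluid as fluid

# Minimal trainer program so the transpiler has params, grads and optimize ops.
x = fluid.layers.data(name="x", shape=[4], dtype="float32")
y = fluid.layers.data(name="y", shape=[1], dtype="float32")
pred = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

# Hypothetical single-pserver layout without a distributed lookup table.
pserver_endpoints = "127.0.0.1:6170"
current_endpoint = "127.0.0.1:6170"

t = fluid.DistributeTranspiler()
t.transpile(trainer_id=0, pservers=pserver_endpoints, trainers=2)

# With the change above, the pserver program is assembled without touching the
# lookup-table/checkpoint attributes, since none exist for this program.
pserver_program = t.get_pserver_program(current_endpoint)
pserver_startup = t.get_startup_program(current_endpoint, pserver_program)
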
@@ -1013,7 +1011,7 @@ to transpile() call.")
         for g, p in zip(grad_blocks, param_blocks):
             g_name, g_bid, _ = g.split(":")
             p_name, p_bid, _ = p.split(":")
             self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \
                 self.param_var_mapping[p_name][int(p_bid)]
 
         # create mapping of endpoint -> split var to create pserver side program
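
The loop above pairs every gradient block with its parameter block by parsing "name:block_id:<third field>" strings and indexing into the previously built grad_var_mapping / param_var_mapping. A standalone sketch of that pairing with made-up descriptors (plain tuples stand in for the split Variables, and the third field is simply ignored, as in the loop):

# Hypothetical block descriptor strings in the format this loop parses.
grad_blocks = ["w@GRAD:0:100", "w@GRAD:1:100"]
param_blocks = ["w:0:100", "w:1:100"]

grad_param_mapping = {}
for g, p in zip(grad_blocks, param_blocks):
    g_name, g_bid, _ = g.split(":")
    p_name, p_bid, _ = p.split(":")
    # The real code maps split grad Variables to split param Variables;
    # (name, block_id) tuples stand in for them here.
    grad_param_mapping[(g_name, int(g_bid))] = (p_name, int(p_bid))

assert grad_param_mapping[("w@GRAD", 1)] == ("w", 1)
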
@@ -1320,7 +1318,7 @@ to transpile() call.")
             if len(splited) == 1:
                 if self.sync_mode and add_trainer_suffix:
                     new_var_name = "%s.trainer_%d" % \
                         (orig_var.name, self.trainer_id)
                     program.global_block()._rename_var(varname, new_var_name)
                     var_mapping[varname] = \
                         [program.global_block().var(new_var_name)]
@@ -1343,10 +1341,10 @@ to transpile() call.")
                 new_var_name = ""
                 if self.sync_mode and add_trainer_suffix:
                     new_var_name = "%s.block%d.trainer_%d" % \
                         (varname, i, self.trainer_id)
                 else:
                     new_var_name = "%s.block%d" % \
                         (varname, i)
                 var = program.global_block().create_var(
                     name=new_var_name,
                     persistable=False,
@@ -1484,7 +1482,7 @@ to transpile() call.")
         vars2merge = []
         for i in range(self.trainer_num):
             per_trainer_name = "%s.trainer_%d" % \
                 (merged_var_name, i)
             vars2merge.append(pserver_block.vars[per_trainer_name])
 
         optimize_block.append_op(
@@ -1645,7 +1643,7 @@ to transpile() call.")
         # one op's output is another op's input, we say
         # the two operator is connected.
         if set(op1.desc.output_arg_names()) & set(op2.desc.input_arg_names()) or \
                 set(op1.desc.input_arg_names()) & set(op2.desc.output_arg_names()):
             return True
         return False
 
@@ -1662,7 +1660,7 @@ to transpile() call.")
 
     def _is_optimizer_op(self, op):
         if "Param" in op.input_names and \
                 "LearningRate" in op.input_names:
             return True
         return False
 
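
_is_op_connected (shown two hunks up) and _is_optimizer_op (above) are small structural heuristics: two ops are connected when one op's output names intersect the other's input names, and an op counts as an optimizer op when it takes both a "Param" and a "LearningRate" input. A self-contained sketch with a stand-in op type (the real checks read op.desc and op.input_names from the Fluid framework):

from collections import namedtuple

# Stand-in for a framework op: its input slot names plus the actual
# input/output argument names.
FakeOp = namedtuple("FakeOp", ["input_names", "input_args", "output_args"])

def is_connected(op1, op2):
    # Same set-intersection test as _is_op_connected, on plain lists.
    return bool(set(op1.output_args) & set(op2.input_args) or
                set(op1.input_args) & set(op2.output_args))

def is_optimizer_op(op):
    # Same rule as _is_optimizer_op: both Param and LearningRate inputs.
    return "Param" in op.input_names and "LearningRate" in op.input_names

mul = FakeOp(["X", "Y"], ["x", "w"], ["fc_out"])
sgd = FakeOp(["Param", "Grad", "LearningRate"], ["w", "w@GRAD", "lr"], ["w"])

assert is_connected(mul, sgd)       # they share the weight variable "w"
assert is_optimizer_op(sgd)
assert not is_optimizer_op(mul)
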
@@ -1737,7 +1735,7 @@ to transpile() call.")
                 # NOTE: we need to skip all optimize ops, since it is connected
                 # with forward/backward ops and lr ops, we only need the lr ops.
                 if op1 != op2 and self._is_op_connected(op1, op2) and \
                         not self._is_optimizer_op(op1) and not self._is_optimizer_op(op2):
                     ufind.union(op1, op2)
         # find all ops which is related with lr var
         for op1 in block.ops: