@@ -71,7 +71,7 @@ def same_or_split_var(p_name, var_name):
     return p_name == var_name or p_name.startswith(var_name + ".block")


-def split_variable(var_list, service_count, min_block_size=8192):
+def slice_variable(var_list, slice_count, min_block_size=8192):
     """
     We may need to split dense tensor to one or more blocks and put
     them equally onto parameter server. One block is a sub-tensor
@@ -83,8 +83,8 @@ def split_variable(var_list, service_count, min_block_size=8192):

     Args:
         var_list (list): List of variables.
-        service_count (int): Numel of pserver services. A pserver may have two
-            or more listening ports.
+        slice_count (int): Numel of count that variables will be sliced, which
+            could be the pserver services' count.
         min_block_size (int): Minimum splitted block size.
     Returns:
         blocks (list[(varname, block_id, current_block_size)]): A list
@@ -92,12 +92,12 @@ def split_variable(var_list, service_count, min_block_size=8192):
     """
     blocks = []
     for var in var_list:
-        split_count = service_count
+        split_count = slice_count
         var_numel = reduce(lambda x, y: x * y, var.shape)
         max_pserver_count = int(math.floor(var_numel / float(min_block_size)))
         if max_pserver_count == 0:
             max_pserver_count = 1
-        if max_pserver_count < service_count:
+        if max_pserver_count < slice_count:
             split_count = max_pserver_count
         block_size = int(math.ceil(var_numel / float(split_count)))
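
The renamed helper keeps the sizing rule intact: cap the number of shards by how many min_block_size pieces the tensor can yield, then split as evenly as possible. Below is a minimal sketch of that arithmetic outside the transpiler; the shape and slice count are made-up inputs, not values taken from this patch.

import math
from functools import reduce


def block_plan(shape, slice_count, min_block_size=8192):
    # Total element count of the dense tensor.
    numel = reduce(lambda x, y: x * y, shape)
    # A tensor cannot be cut finer than min_block_size elements per block,
    # so the usable shard count is capped at numel // min_block_size (>= 1).
    max_shards = max(int(math.floor(numel / float(min_block_size))), 1)
    split_count = min(slice_count, max_shards)
    block_size = int(math.ceil(numel / float(split_count)))
    return split_count, block_size


# A 1000 x 64 parameter sliced for 4 pserver services: 64000 elements allow
# at most 7 shards, so all 4 requested slices are used, 16000 elements each.
print(block_plan((1000, 64), 4))  # (4, 16000)
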
@@ -178,7 +178,7 @@ class DistributeTranspiler:
             for index in range(len(self.pserver_endpoints))
         ]

-    def _init_splited_vars(self, split_method, align_var_to_block=True):
+    def _init_splited_vars(self, slice_var_up):
         # update these mappings for further transpile:
         # 1. param_var_mapping: param var name -> [splited params vars]
         # 2. grad_var_mapping: grad var name -> [splited grads vars]
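
For reference, the two mappings named in those comments end up shaped roughly like the hypothetical dicts below; the strings stand in for the split Variable objects, and the ".block" suffixes follow the same_or_split_var check in the first hunk.

# Hypothetical contents after splitting fc_w and its gradient into two blocks.
param_var_mapping = {"fc_w": ["fc_w.block0", "fc_w.block1"]}
grad_var_mapping = {"fc_w@GRAD": ["fc_w@GRAD.block0", "fc_w@GRAD.block1"]}
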
@@ -197,16 +197,19 @@ class DistributeTranspiler:
         self._update_dist_lookup_table_vars(param_list, grad_list,
                                             self.params_grads)

-        if align_var_to_block:
-            grad_blocks = split_variable(grad_list, len(self.pserver_endpoints))
-            param_blocks = split_variable(param_list,
+        if slice_var_up:
+            # when we slice var up into blocks, we will slice the var according to
+            # pserver services' count. A pserver may have two or more listening ports.
+            grad_blocks = slice_variable(grad_list, len(self.pserver_endpoints))
+            param_blocks = slice_variable(param_list,
                                           len(self.pserver_endpoints))
         else:
-            # when we do NOT align var to block, we will always split params
+            # when we do NOT slice var up into blocks, we will always slice params
             # grads into one block.
-            grad_blocks = split_variable(grad_list, 1)
-            param_blocks = split_variable(param_list, 1)
+            grad_blocks = slice_variable(grad_list, 1)
+            param_blocks = slice_variable(param_list, 1)
         assert (len(grad_blocks) == len(param_blocks))

         # origin_varname -> [splited_var]
         self.param_var_mapping = self._create_vars_from_blocklist(
             self.origin_program, param_blocks)
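
Viewed from the call site, slice_var_up only decides how many slices slice_variable is asked for. A small sketch of that decision follows; the endpoint list is a stand-in, not taken from the patch.

def pick_slice_count(slice_var_up, pserver_endpoints):
    # True: a variable may be cut into up to one block per pserver service,
    # so large tensors spread across servers.
    # False: each variable stays whole (one block) and lands on a single
    # pserver, with balancing left to the shuffle in the dispatch loop below.
    return len(pserver_endpoints) if slice_var_up else 1


endpoints = ["127.0.0.1:6174", "127.0.0.1:6175"]
print(pick_slice_count(True, endpoints))   # 2 -> slice_variable(grad_list, 2)
print(pick_slice_count(False, endpoints))  # 1 -> slice_variable(grad_list, 1)
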
@@ -237,7 +240,7 @@ class DistributeTranspiler:
                   program=None,
                   pservers="127.0.0.1:6174",
                   trainers=1,
-                  align_var_to_block=True,
+                  slice_var_up=True,
                   split_method=RoundRobin,
                   sync_mode=True):
         """
@@ -271,7 +274,7 @@ class DistributeTranspiler:
         self.has_distributed_lookup_table = self._has_distributed_lookup_table()

         # split and create vars, then put splited vars in dicts for later use.
-        self._init_splited_vars(split_method, align_var_to_block)
+        self._init_splited_vars(slice_var_up)

         # step 3.1: insert send op to send gradient vars to parameter servers
         ps_dispatcher.reset()
@@ -283,13 +286,13 @@ class DistributeTranspiler:
         # fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2
         # shuffle the map will avoid the uneven distribution above
         grad_var_mapping_items = self.grad_var_mapping.items()
-        if not align_var_to_block:
+        if not slice_var_up:
             np.random.shuffle(grad_var_mapping_items)

         for orig_varname, splited_vars in grad_var_mapping_items:
             eplist = ps_dispatcher.dispatch(splited_vars)

-            if not align_var_to_block:
+            if not slice_var_up:
                 assert (len(splited_vars) == 1)

             if len(splited_vars) == 1:
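
The shuffle only matters in the slice_var_up=False case, where every variable is a single block and a fixed round-robin order would keep pairing the same variables with the same pservers, as the comment about fc_b@GRAD illustrates. A rough illustration with made-up gradient names:

import numpy as np

# Hypothetical whole-variable blocks (slice_var_up=False: one block per var).
grad_var_mapping_items = [("fc_w@GRAD", ["fc_w@GRAD"]),
                          ("fc_b@GRAD", ["fc_b@GRAD"]),
                          ("emb_w@GRAD", ["emb_w@GRAD"]),
                          ("emb_b@GRAD", ["emb_b@GRAD"])]

# Randomize the order before round-robin dispatch so the large weight
# gradients do not always land on the same endpoint.
np.random.shuffle(grad_var_mapping_items)

endpoints = ["127.0.0.1:6174", "127.0.0.1:6175"]
for i, (name, _) in enumerate(grad_var_mapping_items):
    print(name, "->", endpoints[i % len(endpoints)])
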