|
|
|
@ -121,7 +121,6 @@ def split_dense_variable(var_list,
|
|
|
|
|
block_size += dim1 - remains
|
|
|
|
|
# update split_count after aligning
|
|
|
|
|
split_count = int(math.ceil(var_numel / float(block_size)))
|
|
|
|
|
print("###split var ", var.name, var.shape, block_size, split_count)
|
|
|
|
|
for block_id in xrange(split_count):
|
|
|
|
|
curr_block_size = min(block_size, var_numel - (
|
|
|
|
|
(block_id) * block_size))
|
|
|
|
@ -207,7 +206,7 @@ class DistributeTranspiler:
|
|
|
|
|
|
|
|
|
|
rpc_client_var = program.global_block().create_var(
|
|
|
|
|
name="RPC_CLIENT_VAR",
|
|
|
|
|
psersistable=True,
|
|
|
|
|
persistable=True,
|
|
|
|
|
dtype='float32', # dtype and shape are not actually used
|
|
|
|
|
shape=[0])
|
|
|
|
|
|
|
|
|
@ -256,15 +255,13 @@ class DistributeTranspiler:
|
|
|
|
|
splited_shape = [rows]
|
|
|
|
|
if len(orig_shape) >= 2:
|
|
|
|
|
splited_shape.extend(orig_shape[1:])
|
|
|
|
|
print("###splited: ", size, rows, splited_shape)
|
|
|
|
|
var = program.global_block().create_var(
|
|
|
|
|
name="%s.block%d" % (varname, i),
|
|
|
|
|
psersistable=False,
|
|
|
|
|
persistable=False,
|
|
|
|
|
dtype=orig_var.dtype,
|
|
|
|
|
type=orig_var.type,
|
|
|
|
|
shape=splited_shape) # flattened split var
|
|
|
|
|
var_mapping[varname].append(var)
|
|
|
|
|
print("###created split var ", var)
|
|
|
|
|
return var_mapping
|
|
|
|
|
|
|
|
|
|
def _clone_var(self, block, var):
|
|
|
|
@ -322,7 +319,7 @@ class DistributeTranspiler:
|
|
|
|
|
for i in xrange(trainers):
|
|
|
|
|
var_each = block.create_var(
|
|
|
|
|
name="%s.trainer_%d" % (var.name, i),
|
|
|
|
|
psersistable=var.persistable,
|
|
|
|
|
persistable=var.persistable,
|
|
|
|
|
dtype=var.dtype,
|
|
|
|
|
type=var.type,
|
|
|
|
|
shape=var.shape)
|
|
|
|
@ -531,8 +528,6 @@ class DistributeTranspiler:
|
|
|
|
|
"""
|
|
|
|
|
# step5
|
|
|
|
|
pserver_program = Program()
|
|
|
|
|
print("param mapping on pserver: #### ",
|
|
|
|
|
self.param_grad_ep_mapping[endpoint]["params"])
|
|
|
|
|
for v in self.param_grad_ep_mapping[endpoint]["params"]:
|
|
|
|
|
self._clone_var(pserver_program.global_block(), v)
|
|
|
|
|
for v in self.param_grad_ep_mapping[endpoint]["grads"]:
|
|
|
|
|