|
|
|
@ -56,6 +56,8 @@ def split_dense_variable(var_list,
|
|
|
|
|
(block_id) * block_size))
|
|
|
|
|
block = VarBlock(var.name, block_id, curr_block_size)
|
|
|
|
|
blocks.append(str(block))
|
|
|
|
|
print("$$ splited var: ", var.name, var.shape, split_count, len(blocks),
|
|
|
|
|
block_size)
|
|
|
|
|
return blocks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -132,10 +134,12 @@ class DistributeTranspiler:
|
|
|
|
|
|
|
|
|
|
# step4
|
|
|
|
|
for varname, splited_var in param_var_mapping.iteritems():
|
|
|
|
|
if len(splited_var) <= 1:
|
|
|
|
|
continue
|
|
|
|
|
orig_param = program.global_block().vars[varname]
|
|
|
|
|
concat = program.global_block().append_op(
|
|
|
|
|
type="concat",
|
|
|
|
|
inputs={"X": send_outputs},
|
|
|
|
|
inputs={"X": splited_var},
|
|
|
|
|
outputs={"Out": orig_param},
|
|
|
|
|
attrs={"axis": 0})
|
|
|
|
|
|
|
|
|
@ -147,28 +151,29 @@ class DistributeTranspiler:
|
|
|
|
|
if not block_map.has_key(varname):
|
|
|
|
|
block_map[varname] = []
|
|
|
|
|
block_map[varname].append((long(offset), long(size)))
|
|
|
|
|
|
|
|
|
|
for varname, splited in block_map.iteritems():
|
|
|
|
|
orig_var = program.global_block().vars[varname]
|
|
|
|
|
var_mapping[varname] = []
|
|
|
|
|
if len(splited) == 1:
|
|
|
|
|
var_mapping[varname] = [orig_var]
|
|
|
|
|
continue
|
|
|
|
|
orig_shape = orig_var.shape
|
|
|
|
|
orig_dim1_flatten = 1
|
|
|
|
|
if len(orig_shape) >= 2:
|
|
|
|
|
orig_dim1_flatten = reduce(lambda x, y: x * y, orig_shape[1:])
|
|
|
|
|
var_list = []
|
|
|
|
|
|
|
|
|
|
for i, block in enumerate(splited):
|
|
|
|
|
size = block[1]
|
|
|
|
|
rows = size / orig_dim1_flatten
|
|
|
|
|
splited_shape = [rows]
|
|
|
|
|
if len(orig_shape) >= 2:
|
|
|
|
|
splited_shape.extend(orig_shape[1:])
|
|
|
|
|
print("block, splited shape:", block, splited_shape)
|
|
|
|
|
var = program.global_block().create_var(
|
|
|
|
|
name="%s.block%d" % (varname, i),
|
|
|
|
|
psersistable=False,
|
|
|
|
|
dtype=orig_var.dtype,
|
|
|
|
|
shape=splited_shape) # flattend splited var
|
|
|
|
|
var_list.append(var)
|
|
|
|
|
var_mapping[varname] = var_list
|
|
|
|
|
var_mapping[varname].append(var)
|
|
|
|
|
return var_mapping
|
|
|
|
|
|
|
|
|
|
def _clone_param(self, block, v):
|
|
|
|
@ -199,7 +204,8 @@ class DistributeTranspiler:
|
|
|
|
|
def _append_split_op(self, program, gradblocks):
|
|
|
|
|
var_mapping = self._create_vars_from_blocklist(program, gradblocks)
|
|
|
|
|
for varname, splited_vars in var_mapping.iteritems():
|
|
|
|
|
if len(splited_vars) == 1:
|
|
|
|
|
# variable that don't need to split have empty splited_vars
|
|
|
|
|
if len(splited_vars) <= 1:
|
|
|
|
|
continue
|
|
|
|
|
orig_var = program.global_block().vars[varname]
|
|
|
|
|
sections = []
|
|
|
|
|