|
|
|
@ -33,6 +33,10 @@ class VarBlock:
|
|
|
|
|
return "%s:%d:%d" % (self.varname, self.offset, self.size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def same_or_split_var(p_name, var_name):
|
|
|
|
|
return p_name == var_name or p_name.startswith(var_name + ".block")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def split_dense_variable(var_list,
|
|
|
|
|
pserver_count,
|
|
|
|
|
min_block_size=1024,
|
|
|
|
@ -303,8 +307,8 @@ class DistributeTranspiler:
|
|
|
|
|
return True
|
|
|
|
|
else:
|
|
|
|
|
for n in param_names:
|
|
|
|
|
if n.startswith(op.inputs["Param"].name+".block") and \
|
|
|
|
|
n != op.inputs["Param"].name:
|
|
|
|
|
if same_or_split_var(n, op.inputs[
|
|
|
|
|
"Param"].name) and n != op.inputs["Param"].name:
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
else:
|
|
|
|
@ -335,7 +339,7 @@ class DistributeTranspiler:
|
|
|
|
|
if key == "Grad":
|
|
|
|
|
grad_block = None
|
|
|
|
|
for g in self.param_grad_ep_mapping[endpoint]["grads"]:
|
|
|
|
|
if g.name.startswith(var.name):
|
|
|
|
|
if same_or_split_var(g.name, var.name):
|
|
|
|
|
grad_block = g
|
|
|
|
|
break
|
|
|
|
|
if not grad_block:
|
|
|
|
@ -365,7 +369,7 @@ class DistributeTranspiler:
|
|
|
|
|
# param is already created on global program
|
|
|
|
|
param_block = None
|
|
|
|
|
for p in self.param_grad_ep_mapping[endpoint]["params"]:
|
|
|
|
|
if p.name.startswith(var.name):
|
|
|
|
|
if same_or_split_var(p.name, var.name):
|
|
|
|
|
param_block = p
|
|
|
|
|
break
|
|
|
|
|
if not param_block:
|
|
|
|
@ -502,7 +506,7 @@ class DistributeTranspiler:
|
|
|
|
|
def _get_splited_name_and_shape(varname):
|
|
|
|
|
for idx, splited_param in enumerate(params):
|
|
|
|
|
pname = splited_param.name
|
|
|
|
|
if pname.startswith(varname) and varname != pname:
|
|
|
|
|
if same_or_split_var(pname, varname) and varname != pname:
|
|
|
|
|
return pname, splited_param.shape
|
|
|
|
|
return "", []
|
|
|
|
|
|
|
|
|
|