|
|
@ -626,7 +626,7 @@ class DistributeTranspiler(object):
|
|
|
|
attrs=attrs)
|
|
|
|
attrs=attrs)
|
|
|
|
|
|
|
|
|
|
|
|
# add distributed attrs
|
|
|
|
# add distributed attrs
|
|
|
|
pserver_program._slice_vars_and_atts = self._get_slice_vars_and_atts(
|
|
|
|
pserver_program._slice_vars_and_attrs = self._get_slice_vars_and_attrs(
|
|
|
|
endpoint)
|
|
|
|
endpoint)
|
|
|
|
|
|
|
|
|
|
|
|
pserver_program._sync_with_cpp()
|
|
|
|
pserver_program._sync_with_cpp()
|
|
|
@ -704,31 +704,28 @@ class DistributeTranspiler(object):
|
|
|
|
attrs=op.all_attrs())
|
|
|
|
attrs=op.all_attrs())
|
|
|
|
|
|
|
|
|
|
|
|
# add slice vars
|
|
|
|
# add slice vars
|
|
|
|
s_prog._slice_vars_and_atts = self._get_slice_vars_and_atts(endpoint)
|
|
|
|
s_prog._slice_vars_and_attrs = self._get_slice_vars_and_attrs(endpoint)
|
|
|
|
|
|
|
|
|
|
|
|
return s_prog
|
|
|
|
return s_prog
|
|
|
|
|
|
|
|
|
|
|
|
def _get_slice_vars_and_atts(self, endpoint):
|
|
|
|
def _get_slice_vars_and_attrs(self, endpoint):
|
|
|
|
slice_vars_and_atts = []
|
|
|
|
slice_vars_and_attrs = []
|
|
|
|
block_suffix = ".block"
|
|
|
|
block_suffix = "block"
|
|
|
|
for param in self.param_grad_ep_mapping[endpoint]["params"]:
|
|
|
|
for param in self.param_grad_ep_mapping[endpoint]["params"]:
|
|
|
|
|
|
|
|
orig_var_name, block_name, _ = self._get_varname_parts(param)
|
|
|
|
suff_idx = param.name.find(block_suffix)
|
|
|
|
if not block_name:
|
|
|
|
if suff_idx <= 0:
|
|
|
|
|
|
|
|
continue
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
orig_var_name = param.name[:suff_idx]
|
|
|
|
block_idx = int(block_name.split(block_suffix)[1])
|
|
|
|
block_idx = int(param.name[suff_idx + len(block_suffix):])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
orig_var = self.origin_program.global_block().vars[orig_var_name]
|
|
|
|
orig_var = self.origin_program.global_block().vars[orig_var_name]
|
|
|
|
|
|
|
|
|
|
|
|
skip_numel = 0
|
|
|
|
skip_numel = 0
|
|
|
|
slice_vars = self.param_var_mapping[orig_var_name]
|
|
|
|
slice_vars = self.param_var_mapping[orig_var_name]
|
|
|
|
for slice_var in slice_vars[:block_idx]:
|
|
|
|
for slice_var in slice_vars[:block_idx]:
|
|
|
|
skip_numel += reduce(lambda x, y: x * y, slice_var.shape)
|
|
|
|
skip_numel += reduce(lambda x, y: x * y, slice_var.shape)
|
|
|
|
slice_vars_and_atts.append([orig_var, skip_numel, param])
|
|
|
|
slice_vars_and_attrs.append([orig_var, skip_numel, param])
|
|
|
|
|
|
|
|
|
|
|
|
return slice_vars_and_atts
|
|
|
|
return slice_vars_and_attrs
|
|
|
|
|
|
|
|
|
|
|
|
# ====================== private transpiler functions =====================
|
|
|
|
# ====================== private transpiler functions =====================
|
|
|
|
|
|
|
|
|
|
|
|