|
|
|
@ -719,6 +719,28 @@ class DistributeTranspiler(object):
|
|
|
|
}) for ep in self.pserver_endpoints
|
|
|
|
}) for ep in self.pserver_endpoints
|
|
|
|
]
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_slice_vars_and_atts(self, endpoint):
|
|
|
|
|
|
|
|
slice_vars_and_atts = []
|
|
|
|
|
|
|
|
block_suffix = ".block"
|
|
|
|
|
|
|
|
for param in self.param_grad_ep_mapping[endpoint]["params"]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
suff_idx = param.name.find(block_suffix)
|
|
|
|
|
|
|
|
if suff_idx <= 0:
|
|
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
orig_var_name = param.name[:suff_idx]
|
|
|
|
|
|
|
|
block_idx = int(param.name[suff_idx + len(block_suffix):])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
orig_var = self.origin_program.global_block().vars[orig_var_name]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
skip_numel = 0
|
|
|
|
|
|
|
|
slice_vars = self.param_var_mapping[orig_var_name]
|
|
|
|
|
|
|
|
for slice_var in slice_vars[:block_idx]:
|
|
|
|
|
|
|
|
skip_numel += reduce(lambda x, y: x * y, slice_var.shape)
|
|
|
|
|
|
|
|
slice_vars_and_atts.append([orig_var, skip_numel, param])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return slice_vars_and_atts
|
|
|
|
|
|
|
|
|
|
|
|
# transpiler function for dis lookup_table
|
|
|
|
# transpiler function for dis lookup_table
|
|
|
|
def _replace_lookup_table_op_with_prefetch(self, program,
|
|
|
|
def _replace_lookup_table_op_with_prefetch(self, program,
|
|
|
|
pserver_endpoints):
|
|
|
|
pserver_endpoints):
|
|
|
|
|