|
|
|
@ -177,6 +177,7 @@ class DistributeTranspiler:
|
|
|
|
|
dtype=table_grad_var.dtype)
|
|
|
|
|
for index in range(len(self.pserver_endpoints))
|
|
|
|
|
]
|
|
|
|
|
return param_list, grad_list
|
|
|
|
|
|
|
|
|
|
def _init_splited_vars(self, slice_var_up):
|
|
|
|
|
# update these mappings for further transpile:
|
|
|
|
@ -199,8 +200,8 @@ class DistributeTranspiler:
|
|
|
|
|
grad_list.append(g)
|
|
|
|
|
param_grad_set.add(g.name)
|
|
|
|
|
|
|
|
|
|
self._update_dist_lookup_table_vars(param_list, grad_list,
|
|
|
|
|
self.params_grads)
|
|
|
|
|
param_list, grad_list = self._update_dist_lookup_table_vars(
|
|
|
|
|
param_list, grad_list, self.params_grads)
|
|
|
|
|
|
|
|
|
|
if slice_var_up:
|
|
|
|
|
# when we slice var up into blocks, we will slice the var according to
|
|
|
|
|