@ -712,7 +712,7 @@ class DistributeTranspiler(object):
slice_vars_and_attrs = []
block_suffix = "block"
for param in self.param_grad_ep_mapping[endpoint]["params"]:
orig_var_name, block_name, _ = self._get_varname_parts(param)
orig_var_name, block_name, _ = self._get_varname_parts(param.name)
if not block_name:
continue