|
|
|
@ -144,7 +144,7 @@ class Parameter(MetaTensor_):
|
|
|
|
|
if isinstance(data, bool):
|
|
|
|
|
raise ValueError('Parameter data can not be `bool`')
|
|
|
|
|
if isinstance(data, MetaTensor):
|
|
|
|
|
if _is_in_parallel_mode():
|
|
|
|
|
if _is_in_parallel_mode() or _is_role_worker():
|
|
|
|
|
# do not init data while in auto parallel.
|
|
|
|
|
return (MetaTensor_, data.dtype, data.shape)
|
|
|
|
|
data = data.to_tensor()
|
|
|
|
@ -174,8 +174,12 @@ class Parameter(MetaTensor_):
|
|
|
|
|
|
|
|
|
|
def set_param_ps(self, init_in_server=False):
|
|
|
|
|
if _is_role_worker() or _is_role_pserver() or _is_role_sched():
|
|
|
|
|
if init_in_server and (not self.name.endswith("embedding_table")):
|
|
|
|
|
raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of \
|
|
|
|
|
sparse operator support initialization in server.".format(self.name))
|
|
|
|
|
self.is_param_ps = True
|
|
|
|
|
self.init_in_server = init_in_server
|
|
|
|
|
self._param_info.init_in_server = init_in_server
|
|
|
|
|
else:
|
|
|
|
|
raise RuntimeError("Must complete following two steps before calling set_param_ps: \
|
|
|
|
|
1. set_ps_context(enable_ps=True) \
|
|
|
|
@ -270,6 +274,8 @@ class Parameter(MetaTensor_):
|
|
|
|
|
x._param_info = self._param_info.clone()
|
|
|
|
|
x._param_info.name = prefix + '.' + self._param_info.name
|
|
|
|
|
x.is_init = False
|
|
|
|
|
x.is_param_ps = self.is_param_ps
|
|
|
|
|
x.init_in_server = self.init_in_server
|
|
|
|
|
if init != 'same':
|
|
|
|
|
shape = self.shape
|
|
|
|
|
dtype = self.dtype
|
|
|
|
@ -403,12 +409,18 @@ class Parameter(MetaTensor_):
|
|
|
|
|
raise ValueError("The length of layout must be larger than 3! layout is {}.".format(layout))
|
|
|
|
|
slice_index = int(_get_slice_index(layout[0], layout[1]))
|
|
|
|
|
if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)):
|
|
|
|
|
data = self.init_mode.to_tensor(0, [1])
|
|
|
|
|
if _is_role_worker():
|
|
|
|
|
data = self.init_mode.to_tensor(0, [1])
|
|
|
|
|
else:
|
|
|
|
|
data = self.init_mode.to_tensor(slice_index, layout[2])
|
|
|
|
|
else:
|
|
|
|
|
data = self.init_mode.to_tensor(slice_index, layout[2])
|
|
|
|
|
else:
|
|
|
|
|
if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)):
|
|
|
|
|
data = self.init_mode.to_tensor(0, [1])
|
|
|
|
|
if _is_role_worker():
|
|
|
|
|
data = self.init_mode.to_tensor(0, [1])
|
|
|
|
|
else:
|
|
|
|
|
data = self.init_mode.to_tensor()
|
|
|
|
|
else:
|
|
|
|
|
data = self.init_mode.to_tensor()
|
|
|
|
|
|
|
|
|
|