diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index dbb4625873..9497cb574b 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -167,6 +167,17 @@ class Parameter(MetaTensor_): """For parse check.""" def set_param_ps(self, init_in_server=False): + """ + Set whether the trainable parameter is updated by parameter server and whether the + trainable parameter is initialized on server. + + Note: + It only works when a running task is in the parameter server mode. + + Args: + init_in_server (bool): Whether trainable parameter updated by parameter server is + initialized on server. Default: False. + """ if _is_role_worker() or _is_role_pserver() or _is_role_sched(): if init_in_server and (not self.name.endswith("embedding_table")): raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of " diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 662ee24464..d3dedb1bf3 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -1018,13 +1018,16 @@ class Cell(Cell_): def set_param_ps(self, recurse=True, init_in_server=False): """ - Set whether the trainable parameter is updated by parameter server. + Set whether the trainable parameters are updated by parameter server and whether the + trainable parameters are initialized on server. Note: It only works when a running task is in the parameter server mode. Args: recurse (bool): Whether sets the trainable parameters of subcells. Default: True. + init_in_server (bool): Whether trainable parameters updated by parameter server are + initialized on server. Default: False. """ params = self.trainable_params(recurse) for param in params: