@@ -164,8 +164,10 @@ class Optimizer(Cell):
         self.param_length = len(self.parameters)
         self.map_ = C.Map()
         if context.get_auto_parallel_context("enable_parallel_optimizer"):
-            if _get_parallel_mode() == ParallelMode.DATA_PARALLEL:
+            if _get_parallel_mode() == ParallelMode.DATA_PARALLEL and context.get_context("device_target") == "Ascend":
                 self.use_parallel = True
+            elif context.get_context("device_target") != "Ascend":
+                raise RuntimeError("Parallel optimizer only supports Ascend in data parallel mode.")
             elif _get_parallel_mode() in (ParallelMode.STAND_ALONE, ParallelMode.HYBRID_PARALLEL):
                 raise RuntimeError("Parallel optimizer is not supported in {}.".format(_get_parallel_mode()))
             else:
@@ -174,10 +176,10 @@ class Optimizer(Cell):
             self.use_parallel = False
         if self.use_parallel:
             if self.cls_name not in ["Lamb", "AdamWeightDecay"]:
-                raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name))
+                raise RuntimeError("Parallel optimizer does not support optimizer {}".format(self.cls_name))
             self.dev_num = _get_device_num()
             if self.dev_num > self.param_length:
-                raise RuntimeError("Optimizer segmentation can not be applied when the number of parameters {} is"
+                raise RuntimeError("Parallel optimizer can not be applied when the number of parameters {} is"
                                    " less than the number of devices {}".format(self.param_length, self.dev_num))
             self.param_rank = self._get_parameter_group_id()
             self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank))
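Below is a minimal usage sketch (not part of the diff) of how the adjusted checks are exercised: the parallel optimizer is only accepted for Lamb and AdamWeightDecay in data parallel mode on Ascend. The network class MyNet and the distributed launch environment are assumptions for illustration, not code from this change.

# Minimal sketch, assuming an Ascend data-parallel launch; MyNet is a hypothetical Cell.
from mindspore import context, nn
from mindspore.communication.management import init

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
init()  # initialize HCCL communication for the distributed job
context.set_auto_parallel_context(parallel_mode="data_parallel",
                                  enable_parallel_optimizer=True)

net = MyNet()  # hypothetical network definition
# AdamWeightDecay is one of the two optimizers the parallel optimizer accepts here;
# constructing it in this configuration passes the checks added above.
optimizer = nn.AdamWeightDecay(net.trainable_params(), learning_rate=1e-3)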