!13596 modify comm fusion and grad accu api

From: @gong_zi_yan
Reviewed-by: @guoqi1024,@stsuteng
Signed-off-by: @stsuteng
pull/13596/MERGE
Committed by mindspore-ci-bot via Gitee
commit 7df661fbaf

@@ -288,20 +288,20 @@ class Parameter(Tensor_):
@property
def comm_fusion(self):
"""Get the fusion type for communication operators corresponding to this parameter."""
return self.param_info.comm_fusion
@comm_fusion.setter
def comm_fusion(self, comm_fusion_):
"""
Get and Set the fusion type (int) for communication operators corresponding to this parameter.
In `AUTO_PARALLEL` and `SEMI_AUTO_PARALLEL` mode, some communication operators used for parameters or
gradients aggregation are inserted automatically.Set the fusion type for communication operators generated
for this parameter. Only `Ascend` and `Graph` mode is supported.
gradients aggregation are inserted automatically. Set the fusion type for communication operators generated
for this parameter. The value of fusion must be greater than or equal to 0. When the value of fusion is 0,
operators will not be fused together.
Args:
comm_fusion_ (int): The value of fusion must be greater than or equal to 0.
When the value of fusion is 0, operators will not be fused together.
Only `Ascend` and `Graph` mode is supported.
"""
return self.param_info.comm_fusion
@comm_fusion.setter
def comm_fusion(self, comm_fusion_):
if context.get_context("mode") == context.PYNATIVE_MODE and "auto_parallel" in _get_parallel_mode():
raise RuntimeError("`comm_fusion` does not support PYNATIVE_MODE")
Validator.check_non_negative_int(comm_fusion_)

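For readers of the hunk above, a minimal usage sketch of the reworked comm_fusion getter/setter; the network, device target, and fusion value below are illustrative assumptions, not taken from this commit:

    import mindspore.nn as nn
    from mindspore import context

    # comm_fusion takes effect on Ascend in graph mode; PYNATIVE_MODE combined with
    # an auto-parallel mode raises a RuntimeError (see the setter above).
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    net = nn.Dense(16, 16)                # illustrative network, not part of this change
    for param in net.trainable_params():
        param.comm_fusion = 1             # put the generated communication ops into fusion group 1
    print(net.weight.comm_fusion)         # -> 1
    # net.weight.comm_fusion = -1         # would fail Validator.check_non_negative_int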
@@ -344,7 +344,7 @@ def _context():
@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
all_reduce_fusion_config=list, pipeline_stages=int)
all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int)
def set_auto_parallel_context(**kwargs):
r"""
Set auto parallel context, which is valid only for Ascend and GPU target.
@@ -371,6 +371,7 @@ def set_auto_parallel_context(**kwargs):
all_reduce_fusion_config strategy_ckpt_save_file
enable_parallel_optimizer full_batch
\ pipeline_stages
\ grad_accumulation_step
=========================== ===========================
Args:
@@ -420,6 +421,8 @@ def set_auto_parallel_context(**kwargs):
the devices are distributed along the pipeline. The total devices will be divided into
'pipeline_stages' stages. This currently can only be used when
parallel mode semi_auto_parallel is enabled. Default: 1.
grad_accumulation_step (int): Set the accumulation steps of gradients in auto and semi auto parallel mode.
This should be a positive int. Default: 1.
Raises:
ValueError: If input key is not attribute in auto parallel context.

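A short sketch of passing the new grad_accumulation_step key to set_auto_parallel_context; the step value is illustrative, and reading it back assumes the matching getter key is exposed through get_auto_parallel_context:

    from mindspore import context

    # grad_accumulation_step must be a positive int and applies to auto / semi-auto
    # parallel mode only.
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      grad_accumulation_step=4)
    print(context.get_auto_parallel_context("grad_accumulation_step"))  # expected: 4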
@@ -18,7 +18,7 @@ import mindspore.context as context
from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
from mindspore.parallel._ps_context import _is_role_pserver
from mindspore._c_expression import AutoParallelContext
from mindspore._checkparam import args_type_check
from mindspore._checkparam import args_type_check, Validator
_MAX_GROUP_NAME_LEN = 127
_DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1"
@@ -257,6 +257,7 @@ class _AutoParallelContext:
grad_accumulation_step (int): The grad accumulation step.
"""
self.check_context_handle()
Validator.check_positive_int(grad_accumulation_step)
self._context_handle.set_grad_accumulation_step(grad_accumulation_step)
def get_grad_accumulation_step(self):

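With the Validator.check_positive_int call added above, a non-positive step should now fail before reaching the C++ context handle; a rough illustration of the expected behaviour, not a verified test from this commit:

    from mindspore import context

    context.set_auto_parallel_context(grad_accumulation_step=2)       # accepted
    try:
        context.set_auto_parallel_context(grad_accumulation_step=0)   # not a positive int
    except ValueError as err:
        print(err)                                                    # raised by Validator.check_positive_int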