modify grad accu and comm fusion api

pull/13596/head
Ziyan 4 years ago
parent 1d505ebad3
commit d19d42ee44

@@ -288,20 +288,20 @@ class Parameter(Tensor_):
     @property
     def comm_fusion(self):
+        """Get the fusion type for communication operators corresponding to this parameter."""
+        return self.param_info.comm_fusion
+
+    @comm_fusion.setter
+    def comm_fusion(self, comm_fusion_):
         """
-        Get and Set the fusion type (int) for communication operators corresponding to this parameter.
         In `AUTO_PARALLEL` and `SEMI_AUTO_PARALLEL` mode, some communication operators used for parameters or
-        gradients aggregation are inserted automatically.Set the fusion type for communication operators generated
-        for this parameter. Only `Ascend` and `Graph` mode is supported.
-        Args:
-            comm_fusion_ (int): The value of fusion must be greater than or equal to 0.
-                When the value of fusion is 0, operators will not be fused together.
+        gradients aggregation are inserted automatically. Set the fusion type for communication operators generated
+        for this parameter. The value of fusion must be greater than or equal to 0. When the value of fusion is 0,
+        operators will not be fused together.
+        Only `Ascend` and `Graph` mode is supported.
         """
-        return self.param_info.comm_fusion
-
-    @comm_fusion.setter
-    def comm_fusion(self, comm_fusion_):
         if context.get_context("mode") == context.PYNATIVE_MODE and "auto_parallel" in _get_parallel_mode():
             raise RuntimeError("`comm_fusion` does not support PYNATIVE_MODE")
         Validator.check_non_negative_int(comm_fusion_)

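Usage sketch (not part of the diff): after this change the docstring lives on the setter, and the value is still validated as a non-negative int. A minimal example, assuming a build that carries this patch; whether the fusion actually takes effect still requires Ascend, graph mode and (semi_)auto_parallel, as the docstring says:

    import numpy as np
    from mindspore import Tensor, Parameter, context

    context.set_context(mode=context.GRAPH_MODE)   # the setter rejects PYNATIVE_MODE under auto parallel

    weight = Parameter(Tensor(np.ones((16, 16), np.float32)), name="fc.weight")
    weight.comm_fusion = 2      # assign this parameter's gradient AllReduce to fusion group 2
    print(weight.comm_fusion)   # 2; a value of 0 means the operator is not fused
    # weight.comm_fusion = -1   # would raise ValueError: the value must be >= 0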
@@ -344,7 +344,7 @@ def _context():
 @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
                  auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
-                 all_reduce_fusion_config=list, pipeline_stages=int)
+                 all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int)
 def set_auto_parallel_context(**kwargs):
     r"""
     Set auto parallel context, which is valid only for Ascend and GPU target.
@@ -371,6 +371,7 @@ def set_auto_parallel_context(**kwargs):
     all_reduce_fusion_config     strategy_ckpt_save_file
     enable_parallel_optimizer    full_batch
     \                            pipeline_stages
+    \                            grad_accumulation_step
     ===========================  ===========================
 
     Args:
@@ -420,6 +421,8 @@ def set_auto_parallel_context(**kwargs):
                         the devices are distributed alone the pipeline. The total devices will be divided into
                         'pipeline_stags' stages. This currently could only be used when
                         parallel mode semi_auto_parallel is enabled. Default: 1.
+        grad_accumulation_step (int): Set the accumulation steps of gradients in auto and semi auto parallel mode.
+                        This should be a positive int. Default: 1.
 
     Raises:
         ValueError: If input key is not attribute in auto parallel context.

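Usage sketch (not part of the diff): with this change `grad_accumulation_step` becomes a type-checked keyword of `set_auto_parallel_context` and must be a positive int. The 8-device semi-auto-parallel setup below is an assumption for illustration only:

    from mindspore import context

    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=8,
                                      grad_accumulation_step=4)  # accumulate gradients over 4 micro-steps
    # grad_accumulation_step=0 is now rejected before reaching the C++ context
    context.reset_auto_parallel_context()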
@@ -18,7 +18,7 @@ import mindspore.context as context
 from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
 from mindspore.parallel._ps_context import _is_role_pserver
 from mindspore._c_expression import AutoParallelContext
-from mindspore._checkparam import args_type_check
+from mindspore._checkparam import args_type_check, Validator
 
 _MAX_GROUP_NAME_LEN = 127
 _DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1"
@@ -257,6 +257,7 @@ class _AutoParallelContext:
             grad_accumulation_step (int): The grad accumulation step.
         """
         self.check_context_handle()
+        Validator.check_positive_int(grad_accumulation_step)
         self._context_handle.set_grad_accumulation_step(grad_accumulation_step)
 
     def get_grad_accumulation_step(self):

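Effect sketch (not part of the diff): the added `Validator.check_positive_int` rejects non-positive values at the Python layer instead of forwarding them to the C++ context. The `auto_parallel_context()` accessor used below is MindSpore-internal API and the exact error text is an assumption:

    from mindspore.parallel._auto_parallel_context import auto_parallel_context

    ctx = auto_parallel_context()
    ctx.set_grad_accumulation_step(4)        # accepted: positive int
    try:
        ctx.set_grad_accumulation_step(0)    # rejected by the new check
    except ValueError as err:
        print(err)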