From d19d42ee443925b24b47d08dba909fbae4054a16 Mon Sep 17 00:00:00 2001
From: Ziyan
Date: Fri, 19 Mar 2021 15:11:23 +0800
Subject: [PATCH] modify grad accu and comm fusion api

---
 mindspore/common/parameter.py                | 20 ++++++++++----------
 mindspore/context.py                         |  5 ++++-
 mindspore/parallel/_auto_parallel_context.py |  3 ++-
 3 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py
index a9614855f3..3583251ee3 100644
--- a/mindspore/common/parameter.py
+++ b/mindspore/common/parameter.py
@@ -288,20 +288,20 @@ class Parameter(Tensor_):
 
     @property
     def comm_fusion(self):
-        """Get the fusion type for communication operators corresponding to this parameter."""
-        return self.param_info.comm_fusion
-
-    @comm_fusion.setter
-    def comm_fusion(self, comm_fusion_):
         """
+        Get and Set the fusion type (int) for communication operators corresponding to this parameter.
+
         In `AUTO_PARALLEL` and `SEMI_AUTO_PARALLEL` mode, some communication operators used for parameters or
-        gradients aggregation are inserted automatically.Set the fusion type for communication operators generated
-        for this parameter. Only `Ascend` and `Graph` mode is supported.
+        gradients aggregation are inserted automatically. Set the fusion type for communication operators generated
+        for this parameter. The value of fusion must be greater than or equal to 0. When the value of fusion is 0,
+        operators will not be fused together.
 
-        Args:
-            comm_fusion_ (int): The value of fusion must be greater than or equal to 0.
-                When the value of fusion is 0, operators will not be fused together.
+        Only `Ascend` and `Graph` mode is supported.
         """
+        return self.param_info.comm_fusion
+
+    @comm_fusion.setter
+    def comm_fusion(self, comm_fusion_):
         if context.get_context("mode") == context.PYNATIVE_MODE and "auto_parallel" in _get_parallel_mode():
             raise RuntimeError("`comm_fusion` does not support PYNATIVE_MODE")
         Validator.check_non_negative_int(comm_fusion_)
diff --git a/mindspore/context.py b/mindspore/context.py
index 2c83ea54ff..3c41179561 100644
--- a/mindspore/context.py
+++ b/mindspore/context.py
@@ -344,7 +344,7 @@ def _context():
 @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
                  auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
-                 all_reduce_fusion_config=list, pipeline_stages=int)
+                 all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int)
 def set_auto_parallel_context(**kwargs):
     r"""
     Set auto parallel context, which is valid only for Ascend and GPU target.
@@ -371,6 +371,7 @@ def set_auto_parallel_context(**kwargs):
     all_reduce_fusion_config    strategy_ckpt_save_file
     enable_parallel_optimizer   full_batch
                \                pipeline_stages
+               \                grad_accumulation_step
     =========================== ===========================
 
     Args:
@@ -420,6 +421,8 @@ def set_auto_parallel_context(**kwargs):
                         the devices are distributed alone the pipeline. The total devices will be divided into
                         'pipeline_stags' stages. This currently could only be used when parallel mode semi_auto_parallel
                         is enabled. Default: 1.
+        grad_accumulation_step (int): Set the accumulation steps of gradients in auto and semi auto parallel mode.
+                        This should be a positive int. Default: 1.
 
     Raises:
         ValueError: If input key is not attribute in auto parallel context.
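Usage sketch (reviewer note, not part of the patch): the snippet below exercises the reworked `comm_fusion` property together with the new `grad_accumulation_step` keyword. The one-weight `Net` cell, the shapes, and the chosen values (fusion type 2, accumulation over 4 steps) are hypothetical and only illustrate the two APIs touched above; Graph mode on Ascend is assumed, as the docstring requires.

import numpy as np
from mindspore import context, nn, ops, Tensor, Parameter

# Graph mode on Ascend, as required for `comm_fusion`.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
# New keyword from this patch: accumulate gradients over 4 steps
# in semi-auto-parallel mode.
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                  grad_accumulation_step=4)

class Net(nn.Cell):
    """Hypothetical one-weight cell used only to show the parameter API."""
    def __init__(self):
        super().__init__()
        self.weight = Parameter(Tensor(np.ones((8, 8), np.float32)), name="w")
        # The setter validates a non-negative int; 0 would disable fusion.
        self.weight.comm_fusion = 2
        self.matmul = ops.MatMul()

    def construct(self, x):
        return self.matmul(x, self.weight)

net = Net()
# The getter simply reads the fusion type back from param_info.
assert net.weight.comm_fusion == 2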
diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py
index 83b0a1731a..06fdf0bb94 100644
--- a/mindspore/parallel/_auto_parallel_context.py
+++ b/mindspore/parallel/_auto_parallel_context.py
@@ -18,7 +18,7 @@ import mindspore.context as context
 from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
 from mindspore.parallel._ps_context import _is_role_pserver
 from mindspore._c_expression import AutoParallelContext
-from mindspore._checkparam import args_type_check
+from mindspore._checkparam import args_type_check, Validator
 
 _MAX_GROUP_NAME_LEN = 127
 _DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1"
@@ -257,6 +257,7 @@ class _AutoParallelContext:
             grad_accumulation_step (int): The grad accumulation step.
         """
         self.check_context_handle()
+        Validator.check_positive_int(grad_accumulation_step)
         self._context_handle.set_grad_accumulation_step(grad_accumulation_step)
 
     def get_grad_accumulation_step(self):
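Reviewer note on the added check (not part of the patch): with `Validator.check_positive_int` in place, a non-positive accumulation step is rejected at the Python layer instead of being handed to the C++ `AutoParallelContext`. A minimal sketch of the expected behaviour, assuming `set_auto_parallel_context` forwards the keyword to this setter and that the validator raises `ValueError` for values less than or equal to zero:

from mindspore import context

try:
    # 0 previously went straight to the backend; it should now fail fast.
    context.set_auto_parallel_context(grad_accumulation_step=0)
except ValueError as err:
    print("rejected as expected:", err)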