|
|
|
@ -19,6 +19,8 @@ import numbers
|
|
|
|
|
import numpy as np
|
|
|
|
|
from .._c_expression import ParamInfo
|
|
|
|
|
from . import dtype as mstype
|
|
|
|
|
from .. import context
|
|
|
|
|
from ..parallel._utils import _get_parallel_mode
|
|
|
|
|
from .initializer import initializer
|
|
|
|
|
from .tensor import Tensor
|
|
|
|
|
from .._checkparam import Validator
|
|
|
|
@ -292,7 +294,18 @@ class Parameter(Tensor_):
|
|
|
|
|
|
|
|
|
|
@comm_fusion.setter
|
|
|
|
|
def comm_fusion(self, comm_fusion_):
|
|
|
|
|
"""Set the fusion type for communication operators corresponding to this parameter."""
|
|
|
|
|
"""
|
|
|
|
|
In `AUTO_PARALLEL` and `SEMI_AUTO_PARALLEL` mode, some communication operators used for parameters or
|
|
|
|
|
gradients aggregation are inserted automatically.Set the fusion type for communication operators generated
|
|
|
|
|
for this parameter. Only `Ascend` and `Graph` mode is supported.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
comm_fusion_ (int): The value of fusion must be greater than or equal to 0.
|
|
|
|
|
When the value of fusion is 0, operators will not be fused together.
|
|
|
|
|
"""
|
|
|
|
|
if context.get_context("mode") == context.PYNATIVE_MODE and "auto_parallel" in _get_parallel_mode():
|
|
|
|
|
raise RuntimeError("`comm_fusion` does not support PYNATIVE_MODE")
|
|
|
|
|
Validator.check_non_negative_int(comm_fusion_)
|
|
|
|
|
self.param_info.comm_fusion = comm_fusion_
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|