@@ -21,7 +21,7 @@ from ..._checkparam import Rel
 from ...communication.management import get_rank, get_group_size, GlobalComm, _get_group
 from ...common import dtype as mstype
 from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register
 from ...common.api import context
 
 
 class ReduceOp:
     """
@@ -45,6 +45,12 @@ class ReduceOp:
 target_dtypes = (mstype.int8, mstype.int32, mstype.float16, mstype.float32)
 
 
+def check_hcom_group_valid(group):
+    if context.get_context("mode") == context.PYNATIVE_MODE and \
+            context.get_context("device_target") == "Ascend" and \
+            group != GlobalComm.WORLD_COMM_GROUP:
+        raise RuntimeError("Only hccl_world_group is supported in Pynative mode, but got {}".format(group))
+
 class AllReduce(PrimitiveWithInfer):
     """
     Reduces the tensor data across all devices in such a way that all devices will get the same final result.
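
As a sketch of what the new guard does (illustrative, not part of the diff): in PyNative mode on an Ascend target it rejects any communication group other than the built-in hccl_world_group, while graph mode and other device targets pass through unchanged. The import path below assumes check_hcom_group_valid stays module-level in mindspore.ops.operations.comm_ops, and "user_group_0" is a hypothetical group name.

    # Illustrative only: exercises the guard added above.
    from mindspore import context
    from mindspore.communication.management import GlobalComm
    from mindspore.ops.operations.comm_ops import check_hcom_group_valid  # internal helper

    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

    check_hcom_group_valid(GlobalComm.WORLD_COMM_GROUP)  # world group: no error
    try:
        check_hcom_group_valid("user_group_0")  # hypothetical user-created group
    except RuntimeError as err:
        print(err)  # Only hccl_world_group is supported in Pynative mode, but got user_group_0
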
@@ -103,6 +109,7 @@ class AllReduce(PrimitiveWithInfer):
             raise TypeError("The operation of AllReduce should be str.")
         if not isinstance(_get_group(group), str):
             raise TypeError("The group of AllReduce should be str.")
+        check_hcom_group_valid(group)
         self.op = op
         self.add_prim_attr('group', _get_group(group))
         self.add_prim_attr('fusion', 0)
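
For reference, a typical AllReduce cell (standard MindSpore usage, not taken from this diff) keeps the default world group, so the check added to __init__ above is a no-op for it; only explicitly named sub-groups are affected, and only in PyNative mode on Ascend.

    # Usage sketch: the default group is hccl_world_group, so
    # check_hcom_group_valid passes in both execution modes.
    import mindspore.nn as nn
    import mindspore.ops.operations as ops
    from mindspore import context
    from mindspore.communication.management import init
    from mindspore.ops.operations.comm_ops import ReduceOp

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    init()  # initializes HCCL and the world communication group

    class AllReduceSumNet(nn.Cell):
        def __init__(self):
            super(AllReduceSumNet, self).__init__()
            self.all_reduce = ops.AllReduce(ReduceOp.SUM)

        def construct(self, x):
            return self.all_reduce(x)
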
@@ -407,6 +414,7 @@ class Broadcast(PrimitiveWithInfer):
     def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP):
         validator.check_value_type('root_rank', root_rank, (int,), self.name)
         validator.check_value_type('group', _get_group(group), (str,), self.name)
+        check_hcom_group_valid(group)
         self.add_prim_attr('group', _get_group(group))
 
     def infer_shape(self, x_shape):
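
A corresponding sketch for Broadcast (illustrative; "user_group_0" is again a hypothetical group name): after this change, passing a custom group in PyNative mode on Ascend fails at construction time with the RuntimeError from check_hcom_group_valid, instead of surfacing later inside HCCL.

    # Illustrative only: Broadcast now validates the group in __init__.
    import mindspore.ops.operations as ops
    from mindspore import context

    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

    bcast = ops.Broadcast(root_rank=0)  # default hccl_world_group is accepted
    try:
        ops.Broadcast(root_rank=0, group="user_group_0")  # hypothetical sub-group
    except RuntimeError as err:
        print(err)  # raised by check_hcom_group_valid
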