PyNative only support hccl_world_group

pull/10288/head
caifubi 4 years ago
parent a61aff2fa4
commit ac061052a4

@ -21,7 +21,7 @@ from ..._checkparam import Rel
from ...communication.management import get_rank, get_group_size, GlobalComm, _get_group
from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register
from ...common.api import context
class ReduceOp:
"""
@ -45,6 +45,12 @@ class ReduceOp:
target_dtypes = (mstype.int8, mstype.int32, mstype.float16, mstype.float32)
def check_hcom_group_valid(group):
if context.get_context("mode") == context.PYNATIVE_MODE and \
context.get_context("device_target") == "Ascend" and \
group != GlobalComm.WORLD_COMM_GROUP:
raise RuntimeError("Only hccl_world_group is supported in Pynative mode, but got {}".format(group))
class AllReduce(PrimitiveWithInfer):
"""
Reduces the tensor data across all devices in such a way that all devices will get the same final result.
@ -103,6 +109,7 @@ class AllReduce(PrimitiveWithInfer):
raise TypeError("The operation of AllReduce should be str.")
if not isinstance(_get_group(group), str):
raise TypeError("The group of AllReduce should be str.")
check_hcom_group_valid(group)
self.op = op
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('fusion', 0)
@ -407,6 +414,7 @@ class Broadcast(PrimitiveWithInfer):
def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP):
validator.check_value_type('root_rank', root_rank, (int,), self.name)
validator.check_value_type('group', _get_group(group), (str,), self.name)
check_hcom_group_valid(group)
self.add_prim_attr('group', _get_group(group))
def infer_shape(self, x_shape):

Loading…
Cancel
Save