From ac061052a4ccd8404278aacb4d2f77974cd45490 Mon Sep 17 00:00:00 2001 From: caifubi Date: Mon, 21 Dec 2020 16:48:49 +0800 Subject: [PATCH] PyNative only support hccl_world_group --- mindspore/ops/operations/comm_ops.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index c7e46fdf53..b37575a238 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -21,7 +21,7 @@ from ..._checkparam import Rel from ...communication.management import get_rank, get_group_size, GlobalComm, _get_group from ...common import dtype as mstype from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register - +from ...common.api import context class ReduceOp: """ @@ -45,6 +45,12 @@ class ReduceOp: target_dtypes = (mstype.int8, mstype.int32, mstype.float16, mstype.float32) +def check_hcom_group_valid(group): + if context.get_context("mode") == context.PYNATIVE_MODE and \ + context.get_context("device_target") == "Ascend" and \ + group != GlobalComm.WORLD_COMM_GROUP: + raise RuntimeError("Only hccl_world_group is supported in Pynative mode, but got {}".format(group)) + class AllReduce(PrimitiveWithInfer): """ Reduces the tensor data across all devices in such a way that all devices will get the same final result. @@ -103,6 +109,7 @@ class AllReduce(PrimitiveWithInfer): raise TypeError("The operation of AllReduce should be str.") if not isinstance(_get_group(group), str): raise TypeError("The group of AllReduce should be str.") + check_hcom_group_valid(group) self.op = op self.add_prim_attr('group', _get_group(group)) self.add_prim_attr('fusion', 0) @@ -407,6 +414,7 @@ class Broadcast(PrimitiveWithInfer): def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP): validator.check_value_type('root_rank', root_rank, (int,), self.name) validator.check_value_type('group', _get_group(group), (str,), self.name) + check_hcom_group_valid(group) self.add_prim_attr('group', _get_group(group)) def infer_shape(self, x_shape):