diff --git a/mindspore/ccsrc/parallel/graph_util/generate_graph.cc b/mindspore/ccsrc/parallel/graph_util/generate_graph.cc
index f5f0fe85cb..7bd2fa808d 100644
--- a/mindspore/ccsrc/parallel/graph_util/generate_graph.cc
+++ b/mindspore/ccsrc/parallel/graph_util/generate_graph.cc
@@ -28,9 +28,14 @@ namespace parallel {
 std::string GetOpPythonPath(const OperatorName &op_name) {
   // almost all ops are defined in two main paths
   const std::string ops_module = OP_PATH;
+  const std::string inner_ops_module = INNER_OP_PATH;
   py::module mod = py::module::import(common::SafeCStr(ops_module));
+  py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
   if (!py::hasattr(mod, common::SafeCStr(op_name))) {
-    MS_LOG(EXCEPTION) << ops_module << " don't have op:" << op_name;
+    if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) {
+      MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name;
+    }
+    return inner_ops_module;
   }
   return ops_module;
 }
diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
index 1c40350e6a..7a16aeafcb 100644
--- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
+++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
@@ -56,6 +56,12 @@ Status GatherV2PInfo::GetAttrs() {
     }
   }
 
+  // target=CPU, axis must be 0
+  if (target_ == "CPU" && axis_ != 0) {
+    MS_LOG(ERROR) << name_ << ": target is CPU, axis must be 0, but got " << axis_;
+    return FAILED;
+  }
+
   return SUCCESS;
 }
 
@@ -279,6 +285,11 @@ Status GatherV2PInfo::InferBias() {
   int32_t rank = g_device_manager->global_rank();
   auto input_shape = inputs_shape_.at(0);
   auto params_strategy = strategy_->GetInputDim().at(0);
+  // axis don't split
+  if (params_strategy.at(axis_) == 1) {
+    bias_ = 0;
+    return SUCCESS;
+  }
   // params_size=1, axis=0
   if ((input_shape.size() == 1) && (axis_ == 0)) {
     slice_size_ = input_shape.at(0) / params_strategy.at(0);
@@ -353,26 +364,35 @@ Status GatherV2PInfo::InferForwardCommunication() {
   }
   auto group_size = group_.GetDevNum();
   Attr attr_group;
-  // group size <= 8
-  std::vector<int32_t> rank_list;
-  if (group_size <= 8) {
-    reduce_scatter_flag_ = false;
-    operator_name = HOST_REDUCE_SCATTER;
-    rank_list = GetRankFromGroup(group_);
-    attr_group = std::make_pair(GROUP, MakeValue(rank_list));
+  if (host_reduce_scatter_) {
+    // group size <= 8
+    std::vector<int32_t> rank_list;
+    if (group_size <= 8) {
+      reduce_scatter_flag_ = false;
+      operator_name = HOST_REDUCE_SCATTER;
+      rank_list = GetRankFromGroup(group_);
+      attr_group = std::make_pair(GROUP, MakeValue(rank_list));
+    } else {
+      // group size > 8, don't support host reduce_scatter
+      reduce_scatter_flag_ = true;
+      split_num_ = SizeToInt(group_size / 8);
+      CheckGlobalDeviceManager();
+      operator_name = REDUCE_SCATTER;
+      int32_t rank = g_device_manager->global_rank();
+      size_t repeat = group_size / 8;
+      for (size_t i = 0; i < repeat; ++i) {
+        rank_list.push_back(rank + SizeToInt(i * 8));
+      }
+      Group g = g_device_manager->CreateGroup(rank_list);
+      attr_group = std::make_pair(GROUP, MakeValue(g.name()));
+    }
   } else {
-    // group size > 8
-    reduce_scatter_flag_ = true;
-    split_num_ = SizeToInt(group_size / 8);
-    CheckGlobalDeviceManager();
     operator_name = REDUCE_SCATTER;
-    int32_t rank = g_device_manager->global_rank();
-    size_t repeat = group_size / 8;
-    for (size_t i = 0; i < repeat; ++i) {
-      rank_list.push_back(rank + SizeToInt(i * 8));
+    if (InferGroup() != SUCCESS) {
+      MS_LOG(ERROR) << name_ << ": Infer Group failed.";
+      return FAILED;
     }
-    Group g = g_device_manager->CreateGroup(rank_list);
-    attr_group = std::make_pair(GROUP, MakeValue(g.name()));
+    attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
   }
   Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
   OperatorAttrs attrs = {attr_op, attr_group};
@@ -446,8 +466,8 @@ Status GatherV2PInfo::ComputeReplaceOp() {
   Attr param_offset = std::make_pair("offset", MakeValue(bias_));
   Attr param_flag = std::make_pair("reduce_scatter_flag", MakeValue(reduce_scatter_flag_));
   Attr param_split_num = std::make_pair("split_num", MakeValue(split_num_));
-  OperatorParams params = {std::make_pair(param_offset, 4), std::make_pair(param_flag, 5),
-                           std::make_pair(param_split_num, 6)};
+  OperatorParams params = {std::make_pair(param_offset, 3), std::make_pair(param_flag, 4),
+                           std::make_pair(param_split_num, 5)};
   OperatorArgs args = std::make_pair(attrs, params);
   Operator op = std::make_pair(op_name, args);
   replace_op_.push_back(op);
diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
index b139ee215c..83868606d1 100644
--- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
@@ -70,6 +70,7 @@ class GatherV2PInfo : public OperatorInfo {
   Group group_;
   bool reduce_scatter_flag_ = false;
   int32_t split_num_ = 1;
+  bool host_reduce_scatter_ = false;
 };
 
 class SparseGatherV2Info : public GatherV2PInfo {
diff --git a/mindspore/ccsrc/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/parallel/ops_info/ops_utils.h
index 952b87a14d..e717525237 100644
--- a/mindspore/ccsrc/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/parallel/ops_info/ops_utils.h
@@ -55,6 +55,7 @@ constexpr char REDUCE_OP_SUM[] = "sum";
 constexpr char REDUCE_OP_MAX[] = "max";
 constexpr char REDUCE_OP_MIN[] = "min";
 constexpr char OP_PATH[] = "mindspore.ops.operations";
+constexpr char INNER_OP_PATH[] = "mindspore.ops.operations._inner_ops";
 constexpr char GET_OP_FUNCTION_PATH[] = "mindspore.parallel._utils";
 constexpr char GET_OP_FUNCTION[] = "_get_python_op";
 constexpr char KEEP_DIMS[] = "keep_dims";
diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc
index 39dd2c96e0..a5e5dee990 100644
--- a/mindspore/ccsrc/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_parallel.cc
@@ -536,7 +536,7 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
   std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
   auto prim = GetValueNode<PrimitivePtr>(node->input(0));
   if (prim->name() == GATHERV2 || prim->name() == SPARSE_GATHERV2) {
-    replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2), node->input(3)};
+    replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
   }
   if (!params.empty()) {
     Param param_first = *(params.begin());
diff --git a/tests/ut/python/parallel/test_gather_v2.py b/tests/ut/python/parallel/test_gather_v2.py
index 0b4804ffbe..5d52089cbe 100644
--- a/tests/ut/python/parallel/test_gather_v2.py
+++ b/tests/ut/python/parallel/test_gather_v2.py
@@ -184,7 +184,7 @@ def test_gatherv2_auto1():
     _executor.compile(net, x, y)
 
 
-def need_fix_test_gatherv2_cpu0():
+def test_gatherv2_cpu0():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((8, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
@@ -196,7 +196,7 @@ def need_fix_test_gatherv2_cpu0():
     _executor.compile(net, x, y)
 
 
-def need_fix_test_gatherv2_cpu1():
+def test_gatherv2_cpu1():
     context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((16, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
@@ -208,7 +208,7 @@ def need_fix_test_gatherv2_cpu1():
     _executor.compile(net, x, y)
 
 
-def need_fix_test_gatherv2_cpu2():
+def test_gatherv2_cpu2():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((1, 8), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
diff --git a/tests/ut/python/parallel/test_sparse_gather_v2.py b/tests/ut/python/parallel/test_sparse_gather_v2.py
index f12148e34f..dd0517a08e 100644
--- a/tests/ut/python/parallel/test_sparse_gather_v2.py
+++ b/tests/ut/python/parallel/test_sparse_gather_v2.py
@@ -184,7 +184,7 @@ def test_gatherv2_auto1():
     _executor.compile(net, x, y)
 
 
-def need_fix_test_gatherv2_cpu0():
+def test_gatherv2_cpu0():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((8, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
@@ -196,7 +196,7 @@ def need_fix_test_gatherv2_cpu0():
    _executor.compile(net, x, y)
 
 
-def need_fix_test_gatherv2_cpu1():
+def test_gatherv2_cpu1():
     context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((16, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
@@ -208,7 +208,7 @@ def need_fix_test_gatherv2_cpu1():
     _executor.compile(net, x, y)
 
 
-def need_fix_test_gatherv2_cpu2():
+def test_gatherv2_cpu2():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((1, 8), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
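
Note: the lookup-order change in GetOpPythonPath above is roughly equivalent to the Python sketch below. The module paths are the OP_PATH and INNER_OP_PATH constants added in ops_utils.h; the helper name resolve_op_module is purely illustrative and is not part of MindSpore.

    import importlib

    # Paths taken from ops_utils.h in this patch.
    OP_PATH = "mindspore.ops.operations"
    INNER_OP_PATH = "mindspore.ops.operations._inner_ops"

    def resolve_op_module(op_name):
        """Return the module path defining op_name, preferring OP_PATH.

        Illustrative sketch of the C++ GetOpPythonPath fallback; not a MindSpore API.
        """
        mod = importlib.import_module(OP_PATH)
        inner_mod = importlib.import_module(INNER_OP_PATH)
        if not hasattr(mod, op_name):
            if not hasattr(inner_mod, op_name):
                raise RuntimeError("%s or %s don't have op: %s" % (OP_PATH, INNER_OP_PATH, op_name))
            return INNER_OP_PATH
        return OP_PATH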