From cae254f4df69b219a568d0a8d8307f9f4b30e94a Mon Sep 17 00:00:00 2001
From: Yi Huaijie
Date: Sat, 20 Jun 2020 09:48:12 +0800
Subject: [PATCH] asymmetric row split support for GatherV2
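
Let GatherV2 split dimension 0 of its parameter into row blocks of
unequal sizes. The split is described by a new "manual_split" primitive
attribute: a tuple with one (row_count, index_offset) pair per slice of
the parameter. The row counts must sum to at least the number of
parameter rows, the offsets must be non-negative, and on each device
the offset is subtracted from the lookup indices before the local
GatherV2 runs. The new unit test, for example, splits an (8, 64)
parameter across two devices as one row plus seven rows:

    gatherv2 = P.GatherV2().set_strategy(((2, 1), (1, 2)))
    gatherv2.add_prim_attr("manual_split", ((1, 0), (7, 1)))

i.e. device 0 holds row 0 with index offset 0, and device 1 holds rows
1-7 with index offset 1.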
---
 .../parallel/ops_info/gather_v2_p_info.cc     | 118 ++++++++++++++++++
 .../parallel/ops_info/gather_v2_p_info.h      |   7 ++
 .../python/parallel/test_manual_gatherv2.py   |  61 +++++++++
 3 files changed, 186 insertions(+)
 create mode 100644 tests/ut/python/parallel/test_manual_gatherv2.py

diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
index 9fb8df0883..dfecb29e88 100644
--- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
+++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
@@ -20,6 +20,7 @@
 #include <algorithm>
 #include <memory>
 #include <utility>
+#include <numeric>
 
 #include "parallel/device_matrix.h"
 #include "parallel/graph_util/generate_graph.h"
@@ -62,6 +63,55 @@ Status GatherV2PInfo::GetAttrs() {
     return FAILED;
   }
 
+  auto manual_split_iter = attrs_.find("manual_split");
+  if (manual_split_iter != attrs_.end()) {
+    param_split_shapes_.clear();
+    manual_split_ = true;
+    auto var = manual_split_iter->second->cast<ValueTuplePtr>();
+    MS_LOG(DEBUG) << "Extract manual split strategy " << manual_split_iter->second->ToString();
+
+    if (var->size() > 0) {
+      std::vector<ValuePtr> elements = var->value();
+      for (auto &ele : elements) {
+        if (ele->isa<ValueSequeue>()) {
+          auto value_tuple = ele->cast<ValueTuplePtr>();
+          std::vector<ValuePtr> value_vector = value_tuple->value();
+          if (value_vector.size() != 2) {
+            MS_LOG(ERROR) << "Failure: The size of each manual_split element must be 2.";
+            return FAILED;
+          }
+          param_split_shapes_.push_back(static_cast<int32_t>(GetValue<int>(value_vector[0])));
+          index_offsets_.push_back(static_cast<int32_t>(GetValue<int>(value_vector[1])));
+        } else {
+          MS_LOG(ERROR) << "Failure: The format of the manual split strategy is wrong, each element must be a ValueSequeue.";
+          return FAILED;
+        }
+      }
+
+      if (param_split_shapes_.empty()) {
+        MS_LOG(ERROR) << "Failed to extract the param split strategy.";
+        return FAILED;
+      }
+    }
+  }
+
+  return SUCCESS;
+}
+
+Status GatherV2PInfo::CheckManualSplit() {
+  auto param_shape = inputs_shape_.at(0);
+  int32_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0,
+                                            [](int32_t s, int32_t shape) { return s + shape; });
+  if (split_shape_sum < param_shape.at(0)) {
+    MS_LOG(ERROR) << "Failure: The sum of the split shapes must not be smaller than param_shape(0).";
+    return FAILED;
+  }
+
+  if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int32_t &offset) { return offset < 0; })) {
+    MS_LOG(ERROR) << "Failure: Index offsets must not be less than 0.";
+    return FAILED;
+  }
+
   return SUCCESS;
 }
 
@@ -103,6 +153,14 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
     return FAILED;
   }
 
+  if (manual_split_) {
+    if (CheckManualSplit() != SUCCESS) {
+      return FAILED;
+    }
+    // When manual_split is in effect, the checks below do not apply.
+    return SUCCESS;
+  }
+
   // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
   if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(IntToSize(axis_))) != 0) {
     MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
     return FAILED;
   }
@@ -130,6 +188,11 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
 }
 
 Status GatherV2PInfo::InferMirrorOps() {
+  // There are no mirror operators for manual split.
+  if (manual_split_) {
+    return SUCCESS;
+  }
+
   mirror_ops_.clear();
   Shape input_a_tensor_map = inputs_tensor_map_.at(0);
   std::vector<Group> input_a_group;
@@ -160,6 +223,13 @@ Status GatherV2PInfo::InferDevMatrixShape() {
   // infer input dev_matrix_shape
   auto param_strategy = strategy_->GetInputDim().at(0);
   auto index_strategy = strategy_->GetInputDim().at(1);
+
+  if (manual_split_) {
+    dev_matrix_shape_ = param_strategy;
+    out_dev_matrix_shape_ = dev_matrix_shape_;
+    return SUCCESS;
+  }
+
   dev_matrix_shape_ = param_strategy;
 
   // param_strategy(axis)!=1,
@@ -195,6 +265,12 @@ Status GatherV2PInfo::InferDevMatrixShape() {
 }
 
 Status GatherV2PInfo::InferTensorMap() {
+  if (manual_split_) {
+    inputs_tensor_map_.push_back({1, 0});
+    inputs_tensor_map_.push_back({-1, 1});
+    outputs_tensor_map_.push_back({-1, 1, 0});
+    return SUCCESS;
+  }
   // infer input tensor map
   // param_strategy(axis) != 1
   size_t param_size = inputs_shape_.at(0).size();
@@ -261,8 +337,13 @@ Status GatherV2PInfo::InferTensorInfo() {
   Shape input_shape = inputs_shape_.at(0);
   Shape input_index_shape = inputs_shape_.at(1);
   Shape output_shape = outputs_shape_.at(0);
+  int32_t rank = g_device_manager->global_rank();
   // infer tensor layout
   TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout;
+  if (manual_split_) {
+    input_shape[0] = param_split_shapes_[rank / dev_matrix_shape_[1]];
+    input_shape[0] = input_shape[0] * dev_matrix_shape_[0];
+  }
   if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) ||
       (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) ||
       (output_tensor_layout.InitFromVector(out_dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) !=
@@ -274,6 +355,9 @@ Status GatherV2PInfo::InferTensorInfo() {
   TensorInfo input_index_info(input_index_layout);
   TensorInfo output_tensor_info(output_tensor_layout);
 
+  Shape slice_shape = input_tensor_info.slice_shape();
+  MS_LOG(DEBUG) << "The fake slice shape is: " << ShapeToString(slice_shape);
+
   inputs_tensor_info_.push_back(input_tensor_info);
   inputs_tensor_info_.push_back(input_index_info);
   outputs_tensor_info_.push_back(output_tensor_info);
@@ -312,6 +396,19 @@ Status GatherV2PInfo::InferBias() {
   return FAILED;
 }
 
+Status GatherV2PInfo::InferOffset() {
+  CheckGlobalDeviceManager();
+  size_t rank = g_device_manager->global_rank();
+  if (rank < index_offsets_.size()) {
+    index_offset_ = index_offsets_.at(rank);
+    MS_LOG(DEBUG) << name_ << ": Device rank " << rank << ", index offset: " << index_offset_;
+    return SUCCESS;
+  }
+
+  MS_LOG(ERROR) << name_ << ": Failed to get the index offset, the size of index_offsets_ is " << index_offsets_.size();
+  return FAILED;
+}
+
 Status GatherV2PInfo::InferGroup() {
   auto param_strategy = strategy_->GetInputDim().at(0);
   size_t dim = IntToSize(axis_);
@@ -410,6 +507,19 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
     MS_LOG(ERROR) << "GenerateGraph Init failed";
     return FAILED;
   }
+  if (manual_split_) {
+    if (InferOffset() != SUCCESS) {
+      MS_LOG(ERROR) << name_ << ": Infer Offset failed.";
+      return FAILED;
+    }
+    auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(index_offset_)});
+    auto gather_v2 =
+      gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), sub, CreatInt32Imm(axis_)});
+    std::vector<std::pair<AnfNodePtr, int>> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1)};
+    replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int>>, AnfNodePtr>>(
+      std::make_pair(input_nodes, gather_v2));
+    return SUCCESS;
+  }
   if (InferBias() != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": Infer Bias failed.";
     return FAILED;
   }
@@ -444,6 +554,14 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
 }
 
 ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
+  if (manual_split_) {
+    if (ComputeReplaceGraph(cnode) != SUCCESS) {
+      MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
+      return nullptr;
+    }
+    return replace_graph_;
+  }
+
   auto param_strategy = strategy_->GetInputDim().at(0);
   // target_ == CPU, no need to replace the graph
   if (target_ == CPU) {
diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
index 83868606d1..acdecb49a3 100644
--- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
@@ -36,6 +36,7 @@ class GatherV2PInfo : public OperatorInfo {
       : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()),
         axis_(0),
         bias_(0),
+        index_offset_(0),
         slice_size_(0) {}
   ~GatherV2PInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
@@ -57,20 +58,26 @@ class GatherV2PInfo : public OperatorInfo {
 
  private:
   Status ComputeReplaceGraph(const CNodePtr &cnode);
+  Status CheckManualSplit();
   Status ComputeReplaceOp();
   Status InferBias();
+  Status InferOffset();
   Status InferGroup();
 
   int32_t axis_;
   std::string target_;
   std::string replace_op_name_ = GATHERV2;
   int32_t bias_;
+  int32_t index_offset_;
   int32_t slice_size_;
   Shape out_dev_matrix_shape_;
   Group group_;
   bool reduce_scatter_flag_ = false;
   int32_t split_num_ = 1;
   bool host_reduce_scatter_ = false;
+  bool manual_split_ = false;
+  std::vector<int32_t> param_split_shapes_;
+  std::vector<int32_t> index_offsets_;
 };
 
 class SparseGatherV2Info : public GatherV2PInfo {
diff --git a/tests/ut/python/parallel/test_manual_gatherv2.py b/tests/ut/python/parallel/test_manual_gatherv2.py
new file mode 100644
index 0000000000..21d25ae720
--- /dev/null
+++ b/tests/ut/python/parallel/test_manual_gatherv2.py
@@ -0,0 +1,61 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import mindspore as ms
+from mindspore import context, Tensor, Parameter
+from mindspore.common.api import _executor
+from mindspore.nn import Cell, TrainOneStepCell, Momentum
+from mindspore.ops import operations as P
+from mindspore.common.initializer import initializer
+
+class Net(Cell):
+    def __init__(self, strategy1=None, strategy2=None, strategy3=None):
+        super().__init__()
+        self.gatherv2 = P.GatherV2().set_strategy(strategy1)
+        self.gatherv2.add_prim_attr("manual_split", ((1, 0), (7, 1)))
+        self.mul = P.Mul().set_strategy(strategy2)
+        self.reshape = P.Reshape()
+        self.matmul = P.MatMul().set_strategy(strategy3)
+        self.matmul.add_prim_attr("forward_reduce_scatter", True)
+        self.param = Parameter(initializer("ones", (8, 64), ms.float32), name="gatherv2_param")
+        self.mul_weight = Parameter(initializer("ones", (2, 4, 64), ms.float32), name="mul_weight")
+        self.matmul_weight = Parameter(initializer("ones", (256, 16), ms.float32), name="matmul_weight")
+
+    def construct(self, x, b):
+        out = self.gatherv2(self.param, x, 0)
+        out = self.mul(out, self.mul_weight)
+        out = self.reshape(out, (2, 256))
+        out = self.matmul(out, self.matmul_weight)
+        return out
+
+_x = Tensor(np.ones([2, 4]), dtype=ms.int32)
+_b = Tensor(np.ones([64, 8]), dtype=ms.float32)
+
+def compile_net(net):
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    train_net = TrainOneStepCell(net, optimizer)
+    train_net.set_auto_parallel()
+    _executor.compile(train_net, _x, _b)
+    context.reset_auto_parallel_context()
+
+def test_manual_split():
+    context.set_context(save_graphs=True)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
+    strategy1 = ((2, 1), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3)
+    compile_net(net)
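
For reference, the row-split semantics the test exercises can be modeled in
a few lines of NumPy. This is an illustrative sketch, not MindSpore code or
part of the patch: it assumes that a lookup index falling outside a device's
row block contributes zeros and that the per-device partial results are
summed downstream (the forward_reduce_scatter attribute on the following
MatMul in the test suggests the summation happens there).

    import numpy as np

    # (row_count, index_offset) per device, as in the unit test above.
    manual_split = ((1, 0), (7, 1))
    param = np.arange(8 * 4, dtype=np.float32).reshape(8, 4)  # full parameter
    indices = np.array([0, 3, 5, 7])

    # Cut the parameter into uneven row blocks, as the attribute describes.
    shards, start = [], 0
    for row_count, offset in manual_split:
        shards.append((param[start:start + row_count], offset))
        start += row_count

    # Each device subtracts its offset (the Sub node in the replaced graph),
    # gathers locally, and zeros out rows owned by other devices.
    partials = []
    for shard, offset in shards:
        local = indices - offset
        valid = (local >= 0) & (local < shard.shape[0])
        safe = np.clip(local, 0, shard.shape[0] - 1)
        partials.append(np.where(valid[:, None], shard[safe], 0.0))

    # Summing the partials reproduces a plain gather over the full parameter.
    assert np.array_equal(sum(partials), param[indices])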