diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc index 46f3fc823a..74b3c5aeb6 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc @@ -947,5 +947,122 @@ double GatherV2PCost::GetBackwardComputationCost(const std::vector & return result; } + +// The forward communication is determined by whether the slice is column split or row split +// The number of segments is actually the shape[0] of the output, which is the cost of the AllReduce +double UnsortedSegmentSumCost::GetForwardCommCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { + TensorInfo input0 = inputs[0]; + TensorInfo input1 = inputs[1]; + TensorInfo output0 = outputs[0]; + Shape input0_shape = input0.shape(); + Shape input0_slice_shape = inputs[0].slice_shape(); + double result = 0.0; + if (inputs_type_lengths_.size() != inputs.size()) { + MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for UnsortedSegmentSum cost"; + } + // If the shape b is not the same as the shape a, we regard it as column slice + for (size_t i = 0; i < input1.shape().size(); ++i) { + if (input0_shape[i] != input0_slice_shape[i]) { + result = ListProduct(output0.slice_shape()) * static_cast(outputs_type_lengths_[0]); + return result; + } + } + return result; +} + +double UnsortedSegmentSumCost::GetBackwardCommCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { + TensorInfo input0 = inputs[0]; + TensorInfo input1 = inputs[1]; + TensorInfo output0 = outputs[0]; + Shape input0_shape = input0.shape(); + Shape input0_slice_shape = inputs[0].slice_shape(); + double result = 0.0; + if (inputs_type_lengths_.size() != inputs.size()) { + MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for UnsortedSegmentSum cost"; + } + if (is_parameter_[0]) { + // If the forward process has a AllReduce, then the backward also needs one. + for (size_t i = 0; i < input1.shape().size(); ++i) { + if (input0_shape[i] != input0_slice_shape[i]) { + result = ListProduct(output0.slice_shape()) * static_cast(outputs_type_lengths_[0]); + return result; + } + } + } + return result; +} +double UnsortedSegmentSumCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t) const { + // In forward phase, the computation cost = slice(A) + slice(B) + Shape input0_slice_shape = inputs[0].slice_shape(); + Shape input1_slice_shape = inputs[1].slice_shape(); + Shape output_slice_shape = outputs[0].slice_shape(); + double result = ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) + + ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]) + + ListProduct(output_slice_shape) * static_cast(outputs_type_lengths_[0]); + return result; +} + +double UnsortedSegmentMinCost::GetForwardCommCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { + TensorInfo input0 = inputs[0]; + TensorInfo input1 = inputs[1]; + TensorInfo output0 = outputs[0]; + Shape input0_shape = input0.shape(); + Shape input0_slice_shape = inputs[0].slice_shape(); + double result = 0.0; + if (inputs_type_lengths_.size() != inputs.size()) { + MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() + << " for UnsortedSegmentMinCost cost"; + } + // If the shape b is not the same as the shape a, we regard it as column slice + // The cost is a AllGather operation, the shape is the same as the output of UnsortedSegmentMin. + for (size_t i = 0; i < input1.shape().size(); ++i) { + if (input0_shape[i] != input0_slice_shape[i]) { + result = ListProduct(output0.slice_shape()) * static_cast(outputs_type_lengths_[0]); + return result; + } + } + return result; +} + +double UnsortedSegmentMinCost::GetBackwardCommCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { + TensorInfo input0 = inputs[0]; + TensorInfo input1 = inputs[1]; + TensorInfo output0 = outputs[0]; + Shape input0_shape = input0.shape(); + Shape input0_slice_shape = inputs[0].slice_shape(); + double result = 0.0; + if (inputs_type_lengths_.size() != inputs.size()) { + MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() + << " for UnsortedSegmentMinCost cost"; + } + if (is_parameter_[0]) { + // If the forward process has a AllGather, then the backward also needs one ReduceScatter. + for (size_t i = 0; i < input1.shape().size(); ++i) { + if (input0_shape[i] != input0_slice_shape[i]) { + result = ListProduct(output0.slice_shape()) * static_cast(outputs_type_lengths_[0]); + return result; + } + } + } + return result; +} +double UnsortedSegmentMinCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t) const { + // In forward phase, the computation cost = slice(A) + slice(B) + Shape input0_slice_shape = inputs[0].slice_shape(); + Shape input1_slice_shape = inputs[1].slice_shape(); + Shape output_slice_shape = outputs[0].slice_shape(); + // The forward operation is UnsortedSegmentMin + ReudceMin + double result = ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) + + ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]) + + ListProduct(output_slice_shape) * static_cast(outputs_type_lengths_[0]) + + ListProduct(output_slice_shape) * static_cast(outputs_type_lengths_[0]); // ReduceMin + return result; +} } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h index a2f5509093..0b73b077c6 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h @@ -578,6 +578,58 @@ class DropOutCost : public OperatorCost { using DropOutCostPtr = std::shared_ptr; +class UnsortedSegmentSumCost : public OperatorCost { + public: + explicit UnsortedSegmentSumCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + UnsortedSegmentSumCost() : OperatorCost(true) {} + ~UnsortedSegmentSumCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override; + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &, const std::vector &, + int32_t) const override; + double GetBackwardComputationCost(const std::vector &, const std::vector &, + int32_t) const override { + return 0.0; + } +}; + +using UnsortedSegmentSumCostPtr = std::shared_ptr; + +class UnsortedSegmentMinCost : public OperatorCost { + public: + explicit UnsortedSegmentMinCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + UnsortedSegmentMinCost() : OperatorCost(true) {} + ~UnsortedSegmentMinCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override; + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &, const std::vector &, + int32_t) const override; + double GetBackwardComputationCost(const std::vector &, const std::vector &, + int32_t) const override { + return 0.0; + } +}; + +using UnsortedSegmentMinCostPtr = std::shared_ptr; + class LayerNormCost : public OperatorCost { public: explicit LayerNormCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} diff --git a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h index 6bd6667ce9..dd9ae230d8 100644 --- a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h +++ b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h @@ -173,6 +173,8 @@ REGISTER(ExpandDimsInfo); REGISTER(SqueezeInfo); REGISTER(SigmoidCrossEntropyWithLogitsInfo); REGISTER(SquareInfo); +REGISTER(UnsortedSegmentSumInfo); +REGISTER(UnsortedSegmentMinInfo); REGISTER(GatherV2PInfo); REGISTER(EmbeddingLookupInfo); REGISTER(TileInfo); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h index edc96610ac..a5c9b2e0b4 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h @@ -35,6 +35,7 @@ #include "frontend/parallel/ops_info/reduce_method_info.h" #include "frontend/parallel/ops_info/reshape_info.h" #include "frontend/parallel/ops_info/transpose_info.h" +#include "frontend/parallel/ops_info/unsorted_segment_op_info.h" #include "frontend/parallel/ops_info/virtual_dataset_info.h" #include "frontend/parallel/ops_info/gather_v2_p_info.h" #include "frontend/parallel/ops_info/tile_info.h" diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h index 0eb4758648..d57f2bb7e2 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -283,6 +283,8 @@ constexpr char IN_TOPK[] = "InTopK"; constexpr char GATHER_ND[] = "GatherNd"; constexpr char UNSORTEF_SEGMENT_MIND[] = "UnsortedSegmentMinD"; constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD"; +constexpr char UNSORTED_SEGMENT_SUM[] = "UnsortedSegmentSum"; +constexpr char UNSORTED_SEGMENT_MIN[] = "UnsortedSegmentMin"; constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative"; constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D"; constexpr char ADD[] = "Add"; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.cc new file mode 100644 index 0000000000..91a1456e36 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.cc @@ -0,0 +1,313 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/unsorted_segment_op_info.h" + +#include +#include +#include + +#include "frontend/parallel/auto_parallel/costmodel.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/graph_util/generate_graph.h" +#include "frontend/parallel/strategy.h" +#include "ir/tensor.h" +#include "ir/value.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +// The operator UnsortedSegment accepts three inputs: +// input0 : vector, the shape is x1,x2,x3,...,xr +// input1 : segment id, the shape is x1,x2,..,xn +// input2 : value, the number of the segments +// For Sum: r >= n +// For Min: r >=n, n=1 +Status UnsortedSegmentOpInfo::GetAttrs() { + if (inputs_shape_.size() != UNSORTEDSEGMENTOP_INPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": inputs shape size must be 2, but is " << inputs_shape_.size(); + return FAILED; + } + if (outputs_shape_.size() != UNSORTEDSEGMENTOP_OUTPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": outputs shape size must be 1, but is " << outputs_shape_.size(); + return FAILED; + } + if (input_value_.at(2) == nullptr) { + MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!"; + return FAILED; + } + + if (inputs_shape_.at(0).empty()) { + MS_LOG(ERROR) << name_ << ": input can not be a scalar!"; + return FAILED; + } + int num_segments = GetValue(input_value_.at(2)); + if (num_segments < 0) { + MS_LOG(ERROR) << name_ << ": the number of segments should be non negative value."; + return FAILED; + } + + return SUCCESS; +} + +Status UnsortedSegmentOpInfo::CheckStrategy(const StrategyPtr &strategy) { + // Check size + if (inputs_shape_.size() != UNSORTEDSEGMENTOP_INPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": inputs shape size must be " << UNSORTEDSEGMENTOP_INPUTS_SIZE << ", but is " + << inputs_shape_.size(); + return FAILED; + } + if (outputs_shape_.size() != UNSORTEDSEGMENTOP_OUTPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": outputs shape size must be " << UNSORTEDSEGMENTOP_OUTPUTS_SIZE << ", but is " + << outputs_shape_.size(); + return FAILED; + } + // The strategy of the first and the second input should be set. + if (CheckStrategyValue(strategy, {inputs_shape_.at(0), inputs_shape_.at(1)}) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Invalid strategy."; + return FAILED; + } + Strategys stra = strategy->GetInputDim(); + Dimensions sub_a_strategy = stra.at(0); + Dimensions sub_b_strategy = stra.at(1); + Shape input_a_shape = inputs_shape_.at(0); + Shape input_b_shape = inputs_shape_.at(1); + // The size of the input b must be equal or smaller than input a + for (size_t i = 0; i < input_b_shape.size(); ++i) { + if ((sub_a_strategy[i] != sub_b_strategy[i]) && (input_a_shape[i] != input_b_shape[i])) { + MS_LOG(ERROR) << name_ + << " : Invalid strategy. The shape and the strategy of the input0 and input1 " + "should be same before the front size of the input[1]"; + return FAILED; + } + } + return SUCCESS; +} + +Status UnsortedSegmentOpInfo::InferDevMatrixShape() { + Strategys stra = strategy_->GetInputDim(); + dev_matrix_shape_ = stra.at(0); + return SUCCESS; +} + +// As the op converts the vector x1,x2,x3...,xr -> number of segments, xn,..,xr +// the dimension x1,x2,x3,..,xn is eliminated +// suppose the strategy of the inputs is (a,b,c,d), (a,b) +// the tensor map of the input vector is (3,2,1,0), id:(1, 0) +// the output vector is (-1, 1, 0) +Status UnsortedSegmentOpInfo::InferTensorMap() { + Shape tensor_map_in; + Shape tensor_map_in_index; + Shape tensor_map_out; + size_t input0_size = inputs_shape_.at(0).size(); + // such as 4: tensor_map_index [3,2,1,0] + for (size_t i = 0; i < input0_size; ++i) { + tensor_map_in.push_back(SizeToInt(input0_size - i - 1)); + tensor_map_in_index.push_back(SizeToInt(input0_size - i - 1)); + tensor_map_out.push_back(SizeToInt(input0_size - i - 1)); + } + + (void)tensor_map_out.erase(tensor_map_out.begin(), tensor_map_out.begin() + inputs_shape_.at(1).size() - 1); + // A special case: the input vector (a,) id (a,) or input vector (a,b,c), id(a,b,c) + // The output vector will be a 1-dim vector, + // These two kinds of situations as row slice. + tensor_map_out[0] = -1; + (void)tensor_map_in_index.erase(tensor_map_in_index.begin() + inputs_shape_.at(1).size(), tensor_map_in_index.end()); + if (tensor_map_out.size() != outputs_shape_.at(0).size()) { + MS_LOG(ERROR) << "Out tensor map size is not equal to output size! Out tensor map size is " << tensor_map_out.size() + << " output size is " << outputs_shape_.at(0).size(); + return FAILED; + } + + inputs_tensor_map_.emplace_back(std::move(tensor_map_in)); + inputs_tensor_map_.emplace_back(std::move(tensor_map_in_index)); + outputs_tensor_map_.emplace_back(std::move(tensor_map_out)); + return SUCCESS; +} + +Status UnsortedSegmentOpInfo::InferTensorInfo() { + // infer tensor shape + Shape input_shape = inputs_shape_.at(0); + Shape input_index_shape = inputs_shape_.at(1); + Shape output_shape = outputs_shape_.at(0); + + TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout; + if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) || + (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) || + (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != SUCCESS)) { + return FAILED; + } + + TensorInfo input_tensor_info(input_tensor_layout); + TensorInfo input_index_info(input_index_layout); + TensorInfo output_tensor_info(output_tensor_layout); + + inputs_tensor_info_.push_back(input_tensor_info); + inputs_tensor_info_.push_back(input_index_info); + outputs_tensor_info_.push_back(output_tensor_info); + return SUCCESS; +} + +Status UnsortedSegmentOpInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + MS_LOG(INFO) << name_ << ": Init success."; + return SUCCESS; +} + +Status UnsortedSegmentOpInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init for cost model failed."; + return FAILED; + } + MS_LOG(INFO) << name_ << ": Init for cost model success."; + return SUCCESS; +} + +// Set the default strategy +Status UnsortedSegmentOpInfo::GenerateStrategies(int32_t stage_id) { + Shape input0_split(inputs_shape_[0].size(), 1); + Shapes splittable_inputs = {input0_split}; + + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, {inputs_shape_.at(0)}, splittable_inputs, &sp_vector) != + SUCCESS) { + MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed."; + return FAILED; + } + for (auto &sp : sp_vector) { + Strategys tmp_strategy; + Dimensions first_input_strategy = sp->GetInputDim()[0]; + Dimensions second_input_strategy; + for (size_t i = 0; i < inputs_shape_[1].size(); ++i) { + second_input_strategy.push_back(first_input_strategy[i]); + } + tmp_strategy.push_back(first_input_strategy); + tmp_strategy.push_back(second_input_strategy); + sp->ResetInputs(tmp_strategy); + } + size_t success = 0; + for (auto &sp : sp_vector) { + PrintStrategy(sp); + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy"; + PrintStrategy(sp); + } + } + + return SUCCESS; +} + +// if the dimension of the input b is split, we regarded it as the row slice, thus requires a AllReduce +// otherwise it is column slice, +Status UnsortedSegmentOpInfo::InferForwardCommunication() { + forward_op_.clear(); + std::vector group_list; + Shape tmp_group_tensor_map = outputs_tensor_map_.at(0); + if (repeated_calc_num_ > 1) { + for (size_t i = 1; i < tmp_group_tensor_map.size(); ++i) { + tmp_group_tensor_map[i] += 1; + } + tmp_group_tensor_map.push_back(0); + } + if (CreateGroupByTensorMap(tmp_group_tensor_map, &group_list) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Infer forward communication, create group failed."; + return FAILED; + } else if (group_list.empty()) { + MS_LOG(INFO) << name_ << " : Forward all reduce is not required."; + return SUCCESS; + } + + Operator op; + op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name()); + + forward_op_.push_back(op); + MS_LOG(INFO) << name_ << " : The group name of forward communication is " << group_list[0].name(); + return SUCCESS; +} + +Status UnsortedSegmentOpInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + return SetCostUnderStrategyBase(strategy); +} + +std::shared_ptr UnsortedSegmentOpInfo::GenerateBatchStrategies() { + if (inputs_shape_.size() != UNSORTEDSEGMENTOP_INPUTS_SIZE) { + MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << UNSORTEDSEGMENTOP_INPUTS_SIZE << ", but is " + << inputs_shape_.size(); + } + CheckGlobalDeviceManager(); + size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + if (GetAttrs() != SUCCESS) { + MS_LOG(EXCEPTION) << "GetAttrs failed!"; + } + + Dimensions strategy_a; + Dimensions strategy_b; + strategy_a.push_back(SizeToInt(dev_num)); + for (size_t i = 1; i < inputs_shape_[0].size(); i++) { + strategy_a.push_back(1); + } + + strategy_b.push_back(SizeToInt(dev_num)); + for (size_t i = 1; i < inputs_shape_[1].size(); i++) { + strategy_b.push_back(1); + } + Strategys strategy_v = {strategy_a, strategy_b}; + return std::make_shared(strategy_v); +} + +// When the index is splited, the graph should be replaced +// a special case is when the shape input equals the shape of ids, we regard it as column slice, +// thus there is no need for repalce graphs +ReplaceGraphPtr UnsortedSegmentMinInfo::replace_graph(const CNodePtr &cnode) { + auto input_id_strategy = strategy_->GetInputDim().at(1); + // 1. the two input shapes are same, and the strategy is not all ones + if (std::any_of(input_id_strategy.begin(), input_id_strategy.end(), [](const int32_t &shard) { return shard > 1; })) { + if (ComputeReplaceGraph(cnode) != SUCCESS) { + MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed."; + } + } + return replace_graph_; +} + +Status UnsortedSegmentMinInfo::ComputeReplaceGraph(const CNodePtr &cnode) { + GenerateGraph gen_g = GenerateGraph(); + if (gen_g.Init(cnode) != SUCCESS) { + MS_LOG(ERROR) << "GenerateGraph Init failed"; + return FAILED; + } + // Get the attributes of the UnsortedSegmentMin + auto num_segments = GetValue(input_value_.at(2)); + // Step1: Output branch + auto segment_min = gen_g.PushBack({gen_g.NewOpInst(UNSORTED_SEGMENT_MIN), gen_g.virtual_input_node(), + gen_g.virtual_input_node(), CreatInt32Imm(num_segments)}); + auto expandim_output = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), segment_min, CreatInt32Imm(0)}); + auto all_gather_output = gen_g.PushBack({gen_g.NewOpInst(ALL_GATHER), expandim_output}); + auto final_output = gen_g.PushBack({gen_g.NewOpInst(REDUCE_MIN), all_gather_output, CreatInt32Imm(0)}); + + std::vector> input_nodes = {std::make_pair(segment_min, 1), + std::make_pair(segment_min, 2)}; + replace_graph_ = std::make_shared>, AnfNodePtr>>( + std::make_pair(input_nodes, final_output)); + + return SUCCESS; +} + +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.h new file mode 100644 index 0000000000..8e684b294b --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.h @@ -0,0 +1,84 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNSORTEDSEGMENTOP_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNSORTEDSEGMENTOP_INFO_H_ + +#include +#include +#include +#include + +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" +#include "ir/value.h" + +namespace mindspore { +namespace parallel { +constexpr size_t UNSORTEDSEGMENTOP_INPUTS_SIZE = 2; +constexpr size_t UNSORTEDSEGMENTOP_OUTPUTS_SIZE = 1; +class UnsortedSegmentOpInfo : public OperatorInfo { + public: + UnsortedSegmentOpInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs, OperatorCostPtr cost) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost) {} + ~UnsortedSegmentOpInfo() override = default; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + std::shared_ptr GenerateBatchStrategies() override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferForwardCommunication() override; + Status InferMirrorOps() override { return SUCCESS; } + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status GetAttrs() override; + + private: + Status ComputeReplaceGraph(const CNodePtr &cnode); +}; + +class UnsortedSegmentSumInfo : public UnsortedSegmentOpInfo { + public: + UnsortedSegmentSumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + ~UnsortedSegmentSumInfo() override = default; +}; + +class UnsortedSegmentMinInfo : public UnsortedSegmentOpInfo { + public: + UnsortedSegmentMinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + ~UnsortedSegmentMinInfo() override = default; + + ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; + Status InferForwardCommunication() override { return SUCCESS; } + + protected: + Status ComputeReplaceGraph(const CNodePtr &cnode); +}; + +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNSORTEDSEGMENTOP_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc index d1bca2125e..12b4fa1a7c 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -312,7 +312,8 @@ bool IsSplittableOperator(const std::string &op_name) { EMBEDDING_LOOKUP, FUSE_BATCH_NORM_EX, SPLIT, BROADCAST_TO, ABS, ACOSH, ASIN, ASINH, ATAN, ATANH, CEIL, COSH, EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE, BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2, - SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE}; + SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM, + UNSORTED_SEGMENT_MIN}; // clang-format on auto iter = splittable_op.find(op_name); diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index 5d68aaaec6..5d1432a305 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -740,9 +740,25 @@ void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node if (manager == nullptr) { MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr"; } + // Sovle the input order + // For example input_node:{segment_sum:1, segment_sum:2, gahter:2} + // The Original code here will bind the all operations to the first inputs of theses operatos + // However, the segment_sum operation needs two inputs, To sovle this + // We maintain a dict to count the times of the same operations, + // and bind the inputs according to the times of the op appears. + static std::unordered_map input_map = {}; + static int appear_count = 0; for (auto &replace_input : replace_graph->first) { auto pre_node = node->input(IntToSize(replace_input.second)); - manager->SetEdge(replace_input.first, 1, pre_node); + + auto it = input_map.find(replace_input.first); + if (it != input_map.end()) { + appear_count = 1 + it->second; + } else { + appear_count = 1; + } + input_map[replace_input.first] = appear_count; + manager->SetEdge(replace_input.first, appear_count, pre_node); } // "(void)manager->Replace(replace_graph->first, pre_node);" can not be called auto replace_output = replace_graph->second; diff --git a/tests/ut/python/parallel/test_auto_parallel_segment_min.py b/tests/ut/python/parallel/test_auto_parallel_segment_min.py new file mode 100644 index 0000000000..36dde644f4 --- /dev/null +++ b/tests/ut/python/parallel/test_auto_parallel_segment_min.py @@ -0,0 +1,71 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.common.api import _executor +from mindspore.ops import composite as C +import mindspore.ops as P +from tests.ut.python.ops.test_math_ops import VirtualLoss + + +grad_all = C.GradOperation(get_all=True) + + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.loss = VirtualLoss() + self.network = network + + def construct(self, vectors, index): + predict = self.network(vectors, index) + return self.loss(predict) + + +class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + + def construct(self, vectors, index): + return grad_all(self.network)(vectors, index) + + +def test_auto_parallel_unsortedsegmentmin(): + class Net(nn.Cell): + def __init__(self, num_segments): + super().__init__() + self.merge_op = P.UnsortedSegmentMin() + self.num_segments = num_segments + + def construct(self, vectors, index): + out = self.merge_op(vectors, index, self.num_segments) + return out + + size = 8 + context.set_auto_parallel_context(device_num=size, global_rank=0) + context.set_auto_parallel_context(parallel_mode="auto_parallel") + + x = Tensor(np.random.rand(16, 16, 32, 64), dtype=ms.float32) + indices = Tensor(np.random.randint(16, size=(16,)), ms.int32) + + net = GradWrap(NetWithLoss(Net(16))) + net.set_auto_parallel() + net.set_train() + _executor.compile(net, x, indices) diff --git a/tests/ut/python/parallel/test_auto_parallel_segment_sum.py b/tests/ut/python/parallel/test_auto_parallel_segment_sum.py new file mode 100644 index 0000000000..5b5d2291dc --- /dev/null +++ b/tests/ut/python/parallel/test_auto_parallel_segment_sum.py @@ -0,0 +1,71 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.common.api import _executor +from mindspore.ops import composite as C +import mindspore.ops as P +from tests.ut.python.ops.test_math_ops import VirtualLoss + + +grad_all = C.GradOperation(get_all=True) + + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.loss = VirtualLoss() + self.network = network + + def construct(self, vectors, index): + predict = self.network(vectors, index) + return self.loss(predict) + + +class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + + def construct(self, vectors, index): + return grad_all(self.network)(vectors, index) + + +def test_auto_parallel_unsortedsegmentsum(): + class Net(nn.Cell): + def __init__(self, num_segments): + super().__init__() + self.merge_op = P.UnsortedSegmentSum() + self.num_segments = num_segments + + def construct(self, vectors, index): + out = self.merge_op(vectors, index, self.num_segments) + return out + + size = 8 + context.set_auto_parallel_context(device_num=size, global_rank=0) + context.set_auto_parallel_context(parallel_mode="auto_parallel") + + x = Tensor(np.random.rand(16, 16, 32, 64), dtype=ms.float32) + indices = Tensor(np.random.randint(16, size=(16, 16))) + + net = GradWrap(NetWithLoss(Net(16))) + net.set_auto_parallel() + net.set_train() + _executor.compile(net, x, indices) diff --git a/tests/ut/python/parallel/test_unsortedsegmentmin.py b/tests/ut/python/parallel/test_unsortedsegmentmin.py new file mode 100644 index 0000000000..2b55dff5da --- /dev/null +++ b/tests/ut/python/parallel/test_unsortedsegmentmin.py @@ -0,0 +1,161 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.common.api import _executor +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore.ops.operations.comm_ops import _VirtualDataset +from tests.ut.python.ops.test_math_ops import VirtualLoss + +context.set_context(mode=context.GRAPH_MODE) + + +grad_all = C.GradOperation(get_all=True) + + +class Net(nn.Cell): + def __init__(self, strategy1, strategy2, num_segments): + super(Net, self).__init__() + self.virtual_dataset = _VirtualDataset() + self.merge_op = P.UnsortedSegmentMin().shard((strategy1, strategy2)) + self.num_segments = num_segments + + def construct(self, vectors, segment_ids): + predict = self.merge_op(vectors, segment_ids, self.num_segments) + return predict + + +class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + + def construct(self, x, y): + return grad_all(self.network)(x, y) + + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.loss = VirtualLoss() + self.network = network + + def construct(self, x, y): + predict = self.network(x, y) + return self.loss(predict) + + +def compile_graph(x, y, segments, strategy1, strategy2, auto=False): + net = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments))) + net.set_auto_parallel() + net.set_train() + if auto: + context.set_auto_parallel_context(parallel_mode="auto_parallel") + else: + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + _executor.compile(net, x, y) + + +def test_unsortedsegmentmin_model_parallel_slice_1d(): + context.set_auto_parallel_context(device_num=8, global_rank=0) + x = Tensor(np.ones(8), ms.float32) + y = Tensor(np.ones(8), ms.int32) + num_segments = 16 + strategy1 = (8,) + strategy2 = (8,) + compile_graph(x, y, num_segments, strategy1, strategy2) + + +def test_unsortedsegmentmin_model_parallel_no_slice_1d(): + context.set_auto_parallel_context(device_num=8, global_rank=0) + x = Tensor(np.ones(8), ms.float32) + y = Tensor(np.ones(8), ms.int32) + num_segments = 16 + strategy1 = (1,) + strategy2 = (1,) + compile_graph(x, y, num_segments, strategy1, strategy2) + + +def test_unsortedsegmentmin_model_parallel_index_slice_2d(): + context.set_auto_parallel_context(device_num=4, global_rank=0) + x = Tensor(np.ones((4, 8)), ms.float32) + y = Tensor(np.arange(4), ms.int32) + num_segments = 4 + strategy1 = (4, 1) + strategy2 = (4,) + compile_graph(x, y, num_segments, strategy1, strategy2) + + +def test_unsortedsegmentmin_model_parallel_vector_slice_2d(): + context.set_auto_parallel_context(device_num=4, global_rank=0) + x = Tensor(np.ones((4, 8)), ms.float32) + y = Tensor(np.ones(4), ms.int32) + num_segments = 4 + strategy1 = (1, 4) + strategy2 = (1,) + compile_graph(x, y, num_segments, strategy1, strategy2) + + +def test_unsortedsegmentmin_model_parallel_vector_slice_3d(): + context.set_auto_parallel_context(device_num=4, global_rank=0) + x = Tensor(np.ones((4, 8, 8)), ms.float32) + y = Tensor(np.ones(4), ms.int32) + num_segments = 4 + strategy1 = (1, 2, 2) + strategy2 = (1,) + compile_graph(x, y, num_segments, strategy1, strategy2) + + +def test_unsortedsegmentmin_model_parallel_index_vector_slice_2d(): + context.set_auto_parallel_context(device_num=4, global_rank=0) + x = Tensor(np.ones((4, 8)), ms.float32) + y = Tensor(np.ones(4), ms.int32) + num_segments = 4 + strategy1 = (2, 2) + strategy2 = (2,) + compile_graph(x, y, num_segments, strategy1, strategy2) + + +def test_unsortedsegmentmin_model_parallel_index_vector_slice_3d(): + context.set_auto_parallel_context(device_num=4, global_rank=0) + x = Tensor(np.ones((4, 4, 8)), ms.float32) + y = Tensor(np.ones((4)), ms.int32) + num_segments = 16 + strategy1 = (2, 1, 2) + strategy2 = (2,) + compile_graph(x, y, num_segments, strategy1, strategy2) + +def test_unsortedsegmentmin_model_parallel_float16(): + context.set_auto_parallel_context(device_num=4, global_rank=0) + x = Tensor(np.ones((4, 4, 8)), ms.float16) + y = Tensor(np.ones((4)), ms.int32) + num_segments = 16 + strategy1 = (2, 1, 2) + strategy2 = (2,) + compile_graph(x, y, num_segments, strategy1, strategy2) + +def test_unsortedsegmentmin_model_parallel_int32(): + context.set_auto_parallel_context(device_num=4, global_rank=0) + x = Tensor(np.ones((4, 4, 8)), ms.int32) + y = Tensor(np.ones((4)), ms.int32) + num_segments = 16 + strategy1 = (2, 1, 2) + strategy2 = (2,) + compile_graph(x, y, num_segments, strategy1, strategy2) diff --git a/tests/ut/python/parallel/test_unsortedsegmentsum.py b/tests/ut/python/parallel/test_unsortedsegmentsum.py new file mode 100644 index 0000000000..1b0d3b9682 --- /dev/null +++ b/tests/ut/python/parallel/test_unsortedsegmentsum.py @@ -0,0 +1,153 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.common.api import _executor +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore.ops.operations.comm_ops import _VirtualDataset +from tests.ut.python.ops.test_math_ops import VirtualLoss + +context.set_context(mode=context.GRAPH_MODE) + + +grad_all = C.GradOperation(get_all=True) + + +class Net(nn.Cell): + def __init__(self, strategy1, strategy2, num_segments): + super(Net, self).__init__() + self.virtual_dataset = _VirtualDataset() + self.merge_op = P.UnsortedSegmentSum().shard((strategy1, strategy2)) + self.num_segments = num_segments + + def construct(self, vectors, segment_ids): + predict = self.merge_op(vectors, segment_ids, self.num_segments) + return predict + + +class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + + def construct(self, x, y): + return grad_all(self.network)(x, y) + + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.loss = VirtualLoss() + self.network = network + + def construct(self, x, y): + predict = self.network(x, y) + return self.loss(predict) + + +def compile_graph(x, y, segments, strategy1, strategy2, auto=False): + net = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments))) + net.set_auto_parallel() + net.set_train() + if auto: + context.set_auto_parallel_context(parallel_mode="auto_parallel") + else: + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + _executor.compile(net, x, y) + + +def test_unsortedsegmentsum_model_parallel_slice_1d(): + context.set_auto_parallel_context(device_num=8, global_rank=0) + x = Tensor(np.ones(8), ms.float32) + y = Tensor(np.ones(8), ms.int32) + num_segments = 16 + strategy1 = (8,) + strategy2 = (8,) + compile_graph(x, y, num_segments, strategy1, strategy2) + + +def test_unsortedsegmentsum_model_parallel_no_slice_1d(): + context.set_auto_parallel_context(device_num=8, global_rank=0) + x = Tensor(np.ones(8), ms.float32) + y = Tensor(np.ones(8), ms.int32) + num_segments = 16 + strategy1 = (1,) + strategy2 = (1,) + compile_graph(x, y, num_segments, strategy1, strategy2) + + +def test_unsortedsegmentsum_model_parallel_index_slice_2d(): + context.set_auto_parallel_context(device_num=4, global_rank=0) + x = Tensor(np.ones((4, 8)), ms.float32) + y = Tensor(np.arange(4), ms.int32) + num_segments = 4 + strategy1 = (4, 1) + strategy2 = (4,) + compile_graph(x, y, num_segments, strategy1, strategy2) + + +def test_unsortedsegmentsum_model_parallel_index_slice_3d(): + context.set_auto_parallel_context(device_num=4, global_rank=0) + x = Tensor(np.ones((4, 4, 8)), ms.float32) + y = Tensor(np.ones((4, 4)), ms.int32) + num_segments = 16 + strategy1 = (2, 2, 1) + strategy2 = (2, 2) + compile_graph(x, y, num_segments, strategy1, strategy2) + + +def test_unsortedsegmentsum_model_parallel_vector_slice_2d(): + context.set_auto_parallel_context(device_num=4, global_rank=0) + x = Tensor(np.ones((4, 8)), ms.float32) + y = Tensor(np.ones(4), ms.int32) + num_segments = 4 + strategy1 = (1, 4) + strategy2 = (1,) + compile_graph(x, y, num_segments, strategy1, strategy2) + + +def test_unsortedsegmentsum_model_parallel_vector_slice_3d(): + context.set_auto_parallel_context(device_num=4, global_rank=0) + x = Tensor(np.ones((4, 8, 8)), ms.float32) + y = Tensor(np.ones(4), ms.int32) + num_segments = 4 + strategy1 = (1, 2, 2) + strategy2 = (1,) + compile_graph(x, y, num_segments, strategy1, strategy2) + + +def test_unsortedsegmentsum_model_parallel_index_vector_slice_2d(): + context.set_auto_parallel_context(device_num=4, global_rank=0) + x = Tensor(np.ones((4, 8)), ms.float32) + y = Tensor(np.ones(4), ms.int32) + num_segments = 4 + strategy1 = (2, 2) + strategy2 = (2,) + compile_graph(x, y, num_segments, strategy1, strategy2) + + +def test_unsortedsegmentsum_model_parallel_index_vector_slice_3d(): + context.set_auto_parallel_context(device_num=4, global_rank=0) + x = Tensor(np.ones((4, 4, 8)), ms.float32) + y = Tensor(np.ones((4, 4)), ms.int32) + num_segments = 16 + strategy1 = (2, 1, 2) + strategy2 = (2, 1) + compile_graph(x, y, num_segments, strategy1, strategy2)