From 0cfe72cd223049394313477702d6cefb01feb94f Mon Sep 17 00:00:00 2001 From: yao_yf Date: Sat, 17 Oct 2020 10:08:04 +0800 Subject: [PATCH] auto parallel dynamic --- .../auto_parallel/operator_costmodel.cc | 55 +++++ .../auto_parallel/operator_costmodel.h | 26 +++ .../ccsrc/frontend/parallel/dynamic_creator.h | 1 + .../parallel/ops_info/gather_v2_p_info.cc | 51 +++-- .../parallel/ops_info/gather_v2_p_info.h | 3 + .../parallel/ops_info/ops_info_head_files.h | 1 + .../frontend/parallel/ops_info/ops_utils.h | 4 + .../frontend/parallel/ops_info/unique_info.cc | 192 ++++++++++++++++++ .../frontend/parallel/ops_info/unique_info.h | 60 ++++++ .../frontend/parallel/step_auto_parallel.cc | 2 +- .../parallel/tensor_layout/arrangement.cc | 2 +- .../redistribution_layout_transfer.cc | 14 +- .../redistribution_layout_transfer.h | 2 + .../parallel/tensor_layout/tensor_layout.cc | 4 + .../parallel/tensor_layout/tensor_layout.h | 2 + .../tensor_layout/tensor_redistribution.cc | 27 ++- mindspore/ops/_grad/grad_array_ops.py | 6 +- .../ut/python/parallel/test_dynamic_shape.py | 118 +++++++++++ 18 files changed, 541 insertions(+), 29 deletions(-) create mode 100644 mindspore/ccsrc/frontend/parallel/ops_info/unique_info.cc create mode 100644 mindspore/ccsrc/frontend/parallel/ops_info/unique_info.h create mode 100644 tests/ut/python/parallel/test_dynamic_shape.py diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc index 2d5460bfcc..46f3fc823a 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc @@ -808,6 +808,61 @@ double LayerNormCost::GetForwardComputationCost(const std::vector &i return result; } +double UniqueCost::GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const { + return 0.0; +} +double UniqueCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const { + double result = 0.0; + if (is_parameter_[0]) { + TensorInfo input = inputs[0]; + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + Shape input_shape = input.shape(); + Shape input_slice_shape = input.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input_shape.size(); ++i) { + used_device_num *= input_shape[i] / input_slice_shape[i]; + } + if (total_device_num != IntToSize(used_device_num)) { + result = ListProduct(input_slice_shape) * static_cast(inputs_type_lengths_[0]); + } + } + return result; +} +double UniqueCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { + // In forward phase, the computation cost = slice(A) + slice(B) + Shape input_slice_shape = inputs[0].slice_shape(); + double result = ListProduct(input_slice_shape) * static_cast(inputs_type_lengths_[0]); + return result; +} +double UniqueCost::GetBackwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { + // In backward phase, the computation cost = (0 or 1) allreduce(slice(B)) + double result = 0.0; + if (is_parameter_[0]) { + TensorInfo input = inputs[0]; // tensor B + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + Shape input_shape = input.shape(); + Shape input_slice_shape = input.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input_shape.size(); ++i) { + used_device_num *= input_shape[i] / input_slice_shape[i]; + } + + if (total_device_num != IntToSize(used_device_num)) { + result += ListProduct(input_slice_shape) * static_cast(inputs_type_lengths_[0]); + } + } + return result; +} + double GatherV2PCost::GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const { double result = 0.0; diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h index e8f0360625..a2f5509093 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h @@ -606,6 +606,32 @@ class LayerNormCost : public OperatorCost { using DropOutCostPtr = std::shared_ptr; +class UniqueCost : public OperatorCost { + public: + explicit UniqueCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + UniqueCost() : OperatorCost(true) {} + ~UniqueCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t) const override; +}; + +using UniqueCostPtr = std::shared_ptr; + class GatherV2Cost : public OperatorCost { public: explicit GatherV2Cost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} diff --git a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h index 78dfbc2f06..6bd6667ce9 100644 --- a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h +++ b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h @@ -182,6 +182,7 @@ REGISTER(DropoutInfo); REGISTER(PackInfo); REGISTER(ConcatInfo); REGISTER(SplitInfo); +REGISTER(UniqueInfo); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc index 9ad640189b..eaf9e99228 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc @@ -151,6 +151,10 @@ Status GatherV2PInfo::GetAttrs() { MS_LOG(ERROR) << name_ << ": The axis or offset must be 0 if manual split, bug got " << axis_; return FAILED; } + + if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) { + dynamic_shape_indices_ = true; + } return SUCCESS; } @@ -240,7 +244,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { // axis=0, index_shape(0)%param_strategy(0) must be 0 Shape index_shape = inputs_shape_.at(1); - if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0)) { + if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0) && !dynamic_shape_indices_) { MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by param_strategy(0)."; return FAILED; } @@ -357,13 +361,7 @@ Status GatherV2PInfo::InferDevMatrixShape() { return SUCCESS; } -Status GatherV2PInfo::InferTensorMap() { - if (manual_split_) { - inputs_tensor_map_.push_back({1, 0}); - inputs_tensor_map_.push_back({-1, 1}); - outputs_tensor_map_.push_back({-1, 1, 0}); - return SUCCESS; - } +void GatherV2PInfo::InferInputsTensorMap() { // infer input tensor map // param_strategy(axis) != 1 size_t param_size = inputs_shape_.at(0).size(); @@ -373,7 +371,7 @@ Status GatherV2PInfo::InferTensorMap() { Shape tensor_map_params; auto param_strategy = strategy_->GetInputDim().at(0); if (param_strategy.at(IntToSize(axis_)) != 1) { - tensor_map_index.insert(tensor_map_index.begin(), index_size, -1); + tensor_map_index.insert(tensor_map_index.begin(), index_size, MAP_NONE); for (size_t i = 0; i < param_size; ++i) { tensor_map_params.push_back(SizeToInt(i)); } @@ -386,9 +384,17 @@ Status GatherV2PInfo::InferTensorMap() { tensor_map_index.push_back(SizeToInt(index_size - i - 1)); } } + inputs_tensor_map_.emplace_back(std::move(tensor_map_params)); + inputs_tensor_map_.emplace_back(std::move(tensor_map_index)); +} +void GatherV2PInfo::InferOutputsTensorMap() { // infer output tensor map + size_t param_size = inputs_shape_.at(0).size(); + size_t index_size = inputs_shape_.at(1).size(); + size_t total_size = param_size + index_size; Shape tensor_map_out; + auto param_strategy = strategy_->GetInputDim().at(0); if (param_strategy.at(IntToSize(axis_)) == 1) { // param_strategy(axis) == 1 for (size_t i = 0; i < param_size; ++i) { @@ -403,25 +409,40 @@ Status GatherV2PInfo::InferTensorMap() { } else { // param_strategy(axis) != 1 if (axis_ == 0) { - tensor_map_out.insert(tensor_map_out.end(), 0); - tensor_map_out.insert(tensor_map_out.end(), index_size - 1, -1); + if (dynamic_shape_indices_) { + tensor_map_out.insert(tensor_map_out.end(), MAP_NONE); + } else { + tensor_map_out.insert(tensor_map_out.end(), 0); + } + tensor_map_out.insert(tensor_map_out.end(), index_size - 1, MAP_NONE); for (size_t i = 1; i < param_size; ++i) { tensor_map_out.push_back(i); } } else { for (size_t i = 0; i < param_size; ++i) { if (i == IntToSize(axis_)) { - tensor_map_out.insert(tensor_map_out.end(), index_size, -1); + tensor_map_out.insert(tensor_map_out.end(), index_size, MAP_NONE); } else { + if (i == 0 && dynamic_shape_indices_) { + tensor_map_out.push_back(MAP_NONE); + } tensor_map_out.push_back(SizeToInt(param_size - i - 1)); } } } } - - inputs_tensor_map_.emplace_back(std::move(tensor_map_params)); - inputs_tensor_map_.emplace_back(std::move(tensor_map_index)); outputs_tensor_map_.emplace_back(std::move(tensor_map_out)); +} + +Status GatherV2PInfo::InferTensorMap() { + if (manual_split_) { + inputs_tensor_map_.push_back({1, 0}); + inputs_tensor_map_.push_back({-1, 1}); + outputs_tensor_map_.push_back({-1, 1, 0}); + return SUCCESS; + } + InferInputsTensorMap(); + InferOutputsTensorMap(); return SUCCESS; } diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h index 899ba73db7..ba516dca06 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h @@ -57,6 +57,8 @@ class GatherV2PInfo : public OperatorInfo { Status InferTensorInfo() override; Status InferDevMatrixShape() override; Status InferTensorMap() override; + void InferInputsTensorMap(); + void InferOutputsTensorMap(); Status GetAttrs() override; Status ComputeReplaceGraph(const CNodePtr &cnode); @@ -77,6 +79,7 @@ class GatherV2PInfo : public OperatorInfo { Shape out_dev_matrix_shape_; Group group_; bool manual_split_ = false; + bool dynamic_shape_indices_ = false; std::vector param_split_shapes_; std::vector index_offsets_; }; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h index 7f0b83bc96..edc96610ac 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h @@ -43,5 +43,6 @@ #include "frontend/parallel/ops_info/split_info.h" #include "frontend/parallel/ops_info/pack_info.h" #include "frontend/parallel/ops_info/broadcast_to_info.h" +#include "frontend/parallel/ops_info/unique_info.h" #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h index 5b867c4403..0eb4758648 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -48,6 +48,9 @@ constexpr size_t DROPOUT_DO_MASK_KEEP_PROB_INDEX = 3; constexpr size_t SoftmaxCrossEntropyWithLogitsAttrSize = 1; constexpr size_t SoftmaxCrossEntropyWithLogitsInputsSize = 2; constexpr size_t SoftmaxCrossEntropyWithLogitsOutputsSize = 2; +constexpr size_t UNIQUE_INPUTS_SIZE = 1; +constexpr size_t UNIQUE_INPUT_SIZE = 1; +constexpr size_t UNIQUE_OUTPUTS_SIZE = 2; constexpr double EPS = 1e-6; constexpr double INF = 1e20; @@ -285,6 +288,7 @@ constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D"; constexpr char ADD[] = "Add"; constexpr char DROPOUT[] = "Dropout"; constexpr char KStridedSlice[] = "StridedSlice"; +constexpr char UNIQUE[] = "Unique"; // Parallel don't care constexpr char TUPLE_GETITEM[] = "tuple_getitem"; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.cc new file mode 100644 index 0000000000..46b1c3aefe --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.cc @@ -0,0 +1,192 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/unique_info.h" + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/context.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" + +namespace mindspore { +namespace parallel { +/* + * unique has one input, two outputs. Currently, unique cannot be split. + */ +Status UniqueInfo::InferTensorMap() { + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + for (auto shp : inputs_shape_) { + TensorMap out_tensor_map; + TensorMap in_tensor_map; + for (size_t i = 0; i < shp.size(); ++i) { + in_tensor_map.push_back(MAP_NONE); + out_tensor_map.push_back(MAP_NONE); + } + inputs_tensor_map_.push_back(in_tensor_map); + outputs_tensor_map_.push_back(out_tensor_map); + outputs_tensor_map_.push_back(out_tensor_map); + } + return SUCCESS; +} + +Status UniqueInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { + if (inputs_layout == nullptr || outputs_layout == nullptr) { + MS_LOG(ERROR) << name_ << " : The layout is null."; + return FAILED; + } + TensorLayout input_layout; + TensorLayout output_layout; + TensorLayout index_layout; + if ((input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) || + (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) || + (index_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[1], outputs_shape_[1]) != SUCCESS)) { + return FAILED; + } + inputs_layout->push_back(input_layout); + outputs_layout->push_back(output_layout); + outputs_layout->push_back(index_layout); + return SUCCESS; +} + +Status UniqueInfo::InferTensorInfo() { + TensorLayouts inputs_layout; + TensorLayouts outputs_layout; + if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) { + return FAILED; + } + for (size_t i = 0; i < inputs_layout.size(); ++i) { + TensorInfo input_tensor_info(inputs_layout[i]); + inputs_tensor_info_.push_back(input_tensor_info); + } + for (size_t i = 0; i < outputs_layout.size(); ++i) { + TensorInfo output_tensor_info(outputs_layout[i]); + outputs_tensor_info_.push_back(output_tensor_info); + } + return SUCCESS; +} + +Status UniqueInfo::InferDevMatrixShape() { + dev_matrix_shape_.push_back(dev_num_); + return SUCCESS; +} + +Status UniqueInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Init failed"; + return FAILED; + } + MS_LOG(INFO) << name_ << " : Init success"; + return SUCCESS; +} + +Status UniqueInfo::CheckStrategy(const StrategyPtr &strategy) { + Strategys stras = strategy->GetInputDim(); + if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Invalid strategy."; + return FAILED; + } + for (Dimensions stra : stras) { + if (stra.size() != UNIQUE_INPUT_SIZE) { + MS_LOG(ERROR) << name_ << " : Invalid strategy."; + return FAILED; + } + } + int32_t stage = strategy->GetInputStage(); + int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(stage).size()); + dev_num_ = dev_num; + if (stras[0][0] != 1) { + MS_LOG(ERROR) << "Currently, unique only support repeat calculate in all devices"; + return FAILED; + } + return SUCCESS; +} + +Status UniqueInfo::GetAttrs() { + if ((inputs_shape_.size() != UNIQUE_INPUTS_SIZE) || (outputs_shape_.size() != UNIQUE_OUTPUTS_SIZE)) { + MS_LOG(ERROR) << name_ << ": Inputs shape size " << inputs_shape_.size() << " or outputs shape size " + << outputs_shape_.size() << " is wrong."; + return FAILED; + } + return SUCCESS; +} + +Status UniqueInfo::InferMirrorOps() { + mirror_ops_.clear(); + + Shape tensor_map = inputs_tensor_map_[0]; + std::vector group; + if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create group failed."; + return FAILED; + } + OperatorVector mirror_op; + if (group.empty()) { + MS_LOG(INFO) << name_ << " : The mirror ops is empty."; + return SUCCESS; + } else { + mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum()); + mirror_ops_.push_back(mirror_op); + std::string group_name = group[0].name(); + MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name; + } + + return SUCCESS; +} + +Status UniqueInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Init for cost model failed."; + return FAILED; + } + MS_LOG(INFO) << name_ << " : Init for cost model success."; + return SUCCESS; +} + +Status UniqueInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } + +Status UniqueInfo::GenerateStrategies(int32_t stage_id) { + if (inputs_shape_.size() != UNIQUE_INPUTS_SIZE) { + return FAILED; + } + if (inputs_shape_[0].size() != UNIQUE_INPUT_SIZE) { + return FAILED; + } + Shape input0_split; + input0_split.emplace_back(0); + Shapes splittable_inputs = {input0_split}; + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed"; + return FAILED; + } + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; + PrintStrategy(sp); + } + } + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.h new file mode 100644 index 0000000000..1a7d2851d6 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIQUE_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIQUE_INFO_H_ + +#include +#include +#include +#include + +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class UniqueInfo : public OperatorInfo { + public: + UniqueInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~UniqueInfo() override = default; + + Status Init(const StrategyPtr &strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + Status GenerateStrategies(int32_t stage_id) override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status GetAttrs() override; + Status InferTensorMap() override; + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override { return SUCCESS; } + Status InferAsLossDivisor() override { return SUCCESS; } + + private: + int32_t dev_num_ = 1; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIQUE_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc index 326be244a4..a72c8abb8a 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -285,7 +285,7 @@ bool IsSplittableOperator(const std::string &op_name) { EMBEDDING_LOOKUP, FUSE_BATCH_NORM_EX, SPLIT, BROADCAST_TO, ABS, ACOSH, ASIN, ASINH, ATAN, ATANH, CEIL, COSH, EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE, BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2, - SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD}; + SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE}; // clang-format on auto iter = splittable_op.find(op_name); diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.cc index 40a793b0e3..eb22c631ee 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.cc +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.cc @@ -39,7 +39,7 @@ Status Arrangement::Init(const Shape &array) { } bool Arrangement::IsValidArrangement() { - return !std::any_of(array_.begin(), array_.end(), [](int64_t value) { return value <= 0; }); + return !std::any_of(array_.begin(), array_.end(), [](int64_t value) { return value <= 0 && value != -1; }); } void Arrangement::ComputeSize() { diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.cc index 8ef5afaadf..c6f1c741d9 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.cc +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.cc @@ -21,7 +21,19 @@ namespace mindspore { namespace parallel { -Status RedistributionLayoutTransfer::CheckValidTransfer() { return Status::SUCCESS; } +Status RedistributionLayoutTransfer::CheckValidTransfer() { + Shape from_shape = from_in_.tensor_shape().array(); + if (std::find(from_shape.begin(), from_shape.end(), -1) != from_shape.end()) { + is_dynamic_shape_ = true; + if (from_in_ != to_in_) { + MS_LOG(ERROR) << "In dynamic shape scene, the from_tensor_shape should be equal to to_tensor_shape"; + MS_LOG(ERROR) << "from_in layout" << from_in_.ToString(); + MS_LOG(ERROR) << "to_in layout" << to_in_.ToString(); + return Status::FAILED; + } + } + return Status::SUCCESS; +} /* * unify device arrangement between in_layout and out_layout diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.h index 589eea2383..800f997645 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.h +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.h @@ -29,10 +29,12 @@ class RedistributionLayoutTransfer : public LayoutTransfer { RedistributionLayoutTransfer() = default; ~RedistributionLayoutTransfer() override = default; std::shared_ptr UnifyDeviceArrangementAndTensorShape() const; + bool IsDynamicShape() const { return is_dynamic_shape_; } private: Status CheckValidTransfer() override; std::shared_ptr UnifyDeviceArrangement() const; + bool is_dynamic_shape_ = false; }; } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc index 16fafbd113..2549fa5339 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc @@ -357,6 +357,10 @@ bool TensorLayout::operator==(const TensorLayout &t1) const { return (IsSameDeviceArrangement(t1) && IsSameTensorMap(t1) && IsSameTensorShape(t1)); } +bool TensorLayout::operator!=(const TensorLayout &t1) const { + return !(IsSameDeviceArrangement(t1) && IsSameTensorMap(t1) && IsSameTensorShape(t1)); +} + /* * remove elements equal to 1 in tensor_shape, if all elements are 1, squeeze the tensor_shape to [ 1 ] * example 1: diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h index fefeae3e06..22a00ef4ae 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h @@ -82,6 +82,8 @@ class TensorLayout { bool operator==(const TensorLayout &t1) const; + bool operator!=(const TensorLayout &t1) const; + bool TensorShapeCanBeExpanded(const Arrangement &expanded_shape) const; std::shared_ptr ComputeExpandedTensorShape(const Arrangement &expand_shape) const; diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc index f7ef356303..d9b546bb4a 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc @@ -82,17 +82,24 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL if (status != Status::SUCCESS) { return nullptr; } - std::shared_ptr ptr = layout_transfer.UnifyDeviceArrangementAndTensorShape(); - if (ptr == nullptr) { - MS_LOG(ERROR) << "Infer tensor layout return nullptr!"; - return nullptr; - } - if (!ptr->ExpandAble()) { - expand_able_ = false; - return InferTensorRedistributionOperatorListUnExpand(is_cost_model); + TensorLayout from_layout; + TensorLayout to_layout; + if (layout_transfer.IsDynamicShape()) { + from_layout = layout_transfer.from_in(); + to_layout = layout_transfer.to_in(); + } else { + std::shared_ptr ptr = layout_transfer.UnifyDeviceArrangementAndTensorShape(); + if (ptr == nullptr) { + MS_LOG(ERROR) << "Infer tensor layout return nullptr!"; + return nullptr; + } + if (!ptr->ExpandAble()) { + expand_able_ = false; + return InferTensorRedistributionOperatorListUnExpand(is_cost_model); + } + from_layout = ptr->from_in(); + to_layout = ptr->to_in(); } - TensorLayout from_layout = ptr->from_in(); - TensorLayout to_layout = ptr->to_in(); MS_LOG(DEBUG) << "reshape from_layout " << from_layout.ToString(); MS_LOG(DEBUG) << "reshape to_layout " << to_layout.ToString(); MS_LOG(DEBUG) << "reshape from_origin_ " << from_origin_.ToString(); diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index c704969eea..8255c7088c 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -33,6 +33,7 @@ reduce_sum = P.ReduceSum() unsorted_segment_sum = P.UnsortedSegmentSum() transpose = P.Transpose() shape_op = P.Shape() +dyn_shape_op = P.DynamicShape() reshape = P.Reshape() size_op = P.Size() invert_permutation = P.InvertPermutation() @@ -365,7 +366,10 @@ def get_bprop_gather_v2(self): # Example: out_shape:(3,2,3) axis 1 -> (1,0,2) perm_1 = _generate_shape_index(out_shp, ind_shp, axis) values_transpose = transpose(dout, perm_1) - params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis]) + if -1 in shape_op(x): + params_grad = unsorted_segment_sum(values_transpose, indices, dyn_shape_op(x)[axis]) + else: + params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis]) # Example: out_shape:(3,2,3) axis 2 -> (1,2,0) perm_2 = _generate_inverse_index(x_shp, axis) params_grad = transpose(params_grad, perm_2) diff --git a/tests/ut/python/parallel/test_dynamic_shape.py b/tests/ut/python/parallel/test_dynamic_shape.py new file mode 100644 index 0000000000..86a52873b7 --- /dev/null +++ b/tests/ut/python/parallel/test_dynamic_shape.py @@ -0,0 +1,118 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.common.api import _executor +from mindspore.common.parameter import Parameter +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore.common.initializer import initializer +from mindspore.nn import TrainOneStepCell, Momentum +from tests.ut.python.ops.test_math_ops import VirtualLoss + + +grad_all = C.GradOperation(get_all=True) + + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.loss = VirtualLoss() + self.network = network + + def construct(self, x): + predict = self.network(x) + return self.loss(predict) + + +class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + + def construct(self, x): + return grad_all(self.network)(x) + +def test_unique_column_split(): + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.unique = P.Unique().shard(((1,),)) + self.relu = P.ReLU() + self.mul = P.Mul() + self.embedding_lookp = P.GatherV2().shard(((1, 8), (1,))) + self.embedding_table = Parameter(initializer('normal', [2000, 128]), + name='embedding_table') + self.gatherv2 = P.GatherV2().shard(((1, 8), (1,))) + self.reshape = P.Reshape() + self.matmul = P.MatMul() + self.mul_weight = Parameter(Tensor(np.full([32, 64, 1], 0.5, dtype=np.float32)), name="mul_weight") + + def construct(self, indices): + indices_flatten = self.reshape(indices, (-1,)) + unique_id, unique_idx = self.unique(indices_flatten) + unique_id_weight = self.embedding_lookp(self.embedding_table, unique_id, 0) + weight_flatten = self.gatherv2(unique_id_weight, unique_idx, 0) + weight = self.reshape(weight_flatten, (32, 64, 128)) + vx = self.mul(weight, self.mul_weight) + return vx + + size = 8 + context.set_auto_parallel_context(device_num=size, global_rank=0, parallel_mode="auto_parallel") + x = Tensor(np.ones([32, 64]), dtype=ms.int32) + net = Net() + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + train_net = TrainOneStepCell(net, optimizer) + train_net.set_auto_parallel() + train_net.set_train() + _executor.compile(train_net, x) + +def test_unique_row_split(): + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.unique = P.Unique().shard(((1,),)) + self.relu = P.ReLU() + self.mul = P.Mul() + self.embedding_lookp = P.GatherV2().shard(((8, 1), (1,))) + self.embedding_table = Parameter(initializer('normal', [2000, 128]), + name='embedding_table') + self.gatherv2 = P.GatherV2().shard(((1, 1), (8,))) + self.reshape = P.Reshape() + self.matmul = P.MatMul() + self.mul_weight = Parameter(Tensor(np.full([32, 64, 1], 0.5, dtype=np.float32)), name="mul_weight") + + def construct(self, indices): + indices_flatten = self.reshape(indices, (-1,)) + unique_id, unique_idx = self.unique(indices_flatten) + unique_id_weight = self.embedding_lookp(self.embedding_table, unique_id, 0) + weight_flatten = self.gatherv2(unique_id_weight, unique_idx, 0) + weight = self.reshape(weight_flatten, (32, 64, 128)) + vx = self.mul(weight, self.mul_weight) + return vx + + size = 8 + context.set_auto_parallel_context(device_num=size, global_rank=0, parallel_mode="stand_alone") + x = Tensor(np.ones([32, 64]), dtype=ms.int32) + net = Net() + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + train_net = TrainOneStepCell(net, optimizer) + train_net.set_auto_parallel() + train_net.set_train() + _executor.compile(train_net, x)