diff --git a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h
index a1764d98e9..7d1f9b5b4c 100644
--- a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h
+++ b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h
@@ -135,6 +135,7 @@ REGISTER(EluInfo);
 REGISTER(ReLUInfo);
 REGISTER(RepeatElementsInfo);
 REGISTER(TensorDotInfo);
+REGISTER(RangeInfo);
 REGISTER(ReLU6Info);
 REGISTER(ReLUV2Info);
 REGISTER(SoftplusInfo);
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h
index 16b92a63ce..767b5784d6 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h
@@ -43,6 +43,7 @@
 #include "frontend/parallel/ops_info/concat_info.h"
 #include "frontend/parallel/ops_info/split_info.h"
 #include "frontend/parallel/ops_info/tensordot_info.h"
+#include "frontend/parallel/ops_info/range_info.h"
 #include "frontend/parallel/ops_info/pack_info.h"
 #include "frontend/parallel/ops_info/broadcast_to_info.h"
 #include "frontend/parallel/ops_info/unique_info.h"
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
index 6338b08384..cb76c85822 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
@@ -104,6 +104,9 @@ constexpr char GROUP[] = "group";
 constexpr char FUSION[] = "fusion";
 constexpr char AXIS[] = "axis";
 constexpr char AXES[] = "axes";
+constexpr char START[] = "start";
+constexpr char LIMIT[] = "limit";
+constexpr char DELTA[] = "delta";
 constexpr char OUTPUT_NUM[] = "output_num";
 constexpr char SPLIT_COUNT[] = "split_count";
 constexpr char SPLIT_DIM[] = "split_dim";
@@ -193,6 +196,7 @@ constexpr char SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SparseSoftmaxCrossE
 constexpr char RELU[] = "ReLU";
 constexpr char REPEAT_ELEMENTS[] = "RepeatElements";
 constexpr char TENSOR_DOT[] = "TensorDot";
+constexpr char RANGE[] = "Range";
 constexpr char ONEHOT[] = "OneHot";
 constexpr char DROPOUT_DO_MASK[] = "DropoutDoMask";
 constexpr char DROPOUT_GEN_MASK[] = "DropoutGenMask";
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/range_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/range_info.cc
new file mode 100644
index 0000000000..73bd3eeb50
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/range_info.cc
@@ -0,0 +1,214 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/ops_info/range_info.h"
+
+#include <algorithm>
+#include <functional>
+#include <memory>
+#include <numeric>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "ir/value.h"
+#include "frontend/parallel/auto_parallel/graph_costmodel.h"
+#include "frontend/parallel/device_manager.h"
+#include "frontend/parallel/device_matrix.h"
+#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
+#include "frontend/parallel/graph_util/generate_graph.h"
+
+namespace mindspore {
+namespace parallel {
+float RangeInfo::GetRangeAttr(const std::string &arg) {
+  auto iter = attrs_.find(arg);
+  if (iter == attrs_.end()) {
+    MS_LOG(EXCEPTION) << name_ << ": Can not find the attr for " << arg;
+  }
+
+  MS_EXCEPTION_IF_NULL(iter->second);
+  if (!iter->second->isa<FP32Imm>()) {
+    MS_LOG(EXCEPTION) << name_ << ": The type of the attr is not float, the attr is " << arg;
+  }
+
+  return iter->second->cast<FP32ImmPtr>()->value();
+}
+
+Status RangeInfo::GetAttrs() {
+  start_ = GetRangeAttr(START);
+  limit_ = GetRangeAttr(LIMIT);
+  delta_ = GetRangeAttr(DELTA);
+  MS_LOG(INFO) << name_ << ": The start is " << start_ << ", the limit is " << limit_ << ", the delta is " << delta_;
+  return SUCCESS;
+}
+
+Status RangeInfo::CheckStrategy(const StrategyPtr &strategy) {
+  MS_EXCEPTION_IF_NULL(strategy);
+  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Invalid strategy";
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+Status RangeInfo::InferDevMatrixShape() {
+  Strategys stra = strategy_->GetInputDim();
+  dev_matrix_shape_ = stra[0];
+  split_num_ = stra[0][0];
+  return SUCCESS;
+}
+
+Status RangeInfo::InferMirrorOps() { return SUCCESS; }
+
+Status RangeInfo::InferForwardCommunication() { return SUCCESS; }
+
+Status RangeInfo::InferTensorMap() {
+  TensorMap input_tensor_map = {0}, output_tensor_map = {0};
+
+  inputs_tensor_map_.push_back(input_tensor_map);
+  outputs_tensor_map_.push_back(output_tensor_map);
+  return SUCCESS;
+}
+
+Status RangeInfo::InferTensorInfo() {
+  if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
+    MS_LOG(ERROR) << name_ << ": Invalid args";
+    return FAILED;
+  }
+
+  TensorLayout input_layout, output_layout;
+  for (size_t i = 0; i < inputs_shape_.size(); ++i) {
+    // infer tensor layout
+    if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) {
+      MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
+      return FAILED;
+    }
+    TensorInfo input_tensor_info(input_layout);
+    inputs_tensor_info_.push_back(input_tensor_info);
+  }
+
+  if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
+    return FAILED;
+  }
+  TensorInfo output_tensor_info(output_layout);
+  outputs_tensor_info_.push_back(output_tensor_info);
+
+  for (auto &tensor_info : inputs_tensor_info_) {
+    MS_LOG(INFO) << name_ << ": The input layout: " << tensor_info.tensor_layout().ToString();
+  }
+  MS_LOG(INFO) << name_ << ": The output layout: " << outputs_tensor_info_[0].tensor_layout().ToString();
+  return SUCCESS;
+}
+
+Status RangeInfo::Init(const StrategyPtr &strategy) {
+  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init failed";
+    return FAILED;
+  }
+
+  MS_LOG(INFO) << name_ << ": Init success";
+  return SUCCESS;
+}
+
+Status RangeInfo::InitForCostModel(const StrategyPtr &strategy) {
+  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init for cost model failed";
+    return FAILED;
+  }
+
+  MS_LOG(INFO) << name_ << ": Init for cost model success";
+  return SUCCESS;
+}
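+
+// A worked example of the re-sharding arithmetic implemented in InferNewAttr
+// below (the numbers are illustrative, not taken from this patch): suppose
+// start_ = 0.0, delta_ = 1.0, the full output length inputs_shape_[0][0] = 8,
+// and split_num_ = 4 with repeated_calc_num_ = 1. Then start_bias = 8 / 4 * 1.0 = 2.0,
+// so rank 2 computes new_start_ = 0.0 + 2.0 * 2 = 4.0 and new_limit_ = 6.0,
+// i.e. it generates the slice [4.0, 5.0] of the full range [0.0, 8.0).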
+Status RangeInfo::InferNewAttr() {
+  CheckGlobalDeviceManager();
+  int64_t rank = g_device_manager->global_rank();
+
+  // If repeated calculation is used and the repeat number is the last dimension of the dev-matrix,
+  // the dev-matrix is [split_num_, repeated_calc_num_], so ranks 0 through repeated_calc_num_ - 1
+  // do the repeated calculation, and these ranks share the same 'new_start_'.
+  // If repeated calculation is used and the repeat number is the first dimension of the dev-matrix,
+  // the dev-matrix is [repeated_calc_num_, split_num_], so rank 0, rank split_num_ and so on
+  // do the repeated calculation, and these ranks share the same 'new_start_'.
+  float start_bias = inputs_shape_[0][0] / split_num_ * delta_;
+  if (repeated_num_in_dev_matrix_right_) {
+    new_start_ = start_ + start_bias * (rank / repeated_calc_num_);
+  } else {
+    new_start_ = start_ + start_bias * (rank % split_num_);
+  }
+
+  new_limit_ = new_start_ + start_bias;
+  MS_LOG(INFO) << name_ << ": The new start is " << new_start_ << ", the new limit is " << new_limit_;
+  return SUCCESS;
+}
+
+Status RangeInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
+  GenerateGraph gen_g = GenerateGraph();
+  if (gen_g.Init(cnode) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": GenerateGraph Init failed";
+    return FAILED;
+  }
+
+  if (InferNewAttr() != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Infer new attr failed";
+    return FAILED;
+  }
+
+  Attr attr_start = std::make_pair(START, MakeValue(new_start_));
+  Attr attr_limit = std::make_pair(LIMIT, MakeValue(new_limit_));
+  Attr attr_delta = std::make_pair(DELTA, MakeValue(delta_));
+  OperatorAttrs attrs = {attr_start, attr_limit, attr_delta};
+  auto new_range_op = gen_g.PushBack({gen_g.NewOpInst(RANGE, attrs), gen_g.virtual_input_node()});
+  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(new_range_op, 1)};
+  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
+    std::make_pair(input_nodes, new_range_op));
+
+  return SUCCESS;
+}
+
+ReplaceGraphPtr RangeInfo::replace_graph(const CNodePtr &cnode) {
+  if (ComputeReplaceGraph(cnode) != SUCCESS) {
+    MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
+  }
+  return replace_graph_;
+}
+
+Status RangeInfo::GenerateStrategies(int64_t stage_id) {
+  Shape input0_split(inputs_shape_[0].size(), 1);
+  Shapes splittable_inputs = {input0_split};
+
+  std::vector<StrategyPtr> sp_vector;
+  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed.";
+    return FAILED;
+  }
+  size_t success = 0;
+  for (auto &sp : sp_vector) {
+    if (SetCostUnderStrategy(sp) == SUCCESS) {
+      success++;
+      MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategies";
+      PrintStrategy(sp);
+    }
+  }
+  return SUCCESS;
+}
+
+Status RangeInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
+}  // namespace parallel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/range_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/range_info.h
new file mode 100644
index 0000000000..38b2bad89a
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/range_info.h
@@ -0,0 +1,73 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RANGE_INFO_H_
+#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RANGE_INFO_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "utils/ms_utils.h"
+#include "ir/value.h"
+#include "frontend/parallel/auto_parallel/operator_costmodel.h"
+#include "frontend/parallel/ops_info/operator_info.h"
+#include "frontend/parallel/strategy.h"
+
+namespace mindspore {
+namespace parallel {
+// Range op:
+// (start=8.0, limit=16.0, delta=1.0) -> [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0]
+// (start=8.0, limit=None, delta=1.0) -> [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
+// By the time step_parallel runs, limit=None has already been resolved,
+// so the parallel implementation only needs to modify 'start' and 'limit'.
+class RangeInfo : public OperatorInfo {
+ public:
+  RangeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+            const PrimitiveAttrs &attrs)
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<RangeCost>(true)) {}
+  ~RangeInfo() override = default;
+
+  Status Init(const StrategyPtr &strategy) override;
+  Status InitForCostModel(const StrategyPtr &strategy) override;
+
+  Status GenerateStrategies(int64_t stage_id) override;
+  Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
+  ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
+
+ protected:
+  Status CheckStrategy(const StrategyPtr &strategy) override;
+  Status InferMirrorOps() override;
+  Status InferForwardCommunication() override;
+  Status InferTensorInfo() override;
+  Status InferDevMatrixShape() override;
+  Status InferTensorMap() override;
+  Status GetAttrs() override;
+  Status InferNewAttr();
+  float GetRangeAttr(const std::string &arg);
+  Status ComputeReplaceGraph(const CNodePtr &cnode);
+
+  float start_ = 0.0;
+  float limit_ = 0.0;
+  float delta_ = 0.0;
+  float new_start_ = 0.0;
+  float new_limit_ = 0.0;
+  int64_t split_num_ = 1;
+};
+}  // namespace parallel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RANGE_INFO_H_
diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
index 26ce63c4fc..719394e6b5 100644
--- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
@@ -316,7 +316,7 @@ bool IsSplittableOperator(const std::string &op_name) {
     EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE,
     BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6,
     RELUV2, SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
-    UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT};
+    UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE};
   // clang-format on
 
   auto iter = splittable_op.find(op_name);
diff --git a/tests/ut/python/parallel/test_range.py b/tests/ut/python/parallel/test_range.py
new file mode 100644
index 0000000000..2e8780d1a4
--- /dev/null
+++ b/tests/ut/python/parallel/test_range.py
@@ -0,0 +1,102 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore import context, Tensor, Parameter
+from mindspore.nn import Cell, Momentum
+from mindspore.ops import operations as P
+from mindspore.train import Model
+from tests.dataset_mock import MindData
+
+
+class Dataset(MindData):
+    def __init__(self, predict, label, length=3):
+        super(Dataset, self).__init__(size=length)
+        self.predict = predict
+        self.label = label
+        self.index = 0
+        self.length = length
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.index >= self.length:
+            raise StopIteration
+        self.index += 1
+        return self.predict, self.label
+
+    def reset(self):
+        self.index = 0
+
+
+class Net(Cell):
+    def __init__(self, weight, start, limit, delta, strategy1=None, strategy2=None, strategy3=None):
+        super().__init__()
+        self.mul = P.Mul().shard(strategy1)
+        self.range = nn.Range(start, limit, delta)
+        self.range.range_x.shard(strategy2)
+        self.mul2 = P.Mul().shard(strategy3)
+        self.weight = Parameter(weight, "w")
+
+    def construct(self, x, b):
+        r_out = self.range()
+        out = self.mul(x, self.weight)
+        out = self.mul2(out, r_out)
+        return out
+
+
+dev_num = 4
+_x = Tensor(np.ones([64 // dev_num, 8]), dtype=ms.float32)
+_b = Tensor(np.ones([8]), dtype=ms.float32)
+_w1 = Tensor(np.ones([64, 8]), dtype=ms.float32)
+
+
+def compile_net(net):
+    context.set_context(save_graphs=True)
+    learning_rate = 0.1
+    momentum = 0.9
+    epoch_size = 2
+    dataset = Dataset(_x, _b)
+    opt = Momentum(net.trainable_params(), learning_rate, momentum)
+    model = Model(net, optimizer=opt)
+    model.train(epoch_size, dataset, dataset_sink_mode=False)
+    context.reset_auto_parallel_context()
+
+
+def test_range():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=dev_num, global_rank=2)
+    strategy1 = ((2, 2), (2, 2))
+    strategy2 = ((2,),)
+    strategy3 = ((2, 2), (2,))
+    net = Net(_w1, 0, 8, 1, strategy1, strategy2, strategy3)
+    compile_net(net)
+
+
+def test_range2():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=dev_num, global_rank=0)
+    strategy1 = ((4, 1), (4, 1))
+    strategy2 = ((1,),)
+    strategy3 = ((4, 1), (1,))
+    net = Net(_w1, 0.0, 4.0, 0.5, strategy1, strategy2, strategy3)
+    compile_net(net)
+
+
+def test_range3():
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=dev_num, global_rank=2)
+    net = Net(_w1, 4.0, None, 0.5)
+    compile_net(net)
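+
+
+# A minimal sketch (pure Python, not part of the MindSpore API) of the
+# per-rank (start, limit) arithmetic that RangeInfo::InferNewAttr applies in
+# range_info.cc; the helper name and parameters below are hypothetical and
+# only illustrate how a sharded Range splits the full sequence across ranks.
+def _sharded_range_attrs(start, delta, output_len, split_num, rank,
+                         repeated_calc_num=1, repeated_num_right=True):
+    # Each shard produces output_len / split_num elements of the sequence.
+    start_bias = output_len // split_num * delta
+    # Ranks that repeat the same calculation map to the same shard index.
+    if repeated_num_right:
+        shard_index = rank // repeated_calc_num
+    else:
+        shard_index = rank % split_num
+    new_start = start + start_bias * shard_index
+    return new_start, new_start + start_bias
+
+
+def test_sharded_range_attrs_sketch():
+    # Mirrors test_range above: start=0, limit=8, delta=1 split 2 ways on 4
+    # devices (so repeated_calc_num=2); rank 2 should own [4.0, 8.0).
+    assert _sharded_range_attrs(0, 1, 8, 2, 2, repeated_calc_num=2) == (4, 8)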