From 6066b1683874eec1d0e795eabb186a078c4c5c29 Mon Sep 17 00:00:00 2001 From: Yi Huaijie Date: Thu, 17 Sep 2020 15:37:27 +0800 Subject: [PATCH] implement parallel Pack --- .../auto_parallel/operator_costmodel.h | 2 + .../ccsrc/frontend/parallel/dynamic_creator.h | 1 + .../ccsrc/frontend/parallel/node_check.cc | 1 - .../parallel/ops_info/ops_info_head_files.h | 1 + .../frontend/parallel/ops_info/pack_info.cc | 253 ++++++++++++++++++ .../frontend/parallel/ops_info/pack_info.h | 62 +++++ .../frontend/parallel/step_auto_parallel.cc | 10 +- .../ccsrc/frontend/parallel/step_parallel.cc | 44 +-- mindspore/ccsrc/utils/convert_utils_py.cc | 10 + mindspore/ops/_grad/grad_comm_ops.py | 14 +- mindspore/ops/functional.py | 1 + tests/ut/python/parallel/test_pack.py | 188 +++++++++++++ 12 files changed, 559 insertions(+), 28 deletions(-) create mode 100644 mindspore/ccsrc/frontend/parallel/ops_info/pack_info.cc create mode 100644 mindspore/ccsrc/frontend/parallel/ops_info/pack_info.h create mode 100644 tests/ut/python/parallel/test_pack.py diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h index fbda9c1546..b9caa7be1c 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h @@ -199,6 +199,8 @@ class SoftmaxCost : public OperatorCost { using SoftmaxCostPtr = std::shared_ptr; using TileCost = SoftmaxCost; using TileCostPtr = std::shared_ptr; +using PackCost = TileCost; +using PackCostPtr = std::shared_ptr; using ConcatCost = TileCost; using ConcatCostPtr = std::shared_ptr; using SplitCost = TileCost; diff --git a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h index 9d8d114bb8..ca33986334 100644 --- a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h +++ b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h @@ -178,6 +178,7 @@ REGISTER(EmbeddingLookupInfo); REGISTER(TileInfo); REGISTER(StridedSliceInfo); REGISTER(DropoutInfo); +REGISTER(PackInfo); REGISTER(ConcatInfo); REGISTER(SplitInfo); } // namespace parallel diff --git a/mindspore/ccsrc/frontend/parallel/node_check.cc b/mindspore/ccsrc/frontend/parallel/node_check.cc index 8470a5de90..2bb6f99496 100644 --- a/mindspore/ccsrc/frontend/parallel/node_check.cc +++ b/mindspore/ccsrc/frontend/parallel/node_check.cc @@ -39,7 +39,6 @@ const std::set BLACK_LIST = {TUPLE_GETITEM, TILE_SHAPE, TUPLE_DIV, TUPLE_TO_ARRAY, - MAKE_LIST, MAKE_DICT, MAKE_SLICE, MAKE_RECORD, diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h index 18e3dec6ed..53e87e478a 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h @@ -41,5 +41,6 @@ #include "frontend/parallel/ops_info/strided_slice_info.h" #include "frontend/parallel/ops_info/concat_info.h" #include "frontend/parallel/ops_info/split_info.h" +#include "frontend/parallel/ops_info/pack_info.h" #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/pack_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/pack_info.cc new file mode 100644 index 0000000000..19db3b7bf0 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/pack_info.cc @@ -0,0 +1,253 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * 
Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/pack_info.h" + +#include +#include +#include +#include + +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" +#include "pipeline/jit/resource.h" + +namespace mindspore { +namespace parallel { +Status PackInfo::GetAttrs() { + int axis = 0; + auto axis_iter = attrs_.find(AXIS); + if (axis_iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(axis_iter->second); + if (axis_iter->second->isa()) { + axis = axis_iter->second->cast()->value(); + } else { + MS_LOG(ERROR) << name_ << ": The value of axis is not int"; + return FAILED; + } + } else { + MS_LOG(ERROR) << name_ << ": Can not find the axis attr"; + return FAILED; + } + + if (inputs_shape_.empty()) { + MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; + return FAILED; + } + int dim = SizeToInt(inputs_shape_[0].size()); + + if (axis < 0) { + axis = axis + dim; + } + axis_ = SizeToInt(axis); + return SUCCESS; +} + +Status PackInfo::CheckStrategy(const StrategyPtr &strategy) { + MS_EXCEPTION_IF_NULL(strategy); + if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Invalid strategy"; + return FAILED; + } + + std::vector stra = strategy->GetInputDim(); + for (size_t i = 0; i < stra.size(); ++i) { + auto strategy_ele = stra[i]; + if (axis_ > strategy_ele.size()) { + MS_LOG(ERROR) << name_ << ": The axis is out of range, the axis is " << axis_; + return FAILED; + } + + for (size_t j = 0; j < strategy_ele.size(); ++j) { + if (strategy_ele[j] != stra[0][j]) { + MS_LOG(ERROR) << name_ << ": The strategy of each input tensor must be equal"; + return FAILED; + } + } + } + + return SUCCESS; +} + +Status PackInfo::InferDevMatrixShape() { + MS_EXCEPTION_IF_NULL(strategy_); + std::vector stra = strategy_->GetInputDim(); + if (stra.empty()) { + MS_LOG(ERROR) << name_ << "The strategy is empty"; + return FAILED; + } + + dev_matrix_shape_ = stra[0]; + return SUCCESS; +} + +Status PackInfo::InferTensorMap() { + TensorMap in_tensor_map; + TensorMap out_tensor_map; + + if (inputs_shape_.empty()) { + MS_LOG(ERROR) << name_ << "The inputs shape is empty"; + return FAILED; + } + + int32_t size = SizeToInt(inputs_shape_[0].size()); + for (int i = 0; i < size; ++i) { + in_tensor_map.push_back(size - i - 1); + out_tensor_map.push_back(size - i - 1); + } + + for (size_t i = 0; i < inputs_shape_.size(); ++i) { + inputs_tensor_map_.push_back(in_tensor_map); + } + + out_tensor_map.insert(out_tensor_map.begin() + axis_, MAP_NONE); + outputs_tensor_map_.push_back(out_tensor_map); + return SUCCESS; +} + +Status PackInfo::InferMirrorOps() { + mirror_ops_.clear(); + if (inputs_tensor_map_.empty()) { + MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty"; + return FAILED; + } + + Shape input_tensor_map = inputs_tensor_map_[0]; + std::vector group; + if (CreateGroupByTensorMap(input_tensor_map, &group) != SUCCESS) { + MS_LOG(ERROR) << 
name_ << ": Create group for input failed."; + return FAILED; + } + + if (group.empty()) { + MS_LOG(INFO) << name_ << ": The mirror group is empty."; + return SUCCESS; + } + + OperatorVector input_op; + input_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum()); + for (size_t i = 0; i < inputs_shape_.size(); ++i) { + mirror_ops_.push_back(input_op); + } + + return SUCCESS; +} + +Status PackInfo::InferTensorInfo() { + if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) { + MS_LOG(ERROR) << name_ << ": Invalid args"; + return FAILED; + } + + TensorLayout input_layout, output_layout; + for (size_t i = 0; i < inputs_shape_.size(); ++i) { + // infer tensor layout + if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed."; + return FAILED; + } + TensorInfo input_tensor_info(input_layout); + inputs_tensor_info_.push_back(input_tensor_info); + } + + if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed."; + return FAILED; + } + TensorInfo output_tensor_info(output_layout); + outputs_tensor_info_.push_back(output_tensor_info); + return SUCCESS; +} + +void PackInfo::ReComputeBatchSplitFlagList() { + for (size_t i = 0; i < inputs_shape_.size(); i++) { + split_flag_list_[i] = true; + } +} + +Status PackInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } + +Status PackInfo::GenerateStrategies(int32_t stage_id) { + if (InferAttrs() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer attrs failed"; + return FAILED; + } + if (inputs_shape_.empty()) { + MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; + return FAILED; + } + Shape input_split; + for (size_t i = 0; i < inputs_shape_[0].size(); ++i) { + input_split.push_back(1); + } + + // to generate the first input's strategy + Shapes splittable_input = {input_split}; + Shapes tmp_inputs_shape = {inputs_shape_[0]}; + + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Generate strategies failed"; + return FAILED; + } + + // the others strategies are equal to the first input's strategy + for (auto &sp : sp_vector) { + if ((sp == nullptr) || sp->GetInputDim().empty()) { + MS_LOG(ERROR) << name_ << ": The strategy is null or empty"; + return FAILED; + } + Strategys tmp_strategy; + Dimensions first_input_strategy = sp->GetInputDim()[0]; + for (size_t i = 0; i < inputs_shape_.size(); ++i) { + tmp_strategy.push_back(first_input_strategy); + } + sp->ResetInputs(tmp_strategy); + } + + size_t success = 0; + for (auto &sp : sp_vector) { + PrintStrategy(sp); + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; + PrintStrategy(sp); + } + } + return SUCCESS; +} + +Status PackInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + MS_LOG(INFO) << name_ << ": Init success."; + return SUCCESS; +} + +Status PackInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init for cost model failed."; + return FAILED; + } 
+ + MS_LOG(INFO) << name_ << ": Init for cost model success."; + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/pack_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/pack_info.h new file mode 100644 index 0000000000..a180776430 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/pack_info.h @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_PACK_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_PACK_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class PackInfo : public OperatorInfo { + public: + PackInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~PackInfo() override = default; + + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + Status GenerateStrategies(int32_t) override; + Status SetCostUnderStrategy(const StrategyPtr &) override; + void ReComputeBatchSplitFlagList() override; + + protected: + Status GetAttrs() override; + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override { return SUCCESS; } + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + + private: + size_t axis_ = 0; +}; + +using PackInfoPtr = std::shared_ptr; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_PACK_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc index d6f48071d8..9442cc63be 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -116,7 +116,8 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { std::vector ExtractInputParameterByNode(const CNodePtr &node) { std::vector is_parameter; std::vector node_inputs{node->inputs()}; - if ((node_inputs.size() == 2) && AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE)) { + if ((node_inputs.size() == 2) && + (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) { node_inputs = node_inputs[1]->cast()->inputs(); } for (size_t i = 1; i < node_inputs.size(); ++i) { @@ -193,7 +194,8 @@ std::vector ExtractInputTypeLengthByNode(const CNodePtr &node) { std::vector inputs_type_len; std::vector node_inputs{node->inputs()}; - if 
((node_inputs.size() == 2) && AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE)) { + if ((node_inputs.size() == 2) && + (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) { node_inputs = node_inputs[1]->cast()->inputs(); } @@ -259,7 +261,7 @@ bool IsSplittableOperator(const std::string &op_name) { {MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU, FLOORDIV, L2_NORMALIZE, TENSOR_ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK, REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING, - MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, + MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, PACK, LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT, STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT, SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, @@ -281,7 +283,7 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) { return false; } bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name()); - if (bool_result && (prim->name() != MAKE_TUPLE)) { + if (bool_result && (prim->name() != MAKE_TUPLE) && (prim->name() != MAKE_LIST)) { MS_LOG(EXCEPTION) << "Should implementing OperatorInfo for: " << prim->name(); } else if (prim->name() == CAST) { if (cnode->fullname_with_scope().find(OPTIMIZER_SUB_STRING) != std::string::npos) { diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index a3e80d31e8..e504059bf1 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -450,7 +450,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_ AnfNodeIndexSet node_set = manager->node_users()[node]; CNodePtr insert_node_new; - if (AnfNodeIsPrimitive(node, MAKE_TUPLE)) { + if (AnfNodeIsPrimitive(node, MAKE_TUPLE) || AnfNodeIsPrimitive(node, MAKE_LIST)) { MS_LOG(INFO) << "No need to insert redistribution op betweend make_tuple node and the next node"; return; } @@ -851,7 +851,8 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { FuncGraphManagerPtr manager = func_graph->manager(); MS_EXCEPTION_IF_NULL(manager); - if ((node->inputs().size() == 2) && AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE)) { + if ((node->inputs().size() == 2) && + (AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE) || AnfNodeIsPrimitive(node->input(1), MAKE_LIST))) { MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has handle by make_tuple node"; return; } @@ -1055,7 +1056,7 @@ Shapes GetNodeShape(const AnfNodePtr &node) { MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " shape_ptr is nullptr, full name is " << node->fullname_with_scope(); } - auto tuple_shape_ptr = dyn_cast(base_shape_ptr); + auto tuple_shape_ptr = dyn_cast(base_shape_ptr); if (tuple_shape_ptr != nullptr) { auto tuple_shape = tuple_shape_ptr->shape(); for (auto &shape : tuple_shape) { @@ -1436,7 +1437,7 @@ void ExtractInformation(const std::vector &all_nodes) { SetVirtualDatasetStrategy(cnode); ValueNodePtr prim_anf_node = cnode->input(0)->cast(); PrimitivePtr prim = 
GetValueNode(prim_anf_node); - if (prim->name() == MAKE_TUPLE) { + if (prim->name() == MAKE_TUPLE || prim->name() == MAKE_LIST) { continue; } auto attrs = prim->attrs(); @@ -2459,9 +2460,9 @@ Status ParallelInit() { return SUCCESS; } -void HandleForwardMakeTuple(const std::vector &all_nodes) { +void HandleForwardMakeTupleAndMakeList(const std::vector &all_nodes) { for (auto &node : all_nodes) { - if (!AnfNodeIsPrimitive(node, MAKE_TUPLE)) { + if (!AnfNodeIsPrimitive(node, MAKE_TUPLE) && !AnfNodeIsPrimitive(node, MAKE_LIST)) { continue; } @@ -2473,25 +2474,28 @@ void HandleForwardMakeTuple(const std::vector &all_nodes) { FuncGraphManagerPtr manager = cnode->func_graph()->manager(); MS_EXCEPTION_IF_NULL(manager); - auto make_tuple_user = manager->node_users()[cnode]; - if (make_tuple_user.size() != 1) { - MS_LOG(EXCEPTION) << "Now the make_tuple's user must be 1, but got " << make_tuple_user.size(); + std::string op_type = AnfNodeIsPrimitive(node, MAKE_TUPLE) ? MAKE_TUPLE : MAKE_LIST; + + auto make_tuple_list_user = manager->node_users()[cnode]; + if (make_tuple_list_user.size() != 1) { + MS_LOG(EXCEPTION) << "Now the " << op_type << "'s user must be 1, but got " << make_tuple_list_user.size(); } - CNodePtr make_tuple_next_cnode = make_tuple_user.pop().first->cast(); - MS_EXCEPTION_IF_NULL(make_tuple_next_cnode); + CNodePtr make_tuple_list_next_cnode = make_tuple_list_user.pop().first->cast(); + MS_EXCEPTION_IF_NULL(make_tuple_list_next_cnode); - std::string make_tuple_user_prim_name = GetPrimName(make_tuple_next_cnode); - if (!IsParallelCareNode(make_tuple_next_cnode)) { - MS_LOG(INFO) << "The make_tuple's user is " << make_tuple_user_prim_name << ", no need to set operator info"; + std::string make_tuple__list_user_prim_name = GetPrimName(make_tuple_list_next_cnode); + if (!IsParallelCareNode(make_tuple_list_next_cnode)) { + MS_LOG(INFO) << "The " << op_type << "'s user is " << make_tuple__list_user_prim_name + << ", no need to set operator info"; continue; } - if (make_tuple_next_cnode->inputs().size() != 2) { - MS_LOG(EXCEPTION) << "Now the make_tuple's user only support 1 input, but got " - << make_tuple_next_cnode->inputs().size() - 1; + if (make_tuple_list_next_cnode->inputs().size() != 2) { + MS_LOG(EXCEPTION) << "Now the " << op_type << "'s user only support 1 input, but got " + << make_tuple_list_next_cnode->inputs().size() - 1; } - MS_LOG(INFO) << "Set the make_tuple's operator info, and the op name is " << make_tuple_user_prim_name; - OperatorInfoPtr op_info = GetDistributeOperator(make_tuple_next_cnode); + MS_LOG(INFO) << "Set the " << op_type << "'s operator info, and the op name is " << make_tuple__list_user_prim_name; + OperatorInfoPtr op_info = GetDistributeOperator(make_tuple_list_next_cnode); MS_EXCEPTION_IF_NULL(op_info); cnode->set_user_data(op_info); } @@ -2695,7 +2699,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) ReshapeInit(all_nodes); } - HandleForwardMakeTuple(all_nodes); + HandleForwardMakeTupleAndMakeList(all_nodes); // if the input or parameter has multiple users, check whether its split strategies are consistent. 
CheckParameterSplit(all_nodes); diff --git a/mindspore/ccsrc/utils/convert_utils_py.cc b/mindspore/ccsrc/utils/convert_utils_py.cc index 2fda223195..74e70d3dbd 100644 --- a/mindspore/ccsrc/utils/convert_utils_py.cc +++ b/mindspore/ccsrc/utils/convert_utils_py.cc @@ -348,6 +348,16 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py } auto tuple = std::make_shared(ptr_list); return tuple; + } else if (py::isinstance(shape_obj) && py::isinstance(type_obj)) { + py::list shape_list = shape_obj.cast(); + py::list typeid_list = type_obj.cast(); + AbstractBasePtrList ptr_list; + for (size_t it = 0; it < shape_list.size(); ++it) { + auto tensor_it = PyListDtype2AbstractTensor(shape_list[it], typeid_list[it]); + ptr_list.push_back(tensor_it); + } + auto list = std::make_shared(ptr_list); + return list; } else if (shape_obj.is_none() && type_obj.is_none()) { // AbstractNone indicates there is no output for this CNode node. auto abstract_none = std::make_shared(); diff --git a/mindspore/ops/_grad/grad_comm_ops.py b/mindspore/ops/_grad/grad_comm_ops.py index d1282f8ecb..40ca4d77b5 100644 --- a/mindspore/ops/_grad/grad_comm_ops.py +++ b/mindspore/ops/_grad/grad_comm_ops.py @@ -228,11 +228,19 @@ def get_bprop_virtual_div_operator(self): dx = op(dout, cast(F.scalar_to_array(divisor), dtype(dout))) return (dx,) - dx = () - input_nums = F.tuple_len(dout) + if F.issubclass_(F.typeof(dout), mstype.tuple_): + dx = () + input_nums = F.tuple_len(dout) + for i in range(input_nums): + ele_grad = op(dout[i], cast(F.scalar_to_array(divisor), dtype(dout[i]))) + dx = dx + (ele_grad,) + return (dx,) + + dx = [] + input_nums = F.list_len(dout) for i in range(input_nums): ele_grad = op(dout[i], cast(F.scalar_to_array(divisor), dtype(dout[i]))) - dx = dx + (ele_grad,) + dx.append(ele_grad) return (dx,) return bprop diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index 9de52b091e..6173f3a8d7 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -92,6 +92,7 @@ dict_getitem = Primitive('dict_getitem') dict_setitem = Primitive('dict_setitem') tuple_div = Primitive("tuple_div") tuple_len = Primitive("tuple_len") +list_len = Primitive("list_len") tuple_reversed = Primitive("tuple_reversed") make_range = Primitive("make_range") make_tuple = Primitive('make_tuple') diff --git a/tests/ut/python/parallel/test_pack.py b/tests/ut/python/parallel/test_pack.py new file mode 100644 index 0000000000..ccc8356703 --- /dev/null +++ b/tests/ut/python/parallel/test_pack.py @@ -0,0 +1,188 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +import numpy as np +import mindspore as ms +import mindspore.context as context +from mindspore import Tensor, Parameter +import mindspore.nn as nn +from mindspore.common.api import _executor +from mindspore.nn import TrainOneStepCell, Momentum +from mindspore.ops import operations as P + + +class Net(nn.Cell): + def __init__(self, weight1, weight2, axis=0, strategy1=None, strategy2=None, is_parameter=True): + super(Net, self).__init__() + self.pack = P.Pack(axis=axis).shard(strategy1) + self.mul = P.Mul().shard(strategy2) + if is_parameter: + self.weight1 = Parameter(weight1, "w1") + else: + self.weight1 = weight1 + self.weight2 = Parameter(weight2, "w2") + + def construct(self, x): + out = self.pack([self.weight1, self.weight2]) + out = self.mul(x, out) + return out + + +class Net1(nn.Cell): + def __init__(self, weight1, weight2, axis=0, strategy1=None, strategy2=None): + super(Net1, self).__init__() + self.pack = P.Pack(axis=axis).shard(strategy1) + self.mul = P.Mul().shard(strategy2) + self.weight1 = Parameter(weight1, "w1") + self.weight2 = Parameter(weight2, "w2") + + def construct(self, x): + out = self.mul(x, self.weight1) + out = self.pack([out, self.weight2]) + return out + + +class Net2(nn.Cell): + def __init__(self, weight1, weight2, weight3, axis=0, strategy1=None, strategy2=None, is_parameter=True): + super(Net2, self).__init__() + self.pack = P.Pack(axis=axis).shard(strategy1) + self.mul = P.Mul().shard(strategy2) + if is_parameter: + self.weight1 = Parameter(weight1, "w1") + else: + self.weight1 = weight1 + self.weight2 = Parameter(weight2, "w2") + self.weight3 = Parameter(weight2, "w3") + + def construct(self, x): + out = self.pack([self.weight1, self.weight2, self.weight3]) + out = self.mul(x, out) + return out + + +_w1 = Tensor(np.ones([48, 64]), dtype=ms.float32) +_w2 = Tensor(np.ones([48, 64]), dtype=ms.float32) +_w3 = Tensor(np.ones([48, 64]), dtype=ms.float32) +_x = Tensor(np.ones([2, 48, 64]), dtype=ms.float32) +_x1 = Tensor(np.ones([48, 64]), dtype=ms.float32) +_x2 = Tensor(np.ones([3, 48, 64]), dtype=ms.float32) + + +def compile_net(net): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + train_net = TrainOneStepCell(net, optimizer) + train_net.set_auto_parallel() + _executor.compile(train_net, _x) + context.reset_auto_parallel_context() + + +def compile_net1(net): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + train_net = TrainOneStepCell(net, optimizer) + train_net.set_auto_parallel() + _executor.compile(train_net, _x1) + context.reset_auto_parallel_context() + + +def compile_net2(net): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + train_net = TrainOneStepCell(net, optimizer) + train_net.set_auto_parallel() + _executor.compile(train_net, _x2) + context.reset_auto_parallel_context() + + +def test_pack_parameter(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((4, 2), (4, 2)) + strategy2 = ((1, 4, 2), (1, 4, 2)) + net = Net(_w1, _w2, 0, strategy1, strategy2) + compile_net(net) + + +def test_pack_parameter_no_full_split(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + 
strategy1 = ((2, 2), (2, 2)) + strategy2 = ((1, 4, 2), (1, 4, 2)) + net = Net(_w1, _w2, 0, strategy1, strategy2) + compile_net(net) + + +def test_pack_tensor_and_parameter(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((4, 2), (4, 2)) + strategy2 = ((1, 4, 2), (1, 4, 2)) + net = Net(_w1, _w2, 0, strategy1, strategy2, False) + compile_net(net) + + +def test_pack_output(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((4, 2), (4, 2)) + strategy2 = ((4, 2), (4, 2)) + net = Net1(_w1, _w2, 0, strategy1, strategy2) + compile_net1(net) + + +def test_pack_output_axis1(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((4, 2), (4, 2)) + strategy2 = ((4, 2), (4, 2)) + net = Net1(_w1, _w2, 1, strategy1, strategy2) + compile_net1(net) + + +def test_pack_output_no_full_split(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 2), (2, 2)) + strategy2 = ((4, 2), (4, 2)) + net = Net1(_w1, _w2, 0, strategy1, strategy2) + compile_net1(net) + + +def test_pack_no_strategy(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = None + strategy2 = ((4, 2), (4, 2)) + net = Net1(_w1, _w2, 0, strategy1, strategy2) + compile_net1(net) + + +def test_pack_no_strategy_axis1(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = None + strategy2 = ((4, 2), (4, 2)) + net = Net1(_w1, _w2, 1, strategy1, strategy2) + compile_net1(net) + + +def test_pack_auto_parallel(): + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0) + net = Net1(_w1, _w2, 0) + compile_net1(net) + + +def test_pack_auto_parallel_axis1(): + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0) + net = Net1(_w1, _w2, 1) + compile_net1(net) + + +def test_pack_auto_parallel_3_tensor(): + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0) + net = Net2(_w1, _w2, _w3) + compile_net2(net)
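
Note on the sharding semantics the tests above exercise: PackInfo::CheckStrategy requires every input of Pack to use the same strategy, and PackInfo::InferTensorMap inserts MAP_NONE at the pack axis, so the dimension created by Pack is never split across devices. A minimal plain-Python sketch of that mapping, for illustration only; the helper name `infer_pack_tensor_maps` and the MAP_NONE value are assumptions, not part of this patch:

# Illustrative sketch of PackInfo::InferTensorMap from pack_info.cc above.
MAP_NONE = -1  # assumed to match parallel::MAP_NONE

def infer_pack_tensor_maps(input_rank, axis):
    # Each input of rank `input_rank` gets the tensor map [rank-1, ..., 1, 0];
    # the output map is the same list with MAP_NONE inserted at `axis`,
    # i.e. the new dimension produced by Pack is not sharded.
    in_map = [input_rank - 1 - i for i in range(input_rank)]
    out_map = list(in_map)
    out_map.insert(axis, MAP_NONE)
    return in_map, out_map

# For the [48, 64] inputs used in test_pack.py:
assert infer_pack_tensor_maps(2, 0) == ([1, 0], [MAP_NONE, 1, 0])
assert infer_pack_tensor_maps(2, 1) == ([1, 0], [1, MAP_NONE, 0])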