diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc
index 1255d79bdc..d376e32213 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc
@@ -13,9 +13,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
-#include "parallel/auto_parallel/graph_costmodel.h"
-
 #include
 #include
 #include
@@ -24,6 +21,10 @@
 #include
 #include
 
+#include "parallel/auto_parallel/graph_costmodel.h"
+#include "parallel/ops_info/reshape_info.h"
+#include "parallel/step_auto_parallel.h"
+
 namespace mindspore {
 namespace parallel {
 CostGraphPtr entire_costgraph = nullptr;
@@ -40,6 +41,7 @@ bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
 bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
 bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
 int32_t RUN_PHASE = DEFAULT_RUN_PHASE;
+constexpr char RESHAPEINFO[] = "ReshapeInfo";
 
 void CostGraph::SetDeviceMemoryAndCostParameter() {
   MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
@@ -182,6 +184,20 @@ bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr &op_test) {
   return std::any_of(ops_.begin(), ops_.end(), IsInGraph(op_test));
 }
 
+void CostGraph::AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) {
+  std::vector curr_edges(edges_[{u_node, v_node}]);
+  curr_edges.push_back(edge);
+  edges_[{u_node, v_node}] = curr_edges;
+
+  std::vector curr_out_edges(out_edges_[u_node]);
+  curr_out_edges.push_back(edge);
+  out_edges_[u_node] = curr_out_edges;
+
+  std::vector curr_in_edges(in_edges_[v_node]);
+  curr_in_edges.push_back(edge);
+  in_edges_[v_node] = curr_in_edges;
+}
+
 bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t output_index, size_t input_index) {
   for (auto &edge_pair : edges_) {
     auto edges = edge_pair.second;
@@ -1338,11 +1354,51 @@ std::vector> CostGraph::EliminationStar(const OperatorInfo
 Status CostGraph::InitSelectedStrategy() {
   for (auto &op : ops_) {
     MS_EXCEPTION_IF_NULL(op);
+    if (op->name().find(RESHAPEINFO) != std::string::npos) {
+      continue;
+    }
     auto result = op->InitSelectedStrategy(op->selected_strategy());
     if (result != SUCCESS) {
       return result;
     }
   }
+  // Reshape init should be applied after the init of its previous node and next node.
+  for (size_t i = 0; i < ops_.size(); ++i) {
+    if (ops_[i]->name().find(RESHAPEINFO) != std::string::npos) {
+      auto reshape_info = std::dynamic_pointer_cast(ops_[i]);
+      auto in_edges = GetOriginalPrevEdges(ops_[i]);
+      auto pre_iter = std::find_if(in_edges.begin(), in_edges.end(), [&](std::shared_ptr edge) {
+        return edge->prev_operator()->name() == reshape_info->pre_operator_name();
+      });
+      auto out_edges = GetOriginalNextEdges(ops_[i]);
+      auto next_iter = std::find_if(out_edges.begin(), out_edges.end(), [&](std::shared_ptr edge) {
+        return edge->next_operator()->name() == reshape_info->next_operator_name();
+      });
+      if (pre_iter != in_edges.end()) {
+        MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name();
+        int32_t pre_index = reshape_info->pre_operator_index();
+        Dimensions stra;
+        TensorInfo pre_info;
+        if (ops_[i]->name() == (*pre_iter)->prev_operator()->name()) {
+          pre_info = (*pre_iter)->prev_operator()->inputs_tensor_info()[pre_index];
+        } else {
+          pre_info = (*pre_iter)->prev_operator()->outputs_tensor_info()[pre_index];
+        }
+        reshape_info->SetInputLayout(pre_info.tensor_layout());
+        InferStraByTensorInfo(pre_info, &stra);
+        std::vector stra_inputs = {stra};
+        StrategyPtr reshape_stra =
+          std::make_shared((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs);
+        reshape_info->set_strategy(reshape_stra);
+      }
+      if (next_iter != out_edges.end()) {
+        MS_LOG(DEBUG) << "Set reshape output layout by " << reshape_info->next_operator_name();
+        int32_t next_index = reshape_info->next_operator_index();
+        reshape_info->SetOutputLayout((*next_iter)->next_operator()->inputs_tensor_info()[next_index].tensor_layout());
+      }
+      return reshape_info->Init(nullptr);
+    }
+  }
   return SUCCESS;
 }
 
diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h
index 5077459695..6c58ac7957 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h
@@ -87,11 +87,9 @@ class CostGraph {
   void RemoveOperator(const OperatorInfoPtr &op);
   bool IsOperatorInCostGraph(const OperatorInfoPtr &op);
   // the edge is in the form: u --> v
-  void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) {
-    std::vector curr_edges(edges_[{u_node, v_node}]);
-    curr_edges.push_back(edge);
-    edges_[{u_node, v_node}] = curr_edges;
-  }
+  void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge);
+  std::vector> GetOriginalPrevEdges(OperatorInfoPtr v_node) { return in_edges_[v_node]; }
+  std::vector> GetOriginalNextEdges(OperatorInfoPtr u_node) { return out_edges_[u_node]; }
   // An edge is uniquely identified by its name, and its output index and input index.
   bool IsEdgeInCostGraph(const std::string &, size_t, size_t);
@@ -219,6 +217,8 @@ class CostGraph {
   std::vector ops_;
   std::map, std::vector> edges_;
   std::vector> connected_compoents_;
+  std::map> out_edges_;
+  std::map> in_edges_;
 };
 } // namespace parallel
 } // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.h b/mindspore/ccsrc/parallel/ops_info/operator_info.h
index de95bd84ad..833c4a2c84 100644
--- a/mindspore/ccsrc/parallel/ops_info/operator_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/operator_info.h
@@ -111,6 +111,7 @@ class OperatorInfo {
   Shape dev_matrix_shape() const { return dev_matrix_shape_; }
   std::vector inputs_tensor_info() const { return inputs_tensor_info_; }
   std::vector outputs_tensor_info() const { return outputs_tensor_info_; }
+  std::vector> strategy_cost() const { return strategy_cost_; }
   const std::string &name() const { return name_; }
   void set_name(const std::string &name) { name_ = name; }
   RankList global_device_list() const { return global_device_list_; }
diff --git a/mindspore/ccsrc/parallel/ops_info/reshape_info.cc b/mindspore/ccsrc/parallel/ops_info/reshape_info.cc
index d6e1c277ef..b191e22198 100644
--- a/mindspore/ccsrc/parallel/ops_info/reshape_info.cc
+++ b/mindspore/ccsrc/parallel/ops_info/reshape_info.cc
@@ -22,6 +22,7 @@
 #include "parallel/device_manager.h"
 #include "parallel/device_matrix.h"
 #include "parallel/step_parallel.h"
+#include "parallel/auto_parallel/graph_costmodel.h"
 #include "utils/convert_utils.h"
 #include "utils/log_adapter.h"
 
@@ -46,26 +47,6 @@ Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) {
     }
     return FAILED;
   }
-  std::vector stra = strategy->GetInputDim();
-  for (size_t i = 0; i < strategy_size; ++i) {
-    Shape sub_strategy = stra.at(i);
-    size_t strategy_len = sub_strategy.size();
-    bool flag = false;
-    for (size_t j = 0; j < strategy_len; ++j) {
-      int32_t strategy_value = sub_strategy.at(j);
-      if (strategy_value > 1) {
-        if (flag) {
-          if (is_auto_parallel_) {
-            MS_LOG(DEBUG) << name_ << ": Only support batch parallel strategy.";
-          } else {
-            MS_LOG(ERROR) << name_ << ": Only support batch parallel strategy.";
-          }
-          return FAILED;
-        }
-        flag = true;
-      }
-    }
-  }
   return SUCCESS;
 }
 
@@ -402,6 +383,41 @@ Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr
   return SUCCESS;
 }
 
+void ReshapeInfo::SetCostForReshapeWithParameter() {
+  size_t success = 0;
+  for (auto &sp : sp_vector_) {
+    if (SetCostUnderStrategy(sp) == SUCCESS) {
+      success++;
+      MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
+      PrintStrategy(sp);
+    }
+  }
+}
+
+void ReshapeInfo::SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy) {
+  MS_EXCEPTION_IF_NULL(strategy);
+  int32_t stage_id = strategy->GetInputStage();
+  double computation_cost =
+    operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+  double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+  std::shared_ptr result = std::make_shared(computation_cost, communication_cost);
+  result->communication_without_parameter_ =
+    operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+  result->communication_with_partial_para_ =
+    result->communication_without_parameter_ +
+    COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
+
+  // Breaking ties for preferring data parallelization
+  BreakingTiesForPerferringDataParallel(strategy, result);
+  // refine communication cost calculation for practice
+  RefineForPracticalCost(result, false);
+
+  std::shared_ptr swc =
+    std::make_shared(strategy, inputs_tensor_info_, outputs_tensor_info_);
+  swc->cost_list.push_back(result);
+  strategy_cost_.emplace_back(swc);
+}
+
 Status ReshapeInfo::GenerateStrategies(int32_t stage_id) {
   if (GetAttrs() != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": GetAttrs failed.";
@@ -414,22 +430,14 @@ Status ReshapeInfo::GenerateStrategies(int32_t stage_id) {
   }
   is_auto_parallel_ = true;
   Shape input0_split;
-  input0_split.emplace_back(1);
-  (void)input0_split.insert(input0_split.end(), inputs_shape_[0].size() - 1, 0);
+  (void)input0_split.insert(input0_split.end(), inputs_shape_[0].size(), 1);
   Shapes splittable_inputs = {input0_split};
-  std::vector sp_vector;
-  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
+  // These strategies are used only when the input node is a parameter;
+  // otherwise, the input node's output_layout is used as the input_layout.
+  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector_) != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed.";
     return FAILED;
   }
-  size_t success = 0;
-  for (auto &sp : sp_vector) {
-    if (SetCostUnderStrategy(sp) == SUCCESS) {
-      success++;
-      MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
-      PrintStrategy(sp);
-    }
-  }
   return SUCCESS;
 }
 } // namespace parallel
diff --git a/mindspore/ccsrc/parallel/ops_info/reshape_info.h b/mindspore/ccsrc/parallel/ops_info/reshape_info.h
index 99ee014175..a711e9cb88 100644
--- a/mindspore/ccsrc/parallel/ops_info/reshape_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/reshape_info.h
@@ -50,9 +50,19 @@ class ReshapeInfo : public OperatorInfo {
     output_layout_ = output_layout;
     output_layout_set_flag_ = true;
   }
+  void SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy);
+  void SetCostForReshapeWithParameter();
+  void set_pre_operator_name(const std::string &pre_name) { pre_operator_name_ = pre_name; }
+  void set_next_operator_name(const std::string &next_name) { next_operator_name_ = next_name; }
+  void set_pre_operator_index(int32_t pre_index) { pre_operator_index_ = pre_index; }
+  void set_next_operator_index(int32_t next_index) { next_operator_index_ = next_index; }
   Status InitForCostModel(const StrategyPtr &strategy) override;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
+  std::string pre_operator_name() const { return pre_operator_name_; }
+  std::string next_operator_name() const { return next_operator_name_; }
+  int32_t pre_operator_index() const { return pre_operator_index_; }
+  int32_t next_operator_index() const { return next_operator_index_; }
 
  protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;
@@ -73,12 +83,17 @@ class ReshapeInfo : public OperatorInfo {
   Status InferDefaultLayout(const Shape &shape, TensorLayout *const layout);
 
   int32_t dev_num_;
+  int32_t pre_operator_index_;
+  int32_t next_operator_index_;
   std::vector parameter_input_v_;
+  std::vector sp_vector_;
   Dimensions input_strategy_;
   TensorLayout input_layout_;
   TensorLayout output_layout_;
   bool input_layout_set_flag_;
   bool output_layout_set_flag_;
+  std::string pre_operator_name_;
+  std::string next_operator_name_;
 };
 } // namespace parallel
 } // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc
index 3c538002e6..0447285e5e 100644
--- a/mindspore/ccsrc/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc
@@ -39,6 +39,7 @@
 #include "parallel/auto_parallel/rec_core/rec_partition.h"
 #include "parallel/context.h"
 #include "parallel/ops_info/tmp_identity_info.h"
+#include "parallel/ops_info/reshape_info.h"
 #include "parallel/step_parallel.h"
 #include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
 #include "pipeline/parse/python_adapter.h"
@@ -608,7 +609,8 @@ void ConstructCostGraphEdges(const std::vector &all_nodes) {
       EdgePtr edge_ptr;
       MS_LOG(INFO) << "Creating edge: " << edge_name;
 
-      bool follow_strategy = ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name());
+      bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) ||
+                             (ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name()));
       if (follow_strategy) {
         // Redistribution in not allowed on the edge.
         // Elementwise operators have the same strategy as their previous operators.
@@ -893,6 +895,209 @@ void AugmentCostGraph(const std::vector &all_nodes) {
   }
 }
 
+bool FindReshape(const CNodePtr &cnode) {
+  if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) {
+    return false;
+  }
+  ValueNodePtr prim_anf_node = cnode->input(0)->cast();
+  if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) {
+    return false;
+  }
+  PrimitivePtr prim = GetValueNode(prim_anf_node);
+  MS_EXCEPTION_IF_NULL(prim);
+  OperatorInfoPtr operator_info = cnode->operator_info();
+  if (operator_info == nullptr) {
+    MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr";
+  }
+  if (prim->name() != RESHAPE) {
+    return false;
+  }
+  return true;
+}
+
+// find the previous node, then obtain its strategy_cost_ vector to get its layout vector.
+bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int32_t *out_index) {
+  // if the previous node is a parameter, handle it outside.
+  if (node->isa()) {
+    return false;
+  }
+  if (!node->isa()) {
+    return false;
+  }
+  CNodePtr cnode = node->cast();
+  if (!IsValueNode(cnode->input(0))) {
+    return false;
+  }
+  if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
+    *pre_operator_info = cnode->operator_info();
+    *out_index = 0;
+    return true;
+  }
+  ValueNodePtr prim_anf_node = cnode->input(0)->cast();
+  PrimitivePtr prim = prim_anf_node->value()->cast();
+  if (prim->name() == TUPLE_GETITEM) {
+    *out_index = GetTupleGetItemIndex(cnode);
+    // find tuple_get_item's previous node
+    auto pre_node = cnode->input(1);
+    if (!pre_node->isa()) {
+      MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode";
+    }
+    CNodePtr pre_cnode = pre_node->cast();
+    if (IsParallelCareNode(pre_cnode) && (pre_cnode->operator_info() != nullptr)) {
+      *pre_operator_info = pre_cnode->operator_info();
+      return true;
+    }
+    return false;
+  }
+  for (size_t index = 0; index < cnode->inputs().size(); ++index) {
+    if (prim->name() == DEPEND && index != 1) {
+      continue;
+    }
+    if (!FindPreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) {
+      continue;
+    }
+    return true;
+  }
+  MS_LOG(WARNING) << "FindPreNodeStraCosts failed, if reshape is not the first primitive, there must be some error";
+  return false;
+}
+
+// find the next node, then obtain its strategy_cost_ vector to get its layout vector.
+// if reshape's output connects to several primitives, return the first layout found
+bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int32_t *in_index) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_EXCEPTION_IF_NULL(cnode->func_graph());
+  FuncGraphManagerPtr manager = cnode->func_graph()->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  AnfNodeIndexSet node_set = manager->node_users()[cnode];
+  for (auto &node_pair : node_set) {
+    CNodePtr use_apply = node_pair.first->cast();
+    if (use_apply == nullptr || !IsValueNode(use_apply->input(0))) {
+      continue;
+    }
+    ValueNodePtr prim_anf_node = use_apply->input(0)->cast();
+    MS_EXCEPTION_IF_NULL(prim_anf_node);
+    PrimitivePtr node_prim = prim_anf_node->value()->cast();
+    MS_EXCEPTION_IF_NULL(node_prim);
+    MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
+    if (node_prim->name() == DEPEND && node_pair.second != 1) {
+      continue;
+    }
+    if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) {
+      MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name();
+      *next_operator_info = use_apply->operator_info();
+      *in_index = node_pair.second - 1;
+      return true;
+    }
+    MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
+                  << " " << (use_apply->operator_info() != nullptr);
+
+    if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+void InferStraByTensorInfo(const TensorInfo &pre_out_tensor_info, Dimensions *stra) {
+  Shape shape = pre_out_tensor_info.shape();
+  Shape slice_shape = pre_out_tensor_info.slice_shape();
+  for (size_t i = 0; i < shape.size(); ++i) {
+    if ((slice_shape[i] == 0) || (shape[i] % slice_shape[i] != 0)) {
+      MS_LOG(EXCEPTION) << "slice_shape is wrong in reshape operator";
+    }
+    int32_t dim = (int32_t)(shape[i] / slice_shape[i]);
+    (*stra).push_back(dim);
+  }
+}
+
+void ReshapeCostCompute(const std::vector &all_nodes) {
+  for (auto node : all_nodes) {
+    auto cnode = node->cast();
+    if (!FindReshape(cnode)) {
+      continue;
+    }
+    MS_ASSERT(cnode->inputs().size() == 3);
+    // get the previous node's strategy_cost_
+    auto pre_node = cnode->input(1);
+    int32_t out_index = 0;
+    OperatorInfoPtr pre_operator_info;
+    std::vector> pre_stra_costs;
+    if (pre_node->isa()) {
+      OperatorInfoPtr operator_info = cnode->operator_info();
+      auto reshape_info = std::dynamic_pointer_cast(operator_info);
+      reshape_info->SetCostForReshapeWithParameter();
+      pre_operator_info = reshape_info;
+      pre_stra_costs = reshape_info->strategy_cost();
+    } else {
+      if (!FindPreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) {
+        MS_LOG(EXCEPTION) << "FindPreNodeStraCosts for reshape failed";
+      }
+      pre_stra_costs = pre_operator_info->strategy_cost();
+    }
+    // get the next node's strategy_cost_
+    int32_t in_index = 0;
+    OperatorInfoPtr next_operator_info;
+    std::vector> next_stra_costs;
+    bool find_next_node = FindNextNodeStraCosts(cnode, &next_operator_info, &in_index);
+    if (!find_next_node) {
+      MS_LOG(INFO) << "FindNextNodeStraCosts for reshape failed";
+    }
+    // set input_layout and output_layout for reshape.
+    // init reshape and set cost for each input_layout and output_layout.
+    OperatorInfoPtr operator_info = cnode->operator_info();
+    auto reshape_info = std::dynamic_pointer_cast(operator_info);
+    reshape_info->set_pre_operator_name(pre_operator_info->name());
+    reshape_info->set_pre_operator_index(out_index);
+    if (find_next_node) {
+      next_stra_costs = next_operator_info->strategy_cost();
+      reshape_info->set_next_operator_name(next_operator_info->name());
+      reshape_info->set_next_operator_index(in_index);
+    }
+    for (auto pre_stra_cost : pre_stra_costs) {
+      std::vector pre_out_tensor_infos;
+      if (pre_node->isa()) {
+        pre_out_tensor_infos = pre_stra_cost->inputs_ptr;
+      } else {
+        pre_out_tensor_infos = pre_stra_cost->outputs_ptr;
+      }
+      if (pre_out_tensor_infos.size() <= IntToSize(out_index)) {
+        MS_LOG(EXCEPTION) << "out_index is out of range of the tensor_infos in setting reshape's input_layout";
+      }
+      TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index];
+      TensorLayout pre_out_tensor_layout = pre_out_tensor_info.tensor_layout();
+      reshape_info->SetInputLayout(pre_out_tensor_layout);
+      // infer the previous node's output strategy from its output layout.
+      Dimensions stra;
+      InferStraByTensorInfo(pre_out_tensor_info, &stra);
+      std::vector stra_inputs = {stra};
+      StrategyPtr reshape_stra = std::make_shared(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs);
+      if (next_stra_costs.empty()) {
+        if (reshape_info->Init(nullptr) == FAILED) {
+          MS_LOG(EXCEPTION) << "Failure:operator reshape init failed";
+        }
+        // set cost for each input_layout and output_layout pair.
+        reshape_info->SetCostForReshape(reshape_stra);
+        continue;
+      }
+      for (auto next_stra_cost : next_stra_costs) {
+        std::vector next_in_tensor_infos = next_stra_cost->inputs_ptr;
+        if (next_in_tensor_infos.size() <= IntToSize(in_index)) {
+          MS_LOG(EXCEPTION) << "in_index is out of range of the tensor_infos in setting reshape's output_layout";
+        }
+        TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index];
+        TensorLayout next_in_tensor_layout = next_in_tensor_info.tensor_layout();
+        reshape_info->SetOutputLayout(next_in_tensor_layout);
+        if (reshape_info->Init(nullptr) == FAILED) {
+          MS_LOG(EXCEPTION) << "Failure:operator reshape init failed";
+        }
+        // set cost for each input_layout and output_layout pair.
+        reshape_info->SetCostForReshape(reshape_stra);
+      }
+    }
+  }
+}
+
 Status ParallelStrategySearch(const std::vector &all_nodes, const FuncGraphPtr &root) {
   // There are 4 meta-steps to determine the parallelization strategy for the ANF graph.
   // Step 1: Traverse the ANF graph, and create NODEs for costgraph:
@@ -930,7 +1135,9 @@ Status ParallelStrategySearch(const std::vector &all_nodes, const Fu
       MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
     }
   }
-
+  // The reshape operator needs the next node's input_layout as its output_layout
+  // and the previous node's output_layout as its input_layout.
+  ReshapeCostCompute(all_nodes);
   // Step 2
   ConstructCostGraphEdges(all_nodes);
   MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.h b/mindspore/ccsrc/parallel/step_auto_parallel.h
index fff9dfa4c3..cf05a36fe1 100644
--- a/mindspore/ccsrc/parallel/step_auto_parallel.h
+++ b/mindspore/ccsrc/parallel/step_auto_parallel.h
@@ -51,6 +51,8 @@ void ConstructCostGraphEdges(const std::vector &all_nodes);
 
 void AugmentCostGraph(const std::vector &all_nodes);
 
+void InferStraByTensorInfo(const TensorInfo &pre_out_tensor_info, Dimensions *stra);
+
 Status ParallelStrategySearch(const std::vector &all_nodes, const FuncGraphPtr &root);
 
 Status ParallelStrategyRecSearch(const std::vector &all_nodes, const FuncGraphPtr &root);
diff --git a/tests/ut/cpp/parallel/ops_info/reshape_test.cc b/tests/ut/cpp/parallel/ops_info/reshape_test.cc
index 7ff94e9af5..fb60c6d250 100644
--- a/tests/ut/cpp/parallel/ops_info/reshape_test.cc
+++ b/tests/ut/cpp/parallel/ops_info/reshape_test.cc
@@ -219,22 +219,5 @@ TEST_F(TestReshapeInfo, CheckStrategy3) {
   Status ret = reshape->Init(strategy);
   ASSERT_EQ(ret, SUCCESS);
 }
-
-TEST_F(TestReshapeInfo, AutoStrategy1) {
-  ASSERT_EQ(reshape->GenerateStrategies(0), Status::SUCCESS);
-  std::vector> sc = reshape->GetStrategyCost();
-
-  Shapes splittable_inputs = {{1, 0, 0, 0}};
-  std::vector sp_vector;
-  Shapes inputs_shape = {{32, 512, 7, 7}};
-  GenerateStrategiesForIndependentInputs(0, inputs_shape, splittable_inputs, &sp_vector);
-  ASSERT_EQ(sc.size(), sp_vector.size());
-  for (auto stra : sp_vector) {
-    auto stra0 = stra->GetInputDim()[0];
-    ASSERT_EQ(stra0[1], 1);
-    ASSERT_EQ(stra0[2], 1);
-    ASSERT_EQ(stra0[3], 1);
-  }
-}
 } // namespace parallel
 } // namespace mindspore
diff --git a/tests/ut/python/parallel/test_auto_parallel_reshape.py b/tests/ut/python/parallel/test_auto_parallel_reshape.py
index 09769776a9..1bce733610 100644
--- a/tests/ut/python/parallel/test_auto_parallel_reshape.py
+++ b/tests/ut/python/parallel/test_auto_parallel_reshape.py
@@ -65,6 +65,193 @@ def test_reshape_matmul():
     net.set_auto_parallel()
     _executor.compile(net, x)
 
+def test_reshape_auto_1():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.relu = P.ReLU()
+            self.reshape = P.Reshape()
+            self.matmul = P.MatMul()
+            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")
+
+        def construct(self, x):
+            out = self.relu(x)
+            out = self.reshape(out, (64, 28))
+            out = self.matmul(out, self.matmul_weight)
+            return out
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([8*size, 28, 1, 1]), dtype=ms.float32)
+
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    _executor.compile(net, x)
+
+def test_reshape_auto_2():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.relu = P.ReLU()
+            self.reshape = P.Reshape()
+            self.matmul = P.MatMul()
+            self.add_weight = Parameter(Tensor(np.ones([128, 32]), dtype=ms.float32), name="weight1")
+            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")
+
+        def construct(self, x):
+            out = self.relu(x)
+            out = self.reshape(out, (64, 28))
+            out = self.matmul(out, self.matmul_weight)
+            out = self.reshape(out, (128, 32))
+            out = out + self.add_weight
+            return out
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([8*size, 28, 1, 1]), dtype=ms.float32)
+
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    _executor.compile(net, x)
+
+def test_reshape_auto_3():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.relu = P.ReLU()
+            self.reshape = P.Reshape()
+            self.matmul = P.MatMul()
+            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")
+
+        def construct(self, x):
+            out = self.relu(x)
+            out = self.matmul(out, self.matmul_weight)
+            out = self.reshape(out, (8, 8, 8, 8))
+            return out
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([8*size, 28]), dtype=ms.float32)
+
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    _executor.compile(net, x)
+
+def test_reshape_auto_4():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.relu = P.ReLU()
+            self.reshape = P.Reshape()
+            self.matmul = P.MatMul()
+            self.matmul_weight = Parameter(Tensor(np.ones([28*64]), dtype=ms.float32), name="weight")
+
+        def construct(self, x):
+            out = self.relu(x)
+            out = self.reshape(out, (64, 28))
+            w = self.reshape(self.matmul_weight, (28, 64))
+            out = self.matmul(out, w)
+            return out
-if __name__ == '__main__':
-    test_reshape_matmul()
\ No newline at end of file
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([8*size, 28, 1, 1]), dtype=ms.float32)
+
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    _executor.compile(net, x)
+
+
+def test_reshape_auto_5():
+    class NetWithLoss(nn.Cell):
+        def __init__(self, network):
+            super(NetWithLoss, self).__init__()
+            self.loss = VirtualLoss()
+            self.network = network
+
+        def construct(self, x, y):
+            predict = self.network(x, y)
+            return self.loss(predict)
+
+    class GradWrap(nn.Cell):
+        def __init__(self, network):
+            super(GradWrap, self).__init__()
+            self.network = network
+
+        def construct(self, x, y):
+            return C.grad_all(self.network)(x, y)
+
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.relu = P.ReLU()
+            self.mul = P.Mul()
+            self.reshape = P.Reshape()
+            self.reduce_sum = P.ReduceSum()
+            self.wide_w = Parameter(Tensor(np.ones([4, 1024*8, 64]), dtype=ms.float32), name="weight")
+
+        def construct(self, x, y):
+            mask = self.reshape(y, (4, 1024*8, 1))
+            w_id = self.relu(x)
+            wx = self.mul(w_id, mask)
+            wide_out = self.reshape(self.reduce_sum(wx, 1), (-1,1))
+            deep_id = x + self.wide_w
+            vx = self.mul(deep_id, mask)
+            deep_in = self.reshape(vx, (-1, 1024*8*64))
+            out = wide_out + deep_in
+            return out
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([4, 1024*size, 1]), dtype=ms.float32)
+    y = Tensor(np.ones([4, 1024*size,]), dtype=ms.float32)
+
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    _executor.compile(net, x, y)
+
+def test_reshape_auto_6():
+    class NetWithLoss(nn.Cell):
+        def __init__(self, network):
+            super(NetWithLoss, self).__init__()
+            self.loss = VirtualLoss()
+            self.network = network
+
+        def construct(self, x, y):
+            predict = self.network(x, y)
+            return self.loss(predict)
+
+    class GradWrap(nn.Cell):
+        def __init__(self, network):
+            super(GradWrap, self).__init__()
+            self.network = network
+
+        def construct(self, x, y):
+            return C.grad_all(self.network)(x, y)
+
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.relu = P.ReLU()
+            self.mul = P.Mul()
+            self.reshape = P.Reshape()
+            self.reduce_mean = P.ReduceMean()
+            self.wide_w = Parameter(Tensor(np.ones([4, 1024, 1]), dtype=ms.float32), name="weight")
+
+        def construct(self, x, y):
+            out1 = x + self.wide_w
+            w = self.reshape(self.wide_w, (4,1024))
+            out1 = self.reduce_mean(out1, 1)
+            out1 = out1 - w
+            out2 = self.mul(y, w)
+            out = out1 + out2
+            return out
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
+    y = Tensor(np.ones([4, 1024,]), dtype=ms.float32)
+
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    _executor.compile(net, x, y)