diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc index 292cc4f5f0..f99b271894 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc @@ -948,10 +948,12 @@ OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr& op) { return target_op; } -void CostGraph::CreateTriangleEliminationSubCostListForIdentity( - StrategyPtr elimi_op_stra, StrategyPtr left_op_stra, StrategyPtr right_op_stra, const CostPtr& right_op_cost, - const CostPtrList& elimi_op_clist, const CostPtrList& left_edge_clist, const CostPtr& right_edge_cost, - const CostPtrList& left_node_clist_origin, CostPtrList* left_node_clist_new) { +void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, StrategyPtr left_op_stra, + StrategyPtr right_op_stra, const CostPtr& right_op_cost, + const CostPtrList& elimi_op_clist, + const CostPtrList& left_edge_clist, const CostPtr& right_edge_cost, + const CostPtrList& left_node_clist_origin, + CostPtrList* left_node_clist_new) { MS_EXCEPTION_IF_NULL(right_edge_cost); MS_EXCEPTION_IF_NULL(right_op_cost); MS_EXCEPTION_IF_NULL(left_node_clist_new); @@ -985,93 +987,20 @@ void CostGraph::CreateTriangleEliminationSubCostListForIdentity( } } -void CostGraph::CreateTriangleEliminationSubCostListForOthers( - StrategyPtr elimi_op_stra, StrategyPtr left_node_stra, StrategyPtr right_node_stra, const CostPtr& right_op_cost, - const CostPtrList& elimi_op_clist, const CostPtrList& left_edge_clist, const CostPtr& right_edge_cost, - const CostPtrList& left_node_clist_origin, CostPtrList* left_node_clist_new) { - CostPtr elimi_op_determined = nullptr, left_edge_determined = nullptr, init_ele = nullptr; - std::function LocalCompare = [&](CostPtr init, const CostPtr& cost_x) { - MS_EXCEPTION_IF_NULL(cost_x); - if ((init == nullptr) || (cost_x->memory_cost_ < DEVICE_MEMORY_CAPACITY)) { - init = cost_x; - } - return init; - }; - - // Find a feasible elimi_op_clist - elimi_op_determined = std::accumulate(elimi_op_clist.begin(), elimi_op_clist.end(), init_ele, LocalCompare); - init_ele = nullptr; - // Find a feasible left_edge_cost - left_edge_determined = std::accumulate(left_edge_clist.begin(), left_edge_clist.end(), init_ele, LocalCompare); - if ((elimi_op_determined == nullptr) || (left_edge_determined == nullptr)) { - return; - } - if ((elimi_op_determined->memory_cost_ >= DEVICE_MEMORY_CAPACITY) || - (left_edge_determined->memory_cost_ >= DEVICE_MEMORY_CAPACITY)) { - return; - } - - for (auto& left_node_cost : left_node_clist_origin) { - MS_EXCEPTION_IF_NULL(left_node_cost); - MS_EXCEPTION_IF_NULL(right_op_cost); - double new_memory_cost = left_node_cost->memory_cost_ + elimi_op_determined->memory_cost_ + - left_edge_determined->memory_cost_ + right_edge_cost->memory_cost_ + - right_op_cost->memory_cost_; - double commu_cost = left_node_cost->communication_cost_ + elimi_op_determined->communication_cost_ + - left_edge_determined->communication_cost_ + right_edge_cost->communication_cost_ + - right_op_cost->communication_cost_; - double commu_without = - left_node_cost->communication_without_parameter_ + elimi_op_determined->communication_without_parameter_ + - left_edge_determined->communication_without_parameter_ + right_edge_cost->communication_without_parameter_ + - right_op_cost->communication_without_parameter_; - auto decision = std::make_shared(elimi_op_stra, elimi_op_determined, - left_edge_determined, right_edge_cost, left_node_stra, - left_node_cost, right_node_stra, right_op_cost); - - auto new_cost = std::make_shared(new_memory_cost, commu_cost, decision); - new_cost->communication_without_parameter_ = commu_without; - new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without); - left_node_clist_new->emplace_back(std::move(new_cost)); - } -} - void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr& elimi_op, const CostPtrList& right_node_clist, const CostPtrList& right_edge_clist, const StrategyPtr& elimi_op_stra, const StrategyPtr& left_node_stra, const StrategyPtr& right_node_stra, const CostPtrList& elimi_op_clist, const CostPtrList& left_edge_clist, const CostPtrList& left_node_clist_origin, CostPtrList* left_node_clist_new) { - // The reason for separately dealing with when the 'elimi_op' is 'TMPIDENTITY_INFO' or others is that - // when 'elimi_op' is TMPIDENTITY_INFO, the computation is limited, while 'elimi_op' is others, the computation - // may be huge MS_EXCEPTION_IF_NULL(elimi_op); - if (elimi_op->name().find(TMPIDENTITY_INFO_NAME) != std::string::npos) { - for (auto& right_node_cost : right_node_clist) { - MS_EXCEPTION_IF_NULL(right_node_cost); - for (auto& right_edge_cost : right_edge_clist) { - MS_EXCEPTION_IF_NULL(right_edge_cost); - if ((right_node_cost->memory_cost_ < DEVICE_MEMORY_CAPACITY) && - (right_edge_cost->memory_cost_ < DEVICE_MEMORY_CAPACITY)) { - // Exact computation for TMPIDENTITY_INFO_NAME case - CreateTriangleEliminationSubCostListForIdentity(elimi_op_stra, left_node_stra, right_node_stra, - right_node_cost, elimi_op_clist, left_edge_clist, - right_edge_cost, left_node_clist_origin, left_node_clist_new); - } - } - } - } else { - for (auto& right_node_cost : right_node_clist) { - MS_EXCEPTION_IF_NULL(right_node_cost); - for (auto& right_edge_cost : right_edge_clist) { - MS_EXCEPTION_IF_NULL(right_edge_cost); - if ((right_node_cost->memory_cost_ < DEVICE_MEMORY_CAPACITY) && - (right_edge_cost->memory_cost_ < DEVICE_MEMORY_CAPACITY)) { - // Approximate computation for other case - CreateTriangleEliminationSubCostListForOthers(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost, - elimi_op_clist, left_edge_clist, right_edge_cost, - left_node_clist_origin, left_node_clist_new); - } - } + for (auto& right_node_cost : right_node_clist) { + MS_EXCEPTION_IF_NULL(right_node_cost); + for (auto& right_edge_cost : right_edge_clist) { + MS_EXCEPTION_IF_NULL(right_edge_cost); + CreateTriangleEliminationSubCostList(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost, + elimi_op_clist, left_edge_clist, right_edge_cost, left_node_clist_origin, + left_node_clist_new); } } } diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h index fde9514540..3b04703a47 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h @@ -163,14 +163,9 @@ class CostGraph { void CreateTriangleEliminationCostList(const OperatorInfoPtr&, const CostPtrList&, const CostPtrList&, const StrategyPtr&, const StrategyPtr&, const StrategyPtr&, const CostPtrList&, const CostPtrList&, const CostPtrList&, CostPtrList*); - // Given the relevant costlist, create the TriangleElimination cost for eliminating TmpIdentityInfo - void CreateTriangleEliminationSubCostListForIdentity(StrategyPtr, StrategyPtr, StrategyPtr, const CostPtr&, - const CostPtrList&, const CostPtrList&, const CostPtr&, - const CostPtrList&, CostPtrList*); - // Given the relevant costlist, create the TriangleElimination cost for eliminating other operators - void CreateTriangleEliminationSubCostListForOthers(StrategyPtr, StrategyPtr, StrategyPtr, const CostPtr&, - const CostPtrList&, const CostPtrList&, const CostPtr&, - const CostPtrList&, CostPtrList*); + // Given the relevant costlist, create the TriangleElimination cost + void CreateTriangleEliminationSubCostList(StrategyPtr, StrategyPtr, StrategyPtr, const CostPtr&, const CostPtrList&, + const CostPtrList&, const CostPtr&, const CostPtrList&, CostPtrList*); // Applying the Star Elimination in DP algorithm. Return the successive edges of this merged_op // NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied.