|
|
|
@ -948,10 +948,12 @@ OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr& op) {
|
|
|
|
|
return target_op;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CostGraph::CreateTriangleEliminationSubCostListForIdentity(
|
|
|
|
|
StrategyPtr elimi_op_stra, StrategyPtr left_op_stra, StrategyPtr right_op_stra, const CostPtr& right_op_cost,
|
|
|
|
|
const CostPtrList& elimi_op_clist, const CostPtrList& left_edge_clist, const CostPtr& right_edge_cost,
|
|
|
|
|
const CostPtrList& left_node_clist_origin, CostPtrList* left_node_clist_new) {
|
|
|
|
|
void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, StrategyPtr left_op_stra,
|
|
|
|
|
StrategyPtr right_op_stra, const CostPtr& right_op_cost,
|
|
|
|
|
const CostPtrList& elimi_op_clist,
|
|
|
|
|
const CostPtrList& left_edge_clist, const CostPtr& right_edge_cost,
|
|
|
|
|
const CostPtrList& left_node_clist_origin,
|
|
|
|
|
CostPtrList* left_node_clist_new) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(right_edge_cost);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(right_op_cost);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(left_node_clist_new);
|
|
|
|
@ -985,93 +987,20 @@ void CostGraph::CreateTriangleEliminationSubCostListForIdentity(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CostGraph::CreateTriangleEliminationSubCostListForOthers(
|
|
|
|
|
StrategyPtr elimi_op_stra, StrategyPtr left_node_stra, StrategyPtr right_node_stra, const CostPtr& right_op_cost,
|
|
|
|
|
const CostPtrList& elimi_op_clist, const CostPtrList& left_edge_clist, const CostPtr& right_edge_cost,
|
|
|
|
|
const CostPtrList& left_node_clist_origin, CostPtrList* left_node_clist_new) {
|
|
|
|
|
CostPtr elimi_op_determined = nullptr, left_edge_determined = nullptr, init_ele = nullptr;
|
|
|
|
|
std::function<CostPtr(CostPtr, const CostPtr&)> LocalCompare = [&](CostPtr init, const CostPtr& cost_x) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cost_x);
|
|
|
|
|
if ((init == nullptr) || (cost_x->memory_cost_ < DEVICE_MEMORY_CAPACITY)) {
|
|
|
|
|
init = cost_x;
|
|
|
|
|
}
|
|
|
|
|
return init;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Find a feasible elimi_op_clist
|
|
|
|
|
elimi_op_determined = std::accumulate(elimi_op_clist.begin(), elimi_op_clist.end(), init_ele, LocalCompare);
|
|
|
|
|
init_ele = nullptr;
|
|
|
|
|
// Find a feasible left_edge_cost
|
|
|
|
|
left_edge_determined = std::accumulate(left_edge_clist.begin(), left_edge_clist.end(), init_ele, LocalCompare);
|
|
|
|
|
if ((elimi_op_determined == nullptr) || (left_edge_determined == nullptr)) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
if ((elimi_op_determined->memory_cost_ >= DEVICE_MEMORY_CAPACITY) ||
|
|
|
|
|
(left_edge_determined->memory_cost_ >= DEVICE_MEMORY_CAPACITY)) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto& left_node_cost : left_node_clist_origin) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(left_node_cost);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(right_op_cost);
|
|
|
|
|
double new_memory_cost = left_node_cost->memory_cost_ + elimi_op_determined->memory_cost_ +
|
|
|
|
|
left_edge_determined->memory_cost_ + right_edge_cost->memory_cost_ +
|
|
|
|
|
right_op_cost->memory_cost_;
|
|
|
|
|
double commu_cost = left_node_cost->communication_cost_ + elimi_op_determined->communication_cost_ +
|
|
|
|
|
left_edge_determined->communication_cost_ + right_edge_cost->communication_cost_ +
|
|
|
|
|
right_op_cost->communication_cost_;
|
|
|
|
|
double commu_without =
|
|
|
|
|
left_node_cost->communication_without_parameter_ + elimi_op_determined->communication_without_parameter_ +
|
|
|
|
|
left_edge_determined->communication_without_parameter_ + right_edge_cost->communication_without_parameter_ +
|
|
|
|
|
right_op_cost->communication_without_parameter_;
|
|
|
|
|
auto decision = std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_determined,
|
|
|
|
|
left_edge_determined, right_edge_cost, left_node_stra,
|
|
|
|
|
left_node_cost, right_node_stra, right_op_cost);
|
|
|
|
|
|
|
|
|
|
auto new_cost = std::make_shared<Cost>(new_memory_cost, commu_cost, decision);
|
|
|
|
|
new_cost->communication_without_parameter_ = commu_without;
|
|
|
|
|
new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without);
|
|
|
|
|
left_node_clist_new->emplace_back(std::move(new_cost));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr& elimi_op, const CostPtrList& right_node_clist,
|
|
|
|
|
const CostPtrList& right_edge_clist, const StrategyPtr& elimi_op_stra,
|
|
|
|
|
const StrategyPtr& left_node_stra, const StrategyPtr& right_node_stra,
|
|
|
|
|
const CostPtrList& elimi_op_clist, const CostPtrList& left_edge_clist,
|
|
|
|
|
const CostPtrList& left_node_clist_origin,
|
|
|
|
|
CostPtrList* left_node_clist_new) {
|
|
|
|
|
// The reason for separately dealing with when the 'elimi_op' is 'TMPIDENTITY_INFO' or others is that
|
|
|
|
|
// when 'elimi_op' is TMPIDENTITY_INFO, the computation is limited, while 'elimi_op' is others, the computation
|
|
|
|
|
// may be huge
|
|
|
|
|
MS_EXCEPTION_IF_NULL(elimi_op);
|
|
|
|
|
if (elimi_op->name().find(TMPIDENTITY_INFO_NAME) != std::string::npos) {
|
|
|
|
|
for (auto& right_node_cost : right_node_clist) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(right_node_cost);
|
|
|
|
|
for (auto& right_edge_cost : right_edge_clist) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(right_edge_cost);
|
|
|
|
|
if ((right_node_cost->memory_cost_ < DEVICE_MEMORY_CAPACITY) &&
|
|
|
|
|
(right_edge_cost->memory_cost_ < DEVICE_MEMORY_CAPACITY)) {
|
|
|
|
|
// Exact computation for TMPIDENTITY_INFO_NAME case
|
|
|
|
|
CreateTriangleEliminationSubCostListForIdentity(elimi_op_stra, left_node_stra, right_node_stra,
|
|
|
|
|
right_node_cost, elimi_op_clist, left_edge_clist,
|
|
|
|
|
right_edge_cost, left_node_clist_origin, left_node_clist_new);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (auto& right_node_cost : right_node_clist) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(right_node_cost);
|
|
|
|
|
for (auto& right_edge_cost : right_edge_clist) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(right_edge_cost);
|
|
|
|
|
if ((right_node_cost->memory_cost_ < DEVICE_MEMORY_CAPACITY) &&
|
|
|
|
|
(right_edge_cost->memory_cost_ < DEVICE_MEMORY_CAPACITY)) {
|
|
|
|
|
// Approximate computation for other case
|
|
|
|
|
CreateTriangleEliminationSubCostListForOthers(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost,
|
|
|
|
|
elimi_op_clist, left_edge_clist, right_edge_cost,
|
|
|
|
|
left_node_clist_origin, left_node_clist_new);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (auto& right_node_cost : right_node_clist) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(right_node_cost);
|
|
|
|
|
for (auto& right_edge_cost : right_edge_clist) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(right_edge_cost);
|
|
|
|
|
CreateTriangleEliminationSubCostList(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost,
|
|
|
|
|
elimi_op_clist, left_edge_clist, right_edge_cost, left_node_clist_origin,
|
|
|
|
|
left_node_clist_new);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|