|
|
|
@ -369,7 +369,7 @@ CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list,
|
|
|
|
|
<< ", communication_with_partial_para_: " << ret->communication_with_partial_para_
|
|
|
|
|
<< ", communication_cost_: " << ret->communication_cost_
|
|
|
|
|
<< ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
|
|
|
|
|
MS_LOG(INFO) << "Cost 0: totoal_cost: " << minimum;
|
|
|
|
|
MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
|
|
|
|
|
for (size_t i = 1; i < after_mem_filter.size(); ++i) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
|
|
|
|
|
MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
|
|
|
|
@ -422,7 +422,7 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, d
|
|
|
|
|
<< ", communication_with_partial_para_: " << ret->communication_with_partial_para_
|
|
|
|
|
<< ", communication_cost_: " << ret->communication_cost_
|
|
|
|
|
<< ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
|
|
|
|
|
MS_LOG(INFO) << "Cost 0: totoal_cost: " << minimum;
|
|
|
|
|
MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
|
|
|
|
|
for (size_t i = 1; i < after_mem_filter.size(); ++i) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
|
|
|
|
|
MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
|
|
|
|
@ -1351,6 +1351,14 @@ std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfo
|
|
|
|
|
return succ_edges;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t CostGraph::GetNumEdges() const {
|
|
|
|
|
size_t sum = 0;
|
|
|
|
|
for (const auto &kv : edges_) {
|
|
|
|
|
auto &edges = kv.second;
|
|
|
|
|
sum += edges.size();
|
|
|
|
|
}
|
|
|
|
|
return sum;
|
|
|
|
|
}
|
|
|
|
|
Status CostGraph::InitSelectedStrategy() {
|
|
|
|
|
for (auto &op : ops_) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(op);
|
|
|
|
@ -1416,6 +1424,122 @@ Status CostGraph::ComputeOpsAndEdgesParameterInvolved() {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CostGraph::DFSForTopoOrder(const OperatorInfoPtr ¤t_op, std::map<OperatorInfoPtr, bool> *visited,
|
|
|
|
|
std::vector<OperatorInfoPtr> *topo_order) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(current_op);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(visited);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(topo_order);
|
|
|
|
|
|
|
|
|
|
visited->at(current_op) = true;
|
|
|
|
|
for (const auto &s_edge : current_op->succ_edges()) {
|
|
|
|
|
if (!visited->at(s_edge->next_operator())) {
|
|
|
|
|
DFSForTopoOrder(s_edge->next_operator(), visited, topo_order);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
topo_order->push_back(current_op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Compute a topological order of the costgraph
|
|
|
|
|
void CostGraph::TopologyOrder(std::vector<OperatorInfoPtr> *topo_order) {
|
|
|
|
|
std::map<OperatorInfoPtr, bool> visited;
|
|
|
|
|
for (auto &op : ops_) {
|
|
|
|
|
visited[op] = false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto &op : ops_) {
|
|
|
|
|
if (!visited[op]) {
|
|
|
|
|
DFSForTopoOrder(op, &visited, topo_order);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
void CostGraph::MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int> &candidate_ops) {
|
|
|
|
|
for (auto &op : ops_) {
|
|
|
|
|
auto search = candidate_ops.find(op);
|
|
|
|
|
if (search != candidate_ops.end()) {
|
|
|
|
|
// Mark the critical operators
|
|
|
|
|
op->mark_output_critical();
|
|
|
|
|
// Mark the successive edges
|
|
|
|
|
for (auto &s_edge : op->succ_edges()) {
|
|
|
|
|
s_edge->mark_output_critical();
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
op->mark_output_not_critical();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Walks the operators in 'topo_order' one by one, tracking which operators' outputs are still
// waiting to be consumed and the total size of those outputs. The set of operators alive at the
// moment the total size peaks is then passed to MarkCriticalOpsAndEdges.
//
// topo_order: operators in the order they are processed. The first operator must have no
//             incoming edges.
// Returns FAILED if 'topo_order' is empty, if the first operator has incoming edges, or if an
// operator's remaining-output counter becomes negative; SUCCESS otherwise.
Status CostGraph::DetermineCriticalOps(const std::vector<OperatorInfoPtr> &topo_order) {
  if (topo_order.empty()) {
    MS_LOG(ERROR) << "0 operator in costgraph.";
    return FAILED;
  }
  const auto &first_op = topo_order[0];
  if (!first_op->prev_edges().empty()) {
    // Report the number of incoming edges (not the edge container itself).
    MS_LOG(ERROR) << "The first operator in the first of topological order of "
                     "costgraph should have 0 incoming edge, but has "
                  << first_op->prev_edges().size() << " edges.";
    return FAILED;
  }
  // The 'curr_memory_state' records <OperatorInfo, remaining_output_cnt>, where remaining_output_cnt is the number
  // of the output of OperatorInfo that currently has not been used
  std::map<OperatorInfoPtr, int> curr_memory_state;
  (void)curr_memory_state.emplace(std::make_pair(first_op, SizeToInt(first_op->succ_edges().size())));
  std::map<OperatorInfoPtr, int> max_memory_state = curr_memory_state;
  // The 'curr_memory_size' records the current total memory size, which is the sum of outputs of operators that has
  // not been used
  double curr_memory_size = first_op->GetOutputsTotalSize();
  double max_memory_size = curr_memory_size;

  for (size_t finished = 1; finished < topo_order.size(); ++finished) {
    const auto &curr_op = topo_order[finished];
    // Produce
    (void)curr_memory_state.emplace(std::make_pair(curr_op, SizeToInt(curr_op->succ_edges().size())));
    curr_memory_size += curr_op->GetOutputsTotalSize();
    // Consume: every incoming edge uses one output of its source operator. operator[] is used on
    // purpose: a predecessor that has no entry gets a value-initialized counter that becomes
    // negative here and is reported as a failure below.
    for (const auto &prev_edge : curr_op->prev_edges()) {
      const auto &prev_op = prev_edge->prev_operator();
      curr_memory_state[prev_op]--;
    }
    // Release the outputs of predecessors that are now fully consumed.
    for (const auto &prev_edge : curr_op->prev_edges()) {
      const auto &prev_op = prev_edge->prev_operator();
      auto state_iter = curr_memory_state.find(prev_op);
      if (state_iter == curr_memory_state.end()) {
        // Several incoming edges may share the same predecessor; it was already released by an
        // earlier edge in this loop, so its size must not be subtracted again.
        continue;
      }
      if (state_iter->second < 0) {
        MS_LOG(ERROR) << "Failure: " << prev_op->name() << "'s current output count: " << state_iter->second;
        return FAILED;
      }
      if (state_iter->second == 0) {
        (void)curr_memory_state.erase(state_iter);
        curr_memory_size -= prev_op->GetOutputsTotalSize();
      }
    }

    // A negative total is only logged; processing continues.
    if (curr_memory_size < 0) {
      MS_LOG(ERROR) << "Memory size calculation failed: " << curr_memory_size;
    }
    // Modify the max
    if (curr_memory_size > max_memory_size) {
      max_memory_size = curr_memory_size;
      max_memory_state = curr_memory_state;
    }
  }
  // Mark those critical operators
  MarkCriticalOpsAndEdges(max_memory_state);
  return SUCCESS;
}
|
|
|
|
|
|
|
|
|
|
// Computes the output-critical flag of operators (and the necessary edges) in two steps:
//   1. build an operator ordering with TopologyOrder and reverse it;
//   2. hand the reversed ordering to DetermineCriticalOps, which determines and marks the
//      critical operators.
// Returns FAILED when DetermineCriticalOps does not succeed, SUCCESS otherwise.
Status CostGraph::ComputeOpsAndEdgesOutputCritical() {
  std::vector<OperatorInfoPtr> ordered_ops;
  TopologyOrder(&ordered_ops);
  std::reverse(ordered_ops.begin(), ordered_ops.end());

  if (DetermineCriticalOps(ordered_ops) == SUCCESS) {
    return SUCCESS;
  }
  MS_LOG(ERROR) << "Determining critical operators failed.";
  return FAILED;
}
|
|
|
|
|
|
|
|
|
|
Status CostGraph::CalculateOpsMemoryCost() {
|
|
|
|
|
for (auto &op : ops_) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(op);
|
|
|
|
@ -1427,6 +1551,17 @@ Status CostGraph::CalculateOpsMemoryCost() {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Calls CalculateMemoryCostForInference() on every operator in 'ops_'.
// Stops at the first operator whose calculation does not return SUCCESS, logs its name and
// returns FAILED; returns SUCCESS when all operators succeed.
Status CostGraph::CalculateOpsMemoryCostForInference() {
  for (auto &one_op : ops_) {
    MS_EXCEPTION_IF_NULL(one_op);
    const bool succeeded = (one_op->CalculateMemoryCostForInference() == SUCCESS);
    if (succeeded) {
      continue;
    }
    MS_LOG(ERROR) << "Calculate Operator: " << one_op->name() << " cost for memory usage failed.";
    return FAILED;
  }
  return SUCCESS;
}
|
|
|
|
|
|
|
|
|
|
Status CostGraph::CalculateEdgesMemoryCost() {
|
|
|
|
|
for (auto &edge_pair : edges_) {
|
|
|
|
|
const auto &edges = edge_pair.second;
|
|
|
|
@ -1440,6 +1575,19 @@ Status CostGraph::CalculateEdgesMemoryCost() {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Calls CalculateMemoryCostForInference() on every edge stored in 'edges_'.
// Stops at the first edge whose calculation does not return SUCCESS, logs its name and returns
// FAILED; returns SUCCESS when all edges succeed.
Status CostGraph::CalculateEdgesMemoryCostForInference() {
  for (auto &entry : edges_) {
    for (auto &edge : entry.second) {
      const bool succeeded = (edge->CalculateMemoryCostForInference() == SUCCESS);
      if (succeeded) {
        continue;
      }
      MS_LOG(ERROR) << "Calculate Edge: " << edge->edge_name() << " cost for memory usage failed.";
      return FAILED;
    }
  }
  return SUCCESS;
}
|
|
|
|
|
|
|
|
|
|
OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string &p_name) const {
|
|
|
|
|
for (auto one_op : ops_) {
|
|
|
|
|
if (one_op->name().find(IDENTITY_INFO) != std::string::npos) {
|
|
|
|
@ -1480,5 +1628,49 @@ Status CostGraph::CorrectOpsMemoryCost() {
|
|
|
|
|
}
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Calculates the memory cost of all operators and edges in the cost graph. The steps depend on
// RUN_PHASE:
//   - training phase:  compute the parameter-involved flags, then the operators' memory cost, the
//                      edges' memory cost, and finally correct the memory usage caused by
//                      TmpIdentity;
//   - otherwise (inference phase): compute the output-critical flags, then the operators' and the
//                      edges' memory cost for inference.
// Returns FAILED (after logging which step failed) as soon as one step does not return SUCCESS;
// returns SUCCESS when every step of the selected phase succeeds.
Status CostGraph::CalculateMemoryCost() {
  if (RUN_PHASE == TRAINING_PHASE) {
    // training phase
    if (ComputeOpsAndEdgesParameterInvolved() != SUCCESS) {
      MS_LOG(ERROR) << "Computing operators' parameter_involved failed.";
      return FAILED;
    }
    // Calculate operators' memory usage
    if (CalculateOpsMemoryCost() != SUCCESS) {
      MS_LOG(ERROR) << "Calculating operators' cost for memory cost failed.";
      return FAILED;
    }
    // Calculate edges' memory usage
    if (CalculateEdgesMemoryCost() != SUCCESS) {
      MS_LOG(ERROR) << "Calculating edges' cost for memory cost failed.";
      return FAILED;
    }
    // Correct memory usage caused by TmpIdentity
    if (CorrectOpsMemoryCost() != SUCCESS) {
      MS_LOG(ERROR) << "Correcting operators' cost for memory cost failed.";
      return FAILED;
    }
    return SUCCESS;
  }

  // inference phase
  if (ComputeOpsAndEdgesOutputCritical() != SUCCESS) {
    MS_LOG(ERROR) << "Computing operators' critical flag failed.";
    return FAILED;
  }
  // Calculate operators' memory usage
  if (CalculateOpsMemoryCostForInference() != SUCCESS) {
    MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed.";
    return FAILED;
  }
  // Calculate edges' memory usage
  if (CalculateEdgesMemoryCostForInference() != SUCCESS) {
    // Name the edges here so this failure can be told apart from the operators' failure above.
    MS_LOG(ERROR) << "Calculating edges' memory cost for inference failed.";
    return FAILED;
  }
  return SUCCESS;
}
|
|
|
|
|
} // namespace parallel
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|