!156 [Auto parallel] Separate memory_cost and computation_cost in cost model

Merge pull request !156 from Xiaoda/implementing-memory-calculation-in-auto-parallel
pull/156/MERGE
Committed by mindspore-ci-bot (via Gitee), 5 years ago
commit 22516d3e08

@@ -23,8 +23,8 @@
 namespace mindspore {
 namespace parallel {
 void Simplify(CostPtrList* clist_ptrs) {
-  // Sort the cost_list with the memory_cost increasing, and communication_cost decreasing order. This method
-  // excludes the cost with greater memory_cost and greater communication_cost.
+  // Sort the cost_list with the computation_cost_ increasing, and communication_cost decreasing order. This method
+  // excludes the cost with greater computation_cost_ and greater communication_cost.
   // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>}
   if (!COST_MODEL_SIMPLIFY_CALCULATION) {
     return;
@@ -33,7 +33,7 @@ void Simplify(CostPtrList* clist_ptrs) {
   std::vector<size_t> id(clist_ptrs->size());
   std::iota(id.begin(), id.end(), size_t(0));
   std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) {
-    return clist_ptrs->at(x)->memory_cost_ < clist_ptrs->at(y)->memory_cost_;
+    return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_;
   });
   CostPtrList ret;
   for (size_t i = 0; i < clist_ptrs->size(); ++i) {
@@ -45,8 +45,8 @@
 }
 void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList* clist_ptrs) {
-  // Sort the cost_list with the memory_cost increasing, and communication_with_partial_para_cost decreasing order.
-  // This method excludes the cost with greater memory_cost and greater communication_without_para_cost.
+  // Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing
+  // order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost.
   if (!COST_MODEL_SIMPLIFY_CALCULATION) {
     return;
   }
@@ -54,7 +54,7 @@ void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList* clist_ptrs) {
   std::vector<size_t> id(clist_ptrs->size());
   std::iota(id.begin(), id.end(), size_t(0));
   std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) {
-    return clist_ptrs->at(x)->memory_cost_ < clist_ptrs->at(y)->memory_cost_;
+    return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_;
   });
   CostPtrList ret;
   for (size_t i = 0; i < clist_ptrs->size(); ++i) {
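The hunks above show only the sort; the pruning that the updated comments describe — drop any entry whose communication cost is not lower than that of an already-kept, cheaper-to-compute entry — can be written as a minimal standalone sketch. The types and function names here are invented stand-ins, not the MindSpore implementation.

// Standalone sketch of the dominance pruning (SimpleCost is not parallel::Cost):
// sort by computation cost ascending, then keep an entry only if its
// communication cost is strictly lower than every cheaper entry kept so far.
#include <algorithm>
#include <limits>
#include <vector>

struct SimpleCost {
  double computation_cost;
  double communication_cost;
};

std::vector<SimpleCost> SimplifySketch(std::vector<SimpleCost> clist) {
  std::sort(clist.begin(), clist.end(),
            [](const SimpleCost& a, const SimpleCost& b) { return a.computation_cost < b.computation_cost; });
  std::vector<SimpleCost> kept;
  double best_comm = std::numeric_limits<double>::max();
  for (const auto& c : clist) {
    if (c.communication_cost < best_comm) {  // not dominated by a cheaper entry
      kept.push_back(c);
      best_comm = c.communication_cost;
    }
  }
  return kept;
}
// E.g. {<100, 20>, <200, 10>, <300, 50>} -> {<100, 20>, <200, 10>}: <300, 50> is
// dominated, matching the comment's example up to ordering.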

@@ -44,14 +44,18 @@ using RedistributionOpListPtr = std::shared_ptr<std::pair<OperatorVector, OutPut
 struct Cost {
   Cost();
-  Cost(double memory, double commuication, const std::shared_ptr<Decision>& decision_ = nullptr)
-      : memory_cost_(memory), communication_cost_(commuication), decision_ptr_(std::move(decision_)) {
+  Cost(double computation, double commuication, const std::shared_ptr<Decision>& decision_ = nullptr)
+      : computation_cost_(computation), communication_cost_(commuication), decision_ptr_(std::move(decision_)) {
+    memory_with_reuse_ = 0.0;
     communication_without_parameter_ = 0.0;
     communication_with_partial_para_ = 0.0;
     communication_redis_forward_ = 0.0;
     communication_redis_backward_ = 0.0;
   }
-  double memory_cost_;
+  // 'memory_with_reuse_' calculates the peak memory usage in a training phase
+  double memory_with_reuse_;
+  // 'computation_cost_' models the training time of an iteration in a training phase
+  double computation_cost_;
   // 'communication_cost_' includes communications from operators (forward and backward) and edges
   double communication_cost_;
   // communication_without_parameter_ = communication_cost_ - (backward communication from operators)
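The derived field communication_with_partial_para_ is computed the same way throughout this PR; a hedged sketch of how the fields relate, with an invented kCostModelGamma standing in for COST_MODEL_GAMMA and CostSketch standing in for parallel::Cost:

// Sketch only; the constant value and the struct are placeholders, not the
// framework's configuration or class.
struct CostSketch {
  double computation_cost_ = 0.0;
  double communication_cost_ = 0.0;
  double communication_without_parameter_ = 0.0;
  double communication_with_partial_para_ = 0.0;
};

constexpr double kCostModelGamma = 0.1;  // assumed value for illustration only

CostSketch MakeCostSketch(double computation, double communication, double communication_without_para) {
  CostSketch c;
  c.computation_cost_ = computation;
  c.communication_cost_ = communication;
  c.communication_without_parameter_ = communication_without_para;
  // Keep the non-parameter communication in full, plus a gamma fraction of the
  // parameter-related remainder -- the expression repeated in the hunks below.
  c.communication_with_partial_para_ =
      communication_without_para + kCostModelGamma * (communication - communication_without_para);
  return c;
}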

@@ -35,7 +35,7 @@ namespace parallel {
 // interpretation of 6 operations in costmodel.h.
 // Phase 2: Search the cost_list in the final graph, and determine the optimal one
 // Create the cost_list for the final graph, and choose the optimal one: one the minimum quantity
-// COST_MODEL_ALPHA * memory_cost + COST_MODEL_BETA * communication_cost
+// COST_MODEL_ALPHA * computation_cost + COST_MODEL_BETA * communication_cost
 // Phase 3: Recover the original CostGraph, the determine strategy for each operator
 // After determining the optimal cost for the final graph, the algorithm recovers the original graph by applying
 // the 4 operations in the reverse order in the Phase 1. Because each operation decision contains the strategy,
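Phase 2 picks the entry that minimizes COST_MODEL_ALPHA * computation_cost + COST_MODEL_BETA * communication_cost; a minimal sketch of that selection follows, where kAlpha, kBeta and CandidateCost are assumptions for illustration rather than the configured values or the real cost type.

// Sketch of the Phase-2 criterion over a final graph's cost_list.
#include <cstddef>
#include <vector>

struct CandidateCost {
  double computation_cost;
  double communication_cost;
};

constexpr double kAlpha = 1.0;  // assumed weight for computation
constexpr double kBeta = 1.0;   // assumed weight for communication

std::size_t SelectOptimal(const std::vector<CandidateCost>& cost_list) {
  // Assumes a non-empty cost_list; returns the index of the chosen entry.
  std::size_t best = 0;
  double best_value = kAlpha * cost_list[0].computation_cost + kBeta * cost_list[0].communication_cost;
  for (std::size_t i = 1; i < cost_list.size(); ++i) {
    double value = kAlpha * cost_list[i].computation_cost + kBeta * cost_list[i].communication_cost;
    if (value < best_value) {
      best_value = value;
      best = i;
    }
  }
  return best;
}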

@@ -69,7 +69,7 @@ Status Edge::InitEdgeCost() {
     MS_LOG(EXCEPTION) << "Failure: redistribution cost calculation failed";
   }
   MS_EXCEPTION_IF_NULL(cost);
-  MS_LOG(DEBUG) << "The redistribution cost: memory_cost: " << cost->memory_cost_
+  MS_LOG(DEBUG) << "The redistribution cost: computation_cost: " << cost->computation_cost_
                 << ", communication_cost: " << cost->communication_cost_
                 << ", communication_without_parameter_: " << cost->communication_without_parameter_
                 << ", communication_with_partial_para_: " << cost->communication_with_partial_para_ << ".";
@@ -117,9 +117,9 @@ Status Edge::GetRedistributionCost(const TensorLayout& prev_op_output_layout, co
   double comm_cost = tensor_redistribution.comm_cost();
   double forward_comm_cost = tensor_redistribution.forward_comm_cost();
   double backward_comm_cost = tensor_redistribution.backward_comm_cost();
-  double mem_cost = tensor_redistribution.mem_cost();
-  *cost = std::make_shared<Cost>(type_length * mem_cost, type_length * comm_cost);
+  double computation_cost = tensor_redistribution.computation_cost();
+  *cost = std::make_shared<Cost>(type_length * computation_cost, type_length * comm_cost);
   (*cost)->communication_without_parameter_ = type_length * comm_cost;
   (*cost)->communication_with_partial_para_ =
       (*cost)->communication_without_parameter_ +
@@ -150,26 +150,26 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr
   (void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList);
   CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
-  std::function<void(size_t, double, double, double)> recursive = [&](size_t k, double memory, double communication,
-                                                                      double communication_without_para) {
+  std::function<void(size_t, double, double, double)> recursive =
+      [&](size_t k, double computation, double communication, double communication_without_para) {
     if (k == edges.size()) {
       auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
-      CostPtr new_cost = std::make_shared<Cost>(memory, communication);
+      CostPtr new_cost = std::make_shared<Cost>(computation, communication);
       MS_EXCEPTION_IF_NULL(new_cost);
       new_cost->communication_without_parameter_ = communication_without_para;
       new_cost->communication_with_partial_para_ =
          communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
      new_cost->decision_ptr_ = decision;
      result.push_back(new_cost);
      return;
    }
    for (auto& c : all_cost_list[k]) {
      MS_EXCEPTION_IF_NULL(c);
      selected_cost_list[k] = c;
-      recursive(k + 1, memory + c->memory_cost_, communication + c->communication_cost_,
+      recursive(k + 1, computation + c->computation_cost_, communication + c->communication_cost_,
                communication_without_para + c->communication_without_parameter_);
    }
  };
  recursive(0, 0, 0, 0);
  SimplifyForDreasingCommunicationWithPartialPara(&result);
  return result;
@@ -203,7 +203,8 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
     MS_EXCEPTION_IF_NULL(middle_cost);
     for (auto& right_cost : right_cost_list) {
       MS_EXCEPTION_IF_NULL(right_cost);
-      double memory = left_cost->memory_cost_ + middle_cost->memory_cost_ + right_cost->memory_cost_;
+      double computation =
+          left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_;
       double communication =
           left_cost->communication_cost_ + middle_cost->communication_cost_ + right_cost->communication_cost_;
       double communication_without_para = left_cost->communication_without_parameter_ +
@@ -211,7 +212,7 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
                                           right_cost->communication_without_parameter_;
       auto decision = std::make_shared<OpEliminationDecision>(op_strategy, left_cost, middle_cost, right_cost);
-      auto cost = std::make_shared<Cost>(memory, communication, decision);
+      auto cost = std::make_shared<Cost>(computation, communication, decision);
       MS_EXCEPTION_IF_NULL(cost);
       cost->communication_without_parameter_ = communication_without_para;
       cost->communication_with_partial_para_ =

@@ -133,7 +133,7 @@ class Edge {
   void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; }
   // When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input
   // should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase.
-  Status CorrectStrategyCostForMemoryReuse() const { return SUCCESS; }
+  Status CalculateMemoryCost() const { return SUCCESS; }
 private:
  std::string edge_name_;

File diff suppressed because it is too large.

@@ -175,16 +175,12 @@ class CostGraph {
   void CreateStarEliminationSubCostList(const StrategyPtr&, const CostPtrList&, const CostPtrList&, const StrategyPtr&,
                                         const CostPtrList&, std::vector<StrategyPtr>, CostPtrList&, CostPtrList&,
                                         CostPtrList*);
-  // When a output of a operator is being used by multiple operators, the memory cost of this part should be calculated
-  // only once. This method is for correcting the 'strategy_cost_' for operators
-  Status CorrectOpsStrategyCostForMultiOutputUse();
   // When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then
   // the memory cost can be resused.
-  Status CorrectOpsStrategyCostForMemoryReuse();
+  Status CalculateOpsMemoryCost();
   // When the input of the edge is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then
   // the memory cost can be resused.
-  Status CorrectEdgesStrategyCostForMemoryReuse();
+  Status CalculateEdgesMemoryCost();
   Status ComputeOpsAndEdgesParameterInvolved();
   std::vector<OperatorInfoPtr> GetOperators() const { return ops_; }

File diff suppressed because it is too large.

@@ -592,10 +592,10 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr&
   int32_t stage_id = strategy->GetInputStage();
   // Here, we use the origin outputs_, because we only use the slice size of the output tensor.
   // It does not matter whether the output tensor is transposed or not.
-  double memory_cost =
-      matmulcost_ptr->GetForwardMemoryCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
+  double computation_cost =
+      matmulcost_ptr->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
   double communication_cost = matmulcost_ptr->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
-  std::shared_ptr<Cost> result = std::make_shared<Cost>(memory_cost, communication_cost);
+  std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
   result->communication_without_parameter_ =
       matmulcost_ptr->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
   result->communication_with_partial_para_ =
@@ -604,7 +604,7 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr&
   // Breaking ties for preferring data parallelization
   BreakingTiesForPerferringDataParallel(strategy, result);
-  MS_LOG(DEBUG) << name_ << " : memory_cost: " << result->memory_cost_
+  MS_LOG(DEBUG) << name_ << " : computation_cost: " << result->computation_cost_
                 << ", communication_cost: " << result->communication_cost_
                 << ", communication_without_parameter_: " << result->communication_without_parameter_
                 << ", communication_with_partial_para_: " << result->communication_with_partial_para_;

@@ -1034,9 +1034,10 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr& strategy) {
     return FAILED;
   }
   int32_t stage_id = strategy->GetInputStage();
-  double memory_cost = GetOperatorCost()->GetForwardMemoryCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+  double computation_cost =
+      GetOperatorCost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
   double communication_cost = GetOperatorCost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
-  std::shared_ptr<Cost> result = std::make_shared<Cost>(memory_cost, communication_cost);
+  std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
   result->communication_without_parameter_ =
       GetOperatorCost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
   result->communication_with_partial_para_ =
@@ -1056,22 +1057,6 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr& strategy) {
   return SUCCESS;
 }
-Status OperatorInfo::CorrectStrategyCostForMultiOutputUse(size_t input_index) {
-  for (auto& swc : strategy_cost_) {
-    double parameter_memory_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) *
-                                   static_cast<double>(GetOperatorCost()->inputs_type_lengths()[input_index]);
-    // remove the parameter memory cost
-    swc->cost_list[0]->memory_cost_ -= parameter_memory_cost;
-    if (swc->cost_list[0]->memory_cost_ < -1) {
-      MS_LOG(ERROR) << "The memory cost after correction is " << swc->cost_list[0]->memory_cost_
-                    << ", the parameter_memory_cost is " << parameter_memory_cost;
-      return FAILED;
-    }
-  }
-  corrected_input_indices_.push_back(input_index);
-  return SUCCESS;
-}
 int OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() {
   if (is_output_parameter_involve_ != -1) {
     return is_output_parameter_involve_;
@@ -1217,7 +1202,7 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr& stra
   CheckGlobalDeviceManager();
   auto total_device_num = g_device_manager->GetDeviceListByStageId(stra->GetInputStage()).size();
   if (IntToSize(stra->GetInputDim()[0][0]) == total_device_num) {
-    cost->memory_cost_ -= 1.0;
+    cost->computation_cost_ -= 1.0;
     cost->communication_cost_ -= 1.0;
     cost->communication_with_partial_para_ -= 1.0;
     cost->communication_without_parameter_ -= 1.0;
@@ -1226,7 +1211,7 @@
   }
 }
 double OperatorInfo::GetForwardMemoryCostFromCNode() {
-  return GetOperatorCost()->GetForwardMemoryCost(inputs_tensor_info_, outputs_tensor_info_, 0);
+  return GetOperatorCost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
 }
 }  // namespace parallel

@@ -87,13 +87,9 @@ class OperatorInfo {
   // is checked
   Status SetCostUnderStrategyBase(const StrategyPtr& strategy);
   std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; }
-  // In the case of a Parameter (or a output) being used by multiple operators, the memory cost induced by
-  // the parameter (or a output) should be calculated only once. This method is used to
-  // remove this part from the 'strategy_cost_'.
-  Status CorrectStrategyCostForMultiOutputUse(size_t input_index);
   // When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input
   // should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase.
-  Status CorrectStrategyCostForMemoryReuse() const { return SUCCESS; }
+  Status CalculateMemoryCost() const { return SUCCESS; }
   int ComputeOpAndPrevEdgeParameterInvolved();
   ForwardOp forward_op() const { return forward_op_; }

@@ -387,7 +387,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
   operator_info->set_outputs_dtype(cnode->Type());
   operator_info->set_cnode(cnode);
   // If no strategy has been configured for this operator, then candidate strategies are generated for
-  // auto-strategy searchingm if this primitive is Cast, we ignore the user-specified strategy
+  // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy
   if (!StrategyFound(attrs) || prim->name() == CAST) {
     // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
     // BatchParallelInfo operator
@@ -600,13 +600,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
     }
     MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << cnode->operator_info()->name();
   }
-  // For the case of a output being used by multiple subsequent operators, the output induced memory cost should be
-  // calculated only once. This method is for correct the operators' memory cost calculation.
-  if (entire_costgraph->CorrectOpsStrategyCostForMultiOutputUse() != SUCCESS) {
-    MS_LOG(EXCEPTION) << "Correcting strategy_cost_ for operators failed.";
-  } else {
-    MS_LOG(INFO) << "Correcting strategy_cost_ for operators succeeded.";
-  }
   MS_LOG(INFO) << "Constructing edges for cost graph ends.";
 }
@@ -803,14 +797,6 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
     std::shared_ptr<Edge> edge_ptr = std::make_shared<Edge>(
         edge_name, tmp_identity_ptr, target_cnode->operator_info(), 0, input_index - 1, false, true);
-    // Correct the memory calculation for a parameter being used by multiple operators. The parameter is calculated
-    // only once
-    if (target_cnode->operator_info()->CorrectStrategyCostForMultiOutputUse(IntToSize(input_index - 1)) != SUCCESS) {
-      MS_LOG(EXCEPTION) << "Correcting strategy_cost_ failed : " << prim->name();
-    } else {
-      MS_LOG(INFO) << "Correcting strategy_cost_ succeeded. " << prim->name();
-    }
     if (edge_ptr->InitEdgeCost() != SUCCESS) {
       MS_LOG(EXCEPTION) << "Edge cost initialization failed";
     }
@@ -840,7 +826,7 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
   // taking care for the case of a single Parameter being used by multiple operators. Create a TmpIdentity
   // operator for this Parameter, and add an edge for the use of this Parameter by each
   // subsequent operator;
-  // Step 3.1: Correct the memory calculation for memory reuse
+  // Step 3.1: Calculate memory usage
   // Step 4: Run the Dynamic Programming algorithm:
   // in this process, cost is calculated based on not only the operators, but also the edges. Here, the edge
   // cost is caused by the redistribution of a operator's output tensor layout to the next operator's input
@@ -867,14 +853,14 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
   MS_LOG(INFO) << "After the augmenting procedure, there are " << entire_costgraph->GetOperators().size()
                << " operators, and " << entire_costgraph->GetNumPairs() << " edges.";
-  // Step 3.1: Correcting calculation for memory reuse
+  // Step 3.1: Calculate the memory usage
   if (entire_costgraph->ComputeOpsAndEdgesParameterInvolved() == SUCCESS) {
-    // Correcting operators' memory usage
-    if (entire_costgraph->CorrectOpsStrategyCostForMemoryReuse() != SUCCESS) {
+    // Calculate operators' memory usage
+    if (entire_costgraph->CalculateOpsMemoryCost() != SUCCESS) {
       MS_LOG(EXCEPTION) << "Correcting operators' cost for memory reuse failed.";
     }
-    // Correcting edges' memory usage
-    if (entire_costgraph->CorrectEdgesStrategyCostForMemoryReuse() != SUCCESS) {
+    // Calculate edges' memory usage
+    if (entire_costgraph->CalculateEdgesMemoryCost() != SUCCESS) {
      MS_LOG(EXCEPTION) << "Correcting edges' cost for memory reuse failed.";
    }
  } else {

@@ -144,7 +144,7 @@ Status TensorRedistribution::ComputeCost() {
     MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed";
     return Status::FAILED;
   }
-  // Compute redistribution communication cost and memory cost
+  // Compute redistribution communication cost and computation cost
   for (auto& op_cost : operator_list_) {
     OperatorR op = op_cost.first;
     Shape slice_shape = op_cost.second;
@@ -154,14 +154,14 @@ Status TensorRedistribution::ComputeCost() {
     if (str == PERMUTE_BY_AXIS) {
       // The shape does not change after PermuteByAxis operation.
       // communication cost = all_to_all + all_to_all = 2 * slice_shape
-      // memory cost = slice_shape
+      // computation cost = slice_shape
       forward_comm_cost_ += prod;
       backward_comm_cost_ += prod;
       comm_cost_ += 2.0 * prod;
-      mem_cost_ += prod;
+      computation_cost_ += prod;
     } else if (str == CONCAT_BY_AXIS) {
       // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
-      // memory cost = before_slice_shape
+      // computation cost = before_slice_shape
       if (op.second.size() < 3) {
         MS_LOG(ERROR) << "op.second size should not be less than 3!";
         return Status::FAILED;
@@ -173,22 +173,22 @@ Status TensorRedistribution::ComputeCost() {
       comm_cost_ += prod * (dev_num + 1.0);
       int32_t concat_dim = op.second[0];
       if (concat_dim == 0) {
-        // memory cost = all_gather
-        mem_cost_ += prod;
+        // computation cost = all_gather
+        computation_cost_ += prod;
       } else {
-        // memory cost = all_gather + split + concat
-        mem_cost_ += (prod + prod * dev_num + prod * dev_num);
+        // computation cost = all_gather + split + concat
+        computation_cost_ += (prod + prod * dev_num + prod * dev_num);
       }
     } else {
-      // There is only memory cost in SplitByAxis.
-      // memory cost = before_slice_shape
-      mem_cost_ += prod;
+      // There is only computation cost in SplitByAxis.
+      // computation cost = before_slice_shape
+      computation_cost_ += prod;
     }
   }
   if (reshape_flag()) {
     Shape prev_slice_shape = from_.slice_shape().array();
     double prev_prod = std::accumulate(prev_slice_shape.begin(), prev_slice_shape.end(), 1, std::multiplies<int>());
-    mem_cost_ += 2.0 * prev_prod;
+    computation_cost_ += 2.0 * prev_prod;
   }
   return Status::SUCCESS;
 }
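For reference, the per-operator accumulation of computation_cost_ shown above can be summarized in a self-contained sketch; the enum and function names are invented, and only the arithmetic mirrors the hunk.

// Simplified sketch of the per-element computation cost of one redistribution
// operator after this PR (not the TensorRedistribution implementation).
#include <functional>
#include <numeric>
#include <vector>

enum class RedistOp { kPermuteByAxis, kConcatByAxis, kSplitByAxis };

double SliceProduct(const std::vector<int>& slice_shape) {
  return std::accumulate(slice_shape.begin(), slice_shape.end(), 1.0, std::multiplies<double>());
}

double AccumulateComputationCost(RedistOp op, const std::vector<int>& slice_shape,
                                 int concat_dim, double dev_num) {
  double prod = SliceProduct(slice_shape);
  switch (op) {
    case RedistOp::kPermuteByAxis:
      return prod;                                    // computation cost = slice_shape
    case RedistOp::kConcatByAxis:
      if (concat_dim == 0) return prod;               // all_gather only
      return prod + prod * dev_num + prod * dev_num;  // all_gather + split + concat
    case RedistOp::kSplitByAxis:
    default:
      return prod;                                    // computation cost = before_slice_shape
  }
}
// Example: a {32, 64} slice concatenated along dim 1 across 8 devices contributes
// 2048 + 2048*8 + 2048*8 = 34816 to computation_cost_.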

@@ -41,7 +41,7 @@ class TensorRedistribution {
       comm_cost_(0.0),
       forward_comm_cost_(0.0),
       backward_comm_cost_(0.0),
-      mem_cost_(0.0),
+      computation_cost_(0.0),
       construct_op_flag_(construct_op_flag),
       keep_reshape_(keep_reshape) {}
   Status Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list);
@@ -51,7 +51,7 @@
   bool reshape_flag() const { return reshape_flag_; }
   Status ComputeCost();
   double comm_cost() const { return comm_cost_; }
-  double mem_cost() const { return mem_cost_; }
+  double computation_cost() const { return computation_cost_; }
   double forward_comm_cost() const { return forward_comm_cost_; }
   double backward_comm_cost() const { return backward_comm_cost_; }
@@ -66,10 +66,13 @@ class TensorRedistribution {
   RankList dev_list_;
   OperatorList operator_list_;
   bool reshape_flag_;
+  // communication cost
   double comm_cost_;
+  // forward communication cost
   double forward_comm_cost_;
+  // backward communication cost
   double backward_comm_cost_;
-  double mem_cost_;
+  double computation_cost_;
   bool construct_op_flag_;
   bool keep_reshape_;
 };

@@ -322,8 +322,8 @@ TEST_F(TestCostGraph, test_SelectCostListWithMinTrainingTimeMultiple) {
   auto ret_list = entire_cost_graph.SelectCostListWithMinTrainingTimeMultiple(all_list, memory);
   ASSERT_EQ(ret_list.size(), 2);
-  ASSERT_DOUBLE_EQ(ret_list[0]->memory_cost_, 10);
-  ASSERT_DOUBLE_EQ(ret_list[1]->memory_cost_, 1010);
+  ASSERT_DOUBLE_EQ(ret_list[0]->computation_cost_, 10);
+  ASSERT_DOUBLE_EQ(ret_list[1]->computation_cost_, 1010);
 }
 TEST_F(TestCostGraph, test_CheckOpElimination) {

@@ -76,8 +76,8 @@ TEST_F(TestMatMulCost, test_CostGeneration) {
   mmcost_.SetInputAndOutputTypeLength(inputs_length, outputs_length);
   mmcost_.GetForwardCommCost(inputs, outputs, 0);
   mmcost_.GetBackwardCommCost(inputs, outputs, 0);
-  mmcost_.GetForwardMemoryCost(inputs, outputs, 0);
-  mmcost_.GetBackwardMemoryCost(inputs, outputs, 0);
+  mmcost_.GetForwardComputationCost(inputs, outputs, 0);
+  mmcost_.GetForwardComputationCost(inputs, outputs, 0);
 }
 class TestActivationCost : public UT::Common {
@@ -128,8 +128,8 @@ TEST_F(TestActivationCost, test_CostGeneration) {
   std::vector<size_t> inputs_length = {4, 4};
   std::vector<size_t> outputs_length = {4};
   ac_cost_.SetInputAndOutputTypeLength(inputs_length, outputs_length);
-  ac_cost_.GetForwardMemoryCost(inputs, outputs, 0);
-  ac_cost_.GetBackwardMemoryCost(inputs, outputs, 0);
+  ac_cost_.GetForwardComputationCost(inputs, outputs, 0);
+  ac_cost_.GetBackwardComputationCost(inputs, outputs, 0);
 }
 class TestPReLUCost : public UT::Common {
@@ -184,8 +184,8 @@ TEST_F(TestPReLUCost, test_CostGeneration) {
   prelu_cost_.SetInputAndOutputTypeLength(inputs_length, outputs_length);
   double BCC, FMC, GMC;
   BCC = prelu_cost_.GetBackwardCommCost(inputs, outputs, 0);
-  FMC = prelu_cost_.GetForwardMemoryCost(inputs, outputs, 0);
-  GMC = prelu_cost_.GetBackwardMemoryCost(inputs, outputs, 0);
+  FMC = prelu_cost_.GetForwardComputationCost(inputs, outputs, 0);
+  GMC = prelu_cost_.GetBackwardComputationCost(inputs, outputs, 0);
   ASSERT_EQ(BCC, 32 * 4);
   ASSERT_EQ(FMC, 8 * 32 * 8 * 8 * 4 + 32 * 4);
   ASSERT_EQ(GMC, 128);

@@ -84,8 +84,8 @@ TEST_F(TestActivation, test_activation_strategies) {
   act_ptr_->InitForCostModel(sp);
   std::vector<TensorInfo> inputs_info = act_ptr_->inputs_tensor_info();
   std::vector<TensorInfo> outputs_info = act_ptr_->outputs_tensor_info();
-  ASSERT_DOUBLE_EQ(act_ptr_->GetOperatorCost()->GetMemoryCost(inputs_info, outputs_info, sp->GetInputStage()),
-                   cost.memory_cost_);
+  ASSERT_DOUBLE_EQ(act_ptr_->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
+                   cost.computation_cost_);
   ASSERT_DOUBLE_EQ(act_ptr_->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
                    cost.communication_cost_);
 }
@@ -109,8 +109,8 @@ TEST_F(TestActivation, test_softmax_strategies) {
   soft_ptr_->InitForCostModel(sp);
   std::vector<TensorInfo> inputs_info = soft_ptr_->inputs_tensor_info();
   std::vector<TensorInfo> outputs_info = soft_ptr_->outputs_tensor_info();
-  ASSERT_DOUBLE_EQ(soft_ptr_->GetOperatorCost()->GetMemoryCost(inputs_info, outputs_info, sp->GetInputStage()),
-                   cost.memory_cost_);
+  ASSERT_DOUBLE_EQ(soft_ptr_->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
+                   cost.computation_cost_);
   ASSERT_DOUBLE_EQ(soft_ptr_->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
                    cost.communication_cost_);
 }

@@ -569,8 +569,8 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies1) {
     matmul1->InitForCostModel(sp);
     std::vector<TensorInfo> inputs_info = matmul1->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = matmul1->outputs_tensor_info();
-    ASSERT_DOUBLE_EQ(matmul1->GetOperatorCost()->GetMemoryCost(inputs_info, outputs_info, sp->GetInputStage()),
-                     cost.memory_cost_);
+    ASSERT_DOUBLE_EQ(matmul1->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
+                     cost.computation_cost_);
     break;
   }
 }
@@ -599,8 +599,8 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies2) {
     TensorInfo replica_input1_info(tly, input1_shape, input1_slice_shape);
     replica_inputs_info.push_back(replica_input1_info);
-    ASSERT_DOUBLE_EQ(matmul3->GetOperatorCost()->GetMemoryCost(replica_inputs_info, outputs_info, sp->GetInputStage()),
-                     cost.memory_cost_);
+    ASSERT_DOUBLE_EQ(matmul3->GetOperatorCost()->GetComputationCost(replica_inputs_info, outputs_info, sp->GetInputStage()),
+                     cost.computation_cost_);
     break;
   }
 }

@@ -188,8 +188,8 @@ TEST_F(TestTensorAddInfo, GenerateStrategies) {
     tensor_add->InitForCostModel(sp);
     std::vector<TensorInfo> inputs_info = tensor_add->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = tensor_add->outputs_tensor_info();
-    double memory_cost0 = tensor_add->GetOperatorCost()->GetMemoryCost(inputs_info, outputs_info, sp->GetInputStage());
-    double memory_cost1 = cost.memory_cost_;
+    double memory_cost0 = tensor_add->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
+    double memory_cost1 = cost.computation_cost_;
     bool memory = memory_cost0 - memory_cost1 <= 1.0;
     double comm_cost0 = tensor_add->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
@@ -210,8 +210,8 @@ TEST_F(TestTensorAddInfo, GenerateStrategies1) {
     tensor_add1->InitForCostModel(sp);
     std::vector<TensorInfo> inputs_info = tensor_add1->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = tensor_add1->outputs_tensor_info();
-    double memory_cost0 = tensor_add1->GetOperatorCost()->GetMemoryCost(inputs_info, outputs_info, sp->GetInputStage());
-    double memory_cost1 = cost.memory_cost_;
+    double memory_cost0 = tensor_add1->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
+    double memory_cost1 = cost.computation_cost_;
     bool memory = memory_cost0 - memory_cost1 <= 1.0;
     double comm_cost0 = tensor_add1->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());

@@ -145,8 +145,8 @@ TEST_F(TestTmpIdentityInfo, test_generate_strategies) {
     identity_ptr->Init(sp);
     std::vector<TensorInfo> inputs_info = identity_ptr->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = identity_ptr->outputs_tensor_info();
-    ASSERT_DOUBLE_EQ(identity_ptr->GetOperatorCost()->GetMemoryCost(inputs_info, outputs_info, sp->GetInputStage()),
-                     cost.memory_cost_);
+    ASSERT_DOUBLE_EQ(identity_ptr->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
+                     cost.computation_cost_);
     ASSERT_DOUBLE_EQ(identity_ptr->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.communication_cost_);
 }
