Model the memory cost in auto-parallel. For each operator, the cost is calculated from the sizes of its outputs plus its parameter inputs. Additionally, modify the graph operations (eliminations) in auto_parallel to propagate memory_cost.

pull/232/head
Xiaoda Zhang 5 years ago
parent c9fba7f091
commit 0ac50a19f5
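For orientation, here is a minimal, self-contained sketch of the per-operator memory model this commit introduces. It is a simplified mirror of OperatorCost::GetMemoryCost further down (the real function also handles related non-parameter inputs and the parameter-involved flags); the shapes, byte width, and function name below are illustrative assumptions, not values taken from the commit.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Per-device memory contributed by one operator: all of its output slices, plus any
// input slices that are parameters (those must stay resident for the backward pass).
double OperatorMemoryCost(const std::vector<std::vector<int64_t>>& output_slices,
                          const std::vector<std::vector<int64_t>>& input_slices,
                          const std::vector<bool>& input_is_parameter, double type_len) {
  auto elements = [](const std::vector<int64_t>& shape) {
    double n = 1.0;
    for (int64_t d : shape) n *= static_cast<double>(d);
    return n;
  };
  double result = 0.0;
  for (const auto& s : output_slices) result += elements(s) * type_len;
  for (size_t i = 0; i < input_slices.size(); ++i) {
    if (input_is_parameter[i]) result += elements(input_slices[i]) * type_len;
  }
  return result;
}

int main() {
  // Hypothetical sharded MatMul: output slice [128, 256], activation input [128, 512],
  // parameter (weight) slice [512, 256], float32 (4 bytes per element).
  double bytes = OperatorMemoryCost({{128, 256}}, {{128, 512}, {512, 256}}, {false, true}, 4.0);
  std::cout << bytes << " bytes per device\n";  // 128*256*4 + 512*256*4 = 655360
  return 0;
}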

@@ -207,15 +207,13 @@ struct ContractEliminationDecision : public Decision {
  */
 struct TriangleEliminationDecision : public Decision {
   TriangleEliminationDecision(StrategyPtr elimi_stra, CostPtr elimi_op_cost, CostPtr l_edge_cost, CostPtr r_edge_cost,
-                              StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra, CostPtr r_node_cost)
+                              StrategyPtr left_stra, CostPtr l_node_cost)
       : eliminated_op_strategy_(std::move(elimi_stra)),
         eliminated_op_cost_(std::move(elimi_op_cost)),
         left_edge_cost_(std::move(l_edge_cost)),
         right_edge_cost_(std::move(r_edge_cost)),
         left_node_strategy_(std::move(left_stra)),
-        left_node_cost_(std::move(l_node_cost)),
-        right_node_strategy_(std::move(right_stra)),
-        right_node_cost_(std::move(r_node_cost)) {
+        left_node_cost_(std::move(l_node_cost)) {
     type_ = DecisionType::TRIANGLE_ELIMINATION;
   }
@@ -225,8 +223,6 @@ struct TriangleEliminationDecision : public Decision {
   CostPtr right_edge_cost_;
   StrategyPtr left_node_strategy_;
   CostPtr left_node_cost_;
-  StrategyPtr right_node_strategy_;
-  CostPtr right_node_cost_;
   MS_DECLARE_PARENT(TriangleEliminationDecision, Decision);
 };

@@ -76,7 +76,6 @@ Status GetStrategy(const CostGraphPtr& graph) {
       auto l_r_edge = triangle_pair.second;
       auto left_node = l_r_edge->prev_operator();
-      auto right_node = l_r_edge->next_operator();
       auto left_edge = eliminated_node->GetAliveSuccEdges()[0];
       auto right_edge = eliminated_node->GetAliveSuccEdges()[1];
       MS_EXCEPTION_IF_NULL(left_edge);
@@ -86,8 +85,7 @@ Status GetStrategy(const CostGraphPtr& graph) {
         right_edge = tmp;
       }
       auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge);
-      auto elimi =
-        std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge, right_node);
+      auto elimi = std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge);
       eliminations.emplace_back(std::move(elimi));
     }
     auto star_center = graph->CheckStarElimination();
@@ -183,14 +181,13 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
       auto left_edge = elimination->left_edge_;
       auto eliminated_node = elimination->eliminated_node_;
       auto right_edge = elimination->right_edge_;
-      auto right_node = elimination->right_node_;
       auto decision = left_node->selected_cost()->decision_ptr_->cast<TriangleEliminationDecisionPtr>();
       eliminated_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_);
       left_edge->set_selected_cost(decision->left_edge_cost_);
       right_edge->set_selected_cost(decision->right_edge_cost_);
+      // Since Triangle is eliminated into 'left_node', only 'left_node' is needed to recover the strategy.
       left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_);
-      right_node->SetSelectedStrategyAndCost(decision->right_node_strategy_, decision->right_node_cost_);
       MS_LOG(INFO) << "Recover triangleElimination succeeded.";
     } else if ((*rit)->isa<StarElimination>()) {
       auto elimination = (*rit)->cast<StarEliminationPtr>();
@@ -204,9 +201,11 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
       for (size_t i = 0; i < succ_edges.size(); ++i) {
         succ_edges[i]->set_selected_cost(decision->succ_edges_cost_list_[i]);
       }
-      for (size_t j = 0; j < succ_nodes.size(); ++j) {
-        succ_nodes[j]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[j], decision->succ_ops_cost_list_[j]);
-      }
+      MS_EXCEPTION_IF_NULL(succ_nodes[0]);
+      MS_EXCEPTION_IF_NULL(decision->succ_ops_stra_list_[0]);
+      MS_EXCEPTION_IF_NULL(decision->succ_ops_cost_list_[0]);
+      // Since Star is eliminated into 'succ_nodes[0]', only 'succ_nodes[0]' is needed to recover the strategy.
+      succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]);
       MS_LOG(INFO) << "Recover starElimination succeeded.";
     } else {
       MS_LOG(ERROR) << "Unknown Elimination type.";

@@ -102,20 +102,17 @@ struct ContractElimination : public Elimination {
 // Triangle Elimination
 struct TriangleElimination : public Elimination {
-  TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge,
-                      OperatorInfoPtr r_node)
+  TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge)
       : Elimination(nullptr, Elimination::EliminationType::TRIANGLE),
         eliminated_node_(std::move(elim_node)),
         left_edge_(std::move(l_edge)),
         left_node_(std::move(l_node)),
-        right_edge_(std::move(r_edge)),
-        right_node_(std::move(r_node)) {}
+        right_edge_(std::move(r_edge)) {}
   OperatorInfoPtr eliminated_node_;
   EdgePtr left_edge_;
   OperatorInfoPtr left_node_;
   EdgePtr right_edge_;
-  OperatorInfoPtr right_node_;
   MS_DECLARE_PARENT(TriangleElimination, Elimination);
 };

@@ -119,6 +119,7 @@ Status Edge::GetRedistributionCost(const TensorLayout& prev_op_output_layout, co
   double forward_comm_cost = tensor_redistribution.forward_comm_cost();
   double backward_comm_cost = tensor_redistribution.backward_comm_cost();
   double computation_cost = tensor_redistribution.computation_cost();
+  double mem_cost = tensor_redistribution.memory_cost();
 
   // Now AllGather, ReduceScatter, AlltoAll don't support bool type
   MS_EXCEPTION_IF_NULL(type);
@@ -134,6 +135,7 @@ Status Edge::GetRedistributionCost(const TensorLayout& prev_op_output_layout, co
     COST_MODEL_GAMMA * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_);
   (*cost)->communication_redis_forward_ = type_length * forward_comm_cost;
   (*cost)->communication_redis_backward_ = type_length * backward_comm_cost;
+  (*cost)->memory_with_reuse_ = mem_cost;
   return Status::SUCCESS;
 }
@@ -158,8 +160,8 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr
   (void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList);
 
   CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
-  std::function<void(size_t, double, double, double)> recursive =
-    [&](size_t k, double computation, double communication, double communication_without_para) {
+  std::function<void(size_t, double, double, double, double)> recursive =
+    [&](size_t k, double computation, double memory, double communication, double communication_without_para) {
       if (k == edges.size()) {
         auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
         CostPtr new_cost = std::make_shared<Cost>(computation, communication);
@@ -167,6 +169,7 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr
         new_cost->communication_without_parameter_ = communication_without_para;
         new_cost->communication_with_partial_para_ =
           communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
+        new_cost->memory_with_reuse_ = memory;
        new_cost->decision_ptr_ = decision;
         result.push_back(new_cost);
         return;
@@ -174,11 +177,12 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr
       for (auto& c : all_cost_list[k]) {
         MS_EXCEPTION_IF_NULL(c);
         selected_cost_list[k] = c;
-        recursive(k + 1, computation + c->computation_cost_, communication + c->communication_cost_,
+        recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_,
+                  communication + c->communication_cost_,
                   communication_without_para + c->communication_without_parameter_);
       }
     };
-  recursive(0, 0, 0, 0);
+  recursive(0, 0.0, 0.0, 0.0, 0.0);
   SimplifyForDreasingCommunicationWithPartialPara(&result);
   return result;
 }
@@ -218,6 +222,8 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
       double communication_without_para = left_cost->communication_without_parameter_ +
                                           middle_cost->communication_without_parameter_ +
                                           right_cost->communication_without_parameter_;
+      double memory_cost =
+        left_cost->memory_with_reuse_ + middle_cost->memory_with_reuse_ + right_cost->memory_with_reuse_;
 
       auto decision = std::make_shared<OpEliminationDecision>(op_strategy, left_cost, middle_cost, right_cost);
       auto cost = std::make_shared<Cost>(computation, communication, decision);
@@ -225,6 +231,7 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
       cost->communication_without_parameter_ = communication_without_para;
       cost->communication_with_partial_para_ =
         communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
+      cost->memory_with_reuse_ = memory_cost;
       ret_cost_list->emplace_back(std::move(cost));
     }
   }
@@ -267,5 +274,24 @@ void Edge::OpEliminationSetNewCost(const EdgePtr& e1, const OperatorInfoPtr& op,
     MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed.";
   }
 }
+
+Status Edge::CalculateMemoryCost() {
+  if (is_output_parameter_involve_ == -1) {
+    MS_LOG(ERROR) << "is_output_parameter_involve_ is unset.";
+    return FAILED;
+  }
+  if (is_output_parameter_involve_ == 0) {
+    // In this case, it is sure that the tensor redistribution along this edge is NOT parameter-involved, thus it is
+    // unnecessary to keep them in memory.
+    for (auto& cost_kv : cost_map_) {
+      auto& cost_v = cost_kv.second;
+      if (!cost_v.empty()) {
+        cost_v[0]->memory_with_reuse_ = 0;
+      }
+    }
+  }
+
+  return SUCCESS;
+}
 }  // namespace parallel
 }  // namespace mindspore

@@ -133,7 +133,7 @@ class Edge {
   void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; }
   // When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input
   // should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase.
-  Status CalculateMemoryCost() const { return SUCCESS; }
+  Status CalculateMemoryCost();
 
  private:
   std::string edge_name_;

File diff suppressed because it is too large.

@@ -187,6 +187,9 @@ class CostGraph {
   size_t GetNumPairs() const { return edges_.size(); }
   Status InitSelectedStrategy();
   OperatorInfoPtr FindTmpIdentityByParameterName(std::string&) const;
+  // When TmpIdentity is used by multiple operators, the corresponding parameter's memory cost should be calculated
+  // only once (instead of multiple times); this method corrects that.
+  Status CorrectOpsMemoryCost();
   // Needed by rec_parser
   void add_inputs_tensor_name(const std::vector<std::string>& inputs_tensor_name) {
     inputs_tensor_name_list_.push_back(inputs_tensor_name);
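The body of CorrectOpsMemoryCost lives in the suppressed graph_costmodel.cc diff above, so only its intent is visible here. As a rough, purely hypothetical illustration of that intent (the name, signature, and logic below are invented for this note, not taken from the commit), the correction amounts to removing the duplicate charges of a parameter that several consumers of the same TmpIdentity would otherwise each count:

// Hypothetical sketch only: a parameter of `param_bytes` counted once per consumer
// should be charged a single time, so (consumers - 1) duplicate charges are removed.
double CorrectSharedParameterMemory(double total_memory, double param_bytes, int consumers) {
  if (consumers > 1) {
    total_memory -= static_cast<double>(consumers - 1) * param_bytes;
  }
  return total_memory;
}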

@@ -17,6 +17,7 @@
 #include "parallel/auto_parallel/operator_costmodel.h"
 #include <random>
+#include <algorithm>
 #include "parallel/device_matrix.h"
 #include "parallel/tensor_layout/tensor_redistribution.h"
@@ -24,12 +25,44 @@ namespace mindspore {
 namespace parallel {
 void OperatorCost::set_is_parameter(const std::vector<bool>& is_parameter) { is_parameter_ = is_parameter; }
+
+void OperatorCost::set_is_parameter_involve(const std::vector<bool>& is_parameter_inv) {
+  is_parameter_involve_ = is_parameter_inv;
+}
+
+void OperatorCost::set_output_parameter_involve(int output_para) { output_parameter_involve_ = output_para; }
+
 void OperatorCost::SetInputAndOutputTypeLength(const std::vector<size_t>& input_lengths,
                                                const std::vector<size_t>& output_lengths) {
   inputs_type_lengths_ = input_lengths;
   outputs_type_lengths_ = output_lengths;
 }
+
+double OperatorCost::GetMemoryCost(const std::vector<TensorInfo>& inputs,
+                                   const std::vector<TensorInfo>& outputs) const {
+  double result = 0.0;
+  if (output_parameter_involve_ == 1) {
+    // When this operator has multiple outputs, they all contribute to the memory.
+    for (size_t i = 0; i < outputs.size(); ++i) {
+      result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
+    }
+    bool is_any_para_inv =
+      std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; });
+    if (is_any_para_inv) {
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        if (is_parameter_[i]) {
+          result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
+        } else if (inputs_related_ && (!is_parameter_involve_[i])) {
+          // When the inputs of this operator are related, and they are not parameter-involved, then they are included
+          // in the memory cost.
+          result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
+        }
+      }
+    }
+  }
+  return result;
+}
+
 // return the per device communication cost in the forward phase.
 double MatMulCost::GetForwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
                                       const int32_t&) const {
@@ -72,11 +105,11 @@ double MatMulCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, co
   return result;
 }
 
-// Return the per device memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
 double MatMulCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs,
                                              const std::vector<TensorInfo>& outputs, const int32_t&) const {
-  // In forward phase, the memory cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
+  // In forward phase, the computation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
   double result = 0.0;
   TensorInfo output0 = outputs[0];
   Shape input0_slice_shape = inputs[0].slice_shape();
@@ -91,11 +124,11 @@ double MatMulCost::GetForwardComputationCost(const std::vector<TensorInfo>& inpu
   return result;
 }
 
-// Return the per device memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
 // this operator uses
 double MatMulCost::GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
                                               const int32_t& stage_id) const {
-  // In backward phase, the memory cost = (0 or 1) allreduce(slice(B))
+  // In backward phase, the computation cost = (0 or 1) allreduce(slice(B))
   double result = 0.0;
   if (is_parameter_[1]) {
     TensorInfo input1 = inputs[1];  // tensor B
@@ -145,7 +178,7 @@ double ActivationCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs
   return result;
 }
 
-// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
 double ActivationCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
                                                  const int32_t&) const {
@@ -154,7 +187,7 @@ double ActivationCost::GetForwardComputationCost(const std::vector<TensorInfo>&
   return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
 }
 
-// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
 // this operator uses
 double ActivationCost::GetBackwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
                                                   const int32_t&) const {
@@ -189,17 +222,17 @@ double SoftmaxCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, c
   return result;
 }
 
-// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
 double SoftmaxCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
                                               const int32_t&) const {
-  // In the forward phase, the memory cost = slice(A)
+  // In the forward phase, the computation cost = slice(A)
   TensorInfo input0 = inputs[0];
   Shape input0_slice_shape = input0.slice_shape();
   return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
 }
 
-// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
 // this operator uses
 double SoftmaxCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
                                                const std::vector<mindspore::parallel::TensorInfo>&,
@@ -221,17 +254,15 @@ double TmpIdentityCost::GetBackwardCommCost(const std::vector<mindspore::paralle
   return 0.0;
 }
 
-// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
-double TmpIdentityCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>& inputs,
+double TmpIdentityCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
                                                   const std::vector<mindspore::parallel::TensorInfo>&,
                                                   const int32_t&) const {
-  TensorInfo input0_info = inputs[0];
-  Shape input0_slice_shape = input0_info.slice_shape();
-  return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
+  return 0.0;
 }
 
-// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
 // this operator uses
 double TmpIdentityCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
                                                    const std::vector<mindspore::parallel::TensorInfo>&,
@@ -239,6 +270,11 @@ double TmpIdentityCost::GetBackwardComputationCost(const std::vector<mindspore::
   return 0.0;
 }
 
+// Return the per device PEAK memory cost contributed by this operator in a training iteration.
+double TmpIdentityCost::GetMemoryCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&) const {
+  return 0.0;
+}
+
 double BatchParallelCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>& inputs,
                                                     const std::vector<mindspore::parallel::TensorInfo>&,
                                                     const int32_t&) const {
@@ -284,11 +320,11 @@ double PReLUCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, con
   return result;
 }
 
-// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
 double PReLUCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
                                             const int32_t&) const {
-  // In forward phase, the memory cost = slice(A) + slice(B)
+  // In forward phase, the computation cost = slice(A) + slice(B)
   Shape input0_slice_shape = inputs[0].slice_shape();
   Shape input1_slice_shape = inputs[1].slice_shape();
   double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
@@ -296,12 +332,12 @@ double PReLUCost::GetForwardComputationCost(const std::vector<TensorInfo>& input
   return result;
 }
 
-// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
 // this operator uses
 double PReLUCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>& inputs,
                                              const std::vector<mindspore::parallel::TensorInfo>&,
                                              const int32_t& stage_id) const {
-  // In backward phase, the memory cost = (0 or 1) allreduce(slice(B))
+  // In backward phase, the computation cost = (0 or 1) allreduce(slice(B))
   double result = 0.0;
   if (is_parameter_[1]) {
     TensorInfo input1 = inputs[1];  // tensor B
@@ -337,16 +373,16 @@ double OneHotCost::GetBackwardCommCost(const std::vector<TensorInfo>&, const std
   return 0.0;
 }
 
-// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
 double OneHotCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
                                              const int32_t&) const {
-  // In onehot's forward phase, the memory cost = slice(A)
+  // In onehot's forward phase, the computation cost = slice(A)
   Shape input0_slice_shape = inputs[0].slice_shape();
   return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
 }
 
-// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
 // this operator uses
 double OneHotCost::GetBackwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
                                               const int32_t&) const {
@@ -367,12 +403,12 @@ double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector<
   return 0.0;
 }
 
-// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
 double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs,
                                                                     const std::vector<TensorInfo>&,
                                                                     const int32_t&) const {
-  // In forward phase, the memory cost = slice(A) + slice(B)
+  // In forward phase, the computation cost = slice(A) + slice(B)
   Shape input0_slice_shape = inputs[0].slice_shape();
   Shape input1_slice_shape = inputs[1].slice_shape();
   double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
@@ -380,7 +416,7 @@ double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::v
   return result;
 }
 
-// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
 // this operator uses
 double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector<TensorInfo>&,
                                                                      const std::vector<TensorInfo>&,
@@ -410,7 +446,7 @@ double ReshapeCost::GetBackwardCommCost(const std::vector<TensorInfo>&, const st
   return 0.0;
 }
 
-// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
 double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs,
                                               const std::vector<TensorInfo>& outputs, const int32_t& stage_id) const {
@@ -427,7 +463,7 @@ double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo>& inp
   return (inputs_type_lengths_[0] * tensor_redistribution.computation_cost());
 }
 
-// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
 // this operator uses
 double ReshapeCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
                                                const std::vector<mindspore::parallel::TensorInfo>&,

@@ -43,10 +43,20 @@ double ListProduct(std::vector<T> vec) {
 // entries timing the length of each entry's data type
 class OperatorCost {
  public:
-  OperatorCost() {
+  explicit OperatorCost(bool is_inputs_related) : inputs_related_(is_inputs_related) {
     // this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
     for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
       is_parameter_.push_back(false);
+      is_parameter_involve_.push_back(false);
+      inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
+      outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
+    }
+  }
+  OperatorCost() : inputs_related_(false) {
+    // this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
+    for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
+      is_parameter_.push_back(false);
+      is_parameter_involve_.push_back(false);
       inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
       outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
     }
@@ -54,6 +64,8 @@ class OperatorCost {
   virtual ~OperatorCost() = default;
 
   void set_is_parameter(const std::vector<bool>& is_parameter);
+  void set_is_parameter_involve(const std::vector<bool>&);
+  void set_output_parameter_involve(int);
   void SetInputAndOutputTypeLength(const std::vector<size_t>& input_lengths, const std::vector<size_t>& output_lengths);
   std::vector<size_t> inputs_type_lengths() const { return inputs_type_lengths_; }
   std::vector<size_t> outputs_type_lengths() const { return outputs_type_lengths_; }
@@ -72,8 +84,19 @@ class OperatorCost {
                                            const std::vector<TensorInfo>& outputs, const int32_t& stage_id) const = 0;
   virtual double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs,
                                             const std::vector<TensorInfo>& outputs, const int32_t& stage_id) const = 0;
+  // per device PEAK memory cost in a training iteration
+  // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-involved),
+  // plus necessary inputs.
+  virtual double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs) const;
 
  protected:
+  // For each input in 'inputs_', the corresponding bool is true if that input is a parameter or an output of a
+  // pre-operator that has parameters as input.
+  std::vector<bool> is_parameter_involve_;
+  int output_parameter_involve_ = -1;  // -1: unset; 0: not parameter_involved; 1: parameter_involved
+  // Whether the inputs are related. For example, TensorAdd's two inputs are independent (not related), while
+  // Mul's two inputs are dependent (related).
+  bool inputs_related_;
   // for each input in 'inputs_', there is a bool variable indicating whether that the corresponding input is parameter
   std::vector<bool> is_parameter_;
   // for each input and output, the followings record the number of bytes of each element
@@ -85,7 +108,8 @@ using OperatorCostPtr = std::shared_ptr<OperatorCost>;
 
 class MatMulCost : public OperatorCost {
  public:
-  MatMulCost() = default;
+  explicit MatMulCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  MatMulCost() : OperatorCost(true) {}
   ~MatMulCost() override = default;
 
   // per device communication cost
@@ -108,12 +132,12 @@ class MatMulCost : public OperatorCost {
   double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
                                     const int32_t& stage_id) const override;
 };
 using MatMulCostPtr = std::shared_ptr<MatMulCost>;
 
 class ActivationCost : public OperatorCost {
  public:
-  ActivationCost() = default;
+  explicit ActivationCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  ActivationCost() : OperatorCost(false) {}
   ~ActivationCost() override = default;
 
   double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -133,14 +157,14 @@ class ActivationCost : public OperatorCost {
   double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
                                     const int32_t& stage_id) const override;
 };
 using ActivationCostPtr = std::shared_ptr<ActivationCost>;
 using TransposeCost = ActivationCost;
 using TransposeCostPtr = std::shared_ptr<TransposeCost>;
 
 class SoftmaxCost : public OperatorCost {
  public:
-  SoftmaxCost() = default;
+  explicit SoftmaxCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  SoftmaxCost() : OperatorCost(false) {}
   ~SoftmaxCost() override = default;
 
   double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -160,12 +184,12 @@ class SoftmaxCost : public OperatorCost {
   double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
                                     const int32_t&) const override;
 };
 using SoftmaxCostPtr = std::shared_ptr<SoftmaxCost>;
 
 class TmpIdentityCost : public OperatorCost {
  public:
-  TmpIdentityCost() = default;
+  explicit TmpIdentityCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  TmpIdentityCost() : OperatorCost(false) {}
   ~TmpIdentityCost() override = default;
 
   double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -184,12 +208,15 @@ class TmpIdentityCost : public OperatorCost {
                                    const int32_t& stage_id) const override;
   double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
                                     const int32_t& stage_id) const override;
+  // per device PEAK memory cost in a training iteration
+  double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs) const override;
 };
 using TmpIdentityCostPtr = std::shared_ptr<TmpIdentityCost>;
 
 class BatchParallelCost : public OperatorCost {
  public:
-  BatchParallelCost() = default;
+  explicit BatchParallelCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  BatchParallelCost() : OperatorCost(false) {}
   ~BatchParallelCost() override = default;
 
   double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -217,7 +244,8 @@ using BatchParallelCostPtr = std::shared_ptr<BatchParallelCost>;
 
 class VirtualDatasetCost : public OperatorCost {
  public:
-  VirtualDatasetCost() = default;
+  explicit VirtualDatasetCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  VirtualDatasetCost() : OperatorCost(false) {}
   ~VirtualDatasetCost() override = default;
 
   double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -244,12 +272,17 @@ class VirtualDatasetCost : public OperatorCost {
                                     const int32_t&) const override {
     return 0.0;
   }
+  // per device PEAK memory cost in a training iteration
+  double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs) const override {
+    return 0.0;
+  }
 };
 using VirtualDatasetCostPtr = std::shared_ptr<VirtualDatasetCost>;
 
 class GeneratorBaseCost : public OperatorCost {
  public:
-  GeneratorBaseCost() = default;
+  explicit GeneratorBaseCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  GeneratorBaseCost() : OperatorCost(false) {}
   ~GeneratorBaseCost() override = default;
 
   double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -283,7 +316,8 @@ using GeneratorBaseCostPtr = std::shared_ptr<GeneratorBaseCost>;
 
 class PReLUCost : public OperatorCost {
  public:
-  PReLUCost() = default;
+  explicit PReLUCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  PReLUCost() : OperatorCost(true) {}
   ~PReLUCost() override = default;
 
   // per device communication cost
@@ -310,7 +344,8 @@ using PReLUCostPtr = std::shared_ptr<PReLUCost>;
 
 class OneHotCost : public OperatorCost {
  public:
-  OneHotCost() = default;
+  explicit OneHotCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  OneHotCost() : OperatorCost(true) {}
   ~OneHotCost() override = default;
 
   // per device communication cost
@@ -337,7 +372,8 @@ using OneHotCostPtr = std::shared_ptr<OneHotCost>;
 
 class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost {
  public:
-  SoftmaxCrossEntropyWithLogitsCost() = default;
+  explicit SoftmaxCrossEntropyWithLogitsCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  SoftmaxCrossEntropyWithLogitsCost() : OperatorCost(false) {}
   ~SoftmaxCrossEntropyWithLogitsCost() override = default;
 
   // per device communication cost
@@ -364,7 +400,8 @@ using SoftmaxCrossEntropyWithLogitsCostPtr = std::shared_ptr<SoftmaxCrossEntropy
 
 class ReshapeCost : public OperatorCost {
  public:
-  ReshapeCost() = default;
+  explicit ReshapeCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  ReshapeCost() : OperatorCost(true) {}
   ~ReshapeCost() override = default;
@@ -396,7 +433,8 @@ using ReshapeCostPtr = std::shared_ptr<ReshapeCost>;
 
 class ArithmeticCost : public OperatorCost {
  public:
-  ArithmeticCost() = default;
+  explicit ArithmeticCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  ArithmeticCost() : OperatorCost(false) {}
   ~ArithmeticCost() override = default;
 
   double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -425,7 +463,8 @@ using BiasAddCostPtr = std::shared_ptr<BiasAddCost>;
 
 class ReduceMethodCost : public OperatorCost {
  public:
-  ReduceMethodCost() = default;
+  explicit ReduceMethodCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  ReduceMethodCost() : OperatorCost(true) {}
   ~ReduceMethodCost() override = default;
 
   double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -455,7 +494,8 @@ using ReduceMethodCostPtr = std::shared_ptr<ReduceMethodCost>;
 
 class ReduceMeanCost : public ReduceMethodCost {
  public:
-  ReduceMeanCost() = default;
+  explicit ReduceMeanCost(bool is_inputs_related) : ReduceMethodCost(is_inputs_related) {}
+  ReduceMeanCost() : ReduceMethodCost(true) {}
   ~ReduceMeanCost() override = default;
 
   double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -465,7 +505,8 @@ using ReduceMeanCostPtr = std::shared_ptr<ReduceMeanCost>;
 
 class GetNextCost : public OperatorCost {
  public:
-  GetNextCost() = default;
+  explicit GetNextCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  GetNextCost() : OperatorCost(false) {}
   ~GetNextCost() override = default;
 
   double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -499,7 +540,8 @@ using GetNextCostPtr = std::shared_ptr<GetNextCost>;
 
 class DropOutCost : public OperatorCost {
  public:
-  DropOutCost() = default;
+  explicit DropOutCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  DropOutCost() : OperatorCost(true) {}
   ~DropOutCost() override = default;
 
   double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
@@ -530,7 +572,8 @@ using DropOutCostPtr = std::shared_ptr<DropOutCost>;
 
 class GatherV2Cost : public OperatorCost {
  public:
-  GatherV2Cost() = default;
+  explicit GatherV2Cost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  GatherV2Cost() : OperatorCost(true) {}
   ~GatherV2Cost() override = default;
 
   double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
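The new inputs_related_ flag decides whether an operator's non-parameter inputs are charged to memory. A plausible reading of the true/false choices made in the operator headers below is visible in the gradients themselves: the backward pass of an element-wise Mul needs both forward inputs, while the backward pass of TensorAdd needs neither, so only the former has to keep its inputs alive. The tiny illustration below is not MindSpore code, just that calculus, with invented names.

// z = x * y : dz/dx = y, dz/dy = x  -> both inputs must be kept for backward ("related").
// z = x + y : dz/dx = 1, dz/dy = 1  -> neither input is needed for backward ("not related").
struct Grads { double dx; double dy; };
Grads MulBackward(double x, double y, double dout) { return {dout * y, dout * x}; }   // uses x and y
Grads AddBackward(double /*x*/, double /*y*/, double dout) { return {dout, dout}; }   // inputs not needed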

@@ -51,7 +51,7 @@ class Activation : public ActivationBase {
  public:
   Activation(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
              const PrimitiveAttrs& attrs)
-      : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>()) {}
+      : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>(false)) {}
   ~Activation() override = default;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
@@ -102,7 +102,7 @@ class Softmax : public ActivationBase {
  public:
   explicit Softmax(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                    const PrimitiveAttrs& attrs)
-      : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>()) {}
+      : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>(false)) {}
   ~Softmax() override = default;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;

@ -32,8 +32,8 @@ namespace parallel {
class ArithmeticBase : public OperatorInfo { class ArithmeticBase : public OperatorInfo {
public: public:
ArithmeticBase(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape, ArithmeticBase(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs) const PrimitiveAttrs& attrs, OperatorCostPtr cost)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>()) {} : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {}
~ArithmeticBase() override = default; ~ArithmeticBase() override = default;
Status Init(const StrategyPtr& strategy) override; Status Init(const StrategyPtr& strategy) override;
Status InitForCostModel(const StrategyPtr& strategy) override; Status InitForCostModel(const StrategyPtr& strategy) override;
@ -56,7 +56,7 @@ class ArithmeticBase : public OperatorInfo {
class SubInfo : public ArithmeticBase { class SubInfo : public ArithmeticBase {
public: public:
SubInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) SubInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {} : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~SubInfo() override = default; ~SubInfo() override = default;
}; };
@ -64,21 +64,21 @@ class TensorAddInfo : public ArithmeticBase {
public: public:
TensorAddInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, TensorAddInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs) const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {} : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~TensorAddInfo() override = default; ~TensorAddInfo() override = default;
}; };
class MulInfo : public ArithmeticBase { class MulInfo : public ArithmeticBase {
public: public:
MulInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) MulInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {} : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~MulInfo() override = default; ~MulInfo() override = default;
}; };
class DivInfo : public ArithmeticBase { class DivInfo : public ArithmeticBase {
public: public:
DivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) DivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {} : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~DivInfo() override = default; ~DivInfo() override = default;
}; };
@ -86,7 +86,7 @@ class RealDivInfo : public ArithmeticBase {
public: public:
RealDivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, RealDivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs) const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {} : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~RealDivInfo() override = default; ~RealDivInfo() override = default;
}; };
@ -94,14 +94,14 @@ class FloorDivInfo : public ArithmeticBase {
public: public:
FloorDivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, FloorDivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs) const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {} : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~FloorDivInfo() override = default; ~FloorDivInfo() override = default;
}; };
class PowInfo : public ArithmeticBase { class PowInfo : public ArithmeticBase {
public: public:
PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {} : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~PowInfo() override = default; ~PowInfo() override = default;
}; };
@ -109,7 +109,7 @@ class GreaterInfo : public ArithmeticBase {
public: public:
GreaterInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, GreaterInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs) const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {} : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~GreaterInfo() override = default; ~GreaterInfo() override = default;
}; };
@ -117,7 +117,7 @@ class AssignSubInfo : public ArithmeticBase {
  public:
   AssignSubInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                 const PrimitiveAttrs& attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
   ~AssignSubInfo() override = default;
 };
 }  // namespace parallel
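Note: the true/false argument now passed to each ArithmeticCost constructor is not explained in these hunks; presumably it records whether the operator's inputs must be kept alive for the backward pass and should therefore be charged to memory (true for arithmetic ops such as Mul/Div/Pow, false for comparison-like ops such as Greater, and for AssignSub). The following is a minimal standalone sketch under that assumption; SketchCost and its members are illustrative, not the actual ArithmeticCost API.

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Sketch: a cost object whose constructor flag decides whether input slices are
// charged to memory in addition to output slices.
struct SketchCost {
  explicit SketchCost(bool inputs_related) : inputs_related_(inputs_related) {}

  double Memory(const std::vector<std::vector<int64_t>>& input_slices,
                const std::vector<std::vector<int64_t>>& output_slices, double bytes_per_element) const {
    auto bytes = [&](const std::vector<int64_t>& s) {
      return static_cast<double>(std::accumulate(s.begin(), s.end(), int64_t{1}, std::multiplies<int64_t>())) *
             bytes_per_element;
    };
    double mem = 0.0;
    for (const auto& s : output_slices) mem += bytes(s);  // outputs are always counted
    if (inputs_related_) {
      for (const auto& s : input_slices) mem += bytes(s);  // inputs only when needed for backward
    }
    return mem;
  }

  bool inputs_related_;
};

int main() {
  SketchCost mul_like(true);       // analogous to ArithmeticCost(true)
  SketchCost greater_like(false);  // analogous to ArithmeticCost(false)
  std::cout << mul_like.Memory({{2, 3}}, {{2, 3}}, 4.0) << "\n";      // 48: inputs + outputs
  std::cout << greater_like.Memory({{2, 3}}, {{2, 3}}, 4.0) << "\n";  // 24: outputs only
  return 0;
}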

@ -29,9 +29,13 @@ namespace mindspore {
 namespace parallel {
 class BatchParallelInfo : public OperatorInfo {
  public:
+  BatchParallelInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
+                    const PrimitiveAttrs& attrs, OperatorCostPtr cost)
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost), dev_num_(1) {}
   BatchParallelInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                     const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()), dev_num_(1) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>(false)),
+        dev_num_(1) {}
   ~BatchParallelInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;
@ -58,7 +62,7 @@ class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {
  public:
   SparseSoftmaxCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape,
                                           const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
-      : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs) {}
+      : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>(true)) {}
   ~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default;
   void ReComputeBatchSplitFlagList() override;
 };

@ -34,7 +34,7 @@ class BiasAddInfo : public OperatorInfo {
  public:
   BiasAddInfo(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
               const PrimitiveAttrs& attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>()) {}
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>(false)) {}
   ~BiasAddInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;

@ -18,6 +18,7 @@
 #define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_
 #include <string>
+#include <memory>
 #include <unordered_map>
 #include <vector>
 #include "ir/value.h"
@ -31,7 +32,7 @@ class EqualInfo : public ArithmeticBase {
  public:
   EqualInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
             const PrimitiveAttrs& attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
   ~EqualInfo() override = default;
 };
@ -39,7 +40,7 @@ class NotEqualInfo : public ArithmeticBase {
  public:
   NotEqualInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                const PrimitiveAttrs& attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
   ~NotEqualInfo() override = default;
 };
@ -47,7 +48,7 @@ class MaximumInfo : public ArithmeticBase {
  public:
   MaximumInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
               const PrimitiveAttrs& attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
   ~MaximumInfo() override = default;
 };
@ -55,7 +56,7 @@ class MinimumInfo : public ArithmeticBase {
  public:
   MinimumInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
               const PrimitiveAttrs& attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
   ~MinimumInfo() override = default;
 };
 }  // namespace parallel

@ -33,7 +33,7 @@ class DropoutDoMaskInfo : public OperatorInfo {
  public:
   DropoutDoMaskInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                     const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>()) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>(true)) {}
   ~DropoutDoMaskInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;

@ -32,7 +32,7 @@ class GetNextInfo : public OperatorInfo {
  public:
   GetNextInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
               const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>()) {}
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>(false)) {}
   ~GetNextInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;

@ -36,7 +36,8 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo {
  public:
   SoftmaxCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                                     const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCrossEntropyWithLogitsCost>()) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs,
+                     std::make_shared<SoftmaxCrossEntropyWithLogitsCost>(false)) {}
   ~SoftmaxCrossEntropyWithLogitsInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;

@ -593,11 +593,11 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr&
   // Here, we use the origin outputs_, because we only use the slice size of the output tensor.
   // It does not matter whether the output tensor is transposed or not.
   double computation_cost =
-      cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
+      operator_cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
-  double communication_cost = cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
+  double communication_cost = operator_cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
   std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
   result->communication_without_parameter_ =
-      cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
+      operator_cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
   result->communication_with_partial_para_ =
       result->communication_without_parameter_ +
       COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
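As a reading aid, the last three lines compute a blend of the two communication measurements: the parameter-related share of the traffic is discounted by COST_MODEL_GAMMA. A minimal standalone sketch with made-up numbers; the function name and the gamma value below are illustrative, not taken from the cost model.

#include <iostream>

// Sketch: blend the full communication cost with the cost that excludes parameter
// tensors, charging only a gamma fraction of the parameter-related traffic.
double CommWithPartialParameter(double comm_all, double comm_without_parameter, double gamma) {
  return comm_without_parameter + gamma * (comm_all - comm_without_parameter);
}

int main() {
  // Example: 8.0 total communication, of which 5.0 is not parameter traffic.
  // With gamma = 0.5, half of the parameter-related traffic is charged.
  std::cout << CommWithPartialParameter(8.0, 5.0, 0.5) << "\n";  // prints 6.5
  return 0;
}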

@ -34,7 +34,7 @@ class MatMulBase : public OperatorInfo {
  public:
   MatMulBase(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
              const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>()) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>(true)) {}
   ~MatMulBase() override = default;
   Status Init(const StrategyPtr& strategy) override;

@ -33,7 +33,7 @@ class OneHotInfo : public OperatorInfo {
  public:
   OneHotInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
              const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>()) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>(false)) {}
   ~OneHotInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;

@ -1035,11 +1035,12 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr& strategy) {
     return FAILED;
   }
   int32_t stage_id = strategy->GetInputStage();
-  double computation_cost = cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
-  double communication_cost = cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+  double computation_cost =
+      operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+  double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
   std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
   result->communication_without_parameter_ =
-      cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+      operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
   result->communication_with_partial_para_ =
       result->communication_without_parameter_ +
       COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
@ -1096,7 +1097,38 @@ Status OperatorInfo::set_is_parameter(const std::vector<bool>& is_parameter) {
     return FAILED;
   }
   is_parameter_ = is_parameter;
-  cost()->set_is_parameter(is_parameter);
+  operator_cost()->set_is_parameter(is_parameter);
+  return SUCCESS;
+}
+
+Status OperatorInfo::CalculateMemoryCost() {
+  // First, set 'is_parameter_involve_' and 'is_output_parameter_involve_' into OperatorCost; both are needed to
+  // calculate the memory cost.
+  if (is_parameter_involve_.size() != is_parameter_.size()) {
+    MS_LOG(ERROR) << "'is_parameter_' and 'is_parameter_involve_' do not have the same size.";
+    return FAILED;
+  }
+  operator_cost()->set_is_parameter_involve(is_parameter_involve_);
+  operator_cost()->set_output_parameter_involve(is_output_parameter_involve_);
+  // Set the memory cost in 'strategy_cost_'.
+  for (auto& swc : strategy_cost_) {
+    auto mem_cost = operator_cost()->GetMemoryCost(swc->inputs_ptr, swc->outputs_ptr);
+    swc->cost_list[0]->memory_with_reuse_ = mem_cost;
+  }
+  return SUCCESS;
+}
+
+Status OperatorInfo::CorrectMemoryCost(size_t input_index) {
+  for (auto& swc : strategy_cost_) {
+    double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) *
+                                static_cast<double>(operator_cost()->inputs_type_lengths()[input_index]);
+    swc->cost_list[0]->memory_with_reuse_ -= parameter_mem_cost;
+    if (swc->cost_list[0]->memory_with_reuse_ < 0) {
+      MS_LOG(ERROR) << "The memory cost after correction is: " << swc->cost_list[0]->memory_with_reuse_
+                    << ", the parameter memory cost is: " << parameter_mem_cost;
+      return FAILED;
+    }
+  }
   return SUCCESS;
 }
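For orientation, CorrectMemoryCost charges a parameter input as ListProduct(slice_shape) * type_length, i.e. the number of elements in the sliced tensor times its element size in bytes. A small self-contained sketch, assuming ListProduct is a plain product over the slice dimensions; the names below are illustrative rather than the real helpers.

#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Sketch: bytes occupied by one tensor slice, mirroring ListProduct(slice_shape) * type_length.
double SliceBytes(const std::vector<int64_t>& slice_shape, std::size_t type_length) {
  int64_t numel =
      std::accumulate(slice_shape.begin(), slice_shape.end(), int64_t{1}, std::multiplies<int64_t>());
  return static_cast<double>(numel) * static_cast<double>(type_length);
}

int main() {
  // A float32 parameter slice of shape [1024, 256] occupies 1 MiB.
  double parameter_mem = SliceBytes({1024, 256}, sizeof(float));
  std::cout << static_cast<long long>(parameter_mem) << "\n";  // 1048576

  // If memory_with_reuse_ already counts this slice but the parameter was charged elsewhere,
  // the correction subtracts it once, as CorrectMemoryCost does per strategy.
  double memory_with_reuse = 3.0 * 1024 * 1024;
  memory_with_reuse -= parameter_mem;
  std::cout << static_cast<long long>(memory_with_reuse) << "\n";  // 2097152
  return 0;
}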
@ -1193,7 +1225,7 @@ Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector<size_t>& inpu
   }
   inputs_type_lengths_ = input_lengths;
   outputs_type_lengths_ = output_lengths;
-  cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths);
+  operator_cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths);
   return SUCCESS;
 }
@ -1221,7 +1253,7 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr& stra
 }
 double OperatorInfo::GetForwardMemoryCostFromCNode() {
-  return cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
+  return operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
 }
 }  // namespace parallel

@ -60,7 +60,7 @@ class OperatorInfo {
         outputs_shape_(std::move(outputs_shape)),
         attrs_(std::move(attrs)),
         is_alive_(true),
-        cost_(cost),
+        operator_cost_(cost),
         outputs_type_() {
     std::vector<bool> not_parameteter(inputs_shape_.size(), false);
     is_parameter_ = not_parameteter;
@ -83,8 +83,8 @@ class OperatorInfo {
   // Given the stage_id (which indicates the number of devices),
   // generate all strategies for this operator
   virtual Status GenerateStrategies(int32_t stage_id) = 0;
-  const OperatorCostPtr& cost() const { return cost_; }
-  void set_cost(const OperatorCostPtr& cost) { cost_ = cost; }
+  const OperatorCostPtr& operator_cost() const { return operator_cost_; }
+  void set_cost(const OperatorCostPtr& cost) { operator_cost_ = cost; }
   virtual Status SetCostUnderStrategy(const StrategyPtr& strategy) = 0;
   virtual std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies();
@ -98,7 +98,7 @@ class OperatorInfo {
   std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; }
   // When an input of an operator is a WEIGHT, or is an output of another operator involving a WEIGHT, that input
   // must stay in memory until it is used in the backward phase, i.e. it is still held at the end of the forward phase.
-  Status CalculateMemoryCost() const { return SUCCESS; }
+  Status CalculateMemoryCost();
   int ComputeOpAndPrevEdgeParameterInvolved();
   ForwardOp forward_op() const { return forward_op_; }
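The comment above motivates why the memory cost must track parameter involvement: a tensor that is, or derives from, a WEIGHT is still alive at the end of the forward phase. Below is a sketch of how such involvement could be propagated through a topologically ordered operator list; the structures are hypothetical stand-ins for what ComputeOpAndPrevEdgeParameterInvolved does, not a copy of it.

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical structures (not the OperatorInfo API): an operator becomes "parameter involved"
// if any input is a trainable parameter or comes from a parameter-involved producer.
struct SketchOp {
  std::vector<int> inputs;               // producer index per input, -1 for a graph input
  std::vector<bool> input_is_parameter;  // true where the input is a Parameter
};

std::vector<bool> PropagateParameterInvolved(const std::vector<SketchOp>& ops) {
  std::vector<bool> involved(ops.size(), false);
  for (std::size_t i = 0; i < ops.size(); ++i) {  // assumes ops are in topological order
    for (std::size_t k = 0; k < ops[i].inputs.size(); ++k) {
      bool from_involved_producer = ops[i].inputs[k] >= 0 && involved[ops[i].inputs[k]];
      if (ops[i].input_is_parameter[k] || from_involved_producer) involved[i] = true;
    }
  }
  return involved;
}

int main() {
  // op0 = MatMul(data, weight), op1 = ReLU(op0), op2 = Add(data, data).
  std::vector<SketchOp> ops = {{{-1, -1}, {false, true}}, {{0}, {false}}, {{-1, -1}, {false, false}}};
  for (bool b : PropagateParameterInvolved(ops)) std::cout << b << " ";  // prints: 1 1 0
  std::cout << "\n";
  return 0;
}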
@ -125,7 +125,7 @@ class OperatorInfo {
   void ReplaceSuccEdge(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
   void ReplacePreEdges(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
   void ReplaceSuccEdges(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
-  std::vector<size_t> GetOutputTypeLengths() const { return cost()->outputs_type_lengths(); }
+  std::vector<size_t> GetOutputTypeLengths() const { return operator_cost()->outputs_type_lengths(); }
   void SetSelectedStrategyAndCost(const StrategyPtr& s_strategy, const CostPtr& cost) {
     selected_strategy_ = s_strategy;
     selected_cost_ = cost;
@ -142,6 +142,10 @@ class OperatorInfo {
   void set_strategy(const StrategyPtr& strategy) { strategy_ = strategy; }
   void set_refkey_parameter_name(std::string p_name) { refkey_parameter_name_ = std::move(p_name); }
   const std::string& refkey_parameter_name() const { return refkey_parameter_name_; }
+  // When the output of a Parameter (requires_grad) is used by multiple operators, the Parameter's cost is calculated
+  // multiple times. This method corrects that, so the cost is calculated only once.
+  Status CorrectMemoryCost(size_t input_index);
+  int is_output_parameter_involve() const { return is_output_parameter_involve_; }
   int used_devices() const { return used_devices_; }
   // needed by rec_parser
   void set_type(const std::string& type) { type_ = type; }
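A sketch of the correction the new comment describes: when several operators consume the same Parameter, its memory should be charged once, so all but one consumer subtract it again, which is what a graph-level pass calling CorrectMemoryCost(input_index) on the duplicate consumers would achieve. The driver below is hypothetical and uses plain numbers rather than the cost-graph types.

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  double parameter_bytes = 4096.0;
  // Uncorrected per-consumer memory, each of which already includes the parameter once.
  std::vector<double> consumer_memory = {10240.0, 8192.0, 6144.0};

  // Keep the charge on consumer 0 only; correct the rest.
  for (std::size_t i = 1; i < consumer_memory.size(); ++i) {
    consumer_memory[i] -= parameter_bytes;
  }

  double total = 0.0;
  for (double m : consumer_memory) total += m;
  std::cout << total << "\n";  // 10240 + 4096 + 2048 = 16384
  return 0;
}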
@ -234,7 +238,7 @@ class OperatorInfo {
   int32_t used_devices_ = -1;
  private:
-  OperatorCostPtr cost_;
+  OperatorCostPtr operator_cost_;
   std::vector<TypePtr> outputs_type_;
 };

@ -35,7 +35,7 @@ class PReLUInfo : public OperatorInfo {
  public:
   PReLUInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
             const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>()) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>(true)) {}
   ~PReLUInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;

@ -109,7 +109,7 @@ Status ReduceMethod::GetAttrs() {
     }
     cross_batch_ = cross_batch_iter->second->cast<BoolImmPtr>()->value();
   }
-  auto reducemethodcost = std::dynamic_pointer_cast<ReduceMethodCost>(cost());
+  auto reducemethodcost = std::dynamic_pointer_cast<ReduceMethodCost>(operator_cost());
   if (reducemethodcost == nullptr) {
     MS_LOG(ERROR) << "Cost cast to ReduceMethodCostPtr failed!";
     return FAILED;

@ -34,7 +34,7 @@ class ReduceMethod : public OperatorInfo {
  public:
   ReduceMethod(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMethodCost>()) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMethodCost>(true)) {}
   ~ReduceMethod() override = default;
   Status Init(const StrategyPtr &strategy) override;
