implementing-searching-strategy-for-inference

pull/859/head
Xiaoda Zhang 5 years ago
parent 5a03bd8077
commit def8573275

@ -23,8 +23,17 @@
namespace mindspore {
namespace parallel {
void Simplify(CostPtrList *clist_ptrs) {
// Sort the cost_list with the computation_cost_ increasing, and communication_cost decreasing order. This method
// excludes the cost with greater computation_cost_ and greater communication_cost.
if (RUN_PHASE == TRAINING_PHASE) {
// training phase
SimplifyForDecreasingCommunicationWithPartialPara(clist_ptrs);
} else {
// inference phase
SimplifyForDecreasingCommunicationForward(clist_ptrs);
}
}
void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) {
// Sort the cost_list with the computation_cost_ increasing, and communication_forward decreasing order. This method
// excludes the cost with greater computation_cost_ and greater communication_forward.
// E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>}
if (!COST_MODEL_SIMPLIFY_CALCULATION) {
return;
@ -37,14 +46,15 @@ void Simplify(CostPtrList *clist_ptrs) {
});
CostPtrList ret;
for (size_t i = 0; i < clist_ptrs->size(); ++i) {
if ((ret.size() == size_t(0)) || (clist_ptrs->at(id[i])->communication_cost_ < ret.back()->communication_cost_)) {
if ((ret.size() == size_t(0)) ||
(clist_ptrs->at(id[i])->communication_forward_ < ret.back()->communication_forward_)) {
ret.emplace_back(std::move(clist_ptrs->at(id[i])));
}
}
*clist_ptrs = std::move(ret);
}
void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) {
void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) {
// Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing
// order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost.
if (!COST_MODEL_SIMPLIFY_CALCULATION) {

@ -51,18 +51,22 @@ struct Cost {
communication_with_partial_para_ = 0.0;
communication_redis_forward_ = 0.0;
communication_redis_backward_ = 0.0;
communication_forward_ = 0.0;
}
// 'memory_with_reuse_' calculates the peak memory usage in a training phase
double memory_with_reuse_;
// 'computation_cost_' models the training time of an iteration in a training phase
// 'computation_cost_' models the training time of an iteration in a training phase. Currently, this is calculated
// by ONLY forward phase
double computation_cost_;
// 'communication_cost_' includes communications from operators (forward and backward) and edges
// 'communication_cost_' includes communications from operators (forward and backward) and edges (redistribution)
double communication_cost_;
// communication_without_parameter_ = communication_cost_ - (backward communication from operators)
double communication_without_parameter_;
// communication_with_partial_para_ =
// communication_without_parameter_ + COST_MODEL_GAMMA * (communication_cost_ - communication_without_parameter_ )
double communication_with_partial_para_;
// communication_forward_ = communication cost from operators (only forward phase) and forward redistribution.
double communication_forward_;
double communication_redis_forward_;
double communication_redis_backward_;
std::shared_ptr<Decision> decision_ptr_;
@ -296,7 +300,8 @@ using FinalDecisionPtr = std::shared_ptr<FinalDecision>;
using FinalSingleDecisionPtr = std::shared_ptr<FinalSingleDecision>;
void Simplify(CostPtrList *clist);
void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist);
void SimplifyForDecreasingCommunicationForward(CostPtrList *clist);
void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist);
void RefineForPracticalCost(const CostPtr &, bool is_redistribution);
} // namespace parallel
} // namespace mindspore

@ -76,6 +76,7 @@ Status Edge::InitEdgeCost() {
<< ", communication_with_partial_para_: " << cost->communication_with_partial_para_ << ".";
// refine communication cost calculation for practice
RefineForPracticalCost(cost, true);
cost->communication_forward_ = cost->communication_redis_forward_;
CostPtrKey ck = {target_output_str, target_input_str};
CostPtrList cl;
cl.push_back(cost);
@ -160,8 +161,9 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
(void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList);
CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
std::function<void(size_t, double, double, double, double)> recursive =
[&](size_t k, double computation, double memory, double communication, double communication_without_para) {
std::function<void(size_t, double, double, double, double, double)> recursive =
[&](size_t k, double computation, double memory, double communication, double communication_without_para,
double communication_forward) {
if (k == edges.size()) {
auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
CostPtr new_cost = std::make_shared<Cost>(computation, communication);
@ -170,6 +172,7 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
new_cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
new_cost->memory_with_reuse_ = memory;
new_cost->communication_forward_ = communication_forward;
new_cost->decision_ptr_ = decision;
result.push_back(new_cost);
return;
@ -179,11 +182,12 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
selected_cost_list[k] = c;
recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_,
communication + c->communication_cost_,
communication_without_para + c->communication_without_parameter_);
communication_without_para + c->communication_without_parameter_,
communication_forward + c->communication_forward_);
}
};
recursive(0, 0.0, 0.0, 0.0, 0.0);
SimplifyForDreasingCommunicationWithPartialPara(&result);
recursive(0, 0.0, 0.0, 0.0, 0.0, 0.0);
Simplify(&result);
return result;
}
@ -219,6 +223,8 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_;
double communication =
left_cost->communication_cost_ + middle_cost->communication_cost_ + right_cost->communication_cost_;
double communication_forward =
left_cost->communication_forward_ + middle_cost->communication_forward_ + right_cost->communication_forward_;
double communication_without_para = left_cost->communication_without_parameter_ +
middle_cost->communication_without_parameter_ +
right_cost->communication_without_parameter_;
@ -232,6 +238,7 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
cost->memory_with_reuse_ = memory_cost;
cost->communication_forward_ = communication_forward;
ret_cost_list->emplace_back(std::move(cost));
}
}
@ -251,7 +258,7 @@ CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr &e1, const StrategyP
CreateOpEliminationSubCostList(middle_strategy, e1->GetCostList(output_st_ptr, middle_strategy),
op_strategy->cost_list, e2->GetCostList(middle_strategy, input_st_ptr), &result);
}
SimplifyForDreasingCommunicationWithPartialPara(&result);
Simplify(&result);
return result;
}

File diff suppressed because it is too large Load Diff

@ -45,6 +45,9 @@ namespace parallel {
#define DEFAULT_FULLY_USE_DEVICES true
#define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false
#define DEFAULT_IS_MULTI_SUBGRAPHS false
#define DEFAULT_RUN_PHASE 0
#define TRAINING_PHASE 0
#define INFERENCE_PHASE 1
class CostGraph;
using CostGraphPtr = std::shared_ptr<CostGraph>;
@ -60,6 +63,8 @@ extern bool TENSOR_SLICE_ALIGNMENT_ENABLE;
extern size_t TENSOR_SLICE_ALIGNMENT_SIZE;
extern bool FULLY_USE_DEVICES;
extern bool ELEMENTWISE_OP_STRA_FOLLOW;
extern bool MULTI_SUBGRAPHS;
extern int32_t RUN_PHASE;
class CostGraph {
// 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have
@ -98,7 +103,7 @@ class CostGraph {
CostPtrList CreateFinalCostList(const OperatorInfoPtr &u, const EdgePtr &e, const OperatorInfoPtr &v);
CostPtrList CreateFinalSingleCostList(const OperatorInfoPtr &u);
CostPtr SelectCostWithMemoryConstraint(const CostPtrList &cost_list, double memory);
CostPtr SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory);
CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory);
CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_costlist, double memory);
Status SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &);

@ -47,6 +47,7 @@ void CostModelContext::ResetCostModel() {
costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST;
costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS;
is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS;
run_phase_ = DEFAULT_RUN_PHASE;
costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM;
costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES;
costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT;
@ -125,5 +126,7 @@ void CostModelContext::set_fully_use_device(bool fully_use) { fully_use_device_
void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) {
elementwise_stra_follow_ = elementwise_follow;
}
void CostModelContext::set_run_phase(int32_t phase) { run_phase_ = phase; }
} // namespace parallel
} // namespace mindspore

@ -113,6 +113,9 @@ class CostModelContext {
void set_elementwise_stra_follow(bool);
bool elementwise_stra_follow() const { return elementwise_stra_follow_; }
void set_run_phase(int32_t);
int32_t run_phase() const { return run_phase_; }
private:
CostModelContext();
static std::shared_ptr<CostModelContext> cm_context_inst_;
@ -141,8 +144,11 @@ class CostModelContext {
// COST_MODEL_COMMUNI_BIAS
double costmodel_communi_bias_;
// MULTI_SUBGRAPHS
bool is_multi_subgraphs_;
int32_t run_phase_; // 0: 'training', 1: 'inference'
int32_t costmodel_allreduce_fusion_algorithm_;
int32_t costmodel_allreduce_fusion_times_;

@ -610,6 +610,7 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &
<< ", communication_with_partial_para_: " << result->communication_with_partial_para_;
// refine communication cost calculation for practice
RefineForPracticalCost(result, false);
result->communication_forward_ = result->communication_without_parameter_;
std::shared_ptr<StrategyWithCost> swc =
std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);

@ -1049,6 +1049,7 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) {
BreakingTiesForPerferringDataParallel(strategy, result);
// refine communication cost calculation for practice
RefineForPracticalCost(result, false);
result->communication_forward_ = result->communication_without_parameter_;
std::shared_ptr<StrategyWithCost> swc =
std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);

@ -69,16 +69,16 @@ class TensorRedistribution {
RankList dev_list_;
OperatorList operator_list_;
bool reshape_flag_;
// communication cost
// communication cost, which is the sum of forward communication cost and backward communication cost
double comm_cost_;
// forward communication cost
double forward_comm_cost_;
// backward communication cost
double backward_comm_cost_;
// computation_cost models the time spending on computing in this tensor redistribution, which is calculated by the
// inputs.
// inputs. This is calculated ONLY for forward phase.
double computation_cost_;
// memory_cost models the PEAK memory cost in a traning iteration contributed by this tensor redistribution, which is
// memory_cost models the PEAK memory cost in a training iteration contributed by this tensor redistribution, which is
// calculated by the outputs.
double memory_cost_;
bool construct_op_flag_;

@ -228,6 +228,8 @@ PYBIND11_MODULE(_c_expression, m) {
"Get the parameter cost_model_communi_bias of the DP algorithm.")
.def("set_multi_subgraphs", &CostModelContext::set_multi_subgraphs, "Set the parameter is_multi_subgraphs.")
.def("get_multi_subgraphs", &CostModelContext::is_multi_subgraphs, "Get the parameter is_multi_subgraphs.")
.def("set_run_phase", &CostModelContext::set_run_phase, "Set the flag run_phase.")
.def("get_run_phase", &CostModelContext::run_phase, "Get the flag run_phase.")
.def("set_costmodel_allreduce_fusion_algorithm", &CostModelContext::set_costmodel_allreduce_fusion_algorithm,
"Set the parameter gradient AllReduce fusion algorithm.")
.def("get_costmodel_allreduce_fusion_algorithm", &CostModelContext::costmodel_allreduce_fusion_algorithm,

@ -239,6 +239,33 @@ class _CostModelContext:
raise ValueError("Context handle is none in context!!!")
return self._context_handle.get_multi_subgraphs()
def set_run_phase(self, phase):
"""
Set the flag of running phase: training (0) or inference (1)
Args:
phase (int): A parameter indicating which phase is running.
Raises:
ValueError: If context handle is none, or phase is not in {0, 1}.
"""
if self._context_handle is None:
raise ValueError("Context handle is none in context!!!")
if phase not in (0, 1):
raise ValueError("The argument of set_run_phase() must be '0' or '1', but got {}".format(phase))
self._context_handle.set_run_phase(phase)
def get_run_phase(self):
"""
Get the flag of running phase.
Raises:
ValueError: If context handle is none.
"""
if self._context_handle is None:
raise ValueError("Context handle is none in context!!!")
return self._context_handle.get_run_phase()
def set_costmodel_allreduce_fusion_algorithm(self, algorithm):
"""
Set costmodel allreduce fusion algorithm.
@ -453,6 +480,7 @@ set_cost_model_context_func_map = {
"costmodel_communi_const": cost_model_context().set_costmodel_communi_const,
"costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias,
"multi_subgraphs": cost_model_context().set_multi_subgraphs,
"run_phase": cost_model_context().set_run_phase,
"costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm,
"costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times,
"costmodel_allreduce_fusion_tail_percent": cost_model_context().set_costmodel_allreduce_fusion_tail_percent,
@ -473,7 +501,8 @@ get_cost_model_context_func_map = {
"costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold,
"costmodel_communi_const": cost_model_context().get_costmodel_communi_const,
"costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias,
"multi_subgraphs": cost_model_context().get_multi_subgraphs(),
"multi_subgraphs": cost_model_context().get_multi_subgraphs,
"run_phase": cost_model_context().get_run_phase,
"costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm,
"costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times,
"costmodel_allreduce_fusion_tail_percent": cost_model_context().get_costmodel_allreduce_fusion_tail_percent,
@ -488,7 +517,7 @@ get_cost_model_context_func_map = {
@args_type_check(device_memory_capacity=float, costmodel_alpha=float, costmodel_beta=float, costmodel_gamma=float,
costmodel_communi_threshold=float, costmodel_communi_const=float, costmodel_communi_bias=float,
multi_subgraphs=bool,
multi_subgraphs=bool, run_phase=int,
costmodel_allreduce_fusion_algorithm=int, costmodel_allreduce_fusion_times=int,
costmodel_allreduce_fusion_tail_percent=float, costmodel_allreduce_fusion_tail_time=float,
costmodel_allreduce_fusion_allreduce_inherent_time=float,
@ -510,6 +539,7 @@ def set_cost_model_context(**kwargs):
costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice.
costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice.
multi_subgraphs (bool): A parameter used in marking the flag of ANF graph containing multiple subgraphs.
run_phase (int): A parameter indicating which phase is running: training (0) or inference (1). Default: 0.
costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm.
0: bypass allreduce fusion;
1: only use backward computation time to group allreduce;

@ -371,7 +371,7 @@ TEST_F(TestCostGraph, test_CreateFinalCostList_AND_Select) {
ASSERT_EQ(edge_m1_m2->InitEdgeCost(), SUCCESS);
cost_graph.AddEdge(matmul1, matmul2, edge_m1_m2);
auto cost_list = cost_graph.CreateFinalCostList(matmul1, edge_m1_m2, matmul2);
cost_graph.SelectCostWithMemoryConstraint(cost_list, cost_graph.GetDeviceMemory());
cost_graph.SelectCostWithMinInferenceTime(cost_list, cost_graph.GetDeviceMemory());
}
TEST_F(TestCostGraph, test_EliminationOp) {

@ -14,15 +14,21 @@
import mindspore.context as context
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.parallel._cost_model_context import reset_cost_model_context
from mindspore.parallel.algo_parameter_config import reset_algo_parameters
from mindspore.parallel._utils import _reset_op_id
def setup_module(module):
auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
reset_cost_model_context()
reset_algo_parameters()
_reset_op_id()
def teardown_module():
context.reset_auto_parallel_context()
reset_cost_model_context()
reset_algo_parameters()
_reset_op_id()

@ -0,0 +1,36 @@
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import operations as P
from mindspore.nn import WithLossCell, TrainOneStepCell
from mindspore.nn import Momentum
from mindspore.parallel._cost_model_context import set_cost_model_context
class Net(nn.Cell):
def __init__(self, input_ch, out_ch):
super(Net, self).__init__()
self.dense = nn.Dense(input_ch, out_ch)
self.relu = P.ReLU()
def construct(self, x):
x = self.dense(x)
x = self.relu(x)
return x
def test_inference_phase():
context.set_auto_parallel_context(device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="auto_parallel")
set_cost_model_context(run_phase=1)
net = Net(512, 128)
predict = Tensor(np.ones([64, 512]).astype(np.float32) * 0.001)
label = Tensor(np.ones([64, 128]).astype(np.float32))
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
train_network.set_train()
output = train_network(predict, label)
Loading…
Cancel
Save