!9799 change_last_nodes_strategy_in_auto_parallel_when_eval_and_predict

From: @yao_yf
Reviewed-by: @zhunaipan,@stsuteng
Signed-off-by: @stsuteng
pull/9799/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 5e35bec957

@ -233,7 +233,8 @@ void InitCostGraph() {
entire_costgraph->Init();
}
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) {
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, bool is_last_nodes,
StrategyMap *stra_map) {
MS_EXCEPTION_IF_NULL(prim);
MS_EXCEPTION_IF_NULL(cnode);
auto attrs = prim->attrs();
@ -290,7 +291,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
// If no strategy has been configured for this operator, then candidate strategies are generated for
// auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
// if strategy is set to load from checkpoint, it is prefer to load strategy from checkpoint .
if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) {
if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt && !is_last_nodes) {
// Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
// BatchParallelInfo operator
operator_info->ComputeBatchSplitFlagList();
@ -307,10 +308,16 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
} else {
// In this case, the configured strategy should be extracted to help setting cost
StrategyPtr strategyPtr;
if (load_strategy_from_ckpt) {
strategyPtr = (*stra_map)[strategy_key_name];
} else {
if (is_last_nodes) {
bool full_batch = ParallelContext::GetInstance()->full_batch();
strategyPtr = GenerateBatchParallelStrategy(operator_info, prim);
if (full_batch) {
SetLastNodeStrategy(strategyPtr);
}
} else if (StrategyFound(attrs)) {
strategyPtr = parallel::ExtractStrategy(attrs);
} else {
strategyPtr = (*stra_map)[strategy_key_name];
}
if (strategyPtr != nullptr) {
if (prim->name() == RESHAPE) {
@ -341,8 +348,10 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
}
// Using CNode's UniqueIds to construct nodes
Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
entire_costgraph = std::make_shared<CostGraph>();
entire_costgraph->SetDeviceMemoryAndCostParameter();
// The map from CNode's UniqueId to its operatorInfo
std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
// The operator_infos in a loop
@ -356,7 +365,12 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
}
}
std::vector<std::string> last_forward_node_ids;
if (!root->has_flag(TRAINING)) {
FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
}
// Step 1
for (auto &node : all_nodes) {
// NOTE: we only care about splittable Primitive operators
auto cnode = node->cast<CNodePtr>();
@ -401,7 +415,9 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
(void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), current_op_ptr));
continue;
}
auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
last_forward_node_ids.end();
auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
if (operator_info == nullptr) {
return FAILED;
}
@ -436,8 +452,10 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
}
// Using CNode's UniqueIdThroughCopys to construct nodes
Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
entire_costgraph = std::make_shared<CostGraph>();
entire_costgraph->SetDeviceMemoryAndCostParameter();
// The map from CNode's UniqueIdThroughCopy to its operatorInfo
std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
// The operator_infos in a loop
@ -451,6 +469,11 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
}
}
std::vector<std::string> last_forward_node_ids;
if (!root->has_flag(TRAINING)) {
FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
}
for (auto &node : all_nodes) {
// NOTE: we only care about splittable Primitive operators
auto cnode = node->cast<CNodePtr>();
@ -496,7 +519,9 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
continue;
}
// In this case, the corresponding OperatorInfo is not created, create the new one.
auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
last_forward_node_ids.end();
auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
if (operator_info == nullptr) {
return FAILED;
}

@ -1638,7 +1638,7 @@ bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids) {
return find;
}
void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, vector<std::string> *unique_ids) {
void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, std::vector<std::string> *unique_ids) {
MS_EXCEPTION_IF_NULL(unique_ids);
for (auto &node : all_nodes) {
auto cnode = node->cast<CNodePtr>();
@ -1754,10 +1754,10 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
<< " is empty, using batch parallel";
strategyPtr = GenerateBatchParallelStrategy(operator_, prim);
} else if (load_strategy_from_ckpt) {
strategyPtr = stra_map[strategy_key_name];
} else {
} else if (StrategyFound(attrs)) {
strategyPtr = ExtractStrategy(attrs);
} else {
strategyPtr = stra_map[strategy_key_name];
}
if (strategyPtr != nullptr) {
if (is_last_nodes && full_batch) {

@ -165,6 +165,10 @@ bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter);
void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr &distribute_operator,
const CNodePtr &cnode, const AnfNodePtr &parameter, size_t index);
void SetLastNodeStrategy(const StrategyPtr strategyPtr);
void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, std::vector<std::string> *unique_ids);
} // namespace parallel
} // namespace mindspore

@ -67,3 +67,20 @@ def test_train_and_eval():
_executor.compile(eval_net, _x, _b, phase='eval', auto_parallel_mode=True)
context.reset_auto_parallel_context()
def test_train_and_eval_auto():
context.set_context(save_graphs=True, mode=0)
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16)
strategy1 = ((4, 4), (4, 4))
strategy2 = ((4, 4),)
net = Net(_w1, strategy1, strategy2)
eval_net = EvalNet(net, strategy2=strategy2)
net.set_auto_parallel()
net.set_train()
_executor.compile(net, _x, _b, phase='train', auto_parallel_mode=True)
eval_net.set_train(mode=False)
eval_net.set_auto_parallel()
_executor.compile(eval_net, _x, _b, phase='eval', auto_parallel_mode=True)
context.reset_auto_parallel_context()

Loading…
Cancel
Save