!10691 [PipelineSplit] Optimize PipelineSplit

From: @lichen666
Reviewed-by: @kisnwang,@stsuteng
Signed-off-by: @stsuteng
pull/10691/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 6e5be437e1

@ -286,7 +286,7 @@ void PipelineTransformer::HandleSharedParameter() {
manager_->SetEdge(node, user.second, depend);
break;
} else {
InsertReceive(graph, parameter, node, user.second, stage_, *parameter_stage.begin());
(void)InsertReceive(graph, parameter, node, user.second, stage_, *parameter_stage.begin());
break;
}
}
@ -403,8 +403,9 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
return send_out;
}
void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node,
int index, int64_t user_node_stage, int64_t node_stage) {
AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node,
const AnfNodePtr &use_node, int index, int64_t user_node_stage,
int64_t node_stage) {
auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_;
int64_t recv_tag;
if (recv_tag_map.find(src_rank) != recv_tag_map.end()) {
@ -464,6 +465,28 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode
recv->set_user_data<OperatorInfo>(op_info_pair.first);
}
manager_->SetEdge(use_node, index, recv);
return recv;
}
// Reports whether `out_input` already contains a Depend node that ties `node`
// to an operator whose "dest_rank" attribute equals the rank targeted by this
// stage transition, so the caller can skip inserting a duplicate.
//
// node            - the node whose output would be sent to another stage.
// next_node_stage - stage of the consumer.
// node_stage      - stage of `node`.
// out_input       - candidate nodes to scan; only Depend CNodes are examined.
//
// Returns true when a matching Depend is found, false otherwise.
// NOTE(review): the second Depend input is presumably the Send created by
// InsertSend; entries that do not have that shape are skipped, not dereferenced.
bool PipelineTransformer::Reuse(const AnfNodePtr &node, int64_t next_node_stage, int64_t node_stage,
                                const std::vector<AnfNodePtr> &out_input) {
  // Rank on the destination stage that corresponds to this rank.
  const auto dest_rank = global_rank_ + (next_node_stage - node_stage) * per_stage_rank_num_;
  for (const auto &depend : out_input) {
    if (!IsPrimitiveCNode(depend, prim::kPrimDepend)) {
      continue;
    }
    const auto cnode = depend->cast<CNodePtr>();
    if (cnode->input(1) != node) {
      continue;
    }
    // The cast yields nullptr when the second input is not a CNode.
    const auto send_cnode = cnode->input(2)->cast<CNodePtr>();
    if (send_cnode == nullptr) {
      continue;
    }
    const auto prim = GetValueNode<PrimitivePtr>(send_cnode->input(0));
    if (prim == nullptr) {
      continue;
    }
    // An operator without a "dest_rank" attribute cannot match any destination.
    const auto dest_rank_attr = prim->GetAttr("dest_rank");
    if (dest_rank_attr == nullptr) {
      continue;
    }
    if (GetValue<int64_t>(dest_rank_attr) == dest_rank) {
      return true;
    }
  }
  return false;
}
std::pair<bool, int64_t> PipelineTransformer::IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users) {
@ -496,6 +519,7 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
auto shared_min_tag_pair = IsSharedNode(node, node_users);
auto is_shared = shared_min_tag_pair.first;
auto min_tag = shared_min_tag_pair.second;
AnfNodePtr receive = nullptr;
for (auto &user_pair : node_users) {
auto user_node = user_pair.first;
auto node_stage = node->stage();
@ -508,18 +532,19 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
continue;
}
if (node_stage == stage_) {
if (Reuse(node, user_node_stage, node_stage, out_input)) {
continue;
}
auto send_out = InsertSend(graph, node, user_node_stage, node_stage);
out_input.insert(out_input.begin() + 1, send_out.depend);
type_ptr_ = send_out.type;
shape_ = send_out.shape;
} else {
InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage);
}
continue;
}
if (node_stage == user_node_stage) {
if (is_shared && (min_tag != node_stage)) {
InsertReceive(graph, node, user_node, user_pair.second, stage_, min_tag);
if (!receive) {
receive = InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage);
} else {
manager_->SetEdge(user_node, user_pair.second, receive);
}
}
continue;
}

@ -20,6 +20,7 @@
#include <utility>
#include <string>
#include <memory>
#include <vector>
#include "ir/value.h"
#include "ir/graph_utils.h"
#include "base/base.h"
@ -62,11 +63,13 @@ class PipelineTransformer {
void DoBroadCast(const FuncGraphPtr &func);
SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, int64_t user_node_stage,
int64_t node_stage);
void InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index,
int64_t user_node_stage, int64_t node_stage);
AnfNodePtr InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index,
int64_t user_node_stage, int64_t node_stage);
void SetNoStageNode(const FuncGraphPtr &func);
void CutBorder(const FuncGraphPtr &graph);
bool IsStageNode(const CNodePtr &node);
bool Reuse(const AnfNodePtr &node, int64_t next_node_stage, int64_t node_stage,
const std::vector<AnfNodePtr> &out_input);
AnfNodePtr FindPipelineCareNode(const AnfNodePtr &node);
std::pair<OperatorInfoPtr, TensorInfoPtr> GetOpInfo(const AnfNodePtr &node);
std::pair<OperatorInfoPtr, TensorInfoPtr> GetParameterPair(const AnfNodePtr &node);

@ -2675,6 +2675,20 @@ void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncG
InsertNode(op, node, 2, pre_node, root, "shape");
}
// Walks upward from `cnode`, at each level following only the FIRST input that
// is a CNode, until that input is an EnvGetItem primitive node.
//
// Returns the EnvGetItem node when the walk reaches one, or nullptr when a
// level has no CNode input at all. Later CNode inputs of a level are never
// explored.
static AnfNodePtr FindGrad(const CNodePtr &cnode) {
  CNodePtr current = cnode;
  while (true) {
    // Locate the first input of the current node that is itself a CNode.
    AnfNodePtr first_cnode_input = nullptr;
    for (auto &input : current->inputs()) {
      if (input->isa<CNode>()) {
        first_cnode_input = input;
        break;
      }
    }
    if (first_cnode_input == nullptr) {
      return nullptr;
    }
    if (IsPrimitiveCNode(first_cnode_input, prim::kPrimEnvGetItem)) {
      return first_cnode_input;
    }
    // Not the node we are looking for: descend into it and repeat.
    current = first_cnode_input->cast<CNodePtr>();
  }
}
void HandleRootReshapeAndSaveStrategy(const std::vector<AnfNodePtr> &all_nodes) {
// If root graph has reshape op. Find the corresponding parameter.
// Reshape's shape is the shape of the parameter.
@ -2706,12 +2720,9 @@ void HandleRootReshapeAndSaveStrategy(const std::vector<AnfNodePtr> &all_nodes)
continue;
}
auto root = node->func_graph();
auto all_dfs_nodes = DeepLinkedGraphSearch(node);
for (auto r_iter = all_dfs_nodes.rbegin(); r_iter != all_dfs_nodes.rend(); ++r_iter) {
if ((*r_iter)->isa<Parameter>()) {
InsertShapeOp(cnode, *r_iter, root);
break;
}
auto grad_node = FindGrad(cnode);
if (grad_node) {
InsertShapeOp(cnode, grad_node, root);
}
}
}
@ -3113,7 +3124,8 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
std::reverse(all_nodes.begin(), all_nodes.end());
if (parallel_mode != AUTO_PARALLEL) {
TOTAL_OPS = 0;
if (ParallelInit() != SUCCESS) {
auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
if (pipeline_stages <= 1 && ParallelInit() != SUCCESS) {
MS_LOG(EXCEPTION) << "Parallel init failed";
}

Loading…
Cancel
Save