|
|
|
@ -286,7 +286,7 @@ void PipelineTransformer::HandleSharedParameter() {
|
|
|
|
|
manager_->SetEdge(node, user.second, depend);
|
|
|
|
|
break;
|
|
|
|
|
} else {
|
|
|
|
|
InsertReceive(graph, parameter, node, user.second, stage_, *parameter_stage.begin());
|
|
|
|
|
(void)InsertReceive(graph, parameter, node, user.second, stage_, *parameter_stage.begin());
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -403,8 +403,9 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
|
|
|
|
|
return send_out;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node,
|
|
|
|
|
int index, int64_t user_node_stage, int64_t node_stage) {
|
|
|
|
|
AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
|
|
|
|
const AnfNodePtr &use_node, int index, int64_t user_node_stage,
|
|
|
|
|
int64_t node_stage) {
|
|
|
|
|
auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_;
|
|
|
|
|
int64_t recv_tag;
|
|
|
|
|
if (recv_tag_map.find(src_rank) != recv_tag_map.end()) {
|
|
|
|
@ -464,6 +465,28 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode
|
|
|
|
|
recv->set_user_data<OperatorInfo>(op_info_pair.first);
|
|
|
|
|
}
|
|
|
|
|
manager_->SetEdge(use_node, index, recv);
|
|
|
|
|
return recv;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool PipelineTransformer::Reuse(const AnfNodePtr &node, int64_t next_node_stage, int64_t node_stage,
|
|
|
|
|
const std::vector<AnfNodePtr> &out_input) {
|
|
|
|
|
auto node_users = manager_->node_users()[node];
|
|
|
|
|
auto dest_rank = global_rank_ + (next_node_stage - node_stage) * per_stage_rank_num_;
|
|
|
|
|
for (auto &depend : out_input) {
|
|
|
|
|
if (!IsPrimitiveCNode(depend, prim::kPrimDepend)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto cnode = depend->cast<CNodePtr>();
|
|
|
|
|
if (cnode->input(1) == node) {
|
|
|
|
|
auto send_cnode = cnode->input(2)->cast<CNodePtr>();
|
|
|
|
|
auto prim = GetValueNode<PrimitivePtr>(send_cnode->input(0));
|
|
|
|
|
auto dest_rank_send = GetValue<int64_t>(prim->GetAttr("dest_rank"));
|
|
|
|
|
if (dest_rank_send == dest_rank) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::pair<bool, int64_t> PipelineTransformer::IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users) {
|
|
|
|
@ -496,6 +519,7 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
|
|
|
|
|
auto shared_min_tag_pair = IsSharedNode(node, node_users);
|
|
|
|
|
auto is_shared = shared_min_tag_pair.first;
|
|
|
|
|
auto min_tag = shared_min_tag_pair.second;
|
|
|
|
|
AnfNodePtr receive = nullptr;
|
|
|
|
|
for (auto &user_pair : node_users) {
|
|
|
|
|
auto user_node = user_pair.first;
|
|
|
|
|
auto node_stage = node->stage();
|
|
|
|
@ -508,18 +532,19 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (node_stage == stage_) {
|
|
|
|
|
if (Reuse(node, user_node_stage, node_stage, out_input)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto send_out = InsertSend(graph, node, user_node_stage, node_stage);
|
|
|
|
|
out_input.insert(out_input.begin() + 1, send_out.depend);
|
|
|
|
|
type_ptr_ = send_out.type;
|
|
|
|
|
shape_ = send_out.shape;
|
|
|
|
|
} else {
|
|
|
|
|
InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage);
|
|
|
|
|
}
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (node_stage == user_node_stage) {
|
|
|
|
|
if (is_shared && (min_tag != node_stage)) {
|
|
|
|
|
InsertReceive(graph, node, user_node, user_pair.second, stage_, min_tag);
|
|
|
|
|
if (!receive) {
|
|
|
|
|
receive = InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage);
|
|
|
|
|
} else {
|
|
|
|
|
manager_->SetEdge(user_node, user_pair.second, receive);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|