|
|
|
@ -28,6 +28,8 @@
|
|
|
|
|
#include "frontend/parallel/context.h"
|
|
|
|
|
#include "frontend/parallel/step_parallel.h"
|
|
|
|
|
#include "frontend/parallel/node_check.h"
|
|
|
|
|
#include "ir/anf.h"
|
|
|
|
|
#include "base/core_ops.h"
|
|
|
|
|
#include "utils/comm_manager.h"
|
|
|
|
|
#include "utils/ms_context.h"
|
|
|
|
|
|
|
|
|
@ -136,6 +138,11 @@ OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode) {
|
|
|
|
|
|
|
|
|
|
std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetOpInfo(const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
// handle send/recv a parameter
|
|
|
|
|
if (node->isa<Parameter>()) {
|
|
|
|
|
MS_LOG(INFO) << "parameter: " << node->ToString() << " need to be send/recv.";
|
|
|
|
|
return std::make_pair(nullptr, nullptr);
|
|
|
|
|
}
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
OperatorInfoPtr op_info = nullptr;
|
|
|
|
@ -170,6 +177,23 @@ std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetOpInfo(const A
|
|
|
|
|
return std::make_pair(op_info, std::make_shared<TensorInfo>(tensor_info));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetParameterPair(const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto node_users = manager_->node_users()[node];
|
|
|
|
|
for (auto &user_pair : node_users) {
|
|
|
|
|
auto user_node = user_pair.first->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(user_node);
|
|
|
|
|
if (!IsPipelineCareNode(user_node)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto op_info = CreateOpInfo(user_node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(op_info);
|
|
|
|
|
auto tensor_info = op_info->inputs_tensor_info()[IntToSize(user_pair.second) - 1];
|
|
|
|
|
return std::make_pair(nullptr, std::make_shared<TensorInfo>(tensor_info));
|
|
|
|
|
}
|
|
|
|
|
return std::make_pair(nullptr, nullptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PipelineTransformer::DoBroadCast(const FuncGraphPtr &func) {
|
|
|
|
|
auto need_coloring = true;
|
|
|
|
|
while (need_coloring) {
|
|
|
|
@ -240,6 +264,7 @@ void PipelineTransformer::HandleSharedParameter() {
|
|
|
|
|
auto depend_op = CreatOpInstance(depend_attrs, DEPEND, "");
|
|
|
|
|
std::vector<AnfNodePtr> depend_input = {NewValueNode(depend_op), parameter, make_tuple};
|
|
|
|
|
auto depend = graph->NewCNode(depend_input);
|
|
|
|
|
depend->set_abstract(parameter->abstract());
|
|
|
|
|
manager_->SetEdge(node, user.second, depend);
|
|
|
|
|
break;
|
|
|
|
|
} else {
|
|
|
|
@ -301,7 +326,12 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
|
|
|
|
|
auto send_op = CreatOpInstance(attrs, SEND, "send");
|
|
|
|
|
auto send_node = NewValueNode(send_op);
|
|
|
|
|
auto prim = GetValueNode<PrimitivePtr>(send_node);
|
|
|
|
|
auto op_info_pair = GetOpInfo(parameter);
|
|
|
|
|
std::pair<OperatorInfoPtr, TensorInfoPtr> op_info_pair;
|
|
|
|
|
if (parameter->isa<Parameter>()) {
|
|
|
|
|
op_info_pair = GetParameterPair(parameter);
|
|
|
|
|
} else {
|
|
|
|
|
op_info_pair = GetOpInfo(parameter);
|
|
|
|
|
}
|
|
|
|
|
auto tensor_info = op_info_pair.second;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor_info);
|
|
|
|
|
auto slice_shape = tensor_info->slice_shape();
|
|
|
|
@ -314,6 +344,8 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
|
|
|
|
|
auto depend_op = CreatOpInstance(depend_attrs, DEPEND, "depend");
|
|
|
|
|
std::vector<AnfNodePtr> depend_input = {NewValueNode(depend_op), parameter, send};
|
|
|
|
|
auto depend = graph->NewCNode(depend_input);
|
|
|
|
|
auto abstract = parameter->abstract();
|
|
|
|
|
depend->set_abstract(abstract);
|
|
|
|
|
SendAttr send_out = {shape_type_pair.first, shape_type_pair.second, depend};
|
|
|
|
|
return send_out;
|
|
|
|
|
}
|
|
|
|
@ -324,7 +356,12 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode
|
|
|
|
|
recv_tag += 1;
|
|
|
|
|
auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_;
|
|
|
|
|
Attr attr_rank = std::make_pair("src_rank", MakeValue(src_rank));
|
|
|
|
|
auto op_info_pair = GetOpInfo(node);
|
|
|
|
|
std::pair<OperatorInfoPtr, TensorInfoPtr> op_info_pair;
|
|
|
|
|
if (node->isa<Parameter>()) {
|
|
|
|
|
op_info_pair = GetParameterPair(node);
|
|
|
|
|
} else {
|
|
|
|
|
op_info_pair = GetOpInfo(node);
|
|
|
|
|
}
|
|
|
|
|
auto tensor_info = op_info_pair.second;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor_info);
|
|
|
|
|
auto slice_shape = tensor_info->slice_shape();
|
|
|
|
@ -333,12 +370,19 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode
|
|
|
|
|
Attr attr_dtype = std::make_pair("dtype", shape_type_pair.second);
|
|
|
|
|
OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype};
|
|
|
|
|
auto recv_op = CreatOpInstance(attrs, RECEIVE, "recv");
|
|
|
|
|
std::vector<AnfNodePtr> recv_input = {NewValueNode(recv_op), virtual_param_};
|
|
|
|
|
std::vector<AnfNodePtr> recv_input;
|
|
|
|
|
if (node->isa<Parameter>()) {
|
|
|
|
|
recv_input = {NewValueNode(recv_op), node};
|
|
|
|
|
} else {
|
|
|
|
|
recv_input = {NewValueNode(recv_op), virtual_param_};
|
|
|
|
|
}
|
|
|
|
|
auto recv = graph->NewCNode(recv_input);
|
|
|
|
|
auto node_abstract = node->abstract();
|
|
|
|
|
recv->set_abstract(node_abstract);
|
|
|
|
|
recv->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_info->tensor_layout()));
|
|
|
|
|
recv->set_user_data<OperatorInfo>(op_info_pair.first);
|
|
|
|
|
if (op_info_pair.first != nullptr) {
|
|
|
|
|
recv->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_info->tensor_layout()));
|
|
|
|
|
recv->set_user_data<OperatorInfo>(op_info_pair.first);
|
|
|
|
|
}
|
|
|
|
|
manager_->SetEdge(use_node, index, recv);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -448,13 +492,6 @@ void PipelineTransformer::ElimGraphStage() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool PipelineTransformer::IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
|
|
|
|
|
auto anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(anf_node);
|
|
|
|
|
auto prim = anf_node->value()->cast<PrimitivePtr>();
|
|
|
|
|
return (prim->name() == name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::pair<CNodePtr, FuncGraphPtr> PipelineTransformer::FindSensNode() {
|
|
|
|
|
std::pair<CNodePtr, FuncGraphPtr> sens_graph_pair;
|
|
|
|
|
CNodePtr sens_cnode;
|
|
|
|
@ -471,7 +508,7 @@ std::pair<CNodePtr, FuncGraphPtr> PipelineTransformer::FindSensNode() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
|
|
|
|
|
if (!IsSomePrimitive(expect_tuple_getitem_cnode, TUPLE_GETITEM)) {
|
|
|
|
|
if (!IsPrimitiveCNode(expect_tuple_getitem_cnode, prim::kPrimTupleGetItem)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto expect_anonymous = expect_tuple_getitem_cnode->input(1);
|
|
|
|
@ -484,7 +521,7 @@ std::pair<CNodePtr, FuncGraphPtr> PipelineTransformer::FindSensNode() {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto expect_j_cnode = expect_j->cast<CNodePtr>();
|
|
|
|
|
if (!IsSomePrimitive(expect_j_cnode, J)) {
|
|
|
|
|
if (!IsPrimitiveCNode(expect_j_cnode, prim::kPrimJ)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1));
|
|
|
|
|